1
2
3
4
5
6
7 package sql
8
9 import (
10 "bytes"
11 "database/sql/driver"
12 "errors"
13 "fmt"
14 "reflect"
15 "strconv"
16 "time"
17 "unicode"
18 "unicode/utf8"
19 _ "unsafe"
20 )
21
var (
	// errNilPtr is returned by convertAssignRows when the scan destination
	// is a nil pointer.
	errNilPtr = errors.New("destination pointer is nil")
)
23
// describeNamedValue renders an argument for use in error messages: named
// arguments appear as `with name "x"`, positional ones as `$N` where N is
// the 1-based ordinal.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
30
// validateNamedValueName accepts an empty name (a positional argument) or
// a name whose first rune is a Unicode letter; anything else is rejected.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
41
42
43
44
45 type ccChecker struct {
46 cci driver.ColumnConverter
47 want int
48 }
49
50 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
51 if c.cci == nil {
52 return driver.ErrSkip
53 }
54
55
56
57 index := nv.Ordinal - 1
58 if c.want <= index {
59 return nil
60 }
61
62
63
64
65 if vr, ok := nv.Value.(driver.Valuer); ok {
66 sv, err := callValuerValue(vr)
67 if err != nil {
68 return err
69 }
70 if !driver.IsValue(sv) {
71 return fmt.Errorf("non-subset type %T returned from Value", sv)
72 }
73 nv.Value = sv
74 }
75
76
77
78
79
80
81
82
83 var err error
84 arg := nv.Value
85 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
86 if err != nil {
87 return err
88 }
89 if !driver.IsValue(nv.Value) {
90 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
91 }
92 return nil
93 }
94
95
96
97
// defaultCheckNamedValue converts nv.Value with
// driver.DefaultParameterConverter. nv.Value is overwritten with the
// converter's result even when an error is returned.
func defaultCheckNamedValue(nv *driver.NamedValue) error {
	converted, err := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return err
}
102
103
104
105
106
107
108
109 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
110 nvargs := make([]driver.NamedValue, len(args))
111
112
113
114
115 want := -1
116
117 var si driver.Stmt
118 var cc ccChecker
119 if ds != nil {
120 si = ds.si
121 want = ds.si.NumInput()
122 cc.want = want
123 }
124
125
126
127
128
129 nvc, ok := si.(driver.NamedValueChecker)
130 if !ok {
131 nvc, _ = ci.(driver.NamedValueChecker)
132 }
133 cci, ok := si.(driver.ColumnConverter)
134 if ok {
135 cc.cci = cci
136 }
137
138
139
140
141
142
143 var err error
144 var n int
145 for _, arg := range args {
146 nv := &nvargs[n]
147 if np, ok := arg.(NamedArg); ok {
148 if err = validateNamedValueName(np.Name); err != nil {
149 return nil, err
150 }
151 arg = np.Value
152 nv.Name = np.Name
153 }
154 nv.Ordinal = n + 1
155 nv.Value = arg
156
157
158
159
160
161
162
163
164
165
166
167
168 checker := defaultCheckNamedValue
169 nextCC := false
170 switch {
171 case nvc != nil:
172 nextCC = cci != nil
173 checker = nvc.CheckNamedValue
174 case cci != nil:
175 checker = cc.CheckNamedValue
176 }
177
178 nextCheck:
179 err = checker(nv)
180 switch err {
181 case nil:
182 n++
183 continue
184 case driver.ErrRemoveArgument:
185 nvargs = nvargs[:len(nvargs)-1]
186 continue
187 case driver.ErrSkip:
188 if nextCC {
189 nextCC = false
190 checker = cc.CheckNamedValue
191 } else {
192 checker = defaultCheckNamedValue
193 }
194 goto nextCheck
195 default:
196 return nil, fmt.Errorf("sql: converting argument %s type: %w", describeNamedValue(nv), err)
197 }
198 }
199
200
201
202 if want != -1 && len(nvargs) != want {
203 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
204 }
205
206 return nvargs, nil
207 }
208
209
210
211
212
213
214
215
216
217
218
219
220
221 func convertAssign(dest, src any) error {
222 return convertAssignRows(dest, src, nil)
223 }
224
225
226
227
228
229
230 func convertAssignRows(dest, src any, rows *Rows) error {
231
232 switch s := src.(type) {
233 case string:
234 switch d := dest.(type) {
235 case *string:
236 if d == nil {
237 return errNilPtr
238 }
239 *d = s
240 return nil
241 case *[]byte:
242 if d == nil {
243 return errNilPtr
244 }
245 *d = []byte(s)
246 return nil
247 case *RawBytes:
248 if d == nil {
249 return errNilPtr
250 }
251 *d = rows.setrawbuf(append(rows.rawbuf(), s...))
252 return nil
253 }
254 case []byte:
255 switch d := dest.(type) {
256 case *string:
257 if d == nil {
258 return errNilPtr
259 }
260 *d = string(s)
261 return nil
262 case *any:
263 if d == nil {
264 return errNilPtr
265 }
266 *d = bytes.Clone(s)
267 return nil
268 case *[]byte:
269 if d == nil {
270 return errNilPtr
271 }
272 *d = bytes.Clone(s)
273 return nil
274 case *RawBytes:
275 if d == nil {
276 return errNilPtr
277 }
278 *d = s
279 return nil
280 }
281 case time.Time:
282 switch d := dest.(type) {
283 case *time.Time:
284 *d = s
285 return nil
286 case *string:
287 *d = s.Format(time.RFC3339Nano)
288 return nil
289 case *[]byte:
290 if d == nil {
291 return errNilPtr
292 }
293 *d = s.AppendFormat(make([]byte, 0, len(time.RFC3339Nano)), time.RFC3339Nano)
294 return nil
295 case *RawBytes:
296 if d == nil {
297 return errNilPtr
298 }
299 *d = rows.setrawbuf(s.AppendFormat(rows.rawbuf(), time.RFC3339Nano))
300 return nil
301 }
302 case decimalDecompose:
303 switch d := dest.(type) {
304 case decimalCompose:
305 return d.Compose(s.Decompose(nil))
306 }
307 case nil:
308 switch d := dest.(type) {
309 case *any:
310 if d == nil {
311 return errNilPtr
312 }
313 *d = nil
314 return nil
315 case *[]byte:
316 if d == nil {
317 return errNilPtr
318 }
319 *d = nil
320 return nil
321 case *RawBytes:
322 if d == nil {
323 return errNilPtr
324 }
325 *d = nil
326 return nil
327 }
328
329 case driver.Rows:
330 switch d := dest.(type) {
331 case *Rows:
332 if d == nil {
333 return errNilPtr
334 }
335 if rows == nil {
336 return errors.New("invalid context to convert cursor rows, missing parent *Rows")
337 }
338 *d = Rows{
339 dc: rows.dc,
340 releaseConn: func(error) {},
341 rowsi: s,
342 }
343
344 parentCancel := rows.cancel
345 rows.cancel = func() {
346
347
348 d.close(rows.lasterr)
349 if parentCancel != nil {
350 parentCancel()
351 }
352 }
353 return nil
354 }
355 }
356
357 var sv reflect.Value
358
359 switch d := dest.(type) {
360 case *string:
361 sv = reflect.ValueOf(src)
362 switch sv.Kind() {
363 case reflect.Bool,
364 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
365 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
366 reflect.Float32, reflect.Float64:
367 *d = asString(src)
368 return nil
369 }
370 case *[]byte:
371 sv = reflect.ValueOf(src)
372 if b, ok := asBytes(nil, sv); ok {
373 *d = b
374 return nil
375 }
376 case *RawBytes:
377 sv = reflect.ValueOf(src)
378 if b, ok := asBytes(rows.rawbuf(), sv); ok {
379 *d = rows.setrawbuf(b)
380 return nil
381 }
382 case *bool:
383 bv, err := driver.Bool.ConvertValue(src)
384 if err == nil {
385 *d = bv.(bool)
386 }
387 return err
388 case *any:
389 *d = src
390 return nil
391 }
392
393 if scanner, ok := dest.(Scanner); ok {
394 return scanner.Scan(src)
395 }
396
397 dpv := reflect.ValueOf(dest)
398 if dpv.Kind() != reflect.Pointer {
399 return errors.New("destination not a pointer")
400 }
401 if dpv.IsNil() {
402 return errNilPtr
403 }
404
405 if !sv.IsValid() {
406 sv = reflect.ValueOf(src)
407 }
408
409 dv := reflect.Indirect(dpv)
410 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
411 switch b := src.(type) {
412 case []byte:
413 dv.Set(reflect.ValueOf(bytes.Clone(b)))
414 default:
415 dv.Set(sv)
416 }
417 return nil
418 }
419
420 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
421 dv.Set(sv.Convert(dv.Type()))
422 return nil
423 }
424
425
426
427
428
429
430 switch dv.Kind() {
431 case reflect.Pointer:
432 if src == nil {
433 dv.SetZero()
434 return nil
435 }
436 dv.Set(reflect.New(dv.Type().Elem()))
437 return convertAssignRows(dv.Interface(), src, rows)
438 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
439 if src == nil {
440 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
441 }
442 s := asString(src)
443 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
444 if err != nil {
445 err = strconvErr(err)
446 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
447 }
448 dv.SetInt(i64)
449 return nil
450 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
451 if src == nil {
452 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
453 }
454 s := asString(src)
455 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
456 if err != nil {
457 err = strconvErr(err)
458 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
459 }
460 dv.SetUint(u64)
461 return nil
462 case reflect.Float32, reflect.Float64:
463 if src == nil {
464 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
465 }
466 s := asString(src)
467 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
468 if err != nil {
469 err = strconvErr(err)
470 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
471 }
472 dv.SetFloat(f64)
473 return nil
474 case reflect.String:
475 if src == nil {
476 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
477 }
478 switch v := src.(type) {
479 case string:
480 dv.SetString(v)
481 return nil
482 case []byte:
483 dv.SetString(string(v))
484 return nil
485 }
486 }
487
488 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
489 }
490
// strconvErr unwraps a *strconv.NumError to its underlying cause (such as
// strconv.ErrSyntax or strconv.ErrRange); any other error, including nil,
// is returned unchanged.
func strconvErr(err error) error {
	numErr, isNumErr := err.(*strconv.NumError)
	if !isNumErr {
		return err
	}
	return numErr.Err
}
497
// asString returns a textual form of src:
//   - a string or []byte is returned as its contents;
//   - integers, unsigned integers, floats and bools use strconv
//     formatting (shortest 'g' form for floats, at the float's own
//     precision);
//   - anything else, including nil, uses fmt's %v.
func asString(src any) string {
	if s, ok := src.(string); ok {
		return s
	}
	if b, ok := src.([]byte); ok {
		return string(b)
	}

	val := reflect.ValueOf(src)
	switch kind := val.Kind(); kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(val.Uint(), 10)
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		return strconv.FormatFloat(val.Float(), 'g', -1, bitSize)
	case reflect.Bool:
		return strconv.FormatBool(val.Bool())
	default:
		return fmt.Sprintf("%v", src)
	}
}
520
// asBytes appends the textual form of rv to buf and reports true when rv
// has a bool, integer, unsigned integer, float or string kind. For any
// other kind (including an invalid Value) it returns (nil, false).
func asBytes(buf []byte, rv reflect.Value) ([]byte, bool) {
	switch kind := rv.Kind(); kind {
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, bitSize), true
	case reflect.String:
		return append(buf, rv.String()...), true
	default:
		return nil, false
	}
}
539
540 var valuerReflectType = reflect.TypeFor[driver.Valuer]()
541
542
543
544
545
546
547
548
549
550
551
552
553 func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
554 if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
555 rv.IsNil() &&
556 rv.Type().Elem().Implements(valuerReflectType) {
557 return nil, nil
558 }
559 return vr.Value()
560 }
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
type (
	// decimal is implemented by values that can both export and import
	// their decimal state through Decompose/Compose.
	decimal interface {
		decimalDecompose
		decimalCompose
	}

	// decimalDecompose is implemented by source values that can break
	// themselves into parts. convertAssignRows calls Decompose(nil) and
	// feeds the result straight into a decimalCompose destination.
	decimalDecompose interface {
		// Decompose returns the value's parts. NOTE(review): buf appears to
		// be an optional scratch buffer for the coefficient, and form
		// presumably distinguishes finite/infinite/NaN values. Confirm
		// against the implementing drivers.
		Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
	}

	// decimalCompose is implemented by scan destinations that can rebuild
	// themselves from the parts produced by decimalDecompose.
	decimalCompose interface {
		// Compose sets the value from its parts; the returned error is
		// passed back to the caller of convertAssignRows unchanged.
		Compose(form byte, negative bool, coefficient []byte, exponent int32) error
	}
)
600
View as plain text