Source file
src/database/sql/fakedb_test.go
1
2
3
4
5 package sql
6
7 import (
8 "bytes"
9 "context"
10 "database/sql/driver"
11 "errors"
12 "fmt"
13 "io"
14 "reflect"
15 "slices"
16 "strconv"
17 "strings"
18 "sync"
19 "testing"
20 "time"
21 )
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47 type fakeDriver struct {
48 mu sync.Mutex
49 openCount int
50 closeCount int
51 waitCh chan struct{}
52 waitingCh chan struct{}
53 dbs map[string]*fakeDB
54 }
55
56 type fakeConnector struct {
57 name string
58
59 waiter func(context.Context)
60 closed bool
61 }
62
63 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
64 conn, err := fdriver.Open(c.name)
65 conn.(*fakeConn).waiter = c.waiter
66 return conn, err
67 }
68
69 func (c *fakeConnector) Driver() driver.Driver {
70 return fdriver
71 }
72
73 func (c *fakeConnector) Close() error {
74 if c.closed {
75 return errors.New("fakedb: connector is closed")
76 }
77 c.closed = true
78 return nil
79 }
80
81 type fakeDriverCtx struct {
82 fakeDriver
83 }
84
85 var _ driver.DriverContext = &fakeDriverCtx{}
86
87 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
88 return &fakeConnector{name: name}, nil
89 }
90
91 type fakeDB struct {
92 name string
93
94 mu sync.Mutex
95 tables map[string]*table
96 badConn bool
97 allowAny bool
98 }
99
// fakeError is an error carrying a message and, optionally, a wrapped error
// (in this file, driver.ErrBadConn) that errors.Is/As can see through.
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns only Message; the wrapped error is not included.
func (e fakeError) Error() string { return e.Message }

// Unwrap exposes the wrapped error.
func (e fakeError) Unwrap() error { return e.Wrapped }
112
113 type table struct {
114 mu sync.Mutex
115 colname []string
116 coltype []string
117 rows []*row
118 }
119
120 func (t *table) columnIndex(name string) int {
121 return slices.Index(t.colname, name)
122 }
123
124 type row struct {
125 cols []any
126 }
127
// memToucher is implemented by driver objects in this file (fakeConn,
// rowsCursor) whose touchMem writes to an internal counter on each call.
// NOTE(review): presumably the writes exist so the race detector flags
// unsynchronized use of these objects — confirm.
type memToucher interface {
	// touchMem records an access to the receiver.
	touchMem()
}
132
133 type fakeConn struct {
134 db *fakeDB
135
136 currTx *fakeTx
137
138
139
140 line int64
141
142
143 mu sync.Mutex
144 stmtsMade int
145 stmtsClosed int
146 numPrepare int
147
148
149 bad bool
150 stickyBad bool
151
152 skipDirtySession bool
153
154
155
156 dirtySession bool
157
158
159
160 waiter func(context.Context)
161 }
162
163 func (c *fakeConn) touchMem() {
164 c.line++
165 }
166
167 func (c *fakeConn) incrStat(v *int) {
168 c.mu.Lock()
169 *v++
170 c.mu.Unlock()
171 }
172
173 type fakeTx struct {
174 c *fakeConn
175 }
176
// boundCol is one column=placeholder term of a SELECT's WHERE list.
type boundCol struct {
	Column      string // column to compare
	Placeholder string // "?" for a positional argument, "?name" for a named one
	Ordinal     int    // 1-based position among the statement's placeholders
}
182
183 type fakeStmt struct {
184 memToucher
185 c *fakeConn
186 q string
187
188 cmd string
189 table string
190 panic string
191 wait time.Duration
192
193 next *fakeStmt
194
195 closed bool
196
197 colName []string
198 colType []string
199 colValue []any
200 placeholders int
201
202 whereCol []boundCol
203
204 placeholderConverter []driver.ValueConverter
205 }
206
207 var fdriver driver.Driver = &fakeDriver{}
208
209 func init() {
210 Register("test", fdriver)
211 }
212
213 type Dummy struct {
214 driver.Driver
215 }
216
// TestDrivers checks that Drivers lists every registered driver name, in
// sorted order.
func TestDrivers(t *testing.T) {
	unregisterAllDrivers()
	Register("test", fdriver)
	Register("invalid", Dummy{})

	names := Drivers()
	ok := len(names) >= 2 &&
		slices.IsSorted(names) &&
		slices.Contains(names, "test") &&
		slices.Contains(names, "invalid")
	if !ok {
		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", names)
	}
}
226
227
228 var hookOpenErr struct {
229 sync.Mutex
230 fn func() error
231 }
232
233 func setHookOpenErr(fn func() error) {
234 hookOpenErr.Lock()
235 defer hookOpenErr.Unlock()
236 hookOpenErr.fn = fn
237 }
238
239
240
241
242
243
244
// Open returns a new *fakeConn for dsn, which has the form "name[;badConn]"
// (further segments are ignored). Connections with the same name share one
// fakeDB. The hookOpenErr hook may force a failure. If d.waitCh is non-nil,
// Open sends on d.waitingCh, blocks until it receives from d.waitCh, and then
// clears both channels.
func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
	hookOpenErr.Lock()
	hook := hookOpenErr.fn
	hookOpenErr.Unlock()
	if hook != nil {
		if err := hook(); err != nil {
			return nil, err
		}
	}

	segments := strings.Split(dsn, ";")
	// NOTE(review): strings.Split never returns an empty slice, so this guard
	// cannot fire; an empty dsn yields the database name "".
	if len(segments) < 1 {
		return nil, errors.New("fakedb: no database name")
	}

	shared := d.getDB(segments[0])

	d.mu.Lock()
	d.openCount++
	d.mu.Unlock()

	conn := &fakeConn{
		db:  shared,
		bad: len(segments) >= 2 && segments[1] == "badConn",
	}

	if d.waitCh != nil {
		// Tell the test we are about to block, then wait to be released.
		d.waitingCh <- struct{}{}
		<-d.waitCh
		d.waitCh = nil
		d.waitingCh = nil
	}
	return conn, nil
}
278
279 func (d *fakeDriver) getDB(name string) *fakeDB {
280 d.mu.Lock()
281 defer d.mu.Unlock()
282 if d.dbs == nil {
283 d.dbs = make(map[string]*fakeDB)
284 }
285 db, ok := d.dbs[name]
286 if !ok {
287 db = &fakeDB{name: name}
288 d.dbs[name] = db
289 }
290 return db
291 }
292
// wipe drops every table in the database.
func (db *fakeDB) wipe() {
	db.mu.Lock()
	db.tables = nil
	db.mu.Unlock()
}

// createTable adds a new, empty table. It fails if a table with that name
// already exists or if columnNames and columnTypes differ in length.
func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
	db.mu.Lock()
	defer db.mu.Unlock()

	if db.tables == nil {
		db.tables = make(map[string]*table)
	}
	if _, exist := db.tables[name]; exist {
		return fmt.Errorf("fakedb: table %q already exists", name)
	}
	if nNames, nTypes := len(columnNames), len(columnTypes); nNames != nTypes {
		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
			name, nNames, nTypes)
	}
	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
	return nil
}
315
316
317 func (db *fakeDB) table(table string) (*table, bool) {
318 if db.tables == nil {
319 return nil, false
320 }
321 t, ok := db.tables[table]
322 return t, ok
323 }
324
325 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
326 db.mu.Lock()
327 defer db.mu.Unlock()
328 t, ok := db.table(table)
329 if !ok {
330 return
331 }
332 if i := slices.Index(t.colname, column); i != -1 {
333 return t.coltype[i], true
334 }
335 return "", false
336 }
337
338 func (c *fakeConn) isBad() bool {
339 if c.stickyBad {
340 return true
341 } else if c.bad {
342 if c.db == nil {
343 return false
344 }
345
346 c.db.badConn = !c.db.badConn
347 return c.db.badConn
348 } else {
349 return false
350 }
351 }
352
353 func (c *fakeConn) isDirtyAndMark() bool {
354 if c.skipDirtySession {
355 return false
356 }
357 if c.currTx != nil {
358 c.dirtySession = true
359 return false
360 }
361 if c.dirtySession {
362 return true
363 }
364 c.dirtySession = true
365 return false
366 }
367
// Begin starts a transaction. It fails with a wrapped driver.ErrBadConn when
// isBad reports the connection as bad, and it refuses to nest transactions.
func (c *fakeConn) Begin() (driver.Tx, error) {
	if c.isBad() {
		return nil, fakeError{Wrapped: driver.ErrBadConn}
	}
	if c.currTx != nil {
		return nil, errors.New("fakedb: already in a transaction")
	}
	c.touchMem()
	tx := &fakeTx{c: c}
	c.currTx = tx
	return tx, nil
}
379
380 var hookPostCloseConn struct {
381 sync.Mutex
382 fn func(*fakeConn, error)
383 }
384
385 func setHookpostCloseConn(fn func(*fakeConn, error)) {
386 hookPostCloseConn.Lock()
387 defer hookPostCloseConn.Unlock()
388 hookPostCloseConn.fn = fn
389 }
390
391 var testStrictClose *testing.T
392
393
394
395 func setStrictFakeConnClose(t *testing.T) {
396 testStrictClose = t
397 }
398
399 func (c *fakeConn) ResetSession(ctx context.Context) error {
400 c.dirtySession = false
401 c.currTx = nil
402 if c.isBad() {
403 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
404 }
405 return nil
406 }
407
408 var _ driver.Validator = (*fakeConn)(nil)
409
410 func (c *fakeConn) IsValid() bool {
411 return !c.isBad()
412 }
413
// Close closes the connection. It refuses, returning an error and leaving
// the connection as is, while a transaction is active, if the connection is
// already closed, or while prepared statements remain open.
//
// A deferred epilogue runs in every case. It reports a failure to
// testStrictClose if set, calls the hookPostCloseConn hook if set, and on
// success increments the driver's closeCount.
func (c *fakeConn) Close() (err error) {
	drv := fdriver.(*fakeDriver)
	defer func() {
		if err != nil && testStrictClose != nil {
			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
		}

		hookPostCloseConn.Lock()
		postClose := hookPostCloseConn.fn
		hookPostCloseConn.Unlock()
		if postClose != nil {
			postClose(c, err)
		}

		if err != nil {
			return
		}
		drv.mu.Lock()
		drv.closeCount++
		drv.mu.Unlock()
	}()

	c.touchMem()
	switch {
	case c.currTx != nil:
		return errors.New("fakedb: can't close fakeConn; in a Transaction")
	case c.db == nil:
		return errors.New("fakedb: can't close fakeConn; already closed")
	case c.stmtsMade > c.stmtsClosed:
		return errors.New("fakedb: can't close; dangling statement(s)")
	}
	c.db = nil
	return nil
}
445
// checkSubsetTypes verifies that every argument already holds one of the
// types int64, float64, bool, nil, []byte, string or time.Time. Any other
// type is rejected with an error naming its ordinal, unless allowAny is set.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
458
// Exec is the context-free entry point. The fake expects database/sql to
// use ExecContext instead, so reaching Exec panics.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	panic("ExecContext was not called.")
}

// ExecContext only validates argument types. It then always returns
// driver.ErrSkip, so this connection-level fast path never executes
// anything itself.
func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	if err := checkSubsetTypes(c.db.allowAny, args); err != nil {
		return nil, err
	}
	return nil, driver.ErrSkip
}

// Query is the context-free entry point. The fake expects database/sql to
// use QueryContext instead, so reaching Query panics.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	panic("QueryContext was not called.")
}

// QueryContext only validates argument types. It then always returns
// driver.ErrSkip, so this connection-level fast path never executes
// anything itself.
func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	if err := checkSubsetTypes(c.db.allowAny, args); err != nil {
		return nil, err
	}
	return nil, driver.ErrSkip
}
492
// errf formats msg with args, fmt.Sprintf-style, and prefixes the result
// with "fakedb: ".
func errf(msg string, args ...any) error {
	return fmt.Errorf("fakedb: %s", fmt.Sprintf(msg, args...))
}
496
497
498
499
// prepareSelect fills in stmt from the parts of "SELECT|table|cols|where".
// cols is a comma-separated list of result columns. where is a possibly
// empty comma-separated list of column=?placeholder terms; each term's column
// must exist in the table and its value must start with '?'. On any parse
// error stmt is closed before the error is returned.
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	fail := func(err error) (*fakeStmt, error) {
		stmt.Close()
		return nil, err
	}

	if len(parts) != 3 {
		return fail(errf("invalid SELECT syntax with %d parts; want 3", len(parts)))
	}
	stmt.table = parts[0]
	stmt.colName = strings.Split(parts[1], ",")

	for n, colspec := range strings.Split(parts[2], ",") {
		if colspec == "" {
			continue
		}
		nameVal := strings.Split(colspec, "=")
		if len(nameVal) != 2 {
			return fail(errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n))
		}
		column, value := nameVal[0], nameVal[1]
		if _, ok := c.db.columnType(stmt.table, column); !ok {
			return fail(errf("SELECT on table %q references non-existent column %q", stmt.table, column))
		}
		if !strings.HasPrefix(value, "?") {
			return fail(errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
				stmt.table, column))
		}
		stmt.placeholders++
		stmt.whereCol = append(stmt.whereCol, boundCol{
			Column:      column,
			Placeholder: value,
			Ordinal:     stmt.placeholders,
		})
	}
	return stmt, nil
}
533
534
// prepareCreate fills in stmt from the parts of
// "CREATE|table|col1=type1,col2=type2,...". On any parse error stmt is closed
// before the error is returned.
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	fail := func(err error) (*fakeStmt, error) {
		stmt.Close()
		return nil, err
	}

	if len(parts) != 2 {
		return fail(errf("invalid CREATE syntax with %d parts; want 2", len(parts)))
	}
	stmt.table = parts[0]
	for n, colspec := range strings.Split(parts[1], ",") {
		nameType := strings.Split(colspec, "=")
		if len(nameType) != 2 {
			return fail(errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n))
		}
		name, typ := nameType[0], nameType[1]
		stmt.colName = append(stmt.colName, name)
		stmt.colType = append(stmt.colType, typ)
	}
	return stmt, nil
}
552
553
// prepareInsert fills in stmt from the parts of an INSERT (or NOSERT)
// statement, "table|col1=val1,col2=?,...". Each value is either:
//   - a placeholder ("?" or "?name"), bound at exec time through the
//     column type's converter, or
//   - a pre-bound literal converted by the column's declared type: string
//     and blob become []byte, int32 becomes int64, and table runs the
//     sub-SELECT "tableName!col1!col2..." and stores its cursor.
//
// On every error path stmt is closed, so the connection's statement counters
// stay balanced.
func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
	if len(parts) != 2 {
		stmt.Close()
		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
	}
	stmt.table = parts[0]
	for n, colspec := range strings.Split(parts[1], ",") {
		nameVal := strings.Split(colspec, "=")
		if len(nameVal) != 2 {
			stmt.Close()
			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
		}
		column, value := nameVal[0], nameVal[1]
		ctype, ok := c.db.columnType(stmt.table, column)
		if !ok {
			stmt.Close()
			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
		}
		stmt.colName = append(stmt.colName, column)

		if strings.HasPrefix(value, "?") {
			// Placeholder: remember how to convert the argument bound later.
			stmt.placeholders++
			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
			stmt.colValue = append(stmt.colValue, value)
			continue
		}

		var subsetVal any
		switch ctype {
		case "string", "blob":
			subsetVal = []byte(value)
		case "int32":
			i, err := strconv.Atoi(value)
			if err != nil {
				stmt.Close()
				return nil, errf("invalid conversion to int32 from %q", value)
			}
			subsetVal = int64(i)
		case "table":
			// NOTE(review): this permanently disables the dirty-session check
			// on this connection; it is never reset here.
			c.skipDirtySession = true
			vparts := strings.Split(value, "!")

			substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
			if err != nil {
				// stmt was already counted in stmtsMade; close it so the
				// connection is not left with a dangling statement.
				stmt.Close()
				return nil, err
			}
			cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
			substmt.Close()
			if err != nil {
				stmt.Close()
				return nil, err
			}
			subsetVal = cursor
		default:
			stmt.Close()
			return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
		}
		stmt.colValue = append(stmt.colValue, subsetVal)
	}
	return stmt, nil
}
616
617
618 var hookPrepareBadConn func() bool
619
620 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
621 panic("use PrepareContext")
622 }
623
// PrepareContext parses query into one or more fakeStmts. The query is a
// ';'-separated list of statements. Each statement is a '|'-separated list
// whose first element is the command: WIPE, SELECT, CREATE, INSERT or
// NOSERT. A statement with at least three parts may start with
// "PANIC|<method>|", which makes the named fakeStmt method panic, or
// "WAIT|<duration>|", which delays the prepare while honoring ctx.
//
// The statements are linked through fakeStmt.next and the first is
// returned. If any statement fails, the statements of this query that were
// already prepared and linked are closed before the error is returned, so
// the connection is not left with dangling statements.
func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	c.numPrepare++
	if c.db == nil {
		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
	}

	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
		return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
	}

	c.touchMem()
	var firstStmt, prev *fakeStmt

	// abort closes the already-linked chain firstStmt..prev and returns err.
	// The failing statement is not in that chain: it was either never
	// counted in stmtsMade or has already been closed by the code that
	// rejected it.
	abort := func(err error) (driver.Stmt, error) {
		if prev != nil {
			firstStmt.Close()
		}
		return nil, err
	}

	for _, query := range strings.Split(query, ";") {
		// strings.Split always returns at least one element, so parts[0]
		// always exists.
		parts := strings.Split(query, "|")
		stmt := &fakeStmt{q: query, c: c, memToucher: c}
		if firstStmt == nil {
			firstStmt = stmt
		}
		if len(parts) >= 3 {
			switch parts[0] {
			case "PANIC":
				stmt.panic = parts[1]
				parts = parts[2:]
			case "WAIT":
				wait, err := time.ParseDuration(parts[1])
				if err != nil {
					return abort(errf("expected section after WAIT to be a duration, got %q %v", parts[1], err))
				}
				parts = parts[2:]
				stmt.wait = wait
			}
		}
		cmd := parts[0]
		stmt.cmd = cmd
		parts = parts[1:]

		if c.waiter != nil {
			c.waiter(ctx)
			if err := ctx.Err(); err != nil {
				return abort(err)
			}
		}

		if stmt.wait > 0 {
			wait := time.NewTimer(stmt.wait)
			select {
			case <-wait.C:
			case <-ctx.Done():
				wait.Stop()
				return abort(ctx.Err())
			}
		}

		c.incrStat(&c.stmtsMade)
		var err error
		switch cmd {
		case "WIPE":
			// Nothing to parse.
		case "SELECT":
			stmt, err = c.prepareSelect(stmt, parts)
		case "CREATE":
			stmt, err = c.prepareCreate(stmt, parts)
		case "INSERT", "NOSERT":
			// NOSERT parses like INSERT; only execution differs.
			stmt, err = c.prepareInsert(ctx, stmt, parts)
		default:
			stmt.Close()
			return abort(errf("unsupported command type %q", cmd))
		}
		if err != nil {
			return abort(err)
		}
		if prev != nil {
			prev.next = stmt
		}
		prev = stmt
	}
	return firstStmt, nil
}
709
// ColumnConverter returns the converter for placeholder idx. That is the
// per-column converter recorded by prepareInsert, or
// driver.DefaultParameterConverter when the statement recorded none.
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
	if s.panic == "ColumnConverter" {
		panic(s.panic)
	}
	if len(s.placeholderConverter) > 0 {
		return s.placeholderConverter[idx]
	}
	return driver.DefaultParameterConverter
}

// Close closes the statement and every statement chained after it through
// next. stmtsClosed is incremented only the first time a statement is
// closed. Close panics if the statement has no connection or the connection
// is already closed.
func (s *fakeStmt) Close() error {
	switch {
	case s.panic == "Close":
		panic(s.panic)
	case s.c == nil:
		panic("nil conn in fakeStmt.Close")
	case s.c.db == nil:
		panic("in fakeStmt.Close, conn's db is nil (already closed)")
	}
	s.touchMem()
	if !s.closed {
		s.closed = true
		s.c.incrStat(&s.c.stmtsClosed)
	}
	if next := s.next; next != nil {
		next.Close()
	}
	return nil
}
740
741 var errClosed = errors.New("fakedb: statement has been closed")
742
743
744 var hookExecBadConn func() bool
745
746 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
747 panic("Using ExecContext")
748 }
749
750 var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
751
752 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
753 if s.panic == "Exec" {
754 panic(s.panic)
755 }
756 if s.closed {
757 return nil, errClosed
758 }
759
760 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
761 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
762 }
763 if s.c.isDirtyAndMark() {
764 return nil, errFakeConnSessionDirty
765 }
766
767 err := checkSubsetTypes(s.c.db.allowAny, args)
768 if err != nil {
769 return nil, err
770 }
771 s.touchMem()
772
773 if s.wait > 0 {
774 time.Sleep(s.wait)
775 }
776
777 select {
778 default:
779 case <-ctx.Done():
780 return nil, ctx.Err()
781 }
782
783 db := s.c.db
784 switch s.cmd {
785 case "WIPE":
786 db.wipe()
787 return driver.ResultNoRows, nil
788 case "CREATE":
789 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
790 return nil, err
791 }
792 return driver.ResultNoRows, nil
793 case "INSERT":
794 return s.execInsert(args, true)
795 case "NOSERT":
796
797
798 return s.execInsert(args, false)
799 }
800 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
801 }
802
// valueFromPlaceholderName returns the Value of the first argument whose
// Name equals name, or nil if no argument has that name.
func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
	for _, arg := range args {
		if arg.Name == name {
			return arg.Value
		}
	}
	return nil
}
811
812
813
814
// execInsert runs a prepared INSERT (doInsert true) or NOSERT (doInsert
// false). Both resolve every column value: a "?" placeholder consumes the
// next positional argument, a "?name" placeholder looks the argument up by
// name (nil if absent, but still advancing the positional index), and
// anything else is the pre-bound value. Only INSERT appends the row. Both
// report one row affected. t.mu is held while the row is built.
func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
	db := s.c.db
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}
	db.mu.Lock()
	t, ok := db.table(s.table)
	db.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	var cols []any
	if doInsert {
		cols = make([]any, len(t.colname))
	}
	nextArg := 0
	for n, colname := range s.colName {
		colidx := t.columnIndex(colname)
		if colidx == -1 {
			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
		}

		val := s.colValue[n]
		if placeholder, isStr := val.(string); isStr && strings.HasPrefix(placeholder, "?") {
			if placeholder == "?" {
				val = args[nextArg].Value
			} else {
				val = valueFromPlaceholderName(args, placeholder[1:])
			}
			nextArg++
		}

		if doInsert {
			cols[colidx] = val
		}
	}

	if doInsert {
		t.rows = append(t.rows, &row{cols: cols})
	}
	return driver.RowsAffected(1), nil
}
864
865
866 var hookQueryBadConn func() bool
867
868 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
869 panic("Use QueryContext")
870 }
871
// QueryContext runs a prepared SELECT, together with any SELECTs chained
// after it via next, and returns a rowsCursor holding one result set per
// statement.
//
// Two tables are special:
//   - "magicquery" with WHERE terms op,millis sleeps for millis
//     milliseconds when the op argument is "sleep".
//   - "tx_status" selecting column tx_status immediately returns a single
//     row containing "transaction" or "autocommit".
func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	if s.panic == "Query" {
		panic(s.panic)
	}
	if s.closed {
		return nil, errClosed
	}

	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
		return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
	}
	if s.c.isDirtyAndMark() {
		return nil, errFakeConnSessionDirty
	}

	if err := checkSubsetTypes(s.c.db.allowAny, args); err != nil {
		return nil, err
	}

	s.touchMem()
	db := s.c.db
	if len(args) != s.placeholders {
		panic("error in pkg db; should only get here if size is correct")
	}

	setMRows := make([][]*row, 0, 1)
	setColumns := make([][]string, 0, 1)
	setColType := make([][]string, 0, 1)

	for {
		db.mu.Lock()
		t, ok := db.table(s.table)
		db.mu.Unlock()
		if !ok {
			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
		}

		if s.table == "magicquery" {
			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
				if args[0].Value == "sleep" {
					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
				}
			}
		}
		if s.table == "tx_status" && s.colName[0] == "tx_status" {
			txStatus := "autocommit"
			if s.c.currTx != nil {
				txStatus = "transaction"
			}
			cursor := &rowsCursor{
				db:        s.c.db,
				parentMem: s.c,
				posRow:    -1,
				rows:      [][]*row{{{cols: []any{txStatus}}}},
				cols:      [][]string{{"tx_status"}},
				colType:   [][]string{{"string"}},
				errPos:    -1,
			}
			return cursor, nil
		}

		mrows, colType, err := s.matchRows(t, args)
		if err != nil {
			return nil, err
		}
		setMRows = append(setMRows, mrows)
		setColumns = append(setColumns, s.colName)
		setColType = append(setColType, colType)

		if s.next == nil {
			break
		}
		s = s.next
	}

	cursor := &rowsCursor{
		db:        s.c.db,
		parentMem: s.c,
		posRow:    -1,
		rows:      setMRows,
		cols:      setColumns,
		colType:   setColType,
		errPos:    -1,
	}
	return cursor, nil
}

// matchRows returns the rows of t that satisfy s's WHERE terms, projected
// onto s.colName, plus the declared types of the selected columns. For
// every term, a row matches when the %v formatting of the stored value
// ([]byte is converted to string first) equals that of the bound argument.
// t.mu is held for the whole scan, and defer guarantees it is released on
// every return path.
func (s *fakeStmt) matchRows(t *table, args []driver.NamedValue) ([]*row, []string, error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	// colIdx[i] is the table position of the i-th selected column.
	colIdx := make([]int, len(s.colName))
	for i, name := range s.colName {
		idx := t.columnIndex(name)
		if idx == -1 {
			return nil, nil, fmt.Errorf("fakedb: unknown column name %q", name)
		}
		colIdx[i] = idx
	}

	mrows := []*row{}
rows:
	for _, trow := range t.rows {
		for _, wcol := range s.whereCol {
			idx := t.columnIndex(wcol.Column)
			if idx == -1 {
				return nil, nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol.Column)
			}
			tcol := trow.cols[idx]
			if bs, ok := tcol.([]byte); ok {
				tcol = string(bs)
			}
			var argValue any
			if wcol.Placeholder == "?" {
				argValue = args[wcol.Ordinal-1].Value
			} else {
				argValue = valueFromPlaceholderName(args, wcol.Placeholder[1:])
			}
			if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
				continue rows
			}
		}
		mrow := &row{cols: make([]any, len(s.colName))}
		for seli, idx := range colIdx {
			mrow.cols[seli] = trow.cols[idx]
		}
		mrows = append(mrows, mrow)
	}

	var colType []string
	for _, idx := range colIdx {
		colType = append(colType, t.coltype[idx])
	}
	return mrows, colType, nil
}
1026
1027 func (s *fakeStmt) NumInput() int {
1028 if s.panic == "NumInput" {
1029 panic(s.panic)
1030 }
1031 return s.placeholders
1032 }
1033
1034
1035 var hookCommitBadConn func() bool
1036
1037 func (tx *fakeTx) Commit() error {
1038 tx.c.currTx = nil
1039 if hookCommitBadConn != nil && hookCommitBadConn() {
1040 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1041 }
1042 tx.c.touchMem()
1043 return nil
1044 }
1045
1046
1047 var hookRollbackBadConn func() bool
1048
1049 func (tx *fakeTx) Rollback() error {
1050 tx.c.currTx = nil
1051 if hookRollbackBadConn != nil && hookRollbackBadConn() {
1052 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1053 }
1054 tx.c.touchMem()
1055 return nil
1056 }
1057
1058 type rowsCursor struct {
1059 db *fakeDB
1060 parentMem memToucher
1061 cols [][]string
1062 colType [][]string
1063 posSet int
1064 posRow int
1065 rows [][]*row
1066 closed bool
1067
1068
1069 errPos int
1070 err error
1071
1072
1073
1074 driverOwnedMemory [][]byte
1075
1076
1077
1078
1079
1080 line int64
1081
1082
1083 closeErr error
1084 }
1085
1086 func (rc *rowsCursor) touchMem() {
1087 rc.parentMem.touchMem()
1088 rc.line++
1089 }
1090
1091 func (rc *rowsCursor) invalidateDriverOwnedMemory() {
1092 for _, buf := range rc.driverOwnedMemory {
1093 for i := range buf {
1094 buf[i] = 'x'
1095 }
1096 }
1097 rc.driverOwnedMemory = nil
1098 }
1099
1100 func (rc *rowsCursor) Close() error {
1101 rc.touchMem()
1102 rc.parentMem.touchMem()
1103 rc.invalidateDriverOwnedMemory()
1104 rc.closed = true
1105 return rc.closeErr
1106 }
1107
1108 func (rc *rowsCursor) Columns() []string {
1109 return rc.cols[rc.posSet]
1110 }
1111
1112 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1113 return colTypeToReflectType(rc.colType[rc.posSet][index])
1114 }
1115
1116 var rowsCursorNextHook func(dest []driver.Value) error
1117
1118 func (rc *rowsCursor) Next(dest []driver.Value) error {
1119 if rowsCursorNextHook != nil {
1120 return rowsCursorNextHook(dest)
1121 }
1122
1123 if rc.closed {
1124 return errors.New("fakedb: cursor is closed")
1125 }
1126 rc.touchMem()
1127 rc.posRow++
1128 if rc.posRow == rc.errPos {
1129 return rc.err
1130 }
1131 if rc.posRow >= len(rc.rows[rc.posSet]) {
1132 return io.EOF
1133 }
1134
1135 rc.invalidateDriverOwnedMemory()
1136 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1137
1138
1139
1140
1141
1142
1143 if bs, ok := v.([]byte); ok {
1144
1145 bs = bytes.Clone(bs)
1146 rc.driverOwnedMemory = append(rc.driverOwnedMemory, bs)
1147 v = bs
1148 }
1149 dest[i] = v
1150 }
1151 return nil
1152 }
1153
1154 func (rc *rowsCursor) HasNextResultSet() bool {
1155 rc.touchMem()
1156 return rc.posSet < len(rc.rows)-1
1157 }
1158
1159 func (rc *rowsCursor) NextResultSet() error {
1160 rc.touchMem()
1161 if rc.HasNextResultSet() {
1162 rc.posSet++
1163 rc.posRow = -1
1164 return nil
1165 }
1166 return io.EOF
1167 }
1168
1169
1170
1171
1172
1173
1174
1175 type fakeDriverString struct{}
1176
1177 func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
1178 switch c := v.(type) {
1179 case string, []byte:
1180 return v, nil
1181 case *string:
1182 if c == nil {
1183 return nil, nil
1184 }
1185 return *c, nil
1186 }
1187 return fmt.Sprintf("%v", v), nil
1188 }
1189
1190 type anyTypeConverter struct{}
1191
1192 func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
1193 return v, nil
1194 }
1195
1196 func converterForType(typ string) driver.ValueConverter {
1197 switch typ {
1198 case "bool":
1199 return driver.Bool
1200 case "nullbool":
1201 return driver.Null{Converter: driver.Bool}
1202 case "byte", "int16":
1203 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1204 case "int32":
1205 return driver.Int32
1206 case "nullbyte", "nullint32", "nullint16":
1207 return driver.Null{Converter: driver.DefaultParameterConverter}
1208 case "string":
1209 return driver.NotNull{Converter: fakeDriverString{}}
1210 case "nullstring":
1211 return driver.Null{Converter: fakeDriverString{}}
1212 case "int64":
1213
1214 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1215 case "nullint64":
1216
1217 return driver.Null{Converter: driver.DefaultParameterConverter}
1218 case "float64":
1219
1220 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1221 case "nullfloat64":
1222
1223 return driver.Null{Converter: driver.DefaultParameterConverter}
1224 case "datetime":
1225 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1226 case "nulldatetime":
1227 return driver.Null{Converter: driver.DefaultParameterConverter}
1228 case "any":
1229 return anyTypeConverter{}
1230 }
1231 panic("invalid fakedb column type of " + typ)
1232 }
1233
// colTypeToReflectType maps a fakedb column type name to the Go type that
// rowsCursor.ColumnTypeScanType reports for it. Nullable types map to the
// corresponding database/sql Null* wrapper. The "byte", "nullbyte" and
// "nulldatetime" types (accepted by converterForType) and "blob" (accepted
// for pre-bound INSERT values, stored as []byte) are covered too, so
// querying such columns no longer panics. It still panics on an unknown type
// name.
func colTypeToReflectType(typ string) reflect.Type {
	switch typ {
	case "bool":
		return reflect.TypeFor[bool]()
	case "nullbool":
		return reflect.TypeFor[NullBool]()
	case "byte":
		return reflect.TypeFor[byte]()
	case "nullbyte":
		return reflect.TypeFor[NullByte]()
	case "int16":
		return reflect.TypeFor[int16]()
	case "nullint16":
		return reflect.TypeFor[NullInt16]()
	case "int32":
		return reflect.TypeFor[int32]()
	case "nullint32":
		return reflect.TypeFor[NullInt32]()
	case "string":
		return reflect.TypeFor[string]()
	case "nullstring":
		return reflect.TypeFor[NullString]()
	case "int64":
		return reflect.TypeFor[int64]()
	case "nullint64":
		return reflect.TypeFor[NullInt64]()
	case "float64":
		return reflect.TypeFor[float64]()
	case "nullfloat64":
		return reflect.TypeFor[NullFloat64]()
	case "datetime":
		return reflect.TypeFor[time.Time]()
	case "nulldatetime":
		return reflect.TypeFor[NullTime]()
	case "blob":
		return reflect.TypeFor[[]byte]()
	case "any":
		return reflect.TypeFor[any]()
	}
	panic("invalid fakedb column type of " + typ)
}
1267
View as plain text