Source file src/database/sql/fakedb_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"database/sql/driver"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"reflect"
    15  	"slices"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  )
    22  
// fakeDriver is a fake database that implements Go's driver.Driver
// interface, just for testing.
//
// It speaks a query language that's semantically similar to but
// syntactically different and simpler than SQL.  The syntax is as
// follows:
//
//	WIPE
//	CREATE|<tablename>|<col>=<type>,<col>=<type>,...
//	  where types are: "string", [u]int{8,16,32,64}, "bool"
//	INSERT|<tablename>|col=val,col2=val2,col3=?
//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
//	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
//
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Any of these can be preceded by WAIT|<duration>|, to cause the
// statement to wait for the specified duration while it is being
// prepared and again when it is run through fakeStmt.ExecContext.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
	mu         sync.Mutex // guards openCount, closeCount and dbs
	openCount  int        // conn opens
	closeCount int        // conn closes
	// When waitCh is non-nil, Open sends on waitingCh and then blocks
	// receiving from waitCh before returning; both are reset to nil
	// afterwards. Open accesses these two fields without holding mu.
	waitCh    chan struct{}
	waitingCh chan struct{}
	dbs       map[string]*fakeDB // databases by name, created lazily by getDB
}
    55  
// fakeConnector is a connector for the fake driver. Each Connect call opens
// a connection to the database called name through the package-level fdriver.
type fakeConnector struct {
	name string // data source name handed to fdriver.Open

	waiter func(context.Context) // copied onto every connection made by Connect
	closed bool                  // set by Close; a second Close reports an error
}
    62  
    63  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
    64  	conn, err := fdriver.Open(c.name)
    65  	conn.(*fakeConn).waiter = c.waiter
    66  	return conn, err
    67  }
    68  
    69  func (c *fakeConnector) Driver() driver.Driver {
    70  	return fdriver
    71  }
    72  
    73  func (c *fakeConnector) Close() error {
    74  	if c.closed {
    75  		return errors.New("fakedb: connector is closed")
    76  	}
    77  	c.closed = true
    78  	return nil
    79  }
    80  
// fakeDriverCtx is a fakeDriver that additionally provides OpenConnector,
// and so satisfies driver.DriverContext.
type fakeDriverCtx struct {
	fakeDriver
}

// Compile-time check that *fakeDriverCtx implements driver.DriverContext.
var _ driver.DriverContext = &fakeDriverCtx{}
    86  
    87  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
    88  	return &fakeConnector{name: name}, nil
    89  }
    90  
// fakeDB is one named in-memory database of the fake driver.
type fakeDB struct {
	name string

	mu     sync.Mutex        // held by wipe, createTable and columnType while they use tables
	tables map[string]*table // tables by name; nil until the first CREATE and again after WIPE
	// badConn is flipped by fakeConn.isBad on every call made through a
	// connection opened with the "badConn" option.
	badConn bool
	// allowAny, when true, makes checkSubsetTypes accept arguments of any type.
	allowAny bool
}
    99  
   100  type fakeError struct {
   101  	Message string
   102  	Wrapped error
   103  }
   104  
   105  func (err fakeError) Error() string {
   106  	return err.Message
   107  }
   108  
   109  func (err fakeError) Unwrap() error {
   110  	return err.Wrapped
   111  }
   112  
   113  type table struct {
   114  	mu      sync.Mutex
   115  	colname []string
   116  	coltype []string
   117  	rows    []*row
   118  }
   119  
   120  func (t *table) columnIndex(name string) int {
   121  	return slices.Index(t.colname, name)
   122  }
   123  
   124  type row struct {
   125  	cols []any // must be same size as its table colname + coltype
   126  }
   127  
// memToucher is implemented by values whose memory can be deliberately
// touched, so that unsynchronized concurrent use shows up as a data race.
type memToucher interface {
	// touchMem reads & writes some memory, to help find data races.
	touchMem()
}
   132  
// fakeConn is the fake driver's connection type. fakeDriver.Open returns a
// new one for every successful call.
type fakeConn struct {
	// Set by Open and reset to nil by a successful Close; a nil db means
	// the connection has been closed.
	db *fakeDB // where to return ourselves to

	// Non-nil between Begin and the matching Commit or Rollback (or a
	// ResetSession, which also clears it).
	currTx *fakeTx

	// Every operation writes to line to enable the race detector
	// check for data races.
	line int64

	// Stats for tests:
	mu          sync.Mutex
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool
	stickyBad bool

	skipDirtySession bool // tests that use Conn should set this to true.

	// dirtySession tests ResetSession, true if a query has executed
	// until ResetSession is called.
	dirtySession bool

	// The waiter, if non-nil, is called by PrepareContext for each statement
	// it prepares. May be used in place of the "WAIT" directive.
	waiter func(context.Context)
}
   162  
   163  func (c *fakeConn) touchMem() {
   164  	c.line++
   165  }
   166  
   167  func (c *fakeConn) incrStat(v *int) {
   168  	c.mu.Lock()
   169  	*v++
   170  	c.mu.Unlock()
   171  }
   172  
// fakeTx is the transaction handle returned by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn // connection the transaction was begun on
}
   176  
// boundCol describes one where-clause column of a prepared SELECT together
// with the placeholder that supplies its value.
type boundCol struct {
	Column      string // column name being filtered on
	Placeholder string // "?" for a positional parameter, "?name" for a named one
	Ordinal     int    // 1-based position of the placeholder within the statement
}
   182  
// fakeStmt is a statement prepared by fakeConn.PrepareContext. A query made
// of several semicolon-separated statements becomes a chain linked via next.
type fakeStmt struct {
	memToucher
	c *fakeConn
	q string // just for debugging

	cmd   string        // command word, e.g. "SELECT" or "INSERT"
	table string        // table the command operates on
	panic string        // method name given by a PANIC directive, or ""
	wait  time.Duration // delay given by a WAIT directive, or 0

	next *fakeStmt // used for returning multiple results.

	closed bool

	colName      []string // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string // used by CREATE
	colValue     []any    // used by INSERT (mix of strings and "?" for bound params)
	placeholders int      // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
   206  
// fdriver is the single shared fake driver instance used by the tests.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
   212  
// Dummy wraps a driver.Driver by embedding it. TestDrivers registers a zero
// Dummy (nil embedded driver) under the name "invalid".
type Dummy struct {
	driver.Driver
}
   216  
   217  func TestDrivers(t *testing.T) {
   218  	unregisterAllDrivers()
   219  	Register("test", fdriver)
   220  	Register("invalid", Dummy{})
   221  	all := Drivers()
   222  	if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") {
   223  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   224  	}
   225  }
   226  
   227  // hook to simulate connection failures
   228  var hookOpenErr struct {
   229  	sync.Mutex
   230  	fn func() error
   231  }
   232  
   233  func setHookOpenErr(fn func() error) {
   234  	hookOpenErr.Lock()
   235  	defer hookOpenErr.Unlock()
   236  	hookOpenErr.fn = fn
   237  }
   238  
   239  // Supports dsn forms:
   240  //
   241  //	<dbname>
   242  //	<dbname>;<opts>  (only currently supported option is `badConn`,
   243  //	                  which causes driver.ErrBadConn to be returned on
   244  //	                  every other conn.Begin())
   245  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   246  	hookOpenErr.Lock()
   247  	fn := hookOpenErr.fn
   248  	hookOpenErr.Unlock()
   249  	if fn != nil {
   250  		if err := fn(); err != nil {
   251  			return nil, err
   252  		}
   253  	}
   254  	parts := strings.Split(dsn, ";")
   255  	if len(parts) < 1 {
   256  		return nil, errors.New("fakedb: no database name")
   257  	}
   258  	name := parts[0]
   259  
   260  	db := d.getDB(name)
   261  
   262  	d.mu.Lock()
   263  	d.openCount++
   264  	d.mu.Unlock()
   265  	conn := &fakeConn{db: db}
   266  
   267  	if len(parts) >= 2 && parts[1] == "badConn" {
   268  		conn.bad = true
   269  	}
   270  	if d.waitCh != nil {
   271  		d.waitingCh <- struct{}{}
   272  		<-d.waitCh
   273  		d.waitCh = nil
   274  		d.waitingCh = nil
   275  	}
   276  	return conn, nil
   277  }
   278  
   279  func (d *fakeDriver) getDB(name string) *fakeDB {
   280  	d.mu.Lock()
   281  	defer d.mu.Unlock()
   282  	if d.dbs == nil {
   283  		d.dbs = make(map[string]*fakeDB)
   284  	}
   285  	db, ok := d.dbs[name]
   286  	if !ok {
   287  		db = &fakeDB{name: name}
   288  		d.dbs[name] = db
   289  	}
   290  	return db
   291  }
   292  
   293  func (db *fakeDB) wipe() {
   294  	db.mu.Lock()
   295  	defer db.mu.Unlock()
   296  	db.tables = nil
   297  }
   298  
   299  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   300  	db.mu.Lock()
   301  	defer db.mu.Unlock()
   302  	if db.tables == nil {
   303  		db.tables = make(map[string]*table)
   304  	}
   305  	if _, exist := db.tables[name]; exist {
   306  		return fmt.Errorf("fakedb: table %q already exists", name)
   307  	}
   308  	if len(columnNames) != len(columnTypes) {
   309  		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
   310  			name, len(columnNames), len(columnTypes))
   311  	}
   312  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   313  	return nil
   314  }
   315  
   316  // must be called with db.mu lock held
   317  func (db *fakeDB) table(table string) (*table, bool) {
   318  	if db.tables == nil {
   319  		return nil, false
   320  	}
   321  	t, ok := db.tables[table]
   322  	return t, ok
   323  }
   324  
   325  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   326  	db.mu.Lock()
   327  	defer db.mu.Unlock()
   328  	t, ok := db.table(table)
   329  	if !ok {
   330  		return
   331  	}
   332  	if i := slices.Index(t.colname, column); i != -1 {
   333  		return t.coltype[i], true
   334  	}
   335  	return "", false
   336  }
   337  
   338  func (c *fakeConn) isBad() bool {
   339  	if c.stickyBad {
   340  		return true
   341  	} else if c.bad {
   342  		if c.db == nil {
   343  			return false
   344  		}
   345  		// alternate between bad conn and not bad conn
   346  		c.db.badConn = !c.db.badConn
   347  		return c.db.badConn
   348  	} else {
   349  		return false
   350  	}
   351  }
   352  
   353  func (c *fakeConn) isDirtyAndMark() bool {
   354  	if c.skipDirtySession {
   355  		return false
   356  	}
   357  	if c.currTx != nil {
   358  		c.dirtySession = true
   359  		return false
   360  	}
   361  	if c.dirtySession {
   362  		return true
   363  	}
   364  	c.dirtySession = true
   365  	return false
   366  }
   367  
   368  func (c *fakeConn) Begin() (driver.Tx, error) {
   369  	if c.isBad() {
   370  		return nil, fakeError{Wrapped: driver.ErrBadConn}
   371  	}
   372  	if c.currTx != nil {
   373  		return nil, errors.New("fakedb: already in a transaction")
   374  	}
   375  	c.touchMem()
   376  	c.currTx = &fakeTx{c: c}
   377  	return c.currTx, nil
   378  }
   379  
   380  var hookPostCloseConn struct {
   381  	sync.Mutex
   382  	fn func(*fakeConn, error)
   383  }
   384  
   385  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   386  	hookPostCloseConn.Lock()
   387  	defer hookPostCloseConn.Unlock()
   388  	hookPostCloseConn.fn = fn
   389  }
   390  
// testStrictClose, when non-nil, is the test on which fakeConn.Close
// reports an error whenever closing a connection fails.
var testStrictClose *testing.T

// setStrictFakeConnClose sets the test whose Errorf is called when
// fakeConn.Close fails to close. If t is nil, the check is disabled.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
   398  
   399  func (c *fakeConn) ResetSession(ctx context.Context) error {
   400  	c.dirtySession = false
   401  	c.currTx = nil
   402  	if c.isBad() {
   403  		return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
   404  	}
   405  	return nil
   406  }
   407  
   408  var _ driver.Validator = (*fakeConn)(nil)
   409  
   410  func (c *fakeConn) IsValid() bool {
   411  	return !c.isBad()
   412  }
   413  
   414  func (c *fakeConn) Close() (err error) {
   415  	drv := fdriver.(*fakeDriver)
   416  	defer func() {
   417  		if err != nil && testStrictClose != nil {
   418  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   419  		}
   420  		hookPostCloseConn.Lock()
   421  		fn := hookPostCloseConn.fn
   422  		hookPostCloseConn.Unlock()
   423  		if fn != nil {
   424  			fn(c, err)
   425  		}
   426  		if err == nil {
   427  			drv.mu.Lock()
   428  			drv.closeCount++
   429  			drv.mu.Unlock()
   430  		}
   431  	}()
   432  	c.touchMem()
   433  	if c.currTx != nil {
   434  		return errors.New("fakedb: can't close fakeConn; in a Transaction")
   435  	}
   436  	if c.db == nil {
   437  		return errors.New("fakedb: can't close fakeConn; already closed")
   438  	}
   439  	if c.stmtsMade > c.stmtsClosed {
   440  		return errors.New("fakedb: can't close; dangling statement(s)")
   441  	}
   442  	c.db = nil
   443  	return nil
   444  }
   445  
   446  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   447  	for _, arg := range args {
   448  		switch arg.Value.(type) {
   449  		case int64, float64, bool, nil, []byte, string, time.Time:
   450  		default:
   451  			if !allowAny {
   452  				return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   453  			}
   454  		}
   455  	}
   456  	return nil
   457  }
   458  
// Exec must never be reached: it panics unconditionally so that a caller
// using it instead of ExecContext is noticed.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	// Ensure that ExecContext is called if available.
	panic("ExecContext was not called.")
}
   463  
   464  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   465  	// This is an optional interface, but it's implemented here
   466  	// just to check that all the args are of the proper types.
   467  	// ErrSkip is returned so the caller acts as if we didn't
   468  	// implement this at all.
   469  	err := checkSubsetTypes(c.db.allowAny, args)
   470  	if err != nil {
   471  		return nil, err
   472  	}
   473  	return nil, driver.ErrSkip
   474  }
   475  
// Query must never be reached: it panics unconditionally so that a caller
// using it instead of QueryContext is noticed.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	// Ensure that QueryContext is called if available.
	panic("QueryContext was not called.")
}
   480  
   481  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   482  	// This is an optional interface, but it's implemented here
   483  	// just to check that all the args are of the proper types.
   484  	// ErrSkip is returned so the caller acts as if we didn't
   485  	// implement this at all.
   486  	err := checkSubsetTypes(c.db.allowAny, args)
   487  	if err != nil {
   488  		return nil, err
   489  	}
   490  	return nil, driver.ErrSkip
   491  }
   492  
   493  func errf(msg string, args ...any) error {
   494  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   495  }
   496  
   497  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   498  // (note that where columns must always contain ? marks,
   499  // just a limitation for fakedb)
   500  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   501  	if len(parts) != 3 {
   502  		stmt.Close()
   503  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   504  	}
   505  	stmt.table = parts[0]
   506  
   507  	stmt.colName = strings.Split(parts[1], ",")
   508  	for n, colspec := range strings.Split(parts[2], ",") {
   509  		if colspec == "" {
   510  			continue
   511  		}
   512  		nameVal := strings.Split(colspec, "=")
   513  		if len(nameVal) != 2 {
   514  			stmt.Close()
   515  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   516  		}
   517  		column, value := nameVal[0], nameVal[1]
   518  		_, ok := c.db.columnType(stmt.table, column)
   519  		if !ok {
   520  			stmt.Close()
   521  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   522  		}
   523  		if !strings.HasPrefix(value, "?") {
   524  			stmt.Close()
   525  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   526  				stmt.table, column)
   527  		}
   528  		stmt.placeholders++
   529  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   530  	}
   531  	return stmt, nil
   532  }
   533  
   534  // parts are table|col=type,col2=type2
   535  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   536  	if len(parts) != 2 {
   537  		stmt.Close()
   538  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   539  	}
   540  	stmt.table = parts[0]
   541  	for n, colspec := range strings.Split(parts[1], ",") {
   542  		nameType := strings.Split(colspec, "=")
   543  		if len(nameType) != 2 {
   544  			stmt.Close()
   545  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   546  		}
   547  		stmt.colName = append(stmt.colName, nameType[0])
   548  		stmt.colType = append(stmt.colType, nameType[1])
   549  	}
   550  	return stmt, nil
   551  }
   552  
   553  // parts are table|col=?,col2=val
   554  func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   555  	if len(parts) != 2 {
   556  		stmt.Close()
   557  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   558  	}
   559  	stmt.table = parts[0]
   560  	for n, colspec := range strings.Split(parts[1], ",") {
   561  		nameVal := strings.Split(colspec, "=")
   562  		if len(nameVal) != 2 {
   563  			stmt.Close()
   564  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   565  		}
   566  		column, value := nameVal[0], nameVal[1]
   567  		ctype, ok := c.db.columnType(stmt.table, column)
   568  		if !ok {
   569  			stmt.Close()
   570  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   571  		}
   572  		stmt.colName = append(stmt.colName, column)
   573  
   574  		if !strings.HasPrefix(value, "?") {
   575  			var subsetVal any
   576  			// Convert to driver subset type
   577  			switch ctype {
   578  			case "string":
   579  				subsetVal = []byte(value)
   580  			case "blob":
   581  				subsetVal = []byte(value)
   582  			case "int32":
   583  				i, err := strconv.Atoi(value)
   584  				if err != nil {
   585  					stmt.Close()
   586  					return nil, errf("invalid conversion to int32 from %q", value)
   587  				}
   588  				subsetVal = int64(i) // int64 is a subset type, but not int32
   589  			case "table": // For testing cursor reads.
   590  				c.skipDirtySession = true
   591  				vparts := strings.Split(value, "!")
   592  
   593  				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
   594  				if err != nil {
   595  					return nil, err
   596  				}
   597  				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
   598  				substmt.Close()
   599  				if err != nil {
   600  					return nil, err
   601  				}
   602  				subsetVal = cursor
   603  			default:
   604  				stmt.Close()
   605  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   606  			}
   607  			stmt.colValue = append(stmt.colValue, subsetVal)
   608  		} else {
   609  			stmt.placeholders++
   610  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   611  			stmt.colValue = append(stmt.colValue, value)
   612  		}
   613  	}
   614  	return stmt, nil
   615  }
   616  
// hook to simulate broken connections: when set and returning true,
// PrepareContext fails with an error wrapping driver.ErrBadConn.
var hookPrepareBadConn func() bool

// Prepare must never be reached: it panics unconditionally so that a caller
// using it instead of PrepareContext is noticed.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
   623  
   624  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   625  	c.numPrepare++
   626  	if c.db == nil {
   627  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   628  	}
   629  
   630  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   631  		return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
   632  	}
   633  
   634  	c.touchMem()
   635  	var firstStmt, prev *fakeStmt
   636  	for _, query := range strings.Split(query, ";") {
   637  		parts := strings.Split(query, "|")
   638  		if len(parts) < 1 {
   639  			return nil, errf("empty query")
   640  		}
   641  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
   642  		if firstStmt == nil {
   643  			firstStmt = stmt
   644  		}
   645  		if len(parts) >= 3 {
   646  			switch parts[0] {
   647  			case "PANIC":
   648  				stmt.panic = parts[1]
   649  				parts = parts[2:]
   650  			case "WAIT":
   651  				wait, err := time.ParseDuration(parts[1])
   652  				if err != nil {
   653  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   654  				}
   655  				parts = parts[2:]
   656  				stmt.wait = wait
   657  			}
   658  		}
   659  		cmd := parts[0]
   660  		stmt.cmd = cmd
   661  		parts = parts[1:]
   662  
   663  		if c.waiter != nil {
   664  			c.waiter(ctx)
   665  			if err := ctx.Err(); err != nil {
   666  				return nil, err
   667  			}
   668  		}
   669  
   670  		if stmt.wait > 0 {
   671  			wait := time.NewTimer(stmt.wait)
   672  			select {
   673  			case <-wait.C:
   674  			case <-ctx.Done():
   675  				wait.Stop()
   676  				return nil, ctx.Err()
   677  			}
   678  		}
   679  
   680  		c.incrStat(&c.stmtsMade)
   681  		var err error
   682  		switch cmd {
   683  		case "WIPE":
   684  			// Nothing
   685  		case "SELECT":
   686  			stmt, err = c.prepareSelect(stmt, parts)
   687  		case "CREATE":
   688  			stmt, err = c.prepareCreate(stmt, parts)
   689  		case "INSERT":
   690  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   691  		case "NOSERT":
   692  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   693  			// Used for some of the concurrent tests.
   694  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   695  		default:
   696  			stmt.Close()
   697  			return nil, errf("unsupported command type %q", cmd)
   698  		}
   699  		if err != nil {
   700  			return nil, err
   701  		}
   702  		if prev != nil {
   703  			prev.next = stmt
   704  		}
   705  		prev = stmt
   706  	}
   707  	return firstStmt, nil
   708  }
   709  
   710  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   711  	if s.panic == "ColumnConverter" {
   712  		panic(s.panic)
   713  	}
   714  	if len(s.placeholderConverter) == 0 {
   715  		return driver.DefaultParameterConverter
   716  	}
   717  	return s.placeholderConverter[idx]
   718  }
   719  
   720  func (s *fakeStmt) Close() error {
   721  	if s.panic == "Close" {
   722  		panic(s.panic)
   723  	}
   724  	if s.c == nil {
   725  		panic("nil conn in fakeStmt.Close")
   726  	}
   727  	if s.c.db == nil {
   728  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   729  	}
   730  	s.touchMem()
   731  	if !s.closed {
   732  		s.c.incrStat(&s.c.stmtsClosed)
   733  		s.closed = true
   734  	}
   735  	if s.next != nil {
   736  		s.next.Close()
   737  	}
   738  	return nil
   739  }
   740  
// errClosed is returned by ExecContext and QueryContext when they are
// called on a statement that has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")

// hook to simulate broken connections: when set and returning true,
// fakeStmt.ExecContext fails with an error wrapping driver.ErrBadConn.
var hookExecBadConn func() bool

// Exec must never be reached: it panics unconditionally so that a caller
// using it instead of ExecContext is noticed.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}
   749  
   750  var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
   751  
   752  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   753  	if s.panic == "Exec" {
   754  		panic(s.panic)
   755  	}
   756  	if s.closed {
   757  		return nil, errClosed
   758  	}
   759  
   760  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   761  		return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
   762  	}
   763  	if s.c.isDirtyAndMark() {
   764  		return nil, errFakeConnSessionDirty
   765  	}
   766  
   767  	err := checkSubsetTypes(s.c.db.allowAny, args)
   768  	if err != nil {
   769  		return nil, err
   770  	}
   771  	s.touchMem()
   772  
   773  	if s.wait > 0 {
   774  		time.Sleep(s.wait)
   775  	}
   776  
   777  	select {
   778  	default:
   779  	case <-ctx.Done():
   780  		return nil, ctx.Err()
   781  	}
   782  
   783  	db := s.c.db
   784  	switch s.cmd {
   785  	case "WIPE":
   786  		db.wipe()
   787  		return driver.ResultNoRows, nil
   788  	case "CREATE":
   789  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   790  			return nil, err
   791  		}
   792  		return driver.ResultNoRows, nil
   793  	case "INSERT":
   794  		return s.execInsert(args, true)
   795  	case "NOSERT":
   796  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   797  		// Used for some of the concurrent tests.
   798  		return s.execInsert(args, false)
   799  	}
   800  	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
   801  }
   802  
   803  func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
   804  	for i := range args {
   805  		if args[i].Name == name {
   806  			return args[i].Value
   807  		}
   808  	}
   809  	return nil
   810  }
   811  
   812  // When doInsert is true, add the row to the table.
   813  // When doInsert is false do prep-work and error checking, but don't
   814  // actually add the row to the table.
   815  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   816  	db := s.c.db
   817  	if len(args) != s.placeholders {
   818  		panic("error in pkg db; should only get here if size is correct")
   819  	}
   820  	db.mu.Lock()
   821  	t, ok := db.table(s.table)
   822  	db.mu.Unlock()
   823  	if !ok {
   824  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   825  	}
   826  
   827  	t.mu.Lock()
   828  	defer t.mu.Unlock()
   829  
   830  	var cols []any
   831  	if doInsert {
   832  		cols = make([]any, len(t.colname))
   833  	}
   834  	argPos := 0
   835  	for n, colname := range s.colName {
   836  		colidx := t.columnIndex(colname)
   837  		if colidx == -1 {
   838  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   839  		}
   840  		var val any
   841  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   842  			if strvalue == "?" {
   843  				val = args[argPos].Value
   844  			} else {
   845  				// Assign value from argument placeholder name.
   846  				if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil {
   847  					val = v
   848  				}
   849  			}
   850  			argPos++
   851  		} else {
   852  			val = s.colValue[n]
   853  		}
   854  		if doInsert {
   855  			cols[colidx] = val
   856  		}
   857  	}
   858  
   859  	if doInsert {
   860  		t.rows = append(t.rows, &row{cols: cols})
   861  	}
   862  	return driver.RowsAffected(1), nil
   863  }
   864  
// hook to simulate broken connections: when set and returning true,
// fakeStmt.QueryContext fails with an error wrapping driver.ErrBadConn.
var hookQueryBadConn func() bool

// Query must never be reached: it panics unconditionally so that a caller
// using it instead of QueryContext is noticed.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
   871  
   872  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   873  	if s.panic == "Query" {
   874  		panic(s.panic)
   875  	}
   876  	if s.closed {
   877  		return nil, errClosed
   878  	}
   879  
   880  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   881  		return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
   882  	}
   883  	if s.c.isDirtyAndMark() {
   884  		return nil, errFakeConnSessionDirty
   885  	}
   886  
   887  	err := checkSubsetTypes(s.c.db.allowAny, args)
   888  	if err != nil {
   889  		return nil, err
   890  	}
   891  
   892  	s.touchMem()
   893  	db := s.c.db
   894  	if len(args) != s.placeholders {
   895  		panic("error in pkg db; should only get here if size is correct")
   896  	}
   897  
   898  	setMRows := make([][]*row, 0, 1)
   899  	setColumns := make([][]string, 0, 1)
   900  	setColType := make([][]string, 0, 1)
   901  
   902  	for {
   903  		db.mu.Lock()
   904  		t, ok := db.table(s.table)
   905  		db.mu.Unlock()
   906  		if !ok {
   907  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   908  		}
   909  
   910  		if s.table == "magicquery" {
   911  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   912  				if args[0].Value == "sleep" {
   913  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   914  				}
   915  			}
   916  		}
   917  		if s.table == "tx_status" && s.colName[0] == "tx_status" {
   918  			txStatus := "autocommit"
   919  			if s.c.currTx != nil {
   920  				txStatus = "transaction"
   921  			}
   922  			cursor := &rowsCursor{
   923  				db:        s.c.db,
   924  				parentMem: s.c,
   925  				posRow:    -1,
   926  				rows: [][]*row{
   927  					{
   928  						{
   929  							cols: []any{
   930  								txStatus,
   931  							},
   932  						},
   933  					},
   934  				},
   935  				cols: [][]string{
   936  					{
   937  						"tx_status",
   938  					},
   939  				},
   940  				colType: [][]string{
   941  					{
   942  						"string",
   943  					},
   944  				},
   945  				errPos: -1,
   946  			}
   947  			return cursor, nil
   948  		}
   949  
   950  		t.mu.Lock()
   951  
   952  		colIdx := make(map[string]int) // select column name -> column index in table
   953  		for _, name := range s.colName {
   954  			idx := t.columnIndex(name)
   955  			if idx == -1 {
   956  				t.mu.Unlock()
   957  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   958  			}
   959  			colIdx[name] = idx
   960  		}
   961  
   962  		mrows := []*row{}
   963  	rows:
   964  		for _, trow := range t.rows {
   965  			// Process the where clause, skipping non-match rows. This is lazy
   966  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   967  			// for test code.
   968  			for _, wcol := range s.whereCol {
   969  				idx := t.columnIndex(wcol.Column)
   970  				if idx == -1 {
   971  					t.mu.Unlock()
   972  					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
   973  				}
   974  				tcol := trow.cols[idx]
   975  				if bs, ok := tcol.([]byte); ok {
   976  					// lazy hack to avoid sprintf %v on a []byte
   977  					tcol = string(bs)
   978  				}
   979  				var argValue any
   980  				if wcol.Placeholder == "?" {
   981  					argValue = args[wcol.Ordinal-1].Value
   982  				} else {
   983  					if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil {
   984  						argValue = v
   985  					}
   986  				}
   987  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   988  					continue rows
   989  				}
   990  			}
   991  			mrow := &row{cols: make([]any, len(s.colName))}
   992  			for seli, name := range s.colName {
   993  				mrow.cols[seli] = trow.cols[colIdx[name]]
   994  			}
   995  			mrows = append(mrows, mrow)
   996  		}
   997  
   998  		var colType []string
   999  		for _, column := range s.colName {
  1000  			colType = append(colType, t.coltype[t.columnIndex(column)])
  1001  		}
  1002  
  1003  		t.mu.Unlock()
  1004  
  1005  		setMRows = append(setMRows, mrows)
  1006  		setColumns = append(setColumns, s.colName)
  1007  		setColType = append(setColType, colType)
  1008  
  1009  		if s.next == nil {
  1010  			break
  1011  		}
  1012  		s = s.next
  1013  	}
  1014  
  1015  	cursor := &rowsCursor{
  1016  		db:        s.c.db,
  1017  		parentMem: s.c,
  1018  		posRow:    -1,
  1019  		rows:      setMRows,
  1020  		cols:      setColumns,
  1021  		colType:   setColType,
  1022  		errPos:    -1,
  1023  	}
  1024  	return cursor, nil
  1025  }
  1026  
  1027  func (s *fakeStmt) NumInput() int {
  1028  	if s.panic == "NumInput" {
  1029  		panic(s.panic)
  1030  	}
  1031  	return s.placeholders
  1032  }
  1033  
  1034  // hook to simulate broken connections
  1035  var hookCommitBadConn func() bool
  1036  
  1037  func (tx *fakeTx) Commit() error {
  1038  	tx.c.currTx = nil
  1039  	if hookCommitBadConn != nil && hookCommitBadConn() {
  1040  		return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
  1041  	}
  1042  	tx.c.touchMem()
  1043  	return nil
  1044  }
  1045  
  1046  // hook to simulate broken connections
  1047  var hookRollbackBadConn func() bool
  1048  
  1049  func (tx *fakeTx) Rollback() error {
  1050  	tx.c.currTx = nil
  1051  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
  1052  		return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
  1053  	}
  1054  	tx.c.touchMem()
  1055  	return nil
  1056  }
  1057  
  1058  type rowsCursor struct {
  1059  	db        *fakeDB
  1060  	parentMem memToucher
  1061  	cols      [][]string
  1062  	colType   [][]string
  1063  	posSet    int
  1064  	posRow    int
  1065  	rows      [][]*row
  1066  	closed    bool
  1067  
  1068  	// errPos and err are for making Next return early with error.
  1069  	errPos int
  1070  	err    error
  1071  
  1072  	// Data returned to clients.
  1073  	// We clone and stash it here so it can be invalidated by Close and Next.
  1074  	driverOwnedMemory [][]byte
  1075  
  1076  	// Every operation writes to line to enable the race detector
  1077  	// check for data races.
  1078  	// This is separate from the fakeConn.line to allow for drivers that
  1079  	// can start multiple queries on the same transaction at the same time.
  1080  	line int64
  1081  
  1082  	// closeErr is returned when rowsCursor.Close
  1083  	closeErr error
  1084  }
  1085  
  1086  func (rc *rowsCursor) touchMem() {
  1087  	rc.parentMem.touchMem()
  1088  	rc.line++
  1089  }
  1090  
  1091  func (rc *rowsCursor) invalidateDriverOwnedMemory() {
  1092  	for _, buf := range rc.driverOwnedMemory {
  1093  		for i := range buf {
  1094  			buf[i] = 'x'
  1095  		}
  1096  	}
  1097  	rc.driverOwnedMemory = nil
  1098  }
  1099  
  1100  func (rc *rowsCursor) Close() error {
  1101  	rc.touchMem()
  1102  	rc.parentMem.touchMem()
  1103  	rc.invalidateDriverOwnedMemory()
  1104  	rc.closed = true
  1105  	return rc.closeErr
  1106  }
  1107  
  1108  func (rc *rowsCursor) Columns() []string {
  1109  	return rc.cols[rc.posSet]
  1110  }
  1111  
  1112  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
  1113  	return colTypeToReflectType(rc.colType[rc.posSet][index])
  1114  }
  1115  
  1116  var rowsCursorNextHook func(dest []driver.Value) error
  1117  
  1118  func (rc *rowsCursor) Next(dest []driver.Value) error {
  1119  	if rowsCursorNextHook != nil {
  1120  		return rowsCursorNextHook(dest)
  1121  	}
  1122  
  1123  	if rc.closed {
  1124  		return errors.New("fakedb: cursor is closed")
  1125  	}
  1126  	rc.touchMem()
  1127  	rc.posRow++
  1128  	if rc.posRow == rc.errPos {
  1129  		return rc.err
  1130  	}
  1131  	if rc.posRow >= len(rc.rows[rc.posSet]) {
  1132  		return io.EOF // per interface spec
  1133  	}
  1134  	// Corrupt any previously returned bytes.
  1135  	rc.invalidateDriverOwnedMemory()
  1136  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
  1137  		// TODO(bradfitz): convert to subset types? naah, I
  1138  		// think the subset types should only be input to
  1139  		// driver, but the sql package should be able to handle
  1140  		// a wider range of types coming out of drivers. all
  1141  		// for ease of drivers, and to prevent drivers from
  1142  		// messing up conversions or doing them differently.
  1143  		if bs, ok := v.([]byte); ok {
  1144  			// Clone []bytes and stash for later invalidation.
  1145  			bs = bytes.Clone(bs)
  1146  			rc.driverOwnedMemory = append(rc.driverOwnedMemory, bs)
  1147  			v = bs
  1148  		}
  1149  		dest[i] = v
  1150  	}
  1151  	return nil
  1152  }
  1153  
  1154  func (rc *rowsCursor) HasNextResultSet() bool {
  1155  	rc.touchMem()
  1156  	return rc.posSet < len(rc.rows)-1
  1157  }
  1158  
  1159  func (rc *rowsCursor) NextResultSet() error {
  1160  	rc.touchMem()
  1161  	if rc.HasNextResultSet() {
  1162  		rc.posSet++
  1163  		rc.posRow = -1
  1164  		return nil
  1165  	}
  1166  	return io.EOF // Per interface spec.
  1167  }
  1168  
  1169  // fakeDriverString is like driver.String, but indirects pointers like
  1170  // DefaultValueConverter.
  1171  //
  1172  // This could be surprising behavior to retroactively apply to
  1173  // driver.String now that Go1 is out, but this is convenient for
  1174  // our TestPointerParamsAndScans.
  1175  type fakeDriverString struct{}
  1176  
  1177  func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
  1178  	switch c := v.(type) {
  1179  	case string, []byte:
  1180  		return v, nil
  1181  	case *string:
  1182  		if c == nil {
  1183  			return nil, nil
  1184  		}
  1185  		return *c, nil
  1186  	}
  1187  	return fmt.Sprintf("%v", v), nil
  1188  }
  1189  
  1190  type anyTypeConverter struct{}
  1191  
  1192  func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
  1193  	return v, nil
  1194  }
  1195  
  1196  func converterForType(typ string) driver.ValueConverter {
  1197  	switch typ {
  1198  	case "bool":
  1199  		return driver.Bool
  1200  	case "nullbool":
  1201  		return driver.Null{Converter: driver.Bool}
  1202  	case "byte", "int16":
  1203  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1204  	case "int32":
  1205  		return driver.Int32
  1206  	case "nullbyte", "nullint32", "nullint16":
  1207  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1208  	case "string":
  1209  		return driver.NotNull{Converter: fakeDriverString{}}
  1210  	case "nullstring":
  1211  		return driver.Null{Converter: fakeDriverString{}}
  1212  	case "int64":
  1213  		// TODO(coopernurse): add type-specific converter
  1214  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1215  	case "nullint64":
  1216  		// TODO(coopernurse): add type-specific converter
  1217  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1218  	case "float64":
  1219  		// TODO(coopernurse): add type-specific converter
  1220  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1221  	case "nullfloat64":
  1222  		// TODO(coopernurse): add type-specific converter
  1223  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1224  	case "datetime":
  1225  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1226  	case "nulldatetime":
  1227  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1228  	case "any":
  1229  		return anyTypeConverter{}
  1230  	}
  1231  	panic("invalid fakedb column type of " + typ)
  1232  }
  1233  
  1234  func colTypeToReflectType(typ string) reflect.Type {
  1235  	switch typ {
  1236  	case "bool":
  1237  		return reflect.TypeFor[bool]()
  1238  	case "nullbool":
  1239  		return reflect.TypeFor[NullBool]()
  1240  	case "int16":
  1241  		return reflect.TypeFor[int16]()
  1242  	case "nullint16":
  1243  		return reflect.TypeFor[NullInt16]()
  1244  	case "int32":
  1245  		return reflect.TypeFor[int32]()
  1246  	case "nullint32":
  1247  		return reflect.TypeFor[NullInt32]()
  1248  	case "string":
  1249  		return reflect.TypeFor[string]()
  1250  	case "nullstring":
  1251  		return reflect.TypeFor[NullString]()
  1252  	case "int64":
  1253  		return reflect.TypeFor[int64]()
  1254  	case "nullint64":
  1255  		return reflect.TypeFor[NullInt64]()
  1256  	case "float64":
  1257  		return reflect.TypeFor[float64]()
  1258  	case "nullfloat64":
  1259  		return reflect.TypeFor[NullFloat64]()
  1260  	case "datetime":
  1261  		return reflect.TypeFor[time.Time]()
  1262  	case "any":
  1263  		return reflect.TypeFor[any]()
  1264  	}
  1265  	panic("invalid fakedb column type of " + typ)
  1266  }
  1267  

View as plain text