Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/loopclosure/loopclosure.go

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package loopclosure
     6  
     7  import (
     8  	_ "embed"
     9  	"go/ast"
    10  	"go/types"
    11  
    12  	"golang.org/x/tools/go/analysis"
    13  	"golang.org/x/tools/go/analysis/passes/inspect"
    14  	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/analysisinternal"
    18  	"golang.org/x/tools/internal/typesinternal"
    19  	"golang.org/x/tools/internal/versions"
    20  )
    21  
// doc holds the contents of doc.go, embedded at build time. The Analyzer's
// Doc string is extracted from it.
//
//go:embed doc.go
var doc string
    24  
// Analyzer is the "loopclosure" analysis pass. It requires the inspect
// analyzer for AST traversal and, via run, reports references to loop
// variables made from within func literals in the loop body.
var Analyzer = &analysis.Analyzer{
	Name:     "loopclosure",
	Doc:      analysisutil.MustExtractDoc(doc, "loopclosure"),
	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/loopclosure",
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
    32  
    33  func run(pass *analysis.Pass) (any, error) {
    34  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    35  
    36  	nodeFilter := []ast.Node{
    37  		(*ast.File)(nil),
    38  		(*ast.RangeStmt)(nil),
    39  		(*ast.ForStmt)(nil),
    40  	}
    41  	inspect.Nodes(nodeFilter, func(n ast.Node, push bool) bool {
    42  		if !push {
    43  			// inspect.Nodes is slightly suboptimal as we only use push=true.
    44  			return true
    45  		}
    46  		// Find the variables updated by the loop statement.
    47  		var vars []types.Object
    48  		addVar := func(expr ast.Expr) {
    49  			if id, _ := expr.(*ast.Ident); id != nil {
    50  				if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
    51  					vars = append(vars, obj)
    52  				}
    53  			}
    54  		}
    55  		var body *ast.BlockStmt
    56  		switch n := n.(type) {
    57  		case *ast.File:
    58  			// Only traverse the file if its goversion is strictly before go1.22.
    59  			goversion := versions.FileVersion(pass.TypesInfo, n)
    60  			return versions.Before(goversion, versions.Go1_22)
    61  		case *ast.RangeStmt:
    62  			body = n.Body
    63  			addVar(n.Key)
    64  			addVar(n.Value)
    65  		case *ast.ForStmt:
    66  			body = n.Body
    67  			switch post := n.Post.(type) {
    68  			case *ast.AssignStmt:
    69  				// e.g. for p = head; p != nil; p = p.next
    70  				for _, lhs := range post.Lhs {
    71  					addVar(lhs)
    72  				}
    73  			case *ast.IncDecStmt:
    74  				// e.g. for i := 0; i < n; i++
    75  				addVar(post.X)
    76  			}
    77  		}
    78  		if vars == nil {
    79  			return true
    80  		}
    81  
    82  		// Inspect statements to find function literals that may be run outside of
    83  		// the current loop iteration.
    84  		//
    85  		// For go, defer, and errgroup.Group.Go, we ignore all but the last
    86  		// statement, because it's hard to prove go isn't followed by wait, or
    87  		// defer by return. "Last" is defined recursively.
    88  		//
    89  		// TODO: consider allowing the "last" go/defer/Go statement to be followed by
    90  		// N "trivial" statements, possibly under a recursive definition of "trivial"
    91  		// so that checker could, for example, conclude that a go statement is
    92  		// followed by an if statement made of only trivial statements and trivial expressions,
    93  		// and hence the go statement could still be checked.
    94  		forEachLastStmt(body.List, func(last ast.Stmt) {
    95  			var stmts []ast.Stmt
    96  			switch s := last.(type) {
    97  			case *ast.GoStmt:
    98  				stmts = litStmts(s.Call.Fun)
    99  			case *ast.DeferStmt:
   100  				stmts = litStmts(s.Call.Fun)
   101  			case *ast.ExprStmt: // check for errgroup.Group.Go
   102  				if call, ok := s.X.(*ast.CallExpr); ok {
   103  					stmts = litStmts(goInvoke(pass.TypesInfo, call))
   104  				}
   105  			}
   106  			for _, stmt := range stmts {
   107  				reportCaptured(pass, vars, stmt)
   108  			}
   109  		})
   110  
   111  		// Also check for testing.T.Run (with T.Parallel).
   112  		// We consider every t.Run statement in the loop body, because there is
   113  		// no commonly used mechanism for synchronizing parallel subtests.
   114  		// It is of course theoretically possible to synchronize parallel subtests,
   115  		// though such a pattern is likely to be exceedingly rare as it would be
   116  		// fighting against the test runner.
   117  		for _, s := range body.List {
   118  			switch s := s.(type) {
   119  			case *ast.ExprStmt:
   120  				if call, ok := s.X.(*ast.CallExpr); ok {
   121  					for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
   122  						reportCaptured(pass, vars, stmt)
   123  					}
   124  
   125  				}
   126  			}
   127  		}
   128  		return true
   129  	})
   130  	return nil, nil
   131  }
   132  
   133  // reportCaptured reports a diagnostic stating a loop variable
   134  // has been captured by a func literal if checkStmt has escaping
   135  // references to vars. vars is expected to be variables updated by a loop statement,
   136  // and checkStmt is expected to be a statements from the body of a func literal in the loop.
   137  func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
   138  	ast.Inspect(checkStmt, func(n ast.Node) bool {
   139  		id, ok := n.(*ast.Ident)
   140  		if !ok {
   141  			return true
   142  		}
   143  		obj := pass.TypesInfo.Uses[id]
   144  		if obj == nil {
   145  			return true
   146  		}
   147  		for _, v := range vars {
   148  			if v == obj {
   149  				pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
   150  			}
   151  		}
   152  		return true
   153  	})
   154  }
   155  
   156  // forEachLastStmt calls onLast on each "last" statement in a list of statements.
   157  // "Last" is defined recursively so, for example, if the last statement is
   158  // a switch statement, then each switch case is also visited to examine
   159  // its last statements.
   160  func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
   161  	if len(stmts) == 0 {
   162  		return
   163  	}
   164  
   165  	s := stmts[len(stmts)-1]
   166  	switch s := s.(type) {
   167  	case *ast.IfStmt:
   168  	loop:
   169  		for {
   170  			forEachLastStmt(s.Body.List, onLast)
   171  			switch e := s.Else.(type) {
   172  			case *ast.BlockStmt:
   173  				forEachLastStmt(e.List, onLast)
   174  				break loop
   175  			case *ast.IfStmt:
   176  				s = e
   177  			case nil:
   178  				break loop
   179  			}
   180  		}
   181  	case *ast.ForStmt:
   182  		forEachLastStmt(s.Body.List, onLast)
   183  	case *ast.RangeStmt:
   184  		forEachLastStmt(s.Body.List, onLast)
   185  	case *ast.SwitchStmt:
   186  		for _, c := range s.Body.List {
   187  			cc := c.(*ast.CaseClause)
   188  			forEachLastStmt(cc.Body, onLast)
   189  		}
   190  	case *ast.TypeSwitchStmt:
   191  		for _, c := range s.Body.List {
   192  			cc := c.(*ast.CaseClause)
   193  			forEachLastStmt(cc.Body, onLast)
   194  		}
   195  	case *ast.SelectStmt:
   196  		for _, c := range s.Body.List {
   197  			cc := c.(*ast.CommClause)
   198  			forEachLastStmt(cc.Body, onLast)
   199  		}
   200  	default:
   201  		onLast(s)
   202  	}
   203  }
   204  
   205  // litStmts returns all statements from the function body of a function
   206  // literal.
   207  //
   208  // If fun is not a function literal, it returns nil.
   209  func litStmts(fun ast.Expr) []ast.Stmt {
   210  	lit, _ := fun.(*ast.FuncLit)
   211  	if lit == nil {
   212  		return nil
   213  	}
   214  	return lit.Body.List
   215  }
   216  
   217  // goInvoke returns a function expression that would be called asynchronously
   218  // (but not awaited) in another goroutine as a consequence of the call.
   219  // For example, given the g.Go call below, it returns the function literal expression.
   220  //
   221  //	import "sync/errgroup"
   222  //	var g errgroup.Group
   223  //	g.Go(func() error { ... })
   224  //
   225  // Currently only "golang.org/x/sync/errgroup.Group()" is considered.
   226  func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
   227  	if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
   228  		return nil
   229  	}
   230  	return call.Args[0]
   231  }
   232  
   233  // parallelSubtest returns statements that can be easily proven to execute
   234  // concurrently via the go test runner, as t.Run has been invoked with a
   235  // function literal that calls t.Parallel.
   236  //
   237  // In practice, users rely on the fact that statements before the call to
   238  // t.Parallel are synchronous. For example by declaring test := test inside the
   239  // function literal, but before the call to t.Parallel.
   240  //
   241  // Therefore, we only flag references in statements that are obviously
   242  // dominated by a call to t.Parallel. As a simple heuristic, we only consider
   243  // statements following the final labeled statement in the function body, to
   244  // avoid scenarios where a jump would cause either the call to t.Parallel or
   245  // the problematic reference to be skipped.
   246  //
   247  //	import "testing"
   248  //
   249  //	func TestFoo(t *testing.T) {
   250  //		tests := []int{0, 1, 2}
   251  //		for i, test := range tests {
   252  //			t.Run("subtest", func(t *testing.T) {
   253  //				println(i, test) // OK
   254  //		 		t.Parallel()
   255  //				println(i, test) // Not OK
   256  //			})
   257  //		}
   258  //	}
   259  func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
   260  	if !isMethodCall(info, call, "testing", "T", "Run") {
   261  		return nil
   262  	}
   263  
   264  	if len(call.Args) != 2 {
   265  		// Ignore calls such as t.Run(fn()).
   266  		return nil
   267  	}
   268  
   269  	lit, _ := call.Args[1].(*ast.FuncLit)
   270  	if lit == nil {
   271  		return nil
   272  	}
   273  
   274  	// Capture the *testing.T object for the first argument to the function
   275  	// literal.
   276  	if len(lit.Type.Params.List[0].Names) == 0 {
   277  		return nil
   278  	}
   279  
   280  	tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
   281  	if tObj == nil {
   282  		return nil
   283  	}
   284  
   285  	// Match statements that occur after a call to t.Parallel following the final
   286  	// labeled statement in the function body.
   287  	//
   288  	// We iterate over lit.Body.List to have a simple, fast and "frequent enough"
   289  	// dominance relationship for t.Parallel(): lit.Body.List[i] dominates
   290  	// lit.Body.List[j] for i < j unless there is a jump.
   291  	var stmts []ast.Stmt
   292  	afterParallel := false
   293  	for _, stmt := range lit.Body.List {
   294  		stmt, labeled := unlabel(stmt)
   295  		if labeled {
   296  			// Reset: naively we don't know if a jump could have caused the
   297  			// previously considered statements to be skipped.
   298  			stmts = nil
   299  			afterParallel = false
   300  		}
   301  
   302  		if afterParallel {
   303  			stmts = append(stmts, stmt)
   304  			continue
   305  		}
   306  
   307  		// Check if stmt is a call to t.Parallel(), for the correct t.
   308  		exprStmt, ok := stmt.(*ast.ExprStmt)
   309  		if !ok {
   310  			continue
   311  		}
   312  		expr := exprStmt.X
   313  		if isMethodCall(info, expr, "testing", "T", "Parallel") {
   314  			call, _ := expr.(*ast.CallExpr)
   315  			if call == nil {
   316  				continue
   317  			}
   318  			x, _ := call.Fun.(*ast.SelectorExpr)
   319  			if x == nil {
   320  				continue
   321  			}
   322  			id, _ := x.X.(*ast.Ident)
   323  			if id == nil {
   324  				continue
   325  			}
   326  			if info.Uses[id] == tObj {
   327  				afterParallel = true
   328  			}
   329  		}
   330  	}
   331  
   332  	return stmts
   333  }
   334  
   335  // unlabel returns the inner statement for the possibly labeled statement stmt,
   336  // stripping any (possibly nested) *ast.LabeledStmt wrapper.
   337  //
   338  // The second result reports whether stmt was an *ast.LabeledStmt.
   339  func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
   340  	labeled := false
   341  	for {
   342  		labelStmt, ok := stmt.(*ast.LabeledStmt)
   343  		if !ok {
   344  			return stmt, labeled
   345  		}
   346  		labeled = true
   347  		stmt = labelStmt.Stmt
   348  	}
   349  }
   350  
   351  // isMethodCall reports whether expr is a method call of
   352  // <pkgPath>.<typeName>.<method>.
   353  func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
   354  	call, ok := expr.(*ast.CallExpr)
   355  	if !ok {
   356  		return false
   357  	}
   358  
   359  	// Check that we are calling a method <method>
   360  	f := typeutil.StaticCallee(info, call)
   361  	if f == nil || f.Name() != method {
   362  		return false
   363  	}
   364  	recv := f.Type().(*types.Signature).Recv()
   365  	if recv == nil {
   366  		return false
   367  	}
   368  
   369  	// Check that the receiver is a <pkgPath>.<typeName> or
   370  	// *<pkgPath>.<typeName>.
   371  	_, named := typesinternal.ReceiverNamed(recv)
   372  	return analysisinternal.IsTypeNamed(named, pkgPath, typeName)
   373  }
   374  

View as plain text