1
2
3
4
5 package loopclosure
6
7 import (
8 _ "embed"
9 "go/ast"
10 "go/types"
11
12 "golang.org/x/tools/go/analysis"
13 "golang.org/x/tools/go/analysis/passes/inspect"
14 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
15 "golang.org/x/tools/go/ast/inspector"
16 "golang.org/x/tools/go/types/typeutil"
17 "golang.org/x/tools/internal/analysisinternal"
18 "golang.org/x/tools/internal/typesinternal"
19 "golang.org/x/tools/internal/versions"
20 )
21
22
23 var doc string
24
25 var Analyzer = &analysis.Analyzer{
26 Name: "loopclosure",
27 Doc: analysisutil.MustExtractDoc(doc, "loopclosure"),
28 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/loopclosure",
29 Requires: []*analysis.Analyzer{inspect.Analyzer},
30 Run: run,
31 }
32
33 func run(pass *analysis.Pass) (any, error) {
34 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
35
36 nodeFilter := []ast.Node{
37 (*ast.File)(nil),
38 (*ast.RangeStmt)(nil),
39 (*ast.ForStmt)(nil),
40 }
41 inspect.Nodes(nodeFilter, func(n ast.Node, push bool) bool {
42 if !push {
43
44 return true
45 }
46
47 var vars []types.Object
48 addVar := func(expr ast.Expr) {
49 if id, _ := expr.(*ast.Ident); id != nil {
50 if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
51 vars = append(vars, obj)
52 }
53 }
54 }
55 var body *ast.BlockStmt
56 switch n := n.(type) {
57 case *ast.File:
58
59 goversion := versions.FileVersion(pass.TypesInfo, n)
60 return versions.Before(goversion, versions.Go1_22)
61 case *ast.RangeStmt:
62 body = n.Body
63 addVar(n.Key)
64 addVar(n.Value)
65 case *ast.ForStmt:
66 body = n.Body
67 switch post := n.Post.(type) {
68 case *ast.AssignStmt:
69
70 for _, lhs := range post.Lhs {
71 addVar(lhs)
72 }
73 case *ast.IncDecStmt:
74
75 addVar(post.X)
76 }
77 }
78 if vars == nil {
79 return true
80 }
81
82
83
84
85
86
87
88
89
90
91
92
93
94 forEachLastStmt(body.List, func(last ast.Stmt) {
95 var stmts []ast.Stmt
96 switch s := last.(type) {
97 case *ast.GoStmt:
98 stmts = litStmts(s.Call.Fun)
99 case *ast.DeferStmt:
100 stmts = litStmts(s.Call.Fun)
101 case *ast.ExprStmt:
102 if call, ok := s.X.(*ast.CallExpr); ok {
103 stmts = litStmts(goInvoke(pass.TypesInfo, call))
104 }
105 }
106 for _, stmt := range stmts {
107 reportCaptured(pass, vars, stmt)
108 }
109 })
110
111
112
113
114
115
116
117 for _, s := range body.List {
118 switch s := s.(type) {
119 case *ast.ExprStmt:
120 if call, ok := s.X.(*ast.CallExpr); ok {
121 for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
122 reportCaptured(pass, vars, stmt)
123 }
124
125 }
126 }
127 }
128 return true
129 })
130 return nil, nil
131 }
132
133
134
135
136
137 func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
138 ast.Inspect(checkStmt, func(n ast.Node) bool {
139 id, ok := n.(*ast.Ident)
140 if !ok {
141 return true
142 }
143 obj := pass.TypesInfo.Uses[id]
144 if obj == nil {
145 return true
146 }
147 for _, v := range vars {
148 if v == obj {
149 pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
150 }
151 }
152 return true
153 })
154 }
155
156
157
158
159
// forEachLastStmt calls onLast for every statement that counts as the "last"
// statement of stmts. The notion is recursive. When the final statement of
// stmts is one of the following, the last statements of each of its branch,
// loop, or clause bodies are visited instead:
//   - an if statement, together with its else-if/else chain
//   - a for or range loop
//   - a switch, type switch, or select statement
//
// Any other final statement is passed to onLast itself. An empty stmts
// produces no calls.
func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
	if len(stmts) == 0 {
		return
	}

	switch final := stmts[len(stmts)-1].(type) {
	case *ast.IfStmt:
		forEachLastStmt(final.Body.List, onLast)
		// A plain else block is visited directly; an else-if is handled by
		// recursing on it as a one-statement list.
		switch alt := final.Else.(type) {
		case *ast.BlockStmt:
			forEachLastStmt(alt.List, onLast)
		case *ast.IfStmt:
			forEachLastStmt([]ast.Stmt{alt}, onLast)
		}
	case *ast.ForStmt:
		forEachLastStmt(final.Body.List, onLast)
	case *ast.RangeStmt:
		forEachLastStmt(final.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, clause := range final.Body.List {
			forEachLastStmt(clause.(*ast.CaseClause).Body, onLast)
		}
	case *ast.TypeSwitchStmt:
		for _, clause := range final.Body.List {
			forEachLastStmt(clause.(*ast.CaseClause).Body, onLast)
		}
	case *ast.SelectStmt:
		for _, clause := range final.Body.List {
			forEachLastStmt(clause.(*ast.CommClause).Body, onLast)
		}
	default:
		onLast(final)
	}
}
204
205
206
207
208
// litStmts returns the body statements of fun when fun is a function literal,
// and nil otherwise (including when fun is nil).
func litStmts(fun ast.Expr) []ast.Stmt {
	if lit, isLit := fun.(*ast.FuncLit); isLit && lit != nil {
		return lit.Body.List
	}
	return nil
}
216
217
218
219
220
221
222
223
224
225
226 func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
227 if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
228 return nil
229 }
230 return call.Args[0]
231 }
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259 func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
260 if !isMethodCall(info, call, "testing", "T", "Run") {
261 return nil
262 }
263
264 if len(call.Args) != 2 {
265
266 return nil
267 }
268
269 lit, _ := call.Args[1].(*ast.FuncLit)
270 if lit == nil {
271 return nil
272 }
273
274
275
276 if len(lit.Type.Params.List[0].Names) == 0 {
277 return nil
278 }
279
280 tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
281 if tObj == nil {
282 return nil
283 }
284
285
286
287
288
289
290
291 var stmts []ast.Stmt
292 afterParallel := false
293 for _, stmt := range lit.Body.List {
294 stmt, labeled := unlabel(stmt)
295 if labeled {
296
297
298 stmts = nil
299 afterParallel = false
300 }
301
302 if afterParallel {
303 stmts = append(stmts, stmt)
304 continue
305 }
306
307
308 exprStmt, ok := stmt.(*ast.ExprStmt)
309 if !ok {
310 continue
311 }
312 expr := exprStmt.X
313 if isMethodCall(info, expr, "testing", "T", "Parallel") {
314 call, _ := expr.(*ast.CallExpr)
315 if call == nil {
316 continue
317 }
318 x, _ := call.Fun.(*ast.SelectorExpr)
319 if x == nil {
320 continue
321 }
322 id, _ := x.X.(*ast.Ident)
323 if id == nil {
324 continue
325 }
326 if info.Uses[id] == tObj {
327 afterParallel = true
328 }
329 }
330 }
331
332 return stmts
333 }
334
335
336
337
338
// unlabel strips any number of labels from stmt. It returns the innermost
// unlabeled statement and reports whether at least one label was removed.
func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
	labeledStmt, isLabeled := stmt.(*ast.LabeledStmt)
	if !isLabeled {
		return stmt, false
	}
	inner, _ := unlabel(labeledStmt.Stmt)
	return inner, true
}
350
351
352
353 func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
354 call, ok := expr.(*ast.CallExpr)
355 if !ok {
356 return false
357 }
358
359
360 f := typeutil.StaticCallee(info, call)
361 if f == nil || f.Name() != method {
362 return false
363 }
364 recv := f.Type().(*types.Signature).Recv()
365 if recv == nil {
366 return false
367 }
368
369
370
371 _, named := typesinternal.ReceiverNamed(recv)
372 return analysisinternal.IsTypeNamed(named, pkgPath, typeName)
373 }
374
View as plain text