1
2
3
4
5
6
7 package waitgroup
8
9 import (
10 _ "embed"
11 "go/ast"
12 "reflect"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/go/types/typeutil"
19 "golang.org/x/tools/internal/analysisinternal"
20 )
21
22
23 var doc string
24
25 var Analyzer = &analysis.Analyzer{
26 Name: "waitgroup",
27 Doc: analysisutil.MustExtractDoc(doc, "waitgroup"),
28 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/waitgroup",
29 Requires: []*analysis.Analyzer{inspect.Analyzer},
30 Run: run,
31 }
32
33 func run(pass *analysis.Pass) (any, error) {
34 if !analysisinternal.Imports(pass.Pkg, "sync") {
35 return nil, nil
36 }
37
38 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
39 nodeFilter := []ast.Node{
40 (*ast.CallExpr)(nil),
41 }
42
43 inspect.WithStack(nodeFilter, func(n ast.Node, push bool, stack []ast.Node) (proceed bool) {
44 if push {
45 call := n.(*ast.CallExpr)
46 obj := typeutil.Callee(pass.TypesInfo, call)
47 if analysisinternal.IsMethodNamed(obj, "sync", "WaitGroup", "Add") &&
48 hasSuffix(stack, wantSuffix) &&
49 backindex(stack, 1) == backindex(stack, 2).(*ast.BlockStmt).List[0] {
50
51 pass.Reportf(call.Lparen, "WaitGroup.Add called from inside new goroutine")
52 }
53 }
54 return true
55 })
56
57 return nil, nil
58 }
59
60
61
62
63
// wantSuffix lists, outermost first, the node types that the end of the
// inspector stack must have for a call to be inside a function literal
// launched by a go statement:
//
//	go func() {
//		wg.Add(1)
//		...
//	}()
//
// Only the dynamic type of each element matters (hasSuffix compares
// reflect.TypeOf values), so every element is a typed nil pointer.
var wantSuffix = []ast.Node{
	(*ast.GoStmt)(nil),    // the go statement
	(*ast.CallExpr)(nil),  // the call it launches
	(*ast.FuncLit)(nil),   // a function literal within that call
	(*ast.BlockStmt)(nil), // the literal's body
	(*ast.ExprStmt)(nil),  // a statement in that body
	(*ast.CallExpr)(nil),  // the call being examined
}
72
73
74
// hasSuffix reports whether the last len(suffix) elements of stack have,
// position by position, the same dynamic types as the elements of suffix.
// Node identity and contents are not compared. An empty suffix always
// matches; a stack shorter than suffix never does.
func hasSuffix(stack, suffix []ast.Node) bool {
	start := len(stack) - len(suffix)
	if start < 0 {
		return false // too short to end with suffix
	}
	for j, want := range suffix {
		if reflect.TypeOf(stack[start+j]) != reflect.TypeOf(want) {
			return false
		}
	}
	return true
}

// backindex returns the element i positions from the end of slice:
// backindex(s, 0) is the last element, backindex(s, 1) the one before it.
// Like ordinary indexing, it panics when i is out of range.
func backindex[T any](slice []T, i int) T {
	last := len(slice) - 1
	return slice[last-i]
}
92
View as plain text