1
2
3
4
5 package lostcancel
6
7 import (
8 _ "embed"
9 "fmt"
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/ctrlflow"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
17 "golang.org/x/tools/go/ast/inspector"
18 "golang.org/x/tools/go/cfg"
19 "golang.org/x/tools/internal/analysisinternal"
20 "golang.org/x/tools/internal/astutil"
21 )
22
23
// doc holds the analyzer's documentation source; the "lostcancel" section is
// extracted from it to fill Analyzer.Doc below.
//
// NOTE(review): nothing visible here assigns doc. Given the blank "embed"
// import, it is presumably populated by a go:embed directive placed directly
// above this declaration - confirm the directive is present in the real file.
var doc string

// Analyzer is the lostcancel analysis pass. Its entry point is run, and it
// requires the results of the inspect and ctrlflow analyzers.
var Analyzer = &analysis.Analyzer{
	Name: "lostcancel",
	Doc:  analysisutil.MustExtractDoc(doc, "lostcancel"),
	URL:  "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/lostcancel",
	Run:  run,
	Requires: []*analysis.Analyzer{
		inspect.Analyzer,
		ctrlflow.Analyzer,
	},
}

// debug, when true, makes the pass print each analyzed function's
// control-flow graph and the blocks on any leaking path it finds.
const debug = false

// contextPackage is the import path of the package whose cancel-returning
// functions are checked.
// NOTE(review): declared as a variable rather than a constant, presumably so
// that it can be overridden (e.g. by tests) - confirm.
var contextPackage = "context"
40
41
42
43
44
45
46
47
48
49
50
51 func run(pass *analysis.Pass) (any, error) {
52
53 if !analysisinternal.Imports(pass.Pkg, contextPackage) {
54 return nil, nil
55 }
56
57
58 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
59 nodeTypes := []ast.Node{
60 (*ast.FuncLit)(nil),
61 (*ast.FuncDecl)(nil),
62 }
63 inspect.Preorder(nodeTypes, func(n ast.Node) {
64 runFunc(pass, n)
65 })
66 return nil, nil
67 }
68
69 func runFunc(pass *analysis.Pass, node ast.Node) {
70
71 var funcScope *types.Scope
72 switch v := node.(type) {
73 case *ast.FuncLit:
74 funcScope = pass.TypesInfo.Scopes[v.Type]
75 case *ast.FuncDecl:
76 funcScope = pass.TypesInfo.Scopes[v.Type]
77 }
78
79
80 cancelvars := make(map[*types.Var]ast.Node)
81
82
83
84
85
86
87 astutil.PreorderStack(node, nil, func(n ast.Node, stack []ast.Node) bool {
88 if _, ok := n.(*ast.FuncLit); ok && len(stack) > 0 {
89 return false
90 }
91
92
93
94
95
96
97
98 if !isContextWithCancel(pass.TypesInfo, n) || !isCall(stack[len(stack)-1]) {
99 return true
100 }
101 var id *ast.Ident
102 stmt := stack[len(stack)-2]
103 switch stmt := stmt.(type) {
104 case *ast.ValueSpec:
105 if len(stmt.Names) > 1 {
106 id = stmt.Names[1]
107 }
108 case *ast.AssignStmt:
109 if len(stmt.Lhs) > 1 {
110 id, _ = stmt.Lhs[1].(*ast.Ident)
111 }
112 }
113 if id != nil {
114 if id.Name == "_" {
115 pass.ReportRangef(id,
116 "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
117 n.(*ast.SelectorExpr).Sel.Name)
118 } else if v, ok := pass.TypesInfo.Uses[id].(*types.Var); ok {
119
120
121 if funcScope.Contains(v.Pos()) {
122 cancelvars[v] = stmt
123 }
124 } else if v, ok := pass.TypesInfo.Defs[id].(*types.Var); ok {
125 cancelvars[v] = stmt
126 }
127 }
128 return true
129 })
130
131 if len(cancelvars) == 0 {
132 return
133 }
134
135
136 cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
137 var g *cfg.CFG
138 var sig *types.Signature
139 switch node := node.(type) {
140 case *ast.FuncDecl:
141 sig, _ = pass.TypesInfo.Defs[node.Name].Type().(*types.Signature)
142 if node.Name.Name == "main" && sig.Recv() == nil && pass.Pkg.Name() == "main" {
143
144
145 return
146 }
147 g = cfgs.FuncDecl(node)
148
149 case *ast.FuncLit:
150 sig, _ = pass.TypesInfo.Types[node.Type].Type.(*types.Signature)
151 g = cfgs.FuncLit(node)
152 }
153 if sig == nil {
154 return
155 }
156
157
158 if debug {
159 fmt.Println(g.Format(pass.Fset))
160 }
161
162
163
164
165 for v, stmt := range cancelvars {
166 if ret := lostCancelPath(pass, g, v, stmt, sig); ret != nil {
167 lineno := pass.Fset.Position(stmt.Pos()).Line
168 pass.ReportRangef(stmt, "the %s function is not used on all paths (possible context leak)", v.Name())
169
170 pos, end := ret.Pos(), ret.End()
171
172
173 if pass.Fset.File(pos) != pass.Fset.File(end) {
174 end = pos
175 }
176 pass.Report(analysis.Diagnostic{
177 Pos: pos,
178 End: end,
179 Message: fmt.Sprintf("this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno),
180 })
181 }
182 }
183 }
184
// isCall reports whether n is a call expression (*ast.CallExpr).
func isCall(n ast.Node) bool {
	switch n.(type) {
	case *ast.CallExpr:
		return true
	}
	return false
}
186
187
188
189 func isContextWithCancel(info *types.Info, n ast.Node) bool {
190 sel, ok := n.(*ast.SelectorExpr)
191 if !ok {
192 return false
193 }
194 switch sel.Sel.Name {
195 case "WithCancel", "WithCancelCause",
196 "WithTimeout", "WithTimeoutCause",
197 "WithDeadline", "WithDeadlineCause":
198 default:
199 return false
200 }
201 if x, ok := sel.X.(*ast.Ident); ok {
202 if pkgname, ok := info.Uses[x].(*types.PkgName); ok {
203 return pkgname.Imported().Path() == contextPackage
204 }
205
206
207 return x.Name == "context"
208 }
209 return false
210 }
211
212
213
214
215
216 func lostCancelPath(pass *analysis.Pass, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
217 vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)
218
219
220 uses := func(pass *analysis.Pass, v *types.Var, stmts []ast.Node) bool {
221 found := false
222 for _, stmt := range stmts {
223 ast.Inspect(stmt, func(n ast.Node) bool {
224 switch n := n.(type) {
225 case *ast.Ident:
226 if pass.TypesInfo.Uses[n] == v {
227 found = true
228 }
229 case *ast.ReturnStmt:
230
231
232 if n.Results == nil && vIsNamedResult {
233 found = true
234 }
235 }
236 return !found
237 })
238 }
239 return found
240 }
241
242
243 memo := make(map[*cfg.Block]bool)
244 blockUses := func(pass *analysis.Pass, v *types.Var, b *cfg.Block) bool {
245 res, ok := memo[b]
246 if !ok {
247 res = uses(pass, v, b.Nodes)
248 memo[b] = res
249 }
250 return res
251 }
252
253
254
255 var defblock *cfg.Block
256 var rest []ast.Node
257 outer:
258 for _, b := range g.Blocks {
259 for i, n := range b.Nodes {
260 if n == stmt {
261 defblock = b
262 rest = b.Nodes[i+1:]
263 break outer
264 }
265 }
266 }
267 if defblock == nil {
268 panic("internal error: can't find defining block for cancel var")
269 }
270
271
272 if uses(pass, v, rest) {
273 return nil
274 }
275
276
277 if ret := defblock.Return(); ret != nil {
278 return ret
279 }
280
281
282
283 seen := make(map[*cfg.Block]bool)
284 var search func(blocks []*cfg.Block) *ast.ReturnStmt
285 search = func(blocks []*cfg.Block) *ast.ReturnStmt {
286 for _, b := range blocks {
287 if seen[b] {
288 continue
289 }
290 seen[b] = true
291
292
293 if blockUses(pass, v, b) {
294 continue
295 }
296
297
298 if ret := b.Return(); ret != nil {
299 if debug {
300 fmt.Printf("found path to return in block %s\n", b)
301 }
302 return ret
303 }
304
305
306 if ret := search(b.Succs); ret != nil {
307 if debug {
308 fmt.Printf(" from block %s\n", b)
309 }
310 return ret
311 }
312 }
313 return nil
314 }
315 return search(defblock.Succs)
316 }
317
// tupleContains reports whether v is one of the variables of tuple, compared
// by identity.
func tupleContains(tuple *types.Tuple, v *types.Var) bool {
	count := tuple.Len()
	for idx := 0; idx != count; idx++ {
		if member := tuple.At(idx); member == v {
			return true
		}
	}
	return false
}
326
View as plain text