1
2
3
4
5
6
7 package sigchanyzer
8
9 import (
10 "bytes"
11 "slices"
12
13 _ "embed"
14 "go/ast"
15 "go/format"
16 "go/token"
17 "go/types"
18
19 "golang.org/x/tools/go/analysis"
20 "golang.org/x/tools/go/analysis/passes/inspect"
21 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
22 "golang.org/x/tools/go/ast/inspector"
23 "golang.org/x/tools/internal/analysisinternal"
24 )
25
26
27 var doc string
28
29
30 var Analyzer = &analysis.Analyzer{
31 Name: "sigchanyzer",
32 Doc: analysisutil.MustExtractDoc(doc, "sigchanyzer"),
33 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/sigchanyzer",
34 Requires: []*analysis.Analyzer{inspect.Analyzer},
35 Run: run,
36 }
37
38 func run(pass *analysis.Pass) (any, error) {
39 if !analysisinternal.Imports(pass.Pkg, "os/signal") {
40 return nil, nil
41 }
42
43 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
44
45 nodeFilter := []ast.Node{
46 (*ast.CallExpr)(nil),
47 }
48 inspect.Preorder(nodeFilter, func(n ast.Node) {
49 call := n.(*ast.CallExpr)
50 if !isSignalNotify(pass.TypesInfo, call) {
51 return
52 }
53 var chanDecl *ast.CallExpr
54 switch arg := call.Args[0].(type) {
55 case *ast.Ident:
56 if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
57 chanDecl = decl
58 }
59 case *ast.CallExpr:
60
61
62 if isBuiltinMake(pass.TypesInfo, arg) {
63 return
64 }
65 chanDecl = arg
66 }
67 if chanDecl == nil || len(chanDecl.Args) != 1 {
68 return
69 }
70
71
72
73 chanDeclCopy := &ast.CallExpr{}
74 *chanDeclCopy = *chanDecl
75 chanDeclCopy.Args = slices.Clone(chanDecl.Args)
76 chanDeclCopy.Args = append(chanDeclCopy.Args, &ast.BasicLit{
77 Kind: token.INT,
78 Value: "1",
79 })
80
81 var buf bytes.Buffer
82 if err := format.Node(&buf, token.NewFileSet(), chanDeclCopy); err != nil {
83 return
84 }
85 pass.Report(analysis.Diagnostic{
86 Pos: call.Pos(),
87 End: call.End(),
88 Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
89 SuggestedFixes: []analysis.SuggestedFix{{
90 Message: "Change to buffer channel",
91 TextEdits: []analysis.TextEdit{{
92 Pos: chanDecl.Pos(),
93 End: chanDecl.End(),
94 NewText: buf.Bytes(),
95 }},
96 }},
97 })
98 })
99 return nil, nil
100 }
101
102 func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
103 check := func(id *ast.Ident) bool {
104 obj := info.ObjectOf(id)
105 return obj.Name() == "Notify" && obj.Pkg().Path() == "os/signal"
106 }
107 switch fun := call.Fun.(type) {
108 case *ast.SelectorExpr:
109 return check(fun.Sel)
110 case *ast.Ident:
111 if fun, ok := findDecl(fun).(*ast.SelectorExpr); ok {
112 return check(fun.Sel)
113 }
114 return false
115 default:
116 return false
117 }
118 }
119
// findDecl follows the parser's syntactic object resolution (arg.Obj) to
// the declaration that introduced arg and returns the initializer expression
// bound to that identifier. Supported declarations are assignments /
// short variable declarations (*ast.AssignStmt) and var/const specs
// (*ast.ValueSpec) that pair each name with exactly one value. It returns
// nil for unresolved identifiers, multi-value right-hand sides (a, b := f()),
// specs without values, and every other declaration kind.
func findDecl(arg *ast.Ident) ast.Node {
	obj := arg.Obj
	if obj == nil {
		return nil
	}
	switch decl := obj.Decl.(type) {
	case *ast.AssignStmt:
		if len(decl.Lhs) == len(decl.Rhs) {
			for idx := range decl.Lhs {
				if id, isIdent := decl.Lhs[idx].(*ast.Ident); isIdent && id.Obj == obj {
					return decl.Rhs[idx]
				}
			}
		}
	case *ast.ValueSpec:
		if len(decl.Names) == len(decl.Values) {
			for idx, name := range decl.Names {
				if name.Obj == obj {
					return decl.Values[idx]
				}
			}
		}
	}
	return nil
}
150
// isBuiltinMake reports whether call invokes the predeclared make function
// by plain identifier. A callee that is not a bare identifier (for example
// a parenthesized or selector expression) is never recognized, and a
// user-defined function that shadows make is rejected because its type
// information is not recorded as builtin.
func isBuiltinMake(info *types.Info, call *ast.CallExpr) bool {
	id, isIdent := call.Fun.(*ast.Ident)
	if !isIdent || !info.Types[id].IsBuiltin() {
		return false
	}
	return info.ObjectOf(id).Name() == "make"
}
163
View as plain text