Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/sigchanyzer/sigchanyzer.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package sigchanyzer defines an Analyzer that detects
     6  // misuse of an unbuffered os.Signal channel as the argument to signal.Notify.
     7  package sigchanyzer
     8  
     9  import (
    10  	"bytes"
    11  	"slices"
    12  
    13  	_ "embed"
    14  	"go/ast"
    15  	"go/format"
    16  	"go/token"
    17  	"go/types"
    18  
    19  	"golang.org/x/tools/go/analysis"
    20  	"golang.org/x/tools/go/analysis/passes/inspect"
    21  	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
    22  	"golang.org/x/tools/go/ast/inspector"
    23  	"golang.org/x/tools/internal/analysisinternal"
    24  )
    25  
// doc holds the text of doc.go, embedded at build time. The Analyzer's
// Doc string is extracted from it.
//
//go:embed doc.go
var doc string
    28  
// Analyzer is the sigchanyzer analysis: its Run function, run, reports
// calls to signal.Notify whose channel argument appears to be unbuffered.
//
// It requires the inspect analyzer, whose result run uses to visit call
// expressions, and its documentation is extracted from the embedded doc.go.
var Analyzer = &analysis.Analyzer{
	Name:     "sigchanyzer",
	Doc:      analysisutil.MustExtractDoc(doc, "sigchanyzer"),
	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/sigchanyzer",
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
    37  
    38  func run(pass *analysis.Pass) (any, error) {
    39  	if !analysisinternal.Imports(pass.Pkg, "os/signal") {
    40  		return nil, nil // doesn't directly import signal
    41  	}
    42  
    43  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    44  
    45  	nodeFilter := []ast.Node{
    46  		(*ast.CallExpr)(nil),
    47  	}
    48  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    49  		call := n.(*ast.CallExpr)
    50  		if !isSignalNotify(pass.TypesInfo, call) {
    51  			return
    52  		}
    53  		var chanDecl *ast.CallExpr
    54  		switch arg := call.Args[0].(type) {
    55  		case *ast.Ident:
    56  			if decl, ok := findDecl(arg).(*ast.CallExpr); ok {
    57  				chanDecl = decl
    58  			}
    59  		case *ast.CallExpr:
    60  			// Only signal.Notify(make(chan os.Signal), os.Interrupt) is safe,
    61  			// conservatively treat others as not safe, see golang/go#45043
    62  			if isBuiltinMake(pass.TypesInfo, arg) {
    63  				return
    64  			}
    65  			chanDecl = arg
    66  		}
    67  		if chanDecl == nil || len(chanDecl.Args) != 1 {
    68  			return
    69  		}
    70  
    71  		// Make a copy of the channel's declaration to avoid
    72  		// mutating the AST. See https://golang.org/issue/46129.
    73  		chanDeclCopy := &ast.CallExpr{}
    74  		*chanDeclCopy = *chanDecl
    75  		chanDeclCopy.Args = slices.Clone(chanDecl.Args)
    76  		chanDeclCopy.Args = append(chanDeclCopy.Args, &ast.BasicLit{
    77  			Kind:  token.INT,
    78  			Value: "1",
    79  		})
    80  
    81  		var buf bytes.Buffer
    82  		if err := format.Node(&buf, token.NewFileSet(), chanDeclCopy); err != nil {
    83  			return
    84  		}
    85  		pass.Report(analysis.Diagnostic{
    86  			Pos:     call.Pos(),
    87  			End:     call.End(),
    88  			Message: "misuse of unbuffered os.Signal channel as argument to signal.Notify",
    89  			SuggestedFixes: []analysis.SuggestedFix{{
    90  				Message: "Change to buffer channel",
    91  				TextEdits: []analysis.TextEdit{{
    92  					Pos:     chanDecl.Pos(),
    93  					End:     chanDecl.End(),
    94  					NewText: buf.Bytes(),
    95  				}},
    96  			}},
    97  		})
    98  	})
    99  	return nil, nil
   100  }
   101  
   102  func isSignalNotify(info *types.Info, call *ast.CallExpr) bool {
   103  	check := func(id *ast.Ident) bool {
   104  		obj := info.ObjectOf(id)
   105  		return obj.Name() == "Notify" && obj.Pkg().Path() == "os/signal"
   106  	}
   107  	switch fun := call.Fun.(type) {
   108  	case *ast.SelectorExpr:
   109  		return check(fun.Sel)
   110  	case *ast.Ident:
   111  		if fun, ok := findDecl(fun).(*ast.SelectorExpr); ok {
   112  			return check(fun.Sel)
   113  		}
   114  		return false
   115  	default:
   116  		return false
   117  	}
   118  }
   119  
   120  func findDecl(arg *ast.Ident) ast.Node {
   121  	if arg.Obj == nil {
   122  		return nil
   123  	}
   124  	switch as := arg.Obj.Decl.(type) {
   125  	case *ast.AssignStmt:
   126  		if len(as.Lhs) != len(as.Rhs) {
   127  			return nil
   128  		}
   129  		for i, lhs := range as.Lhs {
   130  			lid, ok := lhs.(*ast.Ident)
   131  			if !ok {
   132  				continue
   133  			}
   134  			if lid.Obj == arg.Obj {
   135  				return as.Rhs[i]
   136  			}
   137  		}
   138  	case *ast.ValueSpec:
   139  		if len(as.Names) != len(as.Values) {
   140  			return nil
   141  		}
   142  		for i, name := range as.Names {
   143  			if name.Obj == arg.Obj {
   144  				return as.Values[i]
   145  			}
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func isBuiltinMake(info *types.Info, call *ast.CallExpr) bool {
   152  	typVal := info.Types[call.Fun]
   153  	if !typVal.IsBuiltin() {
   154  		return false
   155  	}
   156  	switch fun := call.Fun.(type) {
   157  	case *ast.Ident:
   158  		return info.ObjectOf(fun).Name() == "make"
   159  	default:
   160  		return false
   161  	}
   162  }
   163  

View as plain text