diff --git a/README.md b/README.md index 9a4bccd..e495137 100644 --- a/README.md +++ b/README.md @@ -47,21 +47,26 @@ if err != nil { ### What if I think it's bullshit? I understand that each case needs to be analyzed separately, -but I hope that the linter will make you think again - -is it necessary to use an ambiguous API or is it better to do it using a sentinel error? +but I hope that the linter will make you think again – +is it necessary to use an **ambiguous API** or is it better to do it using a sentinel error?
In any case, you can just not enable the linter. ## Configuration +### CLI + ```shell -# command line (see help for full list of types) +# See help for full list of types. $ nilnil --checked-types ptr,func ./... ``` +### golangci-lint + +https://golangci-lint.run/usage/linters/#nilnil + ```yaml -# https://golangci-lint.run/usage/configuration/ nilnil: checked-types: - ptr @@ -69,6 +74,8 @@ nilnil: - iface - map - chan + - uintptr + - unsafeptr ``` ## Examples @@ -219,20 +226,12 @@ func (r *RateLimiter) Allow() bool { ## Assumptions -
- Click to expand - -
- -- Linter only checks functions with two return arguments, the last of which has `error` type. +- Linter only checks functions with two return arguments, the last of which implements `error`. - Next types are checked: - * pointers, functions & interfaces (`panic: invalid memory address or nil pointer dereference`); + * pointers (including `uinptr` and `unsafe.Pointer`), functions and interfaces (`panic: invalid memory address or nil pointer dereference`); * maps (`panic: assignment to entry in nil map`); * channels (`fatal error: all goroutines are asleep - deadlock!`) -- `uinptr` & `unsafe.Pointer` are not checked as a special case. -- Supported only explicit `return nil, nil`. - -
+- Only explicit `return nil, nil` are supported. ## Check Go 1.22.2 source code diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 1ea5204..5646ee9 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -2,7 +2,9 @@ package analyzer import ( "go/ast" + "go/token" "go/types" + "strconv" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" @@ -74,13 +76,32 @@ func (n *nilNil) run(pass *analysis.Pass) (interface{}, error) { return false } - fRes1, fRes2 := ft.Results.List[0], ft.Results.List[1] - if !(n.isDangerNilField(pass, fRes1) && n.isErrorField(pass, fRes2)) { + fRes1Type := pass.TypesInfo.TypeOf(ft.Results.List[0].Type) + if fRes1Type == nil { return false } - rRes1, rRes2 := v.Results[0], v.Results[1] - if isNil(pass, rRes1) && isNil(pass, rRes2) { + fRes2Type := pass.TypesInfo.TypeOf(ft.Results.List[1].Type) + if fRes2Type == nil { + return false + } + + ok, zv := n.isDangerNilType(fRes1Type) + if !(ok && isErrorType(fRes2Type)) { + return false + } + + retVal, retErr := v.Results[0], v.Results[1] + + var needWarn bool + switch zv { + case zeroValueNil: + needWarn = isNil(pass, retVal) && isNil(pass, retErr) + case zeroValueZero: + needWarn = isZero(retVal) && isNil(pass, retErr) + } + + if needWarn { pass.Reportf(v.Pos(), reportMsg) } } @@ -91,41 +112,47 @@ func (n *nilNil) run(pass *analysis.Pass) (interface{}, error) { return nil, nil //nolint:nilnil } -func (n *nilNil) isDangerNilField(pass *analysis.Pass, f *ast.Field) bool { - return n.isDangerNilType(pass.TypesInfo.TypeOf(f.Type)) -} +type zeroValue int -func (n *nilNil) isDangerNilType(t types.Type) bool { +const ( + zeroValueNil = iota + 1 + zeroValueZero +) + +func (n *nilNil) isDangerNilType(t types.Type) (bool, zeroValue) { switch v := t.(type) { case *types.Pointer: - return n.checkedTypes.Contains(ptrType) + return n.checkedTypes.Contains(ptrType), zeroValueNil case *types.Signature: - return n.checkedTypes.Contains(funcType) + return n.checkedTypes.Contains(funcType), zeroValueNil case *types.Interface: - return n.checkedTypes.Contains(ifaceType) + return n.checkedTypes.Contains(ifaceType), zeroValueNil case *types.Map: - return n.checkedTypes.Contains(mapType) + return n.checkedTypes.Contains(mapType), zeroValueNil case *types.Chan: - return n.checkedTypes.Contains(chanType) + return n.checkedTypes.Contains(chanType), zeroValueNil + + case *types.Basic: + if v.Kind() == types.Uintptr { + return n.checkedTypes.Contains(uintptrType), zeroValueZero + } + if v.Kind() == types.UnsafePointer { + return n.checkedTypes.Contains(unsafeptrType), zeroValueNil + } case *types.Named: return n.isDangerNilType(v.Underlying()) } - return false + return false, 0 } var errorIface = types.Universe.Lookup("error").Type().Underlying().(*types.Interface) -func (n *nilNil) isErrorField(pass *analysis.Pass, f *ast.Field) bool { - t := pass.TypesInfo.TypeOf(f.Type) - if t == nil { - return false - } - +func isErrorType(t types.Type) bool { _, ok := t.Underlying().(*types.Interface) return ok && types.Implements(t, errorIface) } @@ -139,3 +166,19 @@ func isNil(pass *analysis.Pass, e ast.Expr) bool { _, ok = pass.TypesInfo.ObjectOf(i).(*types.Nil) return ok } + +func isZero(e ast.Expr) bool { + bl, ok := e.(*ast.BasicLit) + if !ok { + return false + } + if bl.Kind != token.INT { + return false + } + + v, err := strconv.ParseInt(bl.Value, 0, 64) + if err != nil { + return false + } + return v == 0 +} diff --git a/pkg/analyzer/analyzer_test.go b/pkg/analyzer/analyzer_test.go index a509743..d85c12e 100644 --- a/pkg/analyzer/analyzer_test.go +++ b/pkg/analyzer/analyzer_test.go @@ -9,9 +9,22 @@ import ( ) func TestNilNil(t *testing.T) { + t.Parallel() + pkgs := []string{ "examples", "strange", + "unsafe", } analysistest.Run(t, analysistest.TestData(), analyzer.New(), pkgs...) } + +func TestNilNil_Flags(t *testing.T) { + t.Parallel() + + anlzr := analyzer.New() + if err := anlzr.Flags.Set("checked-types", "ptr"); err != nil { + t.Fatal(err) + } + analysistest.Run(t, analysistest.TestData(), anlzr, "pointers-only") +} diff --git a/pkg/analyzer/config.go b/pkg/analyzer/config.go index 520b813..c9b8e3e 100644 --- a/pkg/analyzer/config.go +++ b/pkg/analyzer/config.go @@ -8,11 +8,13 @@ import ( func newDefaultCheckedTypes() checkedTypes { return checkedTypes{ - ptrType: struct{}{}, - funcType: struct{}{}, - ifaceType: struct{}{}, - mapType: struct{}{}, - chanType: struct{}{}, + ptrType: {}, + funcType: {}, + ifaceType: {}, + mapType: {}, + chanType: {}, + uintptrType: {}, + unsafeptrType: {}, } } @@ -25,15 +27,15 @@ func (t typeName) S() string { } const ( - ptrType typeName = "ptr" - funcType typeName = "func" - ifaceType typeName = "iface" - mapType typeName = "map" - chanType typeName = "chan" + ptrType typeName = "ptr" + funcType typeName = "func" + ifaceType typeName = "iface" + mapType typeName = "map" + chanType typeName = "chan" + uintptrType typeName = "uintptr" + unsafeptrType typeName = "unsafeptr" ) -var knownTypes = []typeName{ptrType, funcType, ifaceType, mapType, chanType} - type checkedTypes map[typeName]struct{} func (c checkedTypes) Contains(t typeName) bool { @@ -60,7 +62,7 @@ func (c checkedTypes) Set(s string) error { c.disableAll() for _, t := range types { switch tt := typeName(t); tt { - case ptrType, funcType, ifaceType, mapType, chanType: + case ptrType, funcType, ifaceType, mapType, chanType, uintptrType, unsafeptrType: c[tt] = struct{}{} default: return fmt.Errorf("unknown checked type name %q (see help)", t) diff --git a/pkg/analyzer/config_test.go b/pkg/analyzer/config_test.go index 82f7c3f..fec9aab 100644 --- a/pkg/analyzer/config_test.go +++ b/pkg/analyzer/config_test.go @@ -4,33 +4,18 @@ import "testing" func TestCheckedTypes(t *testing.T) { c := newDefaultCheckedTypes() - - for _, tt := range knownTypes { - assertTrue(t, c.Contains(tt)) - } - assertStringEqual(t, c.String(), "chan,func,iface,map,ptr") + assertStringEqual(t, c.String(), "chan,func,iface,map,ptr,uintptr,unsafeptr") err := c.Set("chan,iface,ptr") assertNoError(t, err) assertTrue(t, c.Contains(chanType)) assertTrue(t, c.Contains(ifaceType)) assertTrue(t, c.Contains(ptrType)) + assertFalse(t, c.Contains(funcType)) + assertFalse(t, c.Contains(mapType)) + assertFalse(t, c.Contains(uintptrType)) + assertFalse(t, c.Contains(unsafeptrType)) assertStringEqual(t, "chan,iface,ptr", c.String()) - - for _, tt := range knownTypes { - err := c.Set(tt.S()) - assertNoError(t, err) - - for _, tt2 := range knownTypes { - if tt2 == tt { - assertTrue(t, c.Contains(tt2)) - } else { - assertFalse(t, c.Contains(tt2)) - } - } - - assertStringEqual(t, tt.S(), c.String()) - } } func TestCheckedTypes_SetError(t *testing.T) { @@ -46,45 +31,40 @@ func TestCheckedTypes_SetWithoutArg(t *testing.T) { err := c.Set("") assertNoError(t, err) - - for _, tt := range knownTypes { - assertTrue(t, c.Contains(tt)) - } - assertStringEqual(t, c.String(), "chan,func,iface,map,ptr") + assertStringEqual(t, c.String(), "chan,func,iface,map,ptr,uintptr,unsafeptr") } func assertError(t *testing.T, err error) { t.Helper() if err == nil { - t.FailNow() + t.Fatal("must be not nil") } } func assertNoError(t *testing.T, err error) { t.Helper() if err != nil { - t.FailNow() + t.Fatalf("must be nil, got %q", err) } } func assertStringEqual(t *testing.T, a, b string) { t.Helper() if a != b { - t.Logf("%q != %q", a, b) - t.FailNow() + t.Fatalf("%q != %q", a, b) } } func assertTrue(t *testing.T, v bool) { t.Helper() if !v { - t.FailNow() + t.Fatal("must be true") } } func assertFalse(t *testing.T, v bool) { t.Helper() if v { - t.FailNow() + t.Fatal("must be false") } } diff --git a/pkg/analyzer/testdata/src/examples/negative.go b/pkg/analyzer/testdata/src/examples/negative.go index 6a26c6a..e297e08 100644 --- a/pkg/analyzer/testdata/src/examples/negative.go +++ b/pkg/analyzer/testdata/src/examples/negative.go @@ -14,8 +14,24 @@ func withoutError2() (*User, *User) { return nil, nil } func withoutError3() (*User, *User, *User) { return nil, nil, nil } func withoutError4() (*User, *User, *User, *User) { return nil, nil, nil, nil } +func invalidOrder() (error, *User) { return nil, nil } +func withError3rd() (*User, bool, error) { return nil, false, nil } +func withError4th() (*User, *User, *User, error) { return nil, nil, nil, nil } + +func slice() ([]int, error) { return nil, nil } + +func strNil() (string, error) { return "nil", nil } +func strEmpty() (string, error) { return "", nil } + // Valid. +func primitivePtrTypeValid() (*int, error) { + if false { + return nil, io.EOF + } + return new(int), nil +} + func structPtrTypeValid() (*User, error) { if false { return nil, io.EOF @@ -23,11 +39,19 @@ func structPtrTypeValid() (*User, error) { return new(User), nil } -func primitivePtrTypeValid() (*int, error) { +func unsafePtrValid() (unsafe.Pointer, error) { if false { return nil, io.EOF } - return new(int), nil + var i int + return unsafe.Pointer(&i), nil +} + +func uintPtrValid() (uintptr, error) { + if false { + return 0, io.EOF + } + return 0xc82000c290, nil } func channelTypeValid() (ChannelType, error) { @@ -55,13 +79,6 @@ func ifaceTypeValid() (io.Reader, error) { // Unsupported. -func invalidOrder() (error, *User) { return nil, nil } -func withError3rd() (*User, bool, error) { return nil, false, nil } -func withError4th() (*User, *User, *User, error) { return nil, nil, nil, nil } -func unsafePtr() (unsafe.Pointer, error) { return nil, nil } -func uintPtr() (uintptr, error) { return 0, nil } -func slice() ([]int, error) { return nil, nil } - func implicitNil1() (*User, error) { err := (error)(nil) return nil, err diff --git a/pkg/analyzer/testdata/src/examples/positive.go b/pkg/analyzer/testdata/src/examples/positive.go index 99295fb..0ce263d 100644 --- a/pkg/analyzer/testdata/src/examples/positive.go +++ b/pkg/analyzer/testdata/src/examples/positive.go @@ -1,5 +1,7 @@ package examples +import "unsafe" + type User struct{} func primitivePtr() (*int, error) { @@ -18,6 +20,26 @@ func anonymousStructPtr() (*struct{ ID string }, error) { return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" } +func unsafePtr() (unsafe.Pointer, error) { + return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + +func uintPtr() (uintptr, error) { + return 0, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + +func uintPtr0b() (uintptr, error) { + return 0b0, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + +func uintPtr0x() (uintptr, error) { + return 0x00, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + +func uintPtr0o() (uintptr, error) { + return 0o000, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + func chBi() (chan int, error) { return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" } diff --git a/pkg/analyzer/testdata/src/examples/positive_own_types.go b/pkg/analyzer/testdata/src/examples/positive_own_types.go index f5f996a..0d94968 100644 --- a/pkg/analyzer/testdata/src/examples/positive_own_types.go +++ b/pkg/analyzer/testdata/src/examples/positive_own_types.go @@ -33,3 +33,12 @@ type checkerAlias = Checker func ifaceTypeAliased() (checkerAlias, error) { return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" } + +type ( + IntegerType int + PtrIntegerType *IntegerType +) + +func ptrIntegerType() (PtrIntegerType, error) { + return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} diff --git a/pkg/analyzer/testdata/src/pointers-only/positive.go b/pkg/analyzer/testdata/src/pointers-only/positive.go new file mode 100644 index 0000000..7be3ddc --- /dev/null +++ b/pkg/analyzer/testdata/src/pointers-only/positive.go @@ -0,0 +1,31 @@ +package examples + +type User struct{} + +func primitivePtr() (*int, error) { + return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + +func structPtr() (*User, error) { + return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead" +} + +func uintPtr0o() (uintptr, error) { + return 0o000, nil +} + +func chBi() (chan int, error) { + return nil, nil +} + +func fun() (func(), error) { + return nil, nil +} + +func anyType() (any, error) { + return nil, nil +} + +func m1() (map[int]int, error) { + return nil, nil +}