diff --git a/README.md b/README.md
index 9a4bccd..e495137 100644
--- a/README.md
+++ b/README.md
@@ -47,21 +47,26 @@ if err != nil {
### What if I think it's bullshit?
I understand that each case needs to be analyzed separately,
-but I hope that the linter will make you think again -
-is it necessary to use an ambiguous API or is it better to do it using a sentinel error?
+but I hope that the linter will make you think again –
+is it necessary to use an **ambiguous API** or is it better to do it using a sentinel error?
In any case, you can just not enable the linter.
## Configuration
+### CLI
+
```shell
-# command line (see help for full list of types)
+# See help for full list of types.
$ nilnil --checked-types ptr,func ./...
```
+### golangci-lint
+
+https://golangci-lint.run/usage/linters/#nilnil
+
```yaml
-# https://golangci-lint.run/usage/configuration/
nilnil:
checked-types:
- ptr
@@ -69,6 +74,8 @@ nilnil:
- iface
- map
- chan
+ - uintptr
+ - unsafeptr
```
## Examples
@@ -219,20 +226,12 @@ func (r *RateLimiter) Allow() bool {
## Assumptions
-
- Click to expand
-
-
-
-- Linter only checks functions with two return arguments, the last of which has `error` type.
+- Linter only checks functions with two return arguments, the last of which implements `error`.
- Next types are checked:
- * pointers, functions & interfaces (`panic: invalid memory address or nil pointer dereference`);
+ * pointers (including `uinptr` and `unsafe.Pointer`), functions and interfaces (`panic: invalid memory address or nil pointer dereference`);
* maps (`panic: assignment to entry in nil map`);
* channels (`fatal error: all goroutines are asleep - deadlock!`)
-- `uinptr` & `unsafe.Pointer` are not checked as a special case.
-- Supported only explicit `return nil, nil`.
-
-
+- Only explicit `return nil, nil` are supported.
## Check Go 1.22.2 source code
diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go
index 1ea5204..5646ee9 100644
--- a/pkg/analyzer/analyzer.go
+++ b/pkg/analyzer/analyzer.go
@@ -2,7 +2,9 @@ package analyzer
import (
"go/ast"
+ "go/token"
"go/types"
+ "strconv"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -74,13 +76,32 @@ func (n *nilNil) run(pass *analysis.Pass) (interface{}, error) {
return false
}
- fRes1, fRes2 := ft.Results.List[0], ft.Results.List[1]
- if !(n.isDangerNilField(pass, fRes1) && n.isErrorField(pass, fRes2)) {
+ fRes1Type := pass.TypesInfo.TypeOf(ft.Results.List[0].Type)
+ if fRes1Type == nil {
return false
}
- rRes1, rRes2 := v.Results[0], v.Results[1]
- if isNil(pass, rRes1) && isNil(pass, rRes2) {
+ fRes2Type := pass.TypesInfo.TypeOf(ft.Results.List[1].Type)
+ if fRes2Type == nil {
+ return false
+ }
+
+ ok, zv := n.isDangerNilType(fRes1Type)
+ if !(ok && isErrorType(fRes2Type)) {
+ return false
+ }
+
+ retVal, retErr := v.Results[0], v.Results[1]
+
+ var needWarn bool
+ switch zv {
+ case zeroValueNil:
+ needWarn = isNil(pass, retVal) && isNil(pass, retErr)
+ case zeroValueZero:
+ needWarn = isZero(retVal) && isNil(pass, retErr)
+ }
+
+ if needWarn {
pass.Reportf(v.Pos(), reportMsg)
}
}
@@ -91,41 +112,47 @@ func (n *nilNil) run(pass *analysis.Pass) (interface{}, error) {
return nil, nil //nolint:nilnil
}
-func (n *nilNil) isDangerNilField(pass *analysis.Pass, f *ast.Field) bool {
- return n.isDangerNilType(pass.TypesInfo.TypeOf(f.Type))
-}
+type zeroValue int
-func (n *nilNil) isDangerNilType(t types.Type) bool {
+const (
+ zeroValueNil = iota + 1
+ zeroValueZero
+)
+
+func (n *nilNil) isDangerNilType(t types.Type) (bool, zeroValue) {
switch v := t.(type) {
case *types.Pointer:
- return n.checkedTypes.Contains(ptrType)
+ return n.checkedTypes.Contains(ptrType), zeroValueNil
case *types.Signature:
- return n.checkedTypes.Contains(funcType)
+ return n.checkedTypes.Contains(funcType), zeroValueNil
case *types.Interface:
- return n.checkedTypes.Contains(ifaceType)
+ return n.checkedTypes.Contains(ifaceType), zeroValueNil
case *types.Map:
- return n.checkedTypes.Contains(mapType)
+ return n.checkedTypes.Contains(mapType), zeroValueNil
case *types.Chan:
- return n.checkedTypes.Contains(chanType)
+ return n.checkedTypes.Contains(chanType), zeroValueNil
+
+ case *types.Basic:
+ if v.Kind() == types.Uintptr {
+ return n.checkedTypes.Contains(uintptrType), zeroValueZero
+ }
+ if v.Kind() == types.UnsafePointer {
+ return n.checkedTypes.Contains(unsafeptrType), zeroValueNil
+ }
case *types.Named:
return n.isDangerNilType(v.Underlying())
}
- return false
+ return false, 0
}
var errorIface = types.Universe.Lookup("error").Type().Underlying().(*types.Interface)
-func (n *nilNil) isErrorField(pass *analysis.Pass, f *ast.Field) bool {
- t := pass.TypesInfo.TypeOf(f.Type)
- if t == nil {
- return false
- }
-
+func isErrorType(t types.Type) bool {
_, ok := t.Underlying().(*types.Interface)
return ok && types.Implements(t, errorIface)
}
@@ -139,3 +166,19 @@ func isNil(pass *analysis.Pass, e ast.Expr) bool {
_, ok = pass.TypesInfo.ObjectOf(i).(*types.Nil)
return ok
}
+
+func isZero(e ast.Expr) bool {
+ bl, ok := e.(*ast.BasicLit)
+ if !ok {
+ return false
+ }
+ if bl.Kind != token.INT {
+ return false
+ }
+
+ v, err := strconv.ParseInt(bl.Value, 0, 64)
+ if err != nil {
+ return false
+ }
+ return v == 0
+}
diff --git a/pkg/analyzer/analyzer_test.go b/pkg/analyzer/analyzer_test.go
index a509743..d85c12e 100644
--- a/pkg/analyzer/analyzer_test.go
+++ b/pkg/analyzer/analyzer_test.go
@@ -9,9 +9,22 @@ import (
)
func TestNilNil(t *testing.T) {
+ t.Parallel()
+
pkgs := []string{
"examples",
"strange",
+ "unsafe",
}
analysistest.Run(t, analysistest.TestData(), analyzer.New(), pkgs...)
}
+
+func TestNilNil_Flags(t *testing.T) {
+ t.Parallel()
+
+ anlzr := analyzer.New()
+ if err := anlzr.Flags.Set("checked-types", "ptr"); err != nil {
+ t.Fatal(err)
+ }
+ analysistest.Run(t, analysistest.TestData(), anlzr, "pointers-only")
+}
diff --git a/pkg/analyzer/config.go b/pkg/analyzer/config.go
index 520b813..c9b8e3e 100644
--- a/pkg/analyzer/config.go
+++ b/pkg/analyzer/config.go
@@ -8,11 +8,13 @@ import (
func newDefaultCheckedTypes() checkedTypes {
return checkedTypes{
- ptrType: struct{}{},
- funcType: struct{}{},
- ifaceType: struct{}{},
- mapType: struct{}{},
- chanType: struct{}{},
+ ptrType: {},
+ funcType: {},
+ ifaceType: {},
+ mapType: {},
+ chanType: {},
+ uintptrType: {},
+ unsafeptrType: {},
}
}
@@ -25,15 +27,15 @@ func (t typeName) S() string {
}
const (
- ptrType typeName = "ptr"
- funcType typeName = "func"
- ifaceType typeName = "iface"
- mapType typeName = "map"
- chanType typeName = "chan"
+ ptrType typeName = "ptr"
+ funcType typeName = "func"
+ ifaceType typeName = "iface"
+ mapType typeName = "map"
+ chanType typeName = "chan"
+ uintptrType typeName = "uintptr"
+ unsafeptrType typeName = "unsafeptr"
)
-var knownTypes = []typeName{ptrType, funcType, ifaceType, mapType, chanType}
-
type checkedTypes map[typeName]struct{}
func (c checkedTypes) Contains(t typeName) bool {
@@ -60,7 +62,7 @@ func (c checkedTypes) Set(s string) error {
c.disableAll()
for _, t := range types {
switch tt := typeName(t); tt {
- case ptrType, funcType, ifaceType, mapType, chanType:
+ case ptrType, funcType, ifaceType, mapType, chanType, uintptrType, unsafeptrType:
c[tt] = struct{}{}
default:
return fmt.Errorf("unknown checked type name %q (see help)", t)
diff --git a/pkg/analyzer/config_test.go b/pkg/analyzer/config_test.go
index 82f7c3f..fec9aab 100644
--- a/pkg/analyzer/config_test.go
+++ b/pkg/analyzer/config_test.go
@@ -4,33 +4,18 @@ import "testing"
func TestCheckedTypes(t *testing.T) {
c := newDefaultCheckedTypes()
-
- for _, tt := range knownTypes {
- assertTrue(t, c.Contains(tt))
- }
- assertStringEqual(t, c.String(), "chan,func,iface,map,ptr")
+ assertStringEqual(t, c.String(), "chan,func,iface,map,ptr,uintptr,unsafeptr")
err := c.Set("chan,iface,ptr")
assertNoError(t, err)
assertTrue(t, c.Contains(chanType))
assertTrue(t, c.Contains(ifaceType))
assertTrue(t, c.Contains(ptrType))
+ assertFalse(t, c.Contains(funcType))
+ assertFalse(t, c.Contains(mapType))
+ assertFalse(t, c.Contains(uintptrType))
+ assertFalse(t, c.Contains(unsafeptrType))
assertStringEqual(t, "chan,iface,ptr", c.String())
-
- for _, tt := range knownTypes {
- err := c.Set(tt.S())
- assertNoError(t, err)
-
- for _, tt2 := range knownTypes {
- if tt2 == tt {
- assertTrue(t, c.Contains(tt2))
- } else {
- assertFalse(t, c.Contains(tt2))
- }
- }
-
- assertStringEqual(t, tt.S(), c.String())
- }
}
func TestCheckedTypes_SetError(t *testing.T) {
@@ -46,45 +31,40 @@ func TestCheckedTypes_SetWithoutArg(t *testing.T) {
err := c.Set("")
assertNoError(t, err)
-
- for _, tt := range knownTypes {
- assertTrue(t, c.Contains(tt))
- }
- assertStringEqual(t, c.String(), "chan,func,iface,map,ptr")
+ assertStringEqual(t, c.String(), "chan,func,iface,map,ptr,uintptr,unsafeptr")
}
func assertError(t *testing.T, err error) {
t.Helper()
if err == nil {
- t.FailNow()
+ t.Fatal("must be not nil")
}
}
func assertNoError(t *testing.T, err error) {
t.Helper()
if err != nil {
- t.FailNow()
+ t.Fatalf("must be nil, got %q", err)
}
}
func assertStringEqual(t *testing.T, a, b string) {
t.Helper()
if a != b {
- t.Logf("%q != %q", a, b)
- t.FailNow()
+ t.Fatalf("%q != %q", a, b)
}
}
func assertTrue(t *testing.T, v bool) {
t.Helper()
if !v {
- t.FailNow()
+ t.Fatal("must be true")
}
}
func assertFalse(t *testing.T, v bool) {
t.Helper()
if v {
- t.FailNow()
+ t.Fatal("must be false")
}
}
diff --git a/pkg/analyzer/testdata/src/examples/negative.go b/pkg/analyzer/testdata/src/examples/negative.go
index 6a26c6a..e297e08 100644
--- a/pkg/analyzer/testdata/src/examples/negative.go
+++ b/pkg/analyzer/testdata/src/examples/negative.go
@@ -14,8 +14,24 @@ func withoutError2() (*User, *User) { return nil, nil }
func withoutError3() (*User, *User, *User) { return nil, nil, nil }
func withoutError4() (*User, *User, *User, *User) { return nil, nil, nil, nil }
+func invalidOrder() (error, *User) { return nil, nil }
+func withError3rd() (*User, bool, error) { return nil, false, nil }
+func withError4th() (*User, *User, *User, error) { return nil, nil, nil, nil }
+
+func slice() ([]int, error) { return nil, nil }
+
+func strNil() (string, error) { return "nil", nil }
+func strEmpty() (string, error) { return "", nil }
+
// Valid.
+func primitivePtrTypeValid() (*int, error) {
+ if false {
+ return nil, io.EOF
+ }
+ return new(int), nil
+}
+
func structPtrTypeValid() (*User, error) {
if false {
return nil, io.EOF
@@ -23,11 +39,19 @@ func structPtrTypeValid() (*User, error) {
return new(User), nil
}
-func primitivePtrTypeValid() (*int, error) {
+func unsafePtrValid() (unsafe.Pointer, error) {
if false {
return nil, io.EOF
}
- return new(int), nil
+ var i int
+ return unsafe.Pointer(&i), nil
+}
+
+func uintPtrValid() (uintptr, error) {
+ if false {
+ return 0, io.EOF
+ }
+ return 0xc82000c290, nil
}
func channelTypeValid() (ChannelType, error) {
@@ -55,13 +79,6 @@ func ifaceTypeValid() (io.Reader, error) {
// Unsupported.
-func invalidOrder() (error, *User) { return nil, nil }
-func withError3rd() (*User, bool, error) { return nil, false, nil }
-func withError4th() (*User, *User, *User, error) { return nil, nil, nil, nil }
-func unsafePtr() (unsafe.Pointer, error) { return nil, nil }
-func uintPtr() (uintptr, error) { return 0, nil }
-func slice() ([]int, error) { return nil, nil }
-
func implicitNil1() (*User, error) {
err := (error)(nil)
return nil, err
diff --git a/pkg/analyzer/testdata/src/examples/positive.go b/pkg/analyzer/testdata/src/examples/positive.go
index 99295fb..0ce263d 100644
--- a/pkg/analyzer/testdata/src/examples/positive.go
+++ b/pkg/analyzer/testdata/src/examples/positive.go
@@ -1,5 +1,7 @@
package examples
+import "unsafe"
+
type User struct{}
func primitivePtr() (*int, error) {
@@ -18,6 +20,26 @@ func anonymousStructPtr() (*struct{ ID string }, error) {
return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
}
+func unsafePtr() (unsafe.Pointer, error) {
+ return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
+func uintPtr() (uintptr, error) {
+ return 0, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
+func uintPtr0b() (uintptr, error) {
+ return 0b0, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
+func uintPtr0x() (uintptr, error) {
+ return 0x00, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
+func uintPtr0o() (uintptr, error) {
+ return 0o000, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
func chBi() (chan int, error) {
return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
}
diff --git a/pkg/analyzer/testdata/src/examples/positive_own_types.go b/pkg/analyzer/testdata/src/examples/positive_own_types.go
index f5f996a..0d94968 100644
--- a/pkg/analyzer/testdata/src/examples/positive_own_types.go
+++ b/pkg/analyzer/testdata/src/examples/positive_own_types.go
@@ -33,3 +33,12 @@ type checkerAlias = Checker
func ifaceTypeAliased() (checkerAlias, error) {
return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
}
+
+type (
+ IntegerType int
+ PtrIntegerType *IntegerType
+)
+
+func ptrIntegerType() (PtrIntegerType, error) {
+ return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
diff --git a/pkg/analyzer/testdata/src/pointers-only/positive.go b/pkg/analyzer/testdata/src/pointers-only/positive.go
new file mode 100644
index 0000000..7be3ddc
--- /dev/null
+++ b/pkg/analyzer/testdata/src/pointers-only/positive.go
@@ -0,0 +1,31 @@
+package examples
+
+type User struct{}
+
+func primitivePtr() (*int, error) {
+ return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
+func structPtr() (*User, error) {
+ return nil, nil // want "return both the `nil` error and invalid value: use a sentinel error instead"
+}
+
+func uintPtr0o() (uintptr, error) {
+ return 0o000, nil
+}
+
+func chBi() (chan int, error) {
+ return nil, nil
+}
+
+func fun() (func(), error) {
+ return nil, nil
+}
+
+func anyType() (any, error) {
+ return nil, nil
+}
+
+func m1() (map[int]int, error) {
+ return nil, nil
+}