diff --git a/internal/errors.go b/internal/errors.go index e209d158..4219bfc3 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -16,6 +16,7 @@ package internal import ( "encoding/json" + "errors" "fmt" "net" "net/http" @@ -97,8 +98,8 @@ func (fe *FirebaseError) Error() string { // HasPlatformErrorCode checks if the given error contains a specific error code. func HasPlatformErrorCode(err error, code ErrorCode) bool { - fe, ok := err.(*FirebaseError) - return ok && fe.ErrorCode == code + var fe *FirebaseError + return errors.As(err, &fe) && fe.ErrorCode == code } var httpStatusToErrorCodes = map[int]ErrorCode{ diff --git a/internal/errors_test.go b/internal/errors_test.go index 3733429e..9e22dc28 100644 --- a/internal/errors_test.go +++ b/internal/errors_test.go @@ -335,3 +335,59 @@ func TestErrorHTTPResponse(t *testing.T) { t.Errorf("Unmarshal(Response.Body) = %v; want = {key: value}", m) } } + +func TestHasPlatformErrorCode(t *testing.T) { + type args struct { + err error + code ErrorCode + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "nil", + args: args{ + err: nil, + code: Aborted, + }, + want: false, + }, + { + name: "not internal", + args: args{ + err: fmt.Errorf("something happened"), + code: Aborted, + }, + want: false, + }, + { + name: "simple", + args: args{ + err: &FirebaseError{ + ErrorCode: Aborted, + }, + code: Aborted, + }, + want: true, + }, + { + name: "wrapped", + args: args{ + err: fmt.Errorf("[prefix] %w", &FirebaseError{ + ErrorCode: Aborted, + }), + code: Aborted, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := HasPlatformErrorCode(tt.args.err, tt.args.code); got != tt.want { + t.Errorf("HasPlatformErrorCode() = %v, want %v", got, tt.want) + } + }) + } +}