diff --git a/.github/scripts/publish_post_check.sh b/.github/scripts/publish_post_check.sh deleted file mode 100755 index 8311ec25..00000000 --- a/.github/scripts/publish_post_check.sh +++ /dev/null @@ -1,105 +0,0 @@ - -#!/bin/bash - -# Copyright 2020 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -###################################### Outputs ##################################### - -# 1. version: The version of this release including the 'v' prefix (e.g. v1.2.3). -# 2. changelog: Formatted changelog text for this release. - -#################################################################################### - -set -e -set -u - -function echo_info() { - local MESSAGE=$1 - echo "[INFO] ${MESSAGE}" -} - -function echo_warn() { - local MESSAGE=$1 - echo "[WARN] ${MESSAGE}" -} - -function terminate() { - echo "" - echo_warn "--------------------------------------------" - echo_warn "POST CHECK FAILED" - echo_warn "--------------------------------------------" - exit 1 -} - - -echo_info "Starting publish post check..." -echo_info "Git revision : ${GITHUB_SHA}" -echo_info "Git ref : ${GITHUB_REF}" -echo_info "Workflow triggered by : ${GITHUB_ACTOR}" -echo_info "GitHub event : ${GITHUB_EVENT_NAME}" - - -echo_info "" -echo_info "--------------------------------------------" -echo_info "Extracting release version" -echo_info "--------------------------------------------" -echo_info "" - -echo_info "Loading version from: firebase.go" - -readonly RELEASE_VERSION=`grep "const Version" firebase.go | awk '{print $4}' | tr -d \"` || true -if [[ -z "${RELEASE_VERSION}" ]]; then - echo_warn "Failed to extract release version from: firebase.go" - terminate -fi - -if [[ ! "${RELEASE_VERSION}" =~ ^([0-9]*)\.([0-9]*)\.([0-9]*)$ ]]; then - echo_warn "Malformed release version string: ${RELEASE_VERSION}. Exiting." - terminate -fi - -echo_info "Extracted release version: ${RELEASE_VERSION}" -echo "::set-output name=version::v${RELEASE_VERSION}" - - -echo_info "" -echo_info "--------------------------------------------" -echo_info "Generating changelog" -echo_info "--------------------------------------------" -echo_info "" - -echo_info "---< git fetch origin master --prune --unshallow >---" -git fetch origin master --prune --unshallow -echo "" - -echo_info "Generating changelog from history..." -readonly CURRENT_DIR=$(dirname "$0") -readonly CHANGELOG=`${CURRENT_DIR}/generate_changelog.sh` -echo "$CHANGELOG" - -# Parse and preformat the text to handle multi-line output. -# See https://github.community/t5/GitHub-Actions/set-output-Truncates-Multiline-Strings/td-p/37870 -FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` -FILTERED_CHANGELOG="${FILTERED_CHANGELOG//'%'/'%25'}" -FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\n'/'%0A'}" -FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\r'/'%0D'}" -echo "::set-output name=changelog::${FILTERED_CHANGELOG}" - - -echo "" -echo_info "--------------------------------------------" -echo_info "POST CHECK SUCCESSFUL" -echo_info "--------------------------------------------" diff --git a/.github/scripts/publish_preflight_check.sh b/.github/scripts/publish_preflight_check.sh index 9dad71bb..c2be10bf 100755 --- a/.github/scripts/publish_preflight_check.sh +++ b/.github/scripts/publish_preflight_check.sh @@ -64,6 +64,7 @@ if [[ ! "${RELEASE_VERSION}" =~ ^([0-9]*)\.([0-9]*)\.([0-9]*)$ ]]; then fi echo_info "Extracted release version: ${RELEASE_VERSION}" +echo "::set-output name=version::v${RELEASE_VERSION}" echo_info "" @@ -91,6 +92,30 @@ fi echo_info "Tag v${RELEASE_VERSION} does not exist." +echo_info "" +echo_info "--------------------------------------------" +echo_info "Generating changelog" +echo_info "--------------------------------------------" +echo_info "" + +echo_info "---< git fetch origin dev --prune --unshallow >---" +git fetch origin dev --prune --unshallow +echo "" + +echo_info "Generating changelog from history..." +readonly CURRENT_DIR=$(dirname "$0") +readonly CHANGELOG=`${CURRENT_DIR}/generate_changelog.sh` +echo "$CHANGELOG" + +# Parse and preformat the text to handle multi-line output. +# See https://github.community/t5/GitHub-Actions/set-output-Truncates-Multiline-Strings/td-p/37870 +FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//'%'/'%25'}" +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\n'/'%0A'}" +FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\r'/'%0D'}" +echo "::set-output name=changelog::${FILTERED_CHANGELOG}" + + echo "" echo_info "--------------------------------------------" echo_info "PREFLIGHT SUCCESSFUL" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb8f93ba..6153e812 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ jobs: GOPATH: ${{ github.workspace }}/go strategy: matrix: - go: [1.11, 1.12, 1.13] + go: [1.12, 1.13, 1.14] steps: - name: Set up Go ${{ matrix.go }} @@ -37,7 +37,6 @@ jobs: - name: Run Formatter working-directory: ./go/src/firebase.google.com/go - if: matrix.go != '1.11' run: | if [[ ! -z "$(gofmt -l -s .)" ]]; then echo "Go code is not formatted:" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index 32f6dbe6..00000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: Publish Release - -on: - # Only run the workflow when a PR is merged to the master branch. - pull_request: - branches: master - types: closed - -jobs: - publish_release: - if: github.event.pull_request.merged - - runs-on: ubuntu-latest - - steps: - - name: Checkout source - uses: actions/checkout@v2 - - - name: Publish post check - id: postcheck - run: ./.github/scripts/publish_post_check.sh - - # We pull this action from a custom fork of a contributor until - # https://github.com/actions/create-release/pull/32 is merged. Also note that v1 of - # this action does not support the "body" parameter. - - name: Create release tag - uses: fleskesvor/create-release@1a72e235c178bf2ae6c51a8ae36febc24568c5fe - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ steps.postcheck.outputs.version }} - release_name: Firebase Admin Go SDK ${{ steps.postcheck.outputs.version }} - body: ${{ steps.postcheck.outputs.changelog }} - draft: false - prerelease: false - - # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - - name: Post to Twitter - if: success() && - contains(github.event.pull_request.labels.*.name, 'release:tweet') - uses: firebase/firebase-admin-node/.github/actions/send-tweet@master - with: - status: > - ${{ steps.postcheck.outputs.version }} of @Firebase Admin Go SDK is available. - https://github.com/firebase/firebase-admin-go/releases/tag/${{ steps.postcheck.outputs.version }} - consumer-key: ${{ secrets.FIREBASE_TWITTER_CONSUMER_KEY }} - consumer-secret: ${{ secrets.FIREBASE_TWITTER_CONSUMER_SECRET }} - access-token: ${{ secrets.FIREBASE_TWITTER_ACCESS_TOKEN }} - access-token-secret: ${{ secrets.FIREBASE_TWITTER_ACCESS_TOKEN_SECRET }} - continue-on-error: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..3bfccd48 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,138 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Release Candidate + +on: + # Only run the workflow when a PR is updated or when a developer explicitly requests + # a build by sending a 'firebase_build' event. + pull_request: + types: [opened, synchronize, closed] + + repository_dispatch: + types: + - firebase_build + +jobs: + stage_release: + # To publish a release, merge the release PR with the label 'release:publish'. + # To stage a release without publishing it, send a 'firebase_build' event or apply + # the 'release:stage' label to a PR. + if: github.event.action == 'firebase_build' || + contains(github.event.pull_request.labels.*.name, 'release:stage') || + (github.event.pull_request.merged && + contains(github.event.pull_request.labels.*.name, 'release:publish')) + + runs-on: ubuntu-latest + + env: + GOPATH: ${{ github.workspace }}/go + + # When manually triggering the build, the requester can specify a target branch or a tag + # via the 'ref' client parameter. + steps: + - name: Check out code into GOPATH + uses: actions/checkout@v2 + with: + path: go/src/firebase.google.com/go + ref: ${{ github.event.client_payload.ref || github.ref }} + + - name: Set up Go + uses: actions/setup-go@v1 + with: + go-version: 1.11 + + - name: Get dependencies + run: go get -t -v $(go list ./... | grep -v integration) + + - name: Run Linter + run: | + echo + go get golang.org/x/lint/golint + $GOPATH/bin/golint -set_exit_status firebase.google.com/go/... + + - name: Run Tests + working-directory: ./go/src/firebase.google.com/go + run: ./.github/scripts/run_all_tests.sh + env: + FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} + FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} + + publish_release: + needs: stage_release + + # Check whether the release should be published. We publish only when the trigger PR is + # 1. merged + # 2. to the dev branch + # 3. with the label 'release:publish', and + # 4. the title prefix '[chore] Release '. + if: github.event.pull_request.merged && + github.ref == 'dev' && + contains(github.event.pull_request.labels.*.name, 'release:publish') && + startsWith(github.event.pull_request.title, '[chore] Release ') + + runs-on: ubuntu-latest + + steps: + - name: Checkout source for publish + uses: actions/checkout@v2 + with: + persist-credentials: false + + - name: Publish preflight check + id: preflight + run: ./.github/scripts/publish_preflight_check.sh + + # We authorize this step with an access token that has write access to the master branch. + - name: Merge to master + uses: actions/github-script@0.9.0 + with: + github-token: ${{ secrets.FIREBASE_GITHUB_TOKEN }} + script: | + github.repos.merge({ + owner: context.repo.owner, + repo: context.repo.repo, + base: 'master', + head: 'dev' + }) + + # We pull this action from a custom fork of a contributor until + # https://github.com/actions/create-release/pull/32 is merged. Also note that v1 of + # this action does not support the "body" parameter. + - name: Create release tag + uses: fleskesvor/create-release@1a72e235c178bf2ae6c51a8ae36febc24568c5fe + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.preflight.outputs.version }} + release_name: Firebase Admin Go SDK ${{ steps.preflight.outputs.version }} + body: ${{ steps.preflight.outputs.changelog }} + commitish: master + draft: false + prerelease: false + + # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. + - name: Post to Twitter + if: success() && + contains(github.event.pull_request.labels.*.name, 'release:tweet') + uses: firebase/firebase-admin-node/.github/actions/send-tweet@master + with: + status: > + ${{ steps.preflight.outputs.version }} of @Firebase Admin Go SDK is available. + https://github.com/firebase/firebase-admin-go/releases/tag/${{ steps.preflight.outputs.version }} + consumer-key: ${{ secrets.FIREBASE_TWITTER_CONSUMER_KEY }} + consumer-secret: ${{ secrets.FIREBASE_TWITTER_CONSUMER_SECRET }} + access-token: ${{ secrets.FIREBASE_TWITTER_ACCESS_TOKEN }} + access-token-secret: ${{ secrets.FIREBASE_TWITTER_ACCESS_TOKEN_SECRET }} + continue-on-error: true diff --git a/.github/workflows/stage.yml b/.github/workflows/stage.yml deleted file mode 100644 index 540d9508..00000000 --- a/.github/workflows/stage.yml +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: Stage Release - -on: - # Only run the workflow when a PR is updated or when a developer explicitly requests - # a build by sending a 'firebase_build' event. - pull_request: - types: [opened, synchronize] - - repository_dispatch: - types: - - firebase_build - -jobs: - stage_release: - # To stage a release without publishing it, send a 'firebase_build' event or apply - # the 'release:stage' label to a PR. PRs targetting the master branch are always - # staged. - if: github.event.action == 'firebase_build' || - contains(github.event.pull_request.labels.*.name, 'release:stage') || - github.event.pull_request.base.ref == 'master' - - runs-on: ubuntu-latest - - env: - GOPATH: ${{ github.workspace }}/go - - # When manually triggering the build, the requester can specify a target branch or a tag - # via the 'ref' client parameter. - steps: - - name: Check out code into GOPATH - uses: actions/checkout@v2 - with: - path: go/src/firebase.google.com/go - ref: ${{ github.event.client_payload.ref || github.ref }} - - - name: Set up Go - uses: actions/setup-go@v1 - with: - go-version: 1.11 - - - name: Get dependencies - run: go get -t -v $(go list ./... | grep -v integration) - - - name: Run Linter - run: | - echo - go get golang.org/x/lint/golint - $GOPATH/bin/golint -set_exit_status firebase.google.com/go/... - - - name: Run Tests - working-directory: ./go/src/firebase.google.com/go - run: ./.github/scripts/run_all_tests.sh - env: - FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} - FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} - - # If triggered by a PR against the master branch, run additional checks. - - name: Publish preflight check - if: github.event.pull_request.base.ref == 'master' - working-directory: ./go/src/firebase.google.com/go - run: ./.github/scripts/publish_preflight_check.sh diff --git a/README.md b/README.md index 899a88d4..6a8fb4f0 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Build Status](https://github.com/firebase/firebase-admin-go/workflows/Continuous%20Integration/badge.svg)](https://github.com/firebase/firebase-admin-go/actions) +[![Build Status](https://github.com/firebase/firebase-admin-go/workflows/Continuous%20Integration/badge.svg?branch=dev)](https://github.com/firebase/firebase-admin-go/actions) [![GoDoc](https://godoc.org/firebase.google.com/go?status.svg)](https://godoc.org/firebase.google.com/go) [![Go Report Card](https://goreportcard.com/badge/github.com/firebase/firebase-admin-go)](https://goreportcard.com/report/github.com/firebase/firebase-admin-go) diff --git a/auth/auth.go b/auth/auth.go index b5088806..8a86e583 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -44,8 +44,6 @@ var reservedClaims = []string{ type Client struct { *baseClient TenantManager *TenantManager - signer cryptoSigner - clock internal.Clock } // NewClient creates a new instance of the Firebase Auth Client. @@ -116,11 +114,11 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error) httpClient: hc, idTokenVerifier: idTokenVerifier, cookieVerifier: cookieVerifier, + signer: signer, + clock: internal.SystemClock, } return &Client{ baseClient: base, - signer: signer, - clock: internal.SystemClock, TenantManager: newTenantManager(hc, conf, base), }, nil } @@ -144,13 +142,13 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error) // conjunction with the IAM service to sign tokens remotely. // // CustomToken returns an error the SDK fails to discover a viable mechanism for signing tokens. -func (c *Client) CustomToken(ctx context.Context, uid string) (string, error) { +func (c *baseClient) CustomToken(ctx context.Context, uid string) (string, error) { return c.CustomTokenWithClaims(ctx, uid, nil) } // CustomTokenWithClaims is similar to CustomToken, but in addition to the user ID, it also encodes // all the key-value pairs in the provided map as claims in the resulting JWT. -func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) { +func (c *baseClient) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) { iss, err := c.signer.Email(ctx) if err != nil { return "", err @@ -176,13 +174,14 @@ func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaim info := &jwtInfo{ header: jwtHeader{Algorithm: "RS256", Type: "JWT"}, payload: &customToken{ - Iss: iss, - Sub: iss, - Aud: firebaseAudience, - UID: uid, - Iat: now, - Exp: now + oneHourInSeconds, - Claims: devClaims, + Iss: iss, + Sub: iss, + Aud: firebaseAudience, + UID: uid, + Iat: now, + Exp: now + oneHourInSeconds, + TenantID: c.tenantID, + Claims: devClaims, }, } return info.Token(ctx, c.signer) @@ -235,6 +234,8 @@ type baseClient struct { httpClient *internal.HTTPClient idTokenVerifier *tokenVerifier cookieVerifier *tokenVerifier + signer cryptoSigner + clock internal.Clock } func (c *baseClient) withTenantID(tenantID string) *baseClient { diff --git a/auth/auth_test.go b/auth/auth_test.go index 1e8e1098..f42df871 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -282,20 +282,26 @@ func TestNewClientExplicitNoAuth(t *testing.T) { func TestCustomToken(t *testing.T) { client := &Client{ - signer: testSigner, - clock: testClock, + baseClient: &baseClient{ + signer: testSigner, + clock: testClock, + }, } token, err := client.CustomToken(context.Background(), "user1") if err != nil { t.Fatal(err) } - verifyCustomToken(context.Background(), token, nil, t) + if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil { + t.Fatal(err) + } } func TestCustomTokenWithClaims(t *testing.T) { client := &Client{ - signer: testSigner, - clock: testClock, + baseClient: &baseClient{ + signer: testSigner, + clock: testClock, + }, } claims := map[string]interface{}{ "foo": "bar", @@ -306,19 +312,46 @@ func TestCustomTokenWithClaims(t *testing.T) { if err != nil { t.Fatal(err) } - verifyCustomToken(context.Background(), token, claims, t) + if err := verifyCustomToken(context.Background(), token, claims, ""); err != nil { + t.Fatal(err) + } } func TestCustomTokenWithNilClaims(t *testing.T) { client := &Client{ - signer: testSigner, - clock: testClock, + baseClient: &baseClient{ + signer: testSigner, + clock: testClock, + }, } token, err := client.CustomTokenWithClaims(context.Background(), "user1", nil) if err != nil { t.Fatal(err) } - verifyCustomToken(context.Background(), token, nil, t) + if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil { + t.Fatal(err) + } +} + +func TestCustomTokenForTenant(t *testing.T) { + client := &Client{ + baseClient: &baseClient{ + tenantID: "tenantID", + signer: testSigner, + clock: testClock, + }, + } + claims := map[string]interface{}{ + "foo": "bar", + "premium": true, + } + token, err := client.CustomTokenWithClaims(context.Background(), "user1", claims) + if err != nil { + t.Fatal(err) + } + if err := verifyCustomToken(context.Background(), token, claims, "tenantID"); err != nil { + t.Fatal(err) + } } func TestCustomTokenError(t *testing.T) { @@ -333,7 +366,7 @@ func TestCustomTokenError(t *testing.T) { {"ReservedClaims", "uid", map[string]interface{}{"sub": "1234", "aud": "foo"}}, } - client := &Client{ + client := &baseClient{ signer: testSigner, clock: testClock, } @@ -628,9 +661,9 @@ func TestCustomTokenVerification(t *testing.T) { client := &Client{ baseClient: &baseClient{ idTokenVerifier: testIDTokenVerifier, + signer: testSigner, + clock: testClock, }, - signer: testSigner, - clock: testClock, } token, err := client.CustomToken(context.Background(), "user1") if err != nil { @@ -1137,52 +1170,61 @@ func checkBaseClient(client *Client, wantProjectID string) error { return nil } -func verifyCustomToken(ctx context.Context, token string, expected map[string]interface{}, t *testing.T) { +func verifyCustomToken( + ctx context.Context, token string, expected map[string]interface{}, tenantID string) error { + if err := testIDTokenVerifier.verifySignature(ctx, token); err != nil { - t.Fatal(err) + return err } + var ( header jwtHeader payload customToken ) segments := strings.Split(token, ".") if err := decode(segments[0], &header); err != nil { - t.Fatal(err) + return err } if err := decode(segments[1], &payload); err != nil { - t.Fatal(err) + return err } email, err := testSigner.Email(ctx) if err != nil { - t.Fatal(err) + return err } if header.Algorithm != "RS256" { - t.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm) + return fmt.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm) } else if header.Type != "JWT" { - t.Errorf("Type: %q; want: 'JWT'", header.Type) + return fmt.Errorf("Type: %q; want: 'JWT'", header.Type) } else if payload.Aud != firebaseAudience { - t.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience) + return fmt.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience) } else if payload.Iss != email { - t.Errorf("Issuer: %q; want: %q", payload.Iss, email) + return fmt.Errorf("Issuer: %q; want: %q", payload.Iss, email) } else if payload.Sub != email { - t.Errorf("Subject: %q; want: %q", payload.Sub, email) + return fmt.Errorf("Subject: %q; want: %q", payload.Sub, email) } now := testClock.Now().Unix() if payload.Exp != now+3600 { - t.Errorf("Exp: %d; want: %d", payload.Exp, now+3600) + return fmt.Errorf("Exp: %d; want: %d", payload.Exp, now+3600) } if payload.Iat != now { - t.Errorf("Iat: %d; want: %d", payload.Iat, now) + return fmt.Errorf("Iat: %d; want: %d", payload.Iat, now) } for k, v := range expected { if payload.Claims[k] != v { - t.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v) + return fmt.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v) } } + + if payload.TenantID != tenantID { + return fmt.Errorf("Tenant ID: %q; want: %q", payload.TenantID, tenantID) + } + + return nil } func logFatal(err error) { diff --git a/auth/token_generator.go b/auth/token_generator.go index 963767d1..98acad5a 100644 --- a/auth/token_generator.go +++ b/auth/token_generator.go @@ -41,13 +41,14 @@ type jwtHeader struct { } type customToken struct { - Iss string `json:"iss"` - Aud string `json:"aud"` - Exp int64 `json:"exp"` - Iat int64 `json:"iat"` - Sub string `json:"sub,omitempty"` - UID string `json:"uid,omitempty"` - Claims map[string]interface{} `json:"claims,omitempty"` + Iss string `json:"iss"` + Aud string `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + Sub string `json:"sub,omitempty"` + UID string `json:"uid,omitempty"` + TenantID string `json:"tenant_id,omitempty"` + Claims map[string]interface{} `json:"claims,omitempty"` } type jwtInfo struct { diff --git a/auth/user_mgt.go b/auth/user_mgt.go index e1a18a0f..e601d8bc 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -34,6 +34,12 @@ const ( maxLenPayloadCC = 1000 defaultProviderID = "firebase" idToolkitV1Endpoint = "https://identitytoolkit.googleapis.com/v1" + + // Maximum number of users allowed to batch get at a time. + maxGetAccountsBatchSize = 100 + + // Maximum number of users allowed to batch delete at a time. + maxDeleteAccountsBatchSize = 1000 ) // 'REDACTED', encoded as a base64 string. @@ -57,6 +63,9 @@ type UserInfo struct { type UserMetadata struct { CreationTimestamp int64 LastLogInTimestamp int64 + // The time at which the user was last active (ID token refreshed), or 0 if + // the user was never active. + LastRefreshTimestamp int64 } // UserRecord contains metadata associated with a Firebase user account. @@ -491,6 +500,15 @@ func validatePhone(phone string) error { return nil } +func validateProvider(providerID string, providerUID string) error { + if providerID == "" { + return fmt.Errorf("providerID must be a non-empty string") + } else if providerUID == "" { + return fmt.Errorf("providerUID must be a non-empty string") + } + return nil +} + // End of validators // GetUser gets the user data corresponding to the specified user ID. @@ -545,12 +563,13 @@ func (q *userQuery) build() map[string]interface{} { } } +type getAccountInfoResponse struct { + Users []*userQueryResponse `json:"users"` +} + func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) { - var parsed struct { - Users []*userQueryResponse `json:"users"` - } - _, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed) - if err != nil { + var parsed getAccountInfoResponse + if _, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed); err != nil { return nil, err } @@ -561,6 +580,195 @@ func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord return parsed.Users[0].makeUserRecord() } +// A UserIdentifier identifies a user to be looked up. +type UserIdentifier interface { + matches(ur *UserRecord) bool + populate(req *getAccountInfoRequest) +} + +// A UIDIdentifier is used for looking up an account by uid. +// +// See GetUsers function. +type UIDIdentifier struct { + UID string +} + +func (id UIDIdentifier) matches(ur *UserRecord) bool { + return id.UID == ur.UID +} + +func (id UIDIdentifier) populate(req *getAccountInfoRequest) { + req.LocalID = append(req.LocalID, id.UID) +} + +// An EmailIdentifier is used for looking up an account by email. +// +// See GetUsers function. +type EmailIdentifier struct { + Email string +} + +func (id EmailIdentifier) matches(ur *UserRecord) bool { + return id.Email == ur.Email +} + +func (id EmailIdentifier) populate(req *getAccountInfoRequest) { + req.Email = append(req.Email, id.Email) +} + +// A PhoneIdentifier is used for looking up an account by phone number. +// +// See GetUsers function. +type PhoneIdentifier struct { + PhoneNumber string +} + +func (id PhoneIdentifier) matches(ur *UserRecord) bool { + return id.PhoneNumber == ur.PhoneNumber +} + +func (id PhoneIdentifier) populate(req *getAccountInfoRequest) { + req.PhoneNumber = append(req.PhoneNumber, id.PhoneNumber) +} + +// A ProviderIdentifier is used for looking up an account by federated provider. +// +// See GetUsers function. +type ProviderIdentifier struct { + ProviderID string + ProviderUID string +} + +func (id ProviderIdentifier) matches(ur *UserRecord) bool { + for _, userInfo := range ur.ProviderUserInfo { + if id.ProviderID == userInfo.ProviderID && id.ProviderUID == userInfo.UID { + return true + } + } + return false +} + +func (id ProviderIdentifier) populate(req *getAccountInfoRequest) { + req.FederatedUserID = append( + req.FederatedUserID, + federatedUserIdentifier{ProviderID: id.ProviderID, RawID: id.ProviderUID}) +} + +// A GetUsersResult represents the result of the GetUsers() API. +type GetUsersResult struct { + // Set of UserRecords corresponding to the set of users that were requested. + // Only users that were found are listed here. The result set is unordered. + Users []*UserRecord + + // Set of UserIdentifiers that were requested, but not found. + NotFound []UserIdentifier +} + +type federatedUserIdentifier struct { + ProviderID string `json:"providerId,omitempty"` + RawID string `json:"rawId,omitempty"` +} + +type getAccountInfoRequest struct { + LocalID []string `json:"localId,omitempty"` + Email []string `json:"email,omitempty"` + PhoneNumber []string `json:"phoneNumber,omitempty"` + FederatedUserID []federatedUserIdentifier `json:"federatedUserId,omitempty"` +} + +func (req *getAccountInfoRequest) validate() error { + for i := range req.LocalID { + if err := validateUID(req.LocalID[i]); err != nil { + return err + } + } + + for i := range req.Email { + if err := validateEmail(req.Email[i]); err != nil { + return err + } + } + + for i := range req.PhoneNumber { + if err := validatePhone(req.PhoneNumber[i]); err != nil { + return err + } + } + + for i := range req.FederatedUserID { + id := &req.FederatedUserID[i] + if err := validateProvider(id.ProviderID, id.RawID); err != nil { + return err + } + } + + return nil +} + +func isUserFound(id UserIdentifier, urs [](*UserRecord)) bool { + for i := range urs { + if id.matches(urs[i]) { + return true + } + } + return false +} + +// GetUsers returns the user data corresponding to the specified identifiers. +// +// There are no ordering guarantees; in particular, the nth entry in the users +// result list is not guaranteed to correspond to the nth entry in the input +// parameters list. +// +// A maximum of 100 identifiers may be supplied. If more than 100 +// identifiers are supplied, this method returns an error. +// +// Returns the corresponding user records. An error is returned instead if any +// of the identifiers are invalid or if more than 100 identifiers are +// specified. +func (c *baseClient) GetUsers( + ctx context.Context, identifiers []UserIdentifier, +) (*GetUsersResult, error) { + if len(identifiers) == 0 { + return &GetUsersResult{[](*UserRecord){}, [](UserIdentifier){}}, nil + } else if len(identifiers) > maxGetAccountsBatchSize { + return nil, fmt.Errorf( + "`identifiers` parameter must have <= %d entries", maxGetAccountsBatchSize) + } + + var request getAccountInfoRequest + for i := range identifiers { + identifiers[i].populate(&request) + } + + if err := request.validate(); err != nil { + return nil, err + } + + var parsed getAccountInfoResponse + if _, err := c.post(ctx, "/accounts:lookup", request, &parsed); err != nil { + return nil, err + } + + var userRecords [](*UserRecord) + for _, user := range parsed.Users { + userRecord, err := user.makeUserRecord() + if err != nil { + return nil, err + } + userRecords = append(userRecords, userRecord) + } + + var notFound []UserIdentifier + for i := range identifiers { + if !isUserFound(identifiers[i], userRecords) { + notFound = append(notFound, identifiers[i]) + } + } + + return &GetUsersResult{userRecords, notFound}, nil +} + type userQueryResponse struct { UID string `json:"localId,omitempty"` DisplayName string `json:"displayName,omitempty"` @@ -569,6 +777,7 @@ type userQueryResponse struct { PhotoURL string `json:"photoUrl,omitempty"` CreationTimestamp int64 `json:"createdAt,string,omitempty"` LastLogInTimestamp int64 `json:"lastLoginAt,string,omitempty"` + LastRefreshAt string `json:"lastRefreshAt,omitempty"` ProviderID string `json:"providerId,omitempty"` CustomAttributes string `json:"customAttributes,omitempty"` Disabled bool `json:"disabled,omitempty"` @@ -592,8 +801,7 @@ func (r *userQueryResponse) makeUserRecord() (*UserRecord, error) { func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error) { var customClaims map[string]interface{} if r.CustomAttributes != "" { - err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims) - if err != nil { + if err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims); err != nil { return nil, err } if len(customClaims) == 0 { @@ -609,6 +817,15 @@ func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error hash = "" } + var lastRefreshTimestamp int64 + if r.LastRefreshAt != "" { + t, err := time.Parse(time.RFC3339, r.LastRefreshAt) + if err != nil { + return nil, err + } + lastRefreshTimestamp = t.Unix() * 1000 + } + return &ExportedUserRecord{ UserRecord: &UserRecord{ UserInfo: &UserInfo{ @@ -626,8 +843,9 @@ func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error TenantID: r.TenantID, TokensValidAfterMillis: r.ValidSinceSeconds * 1000, UserMetadata: &UserMetadata{ - LastLogInTimestamp: r.LastLogInTimestamp, - CreationTimestamp: r.CreationTimestamp, + LastLogInTimestamp: r.LastLogInTimestamp, + CreationTimestamp: r.CreationTimestamp, + LastRefreshTimestamp: lastRefreshTimestamp, }, }, PasswordHash: hash, @@ -728,6 +946,91 @@ func (c *baseClient) DeleteUser(ctx context.Context, uid string) error { return err } +// A DeleteUsersResult represents the result of the DeleteUsers() call. +type DeleteUsersResult struct { + // The number of users that were deleted successfully (possibly zero). Users + // that did not exist prior to calling DeleteUsers() are considered to be + // successfully deleted. + SuccessCount int + + // The number of users that failed to be deleted (possibly zero). + FailureCount int + + // A list of DeleteUsersErrorInfo instances describing the errors that were + // encountered during the deletion. Length of this list is equal to the value + // of FailureCount. + Errors []*DeleteUsersErrorInfo +} + +// DeleteUsersErrorInfo represents an error encountered while deleting a user +// account. +// +// The Index field corresponds to the index of the failed user in the uids +// array that was passed to DeleteUsers(). +type DeleteUsersErrorInfo struct { + Index int `json:"index,omitEmpty"` + Reason string `json:"message,omitEmpty"` +} + +// DeleteUsers deletes the users specified by the given identifiers. +// +// Deleting a non-existing user won't generate an error. (i.e. this method is +// idempotent.) Non-existing users are considered to be successfully +// deleted, and are therefore counted in the DeleteUsersResult.SuccessCount +// value. +// +// A maximum of 1000 identifiers may be supplied. If more than 1000 +// identifiers are supplied, this method returns an error. +// +// This API is currently rate limited at the server to 1 QPS. If you exceed +// this, you may get a quota exceeded error. Therefore, if you want to delete +// more than 1000 users, you may need to add a delay to ensure you don't go +// over this limit. +// +// Returns the total number of successful/failed deletions, as well as the +// array of errors that correspond to the failed deletions. An error is +// returned if any of the identifiers are invalid or if more than 1000 +// identifiers are specified. +func (c *baseClient) DeleteUsers(ctx context.Context, uids []string) (*DeleteUsersResult, error) { + if len(uids) == 0 { + return &DeleteUsersResult{}, nil + } else if len(uids) > maxDeleteAccountsBatchSize { + return nil, fmt.Errorf( + "`uids` parameter must have <= %d entries", maxDeleteAccountsBatchSize) + } + + var payload struct { + LocalIds []string `json:"localIds"` + Force bool `json:"force"` + } + payload.Force = true + + for i := range uids { + if err := validateUID(uids[i]); err != nil { + return nil, err + } + + payload.LocalIds = append(payload.LocalIds, uids[i]) + } + + type batchDeleteAccountsResponse struct { + Errors []*DeleteUsersErrorInfo `json:"errors"` + } + + resp := batchDeleteAccountsResponse{} + if _, err := c.post(ctx, "/accounts:batchDelete", payload, &resp); err != nil { + return nil, err + } + + result := DeleteUsersResult{ + FailureCount: len(resp.Errors), + SuccessCount: len(uids) - len(resp.Errors), + Errors: resp.Errors, + } + + return &result, nil +} + // SessionCookie creates a new Firebase session cookie from the given ID token and expiry // duration. The returned JWT can be set as a server-side session cookie with a custom cookie // policy. Expiry duration must be at least 5 minutes but may not exceed 14 days. diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index b2591f1e..0ee5c678 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -24,6 +24,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "sort" "strconv" "strings" "testing" @@ -157,6 +158,236 @@ func TestInvalidGetUser(t *testing.T) { } } +// Checks to see if the users list contain the given uids. Order is ignored. +// +// Behaviour is undefined if there are duplicate entries in either of the +// slices. +// +// This function is identical to the one in integration/auth/user_mgt_test.go +func sameUsers(users [](*UserRecord), uids []string) bool { + if len(users) != len(uids) { + return false + } + + sort.Slice(users, func(i, j int) bool { + return users[i].UID < users[j].UID + }) + sort.Slice(uids, func(i, j int) bool { + return uids[i] < uids[j] + }) + + for i := range users { + if users[i].UID != uids[i] { + return false + } + } + + return true +} + +func TestGetUsersExceeds100(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + var identifiers [101]UserIdentifier + for i := 0; i < 101; i++ { + identifiers[i] = &UIDIdentifier{UID: fmt.Sprintf("id%d", i)} + } + + getUsersResult, err := client.GetUsers(context.Background(), identifiers[:]) + want := "`identifiers` parameter must have <= 100 entries" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf( + "GetUsers() = (%v, %q); want = (nil, %q)", + getUsersResult, err, want) + } +} + +func TestGetUsersEmpty(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers(context.Background(), [](UserIdentifier){}) + if getUsersResult == nil || err != nil { + t.Fatalf("GetUsers([]) = %q", err) + } + + if len(getUsersResult.Users) != 0 { + t.Errorf("len(GetUsers([]).Users) = %d; want 0", len(getUsersResult.Users)) + } + if len(getUsersResult.NotFound) != 0 { + t.Errorf("len(GetUsers([]).NotFound) = %d; want 0", len(getUsersResult.NotFound)) + } +} + +func TestGetUsersAllNonExisting(t *testing.T) { + resp := `{ + "kind" : "identitytoolkit#GetAccountInfoResponse", + "users" : [] + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + notFoundIds := []UserIdentifier{&UIDIdentifier{"id that doesnt exist"}} + getUsersResult, err := s.Client.GetUsers(context.Background(), notFoundIds) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if len(getUsersResult.Users) != 0 { + t.Errorf( + "len(GetUsers().Users) = %d; want 0", + len(getUsersResult.Users)) + } + if len(getUsersResult.NotFound) != len(notFoundIds) { + t.Errorf("len(GetUsers()).NotFound) = %d; want %d", + len(getUsersResult.NotFound), len(notFoundIds)) + } else { + for i := range notFoundIds { + if getUsersResult.NotFound[i] != notFoundIds[i] { + t.Errorf("GetUsers().NotFound[%d] = %v; want %v", + i, getUsersResult.NotFound[i], notFoundIds[i]) + } + } + } +} + +func TestGetUsersInvalidUid(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers( + context.Background(), + []UserIdentifier{&UIDIdentifier{"too long " + strings.Repeat(".", 128)}}) + want := "uid string must not be longer than 128 characters" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersInvalidEmail(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers( + context.Background(), + []UserIdentifier{EmailIdentifier{"invalid email addr"}}) + want := `malformed email string: "invalid email addr"` + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersInvalidPhoneNumber(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers(context.Background(), []UserIdentifier{ + PhoneIdentifier{"invalid phone number"}, + }) + want := "phone number must be a valid, E.164 compliant identifier" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersInvalidProvider(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + getUsersResult, err := client.GetUsers(context.Background(), []UserIdentifier{ + ProviderIdentifier{ProviderID: "", ProviderUID: ""}, + }) + want := "providerID must be a non-empty string" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersSingleBadIdentifier(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + identifiers := []UserIdentifier{ + UIDIdentifier{"valid_id1"}, + UIDIdentifier{"valid_id2"}, + UIDIdentifier{"invalid id; too long. " + strings.Repeat(".", 128)}, + UIDIdentifier{"valid_id3"}, + UIDIdentifier{"valid_id4"}, + } + + getUsersResult, err := client.GetUsers(context.Background(), identifiers) + want := "uid string must not be longer than 128 characters" + if getUsersResult != nil || err == nil || err.Error() != want { + t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) + } +} + +func TestGetUsersMultipleIdentifierTypes(t *testing.T) { + mockUsers := []byte(` + { + "users": [{ + "localId": "uid1", + "email": "user1@example.com", + "phoneNumber": "+15555550001" + }, { + "localId": "uid2", + "email": "user2@example.com", + "phoneNumber": "+15555550002" + }, { + "localId": "uid3", + "email": "user3@example.com", + "phoneNumber": "+15555550003" + }, { + "localId": "uid4", + "email": "user4@example.com", + "phoneNumber": "+15555550004", + "providerUserInfo": [{ + "providerId": "google.com", + "rawId": "google_uid4" + }] + }] + }`) + s := echoServer(mockUsers, t) + defer s.Close() + + identifiers := []UserIdentifier{ + &UIDIdentifier{"uid1"}, + &EmailIdentifier{"user2@example.com"}, + &PhoneIdentifier{"+15555550003"}, + &ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_uid4"}, + &UIDIdentifier{"this-user-doesnt-exist"}, + } + + getUsersResult, err := s.Client.GetUsers(context.Background(), identifiers) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if !sameUsers(getUsersResult.Users, []string{"uid1", "uid2", "uid3", "uid4"}) { + t.Errorf("GetUsers() = %v; want = (uids from) %v (in any order)", + getUsersResult.Users, []string{"uid1", "uid2", "uid3", "uid4"}) + } + if len(getUsersResult.NotFound) != 1 { + t.Errorf("GetUsers() = %d; want = 1", len(getUsersResult.NotFound)) + } else { + if id, ok := getUsersResult.NotFound[0].(*UIDIdentifier); !ok { + t.Errorf("GetUsers().NotFound[0] not a UIDIdentifier") + } else { + if id.UID != "this-user-doesnt-exist" { + t.Errorf("GetUsers().NotFound[0].UID = %s; want = 'this-user-doesnt-exist'", id.UID) + } + } + } +} + func TestGetNonExistingUser(t *testing.T) { resp := `{ "kind" : "identitytoolkit#GetAccountInfoResponse", @@ -1079,6 +1310,110 @@ func TestInvalidDeleteUser(t *testing.T) { } } +func TestDeleteUsers(t *testing.T) { + client := &Client{ + baseClient: &baseClient{}, + } + + t.Run("should succeed given an empty list", func(t *testing.T) { + result, err := client.DeleteUsers(context.Background(), []string{}) + + if err != nil { + t.Fatalf("DeleteUsers([]) error %v; want = nil", err) + } + + if result.SuccessCount != 0 { + t.Errorf("DeleteUsers([]).SuccessCount = %d; want = 0", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers([]).FailureCount = %d; want = 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers([]).Errors) = %d; want = 0", len(result.Errors)) + } + }) + + t.Run("should be rejected when given more than 1000 identifiers", func(t *testing.T) { + uids := []string{} + for i := 0; i < 1001; i++ { + uids = append(uids, fmt.Sprintf("id%d", i)) + } + + _, err := client.DeleteUsers(context.Background(), uids) + if err == nil { + t.Fatalf("DeleteUsers([too_many_uids]) error nil; want not nil") + } + + if err.Error() != "`uids` parameter must have <= 1000 entries" { + t.Errorf( + "DeleteUsers([too_many_uids]) returned an error of '%s'; "+ + "expected '`uids` parameter must have <= 1000 entries'", + err.Error()) + } + }) + + t.Run("should immediately fail given an invalid id", func(t *testing.T) { + tooLongUID := "too long " + strings.Repeat(".", 128) + _, err := client.DeleteUsers(context.Background(), []string{tooLongUID}) + + if err == nil { + t.Fatalf("DeleteUsers([too_long_uid]) error nil; want not nil") + } + + if err.Error() != "uid string must not be longer than 128 characters" { + t.Errorf( + "DeleteUsers([too_long_uid]) returned an error of '%s'; "+ + "expected 'uid string must not be longer than 128 characters'", + err.Error()) + } + }) + + t.Run("should index errors correctly in result", func(t *testing.T) { + resp := `{ + "errors": [{ + "index": 0, + "localId": "uid1", + "message": "Error Message 1" + }, { + "index": 2, + "localId": "uid3", + "message": "Error Message 2" + }] + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + result, err := s.Client.DeleteUsers(context.Background(), []string{"uid1", "uid2", "uid3", "uid4"}) + + if err != nil { + t.Fatalf("DeleteUsers([...]) error %v; want = nil", err) + } + + if result.SuccessCount != 2 { + t.Errorf("DeleteUsers([...]).SuccessCount = %d; want 2", result.SuccessCount) + } + if result.FailureCount != 2 { + t.Errorf("DeleteUsers([...]).FailureCount = %d; want 2", result.FailureCount) + } + if len(result.Errors) != 2 { + t.Errorf("len(DeleteUsers([...]).Errors) = %d; want 2", len(result.Errors)) + } else { + if result.Errors[0].Index != 0 { + t.Errorf("DeleteUsers([...]).Errors[0].Index = %d; want 0", result.Errors[0].Index) + } + if result.Errors[0].Reason != "Error Message 1" { + t.Errorf("DeleteUsers([...]).Errors[0].Reason = %s; want Error Message 1", result.Errors[0].Reason) + } + if result.Errors[1].Index != 2 { + t.Errorf("DeleteUsers([...]).Errors[1].Index = %d; want 2", result.Errors[1].Index) + } + if result.Errors[1].Reason != "Error Message 2" { + t.Errorf("DeleteUsers([...]).Errors[1].Reason = %s; want Error Message 2", result.Errors[1].Reason) + } + } + }) +} + func TestMakeExportedUser(t *testing.T) { queryResponse := &userQueryResponse{ UID: "testuser", diff --git a/firebase.go b/firebase.go index 0343d15b..f7cfaeec 100644 --- a/firebase.go +++ b/firebase.go @@ -38,7 +38,7 @@ import ( var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. -const Version = "3.12.1" +const Version = "3.13.0" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index b5638f6a..a9fa7b3a 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -207,10 +207,19 @@ func verifyCustomToken(t *testing.T, ct, uid string) *auth.Token { } func signInWithCustomToken(token string) (string, error) { - req, err := json.Marshal(map[string]interface{}{ + return signInWithCustomTokenForTenant(token, "") +} + +func signInWithCustomTokenForTenant(token string, tenantID string) (string, error) { + payload := map[string]interface{}{ "token": token, "returnSecureToken": true, - }) + } + if tenantID != "" { + payload["tenantId"] = tenantID + } + + req, err := json.Marshal(payload) if err != nil { return "", err } @@ -230,8 +239,9 @@ func signInWithCustomToken(token string) (string, error) { func signInWithPassword(email, password string) (string, error) { req, err := json.Marshal(map[string]interface{}{ - "email": email, - "password": password, + "email": email, + "password": password, + "returnSecureToken": true, }) if err != nil { return "", err diff --git a/integration/auth/tenant_mgt_test.go b/integration/auth/tenant_mgt_test.go index e1f4986b..e803443b 100644 --- a/integration/auth/tenant_mgt_test.go +++ b/integration/auth/tenant_mgt_test.go @@ -97,6 +97,10 @@ func TestTenantManager(t *testing.T) { } }) + t.Run("CustomTokens", func(t *testing.T) { + testTenantAwareCustomToken(t, id) + }) + t.Run("UserManagement", func(t *testing.T) { testTenantAwareUserManagement(t, id) }) @@ -154,6 +158,40 @@ func TestTenantManager(t *testing.T) { }) } +func testTenantAwareCustomToken(t *testing.T, id string) { + tenantClient, err := client.TenantManager.AuthForTenant(id) + if err != nil { + t.Fatalf("AuthForTenant() = %v", err) + } + + uid := randomUID() + ct, err := tenantClient.CustomToken(context.Background(), uid) + if err != nil { + t.Fatal(err) + } + + idToken, err := signInWithCustomTokenForTenant(ct, id) + if err != nil { + t.Fatal(err) + } + + defer func() { + tenantClient.DeleteUser(context.Background(), uid) + }() + + vt, err := tenantClient.VerifyIDToken(context.Background(), idToken) + if err != nil { + t.Fatal(err) + } + + if vt.UID != uid { + t.Errorf("UID = %q; want UID = %q", vt.UID, uid) + } + if vt.Firebase.Tenant != id { + t.Errorf("Tenant = %q; want = %q", vt.Firebase.Tenant, id) + } +} + func testTenantAwareUserManagement(t *testing.T, id string) { tenantClient, err := client.TenantManager.AuthForTenant(id) if err != nil { diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 1d4ef64e..63419a87 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -23,6 +23,7 @@ import ( "math/rand" "net/url" "reflect" + "sort" "strings" "testing" "time" @@ -91,6 +92,179 @@ func TestGetNonExistingUser(t *testing.T) { } } +func TestGetUsers(t *testing.T) { + // Checks to see if the users list contain the given uids. Order is ignored. + // + // Behaviour is undefined if there are duplicate entries in either of the + // slices. + // + // This function is identical to the one in auth/user_mgt_test.go + sameUsers := func(users [](*auth.UserRecord), uids []string) bool { + if len(users) != len(uids) { + return false + } + + sort.Slice(users, func(i, j int) bool { + return users[i].UID < users[j].UID + }) + sort.Slice(uids, func(i, j int) bool { + return uids[i] < uids[j] + }) + + for i := range users { + if users[i].UID != uids[i] { + return false + } + } + + return true + } + + testUser1 := newUserWithParams(t) + defer deleteUser(testUser1.UID) + testUser2 := newUserWithParams(t) + defer deleteUser(testUser2.UID) + testUser3 := newUserWithParams(t) + defer deleteUser(testUser3.UID) + + importUser1UID := randomUID() + importUser1 := (&auth.UserToImport{}). + UID(importUser1UID). + Email(randomEmail(importUser1UID)). + PhoneNumber(randomPhoneNumber()). + ProviderData([](*auth.UserProvider){ + &auth.UserProvider{ + ProviderID: "google.com", + UID: "google_" + importUser1UID, + }, + }) + importUser(t, importUser1UID, importUser1) + defer deleteUser(importUser1UID) + + userRecordsToUIDs := func(users [](*auth.UserRecord)) []string { + results := []string{} + for i := range users { + results = append(results, users[i].UID) + } + return results + } + + t.Run("various identifier types", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: testUser1.UID}, + auth.EmailIdentifier{Email: testUser2.Email}, + auth.PhoneIdentifier{PhoneNumber: testUser3.PhoneNumber}, + auth.ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_" + importUser1UID}, + }) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if !sameUsers(getUsersResult.Users, []string{testUser1.UID, testUser2.UID, testUser3.UID, importUser1UID}) { + t.Errorf("GetUsers() = %v; want = %v (in any order)", + userRecordsToUIDs(getUsersResult.Users), []string{testUser1.UID, testUser2.UID, testUser3.UID, importUser1UID}) + } + }) + + t.Run("mix of existing and non-existing users", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: testUser1.UID}, + auth.UIDIdentifier{UID: "uid_that_doesnt_exist"}, + auth.UIDIdentifier{UID: testUser3.UID}, + }) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if !sameUsers(getUsersResult.Users, []string{testUser1.UID, testUser3.UID}) { + t.Errorf("GetUsers() = %v; want = %v (in any order)", + getUsersResult.Users, []string{testUser1.UID, testUser3.UID}) + } + if len(getUsersResult.NotFound) != 1 { + t.Errorf("len(GetUsers().NotFound) = %d; want 1", len(getUsersResult.NotFound)) + } else { + if getUsersResult.NotFound[0].(auth.UIDIdentifier).UID != "uid_that_doesnt_exist" { + t.Errorf("GetUsers().NotFound[0].UID = %s; want 'uid_that_doesnt_exist'", + getUsersResult.NotFound[0].(auth.UIDIdentifier).UID) + } + } + }) + + t.Run("only non-existing users", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: "non-existing user"}, + }) + if err != nil { + t.Fatalf("GetUsers() = %q", err) + } + + if len(getUsersResult.Users) != 0 { + t.Errorf("len(GetUsers().Users) = %d; want = 0", len(getUsersResult.Users)) + } + if len(getUsersResult.NotFound) != 1 { + t.Errorf("len(GetUsers().NotFound) = %d; want = 1", len(getUsersResult.NotFound)) + } else { + if getUsersResult.NotFound[0].(auth.UIDIdentifier).UID != "non-existing user" { + t.Errorf("GetUsers().NotFound[0].UID = %s; want 'non-existing user'", + getUsersResult.NotFound[0].(auth.UIDIdentifier).UID) + } + } + }) + + t.Run("de-dups duplicate users", func(t *testing.T) { + getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ + auth.UIDIdentifier{UID: testUser1.UID}, + auth.UIDIdentifier{UID: testUser1.UID}, + }) + if err != nil { + t.Fatalf("GetUsers() returned an error: %v", err) + } + + if len(getUsersResult.Users) != 1 { + t.Errorf("len(GetUsers().Users) = %d; want = 1", len(getUsersResult.Users)) + } else { + if getUsersResult.Users[0].UID != testUser1.UID { + t.Errorf("GetUsers().Users[0].UID = %s; want = '%s'", getUsersResult.Users[0].UID, testUser1.UID) + } + } + if len(getUsersResult.NotFound) != 0 { + t.Errorf("len(GetUsers().NotFound) = %d; want = 0", len(getUsersResult.NotFound)) + } + }) +} + +func TestLastRefreshTime(t *testing.T) { + userRecord := newUserWithParams(t) + defer deleteUser(userRecord.UID) + + // New users should not have a LastRefreshTimestamp set. + if userRecord.UserMetadata.LastRefreshTimestamp != 0 { + t.Errorf( + "CreateUser(...).UserMetadata.LastRefreshTimestamp = %d; want = 0", + userRecord.UserMetadata.LastRefreshTimestamp) + } + + // Login to cause the LastRefreshTimestamp to be set + if _, err := signInWithPassword(userRecord.Email, "password"); err != nil { + t.Errorf("signInWithPassword failed: %v", err) + } + + getUsersResult, err := client.GetUser(context.Background(), userRecord.UID) + if err != nil { + t.Fatalf("GetUser(...) failed with error: %v", err) + } + + // Ensure last refresh time is approx now (with tollerance of 10m) + nowMillis := time.Now().Unix() * 1000 + lastRefreshTimestamp := getUsersResult.UserMetadata.LastRefreshTimestamp + if lastRefreshTimestamp < nowMillis-10*60*1000 { + t.Errorf("GetUser(...).UserMetadata.LastRefreshTimestamp = %d; want >= %d", lastRefreshTimestamp, nowMillis-10*60*1000) + } + if nowMillis+10*60*1000 < lastRefreshTimestamp { + t.Errorf("GetUser(...).UserMetadata.LastRefreshTimestamp = %d; want <= %d", lastRefreshTimestamp, nowMillis+10*60*1000) + } +} + func TestUpdateNonExistingUser(t *testing.T) { update := (&auth.UserToUpdate{}).Email("test@example.com") user, err := client.UpdateUser(context.Background(), "non.existing", update) @@ -334,6 +508,119 @@ func TestDeleteUser(t *testing.T) { } } +func TestDeleteUsers(t *testing.T) { + // Deletes users slowly. There's currently a 1qps limitation on this API. + // Without this helper, the integration tests occasionally hit that limit + // and fail. + // + // TODO(rsgowman): Remove this function when/if the 1qps limitation is + // relaxed. + slowDeleteUsers := func(ctx context.Context, uids []string) (*auth.DeleteUsersResult, error) { + time.Sleep(1 * time.Second) + return client.DeleteUsers(ctx, uids) + } + + // Ensures the specified users don't exist. Expected to be called after + // deleting the users to ensure the delete method worked. + ensureUsersNotFound := func(t *testing.T, uids []string) { + identifiers := []auth.UserIdentifier{} + for i := range uids { + identifiers = append(identifiers, auth.UIDIdentifier{UID: uids[i]}) + } + + getUsersResult, err := client.GetUsers(context.Background(), identifiers) + if err != nil { + t.Errorf("GetUsers(notfound_ids) error %v; want nil", err) + return + } + + if len(getUsersResult.NotFound) != len(uids) { + t.Errorf("len(GetUsers(notfound_ids).NotFound) = %d; want %d", len(getUsersResult.NotFound), len(uids)) + return + } + + sort.Strings(uids) + notFoundUids := []string{} + for i := range getUsersResult.NotFound { + notFoundUids = append(notFoundUids, getUsersResult.NotFound[i].(auth.UIDIdentifier).UID) + } + sort.Strings(notFoundUids) + for i := range uids { + if notFoundUids[i] != uids[i] { + t.Errorf("GetUsers(deleted_ids).NotFound[%d] = %s; want %s", i, notFoundUids[i], uids[i]) + } + } + } + + t.Run("deletes users", func(t *testing.T) { + uids := []string{ + newUserWithParams(t).UID, newUserWithParams(t).UID, newUserWithParams(t).UID, + } + + result, err := slowDeleteUsers(context.Background(), uids) + if err != nil { + t.Fatalf("DeleteUsers([valid_ids]) error %v; want nil", err) + } + + if result.SuccessCount != 3 { + t.Errorf("DeleteUsers([valid_ids]).SuccessCount = %d; want 3", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers([valid_ids]).FailureCount = %d; want 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers([valid_ids]).Errors) = %d; want 0", len(result.Errors)) + } + + ensureUsersNotFound(t, uids) + }) + + t.Run("deletes users that exist even when non-existing users also specified", func(t *testing.T) { + uids := []string{newUserWithParams(t).UID, "uid-that-doesnt-exist"} + result, err := slowDeleteUsers(context.Background(), uids) + if err != nil { + t.Fatalf("DeleteUsers(uids) error %v; want nil", err) + } + + if result.SuccessCount != 2 { + t.Errorf("DeleteUsers(uids).SuccessCount = %d; want 2", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers(uids).FailureCount = %d; want 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers(uids).Errors) = %d; want 0", len(result.Errors)) + } + + ensureUsersNotFound(t, uids) + }) + + t.Run("is idempotent", func(t *testing.T) { + deleteUserAndEnsureSuccess := func(t *testing.T, uids []string) { + result, err := slowDeleteUsers(context.Background(), uids) + if err != nil { + t.Fatalf("DeleteUsers(uids) error %v; want nil", err) + } + + if result.SuccessCount != 1 { + t.Errorf("DeleteUsers(uids).SuccessCount = %d; want 1", result.SuccessCount) + } + if result.FailureCount != 0 { + t.Errorf("DeleteUsers(uids).FailureCount = %d; want 0", result.FailureCount) + } + if len(result.Errors) != 0 { + t.Errorf("len(DeleteUsers(uids).Errors) = %d; want 0", len(result.Errors)) + } + } + + uids := []string{newUserWithParams(t).UID} + deleteUserAndEnsureSuccess(t, uids) + + // Delete the user again, ensuring that everything still counts as a success. + deleteUserAndEnsureSuccess(t, uids) + }) +} + func TestImportUsers(t *testing.T) { uid := randomUID() email := randomEmail(uid) @@ -660,3 +947,26 @@ func newUserWithParams(t *testing.T) *auth.UserRecord { } return user } + +// Helper to import a user and return its UserRecord. Upon error, exits via +// t.Fatalf. `uid` must match the UID set on the `userToImport` parameter. +func importUser(t *testing.T, uid string, userToImport *auth.UserToImport) *auth.UserRecord { + userImportResult, err := client.ImportUsers( + context.Background(), [](*auth.UserToImport){userToImport}) + if err != nil { + t.Fatalf("Unable to import user %v (uid %v): %v", *userToImport, uid, err) + } + + if userImportResult.FailureCount > 0 { + t.Fatalf("Unable to import user %v (uid %v): %v", *userToImport, uid, userImportResult.Errors[0].Reason) + } + if userImportResult.SuccessCount != 1 { + t.Fatalf("Import didn't fail, but it didn't succeed either?") + } + + userRecord, err := client.GetUser(context.Background(), uid) + if err != nil { + t.Fatalf("GetUser(%s) for imported user failed: %v", uid, err) + } + return userRecord +}