diff --git a/.travis.yml b/.travis.yml index b1d02a82..53f411b4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,30 @@ language: go + +go: + - 1.7.x + - 1.8.x + - 1.9.x + - "1.10.x" + - master + +matrix: + # Build OK if fails on unstable development versions of Go. + allow_failures: + - go: master + # Don't wait for tests to finish on allow_failures. + # Mark the build finished if tests pass on other versions of Go. + fast_finish: true + go_import_path: firebase.google.com/go + before_install: - - go get github.com/golang/lint/golint + - go get github.com/golang/lint/golint # Golint requires Go 1.6 or later. + +install: + # Prior to golang 1.8, this can trigger an error for packages containing only tests. + - go get -t -v $(go list ./... | grep -v integration) + script: - golint -set_exit_status $(go list ./...) - - go test -v -test.short ./... - + - go test -v -race -test.short ./... # Run tests with the race detector. + - go vet -v ./... # Run Go static analyzer. diff --git a/CHANGELOG.md b/CHANGELOG.md index 3060d302..5852ed27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ - +# v2.6.0 + +- [changed] Improved error handling in FCM by mapping more server-side + errors to client-side error codes. +- [added] Added the `db` package for interacting with the Firebase database. + # v2.5.0 - [changed] Import context from `golang.org/x/net` for 1.6 compatibility diff --git a/README.md b/README.md index 441b5fad..3ea293a2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ [![Build Status](https://travis-ci.org/firebase/firebase-admin-go.svg?branch=master)](https://travis-ci.org/firebase/firebase-admin-go) [![GoDoc](https://godoc.org/firebase.google.com/go?status.svg)](https://godoc.org/firebase.google.com/go) +[![Go Report Card](https://goreportcard.com/badge/github.com/firebase/firebase-admin-go)](https://goreportcard.com/report/github.com/firebase/firebase-admin-go) # Firebase Admin Go SDK @@ -43,6 +44,9 @@ requests, code review feedback, and also pull requests. * [Setup Guide](https://firebase.google.com/docs/admin/setup/) * [Authentication Guide](https://firebase.google.com/docs/auth/admin/) +* [Cloud Firestore](https://firebase.google.com/docs/firestore/) +* [Cloud Messaging Guide](https://firebase.google.com/docs/cloud-messaging/admin/) +* [Storage Guide](https://firebase.google.com/docs/storage/admin/start) * [API Reference](https://godoc.org/firebase.google.com/go) * [Release Notes](https://firebase.google.com/support/release-notes/admin/go) diff --git a/auth/auth.go b/auth/auth.go index f6605c7b..98822fef 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -16,6 +16,7 @@ package auth import ( + "context" "crypto/rsa" "crypto/x509" "encoding/json" @@ -25,7 +26,6 @@ import ( "strings" "firebase.google.com/go/internal" - "golang.org/x/net/context" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/transport" ) @@ -78,7 +78,7 @@ type signer interface { // NewClient creates a new instance of the Firebase Auth Client. // // This function can only be invoked from within the SDK. Client applications should access the -// the Auth service through firebase.App. +// Auth service through firebase.App. func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) { var ( err error diff --git a/auth/auth_appengine.go b/auth/auth_appengine.go index 351f61c1..5e05cdb1 100644 --- a/auth/auth_appengine.go +++ b/auth/auth_appengine.go @@ -17,7 +17,7 @@ package auth import ( - "golang.org/x/net/context" + "context" "google.golang.org/appengine" ) diff --git a/auth/auth_std.go b/auth/auth_std.go index 2055af38..f593a7cc 100644 --- a/auth/auth_std.go +++ b/auth/auth_std.go @@ -16,7 +16,7 @@ package auth -import "golang.org/x/net/context" +import "context" func newSigner(ctx context.Context) (signer, error) { return serviceAcctSigner{}, nil diff --git a/auth/auth_test.go b/auth/auth_test.go index 690b5d6f..2676f6c4 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -15,6 +15,7 @@ package auth import ( + "context" "encoding/json" "errors" "fmt" @@ -25,7 +26,6 @@ import ( "testing" "time" - "golang.org/x/net/context" "golang.org/x/oauth2/google" "google.golang.org/api/option" diff --git a/auth/jwt_test.go b/auth/jwt_test.go index 79264ee0..4b0858af 100644 --- a/auth/jwt_test.go +++ b/auth/jwt_test.go @@ -1,3 +1,17 @@ +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 423ceece..551753ea 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -15,6 +15,7 @@ package auth import ( + "context" "encoding/json" "fmt" "net/http" @@ -23,7 +24,6 @@ import ( "strings" "time" - "golang.org/x/net/context" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/iterator" ) diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index db7b7300..3f298ed6 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -16,6 +16,7 @@ package auth import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -28,7 +29,6 @@ import ( "firebase.google.com/go/internal" - "golang.org/x/net/context" "golang.org/x/oauth2" "google.golang.org/api/identitytoolkit/v3" "google.golang.org/api/iterator" @@ -167,9 +167,9 @@ func TestListUsers(t *testing.T) { defer s.Close() want := []*ExportedUserRecord{ - &ExportedUserRecord{UserRecord: testUser, PasswordHash: "passwordhash1", PasswordSalt: "salt1"}, - &ExportedUserRecord{UserRecord: testUser, PasswordHash: "passwordhash2", PasswordSalt: "salt2"}, - &ExportedUserRecord{UserRecord: testUser, PasswordHash: "passwordhash3", PasswordSalt: "salt3"}, + {UserRecord: testUser, PasswordHash: "passwordhash1", PasswordSalt: "salt1"}, + {UserRecord: testUser, PasswordHash: "passwordhash2", PasswordSalt: "salt2"}, + {UserRecord: testUser, PasswordHash: "passwordhash3", PasswordSalt: "salt3"}, } testIterator := func(iter *UserIterator, token string, req map[string]interface{}) { @@ -574,9 +574,9 @@ func TestInvalidSetCustomClaims(t *testing.T) { func TestSetCustomClaims(t *testing.T) { cases := []map[string]interface{}{ nil, - map[string]interface{}{}, - map[string]interface{}{"admin": true}, - map[string]interface{}{"admin": true, "package": "gold"}, + {}, + {"admin": true}, + {"admin": true, "package": "gold"}, } resp := `{ diff --git a/db/auth_override_test.go b/db/auth_override_test.go new file mode 100644 index 00000000..86cbeef2 --- /dev/null +++ b/db/auth_override_test.go @@ -0,0 +1,107 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "testing" + + "golang.org/x/net/context" +) + +func TestAuthOverrideGet(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + var got string + if err := ref.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Ref(AuthOverride).Get() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"auth_variable_override": testAuthOverrides}, + }) +} + +func TestAuthOverrideSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + if err := ref.Set(context.Background(), want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Body: serialize(want), + Path: "/peter.json", + Query: map[string]string{"auth_variable_override": testAuthOverrides, "print": "silent"}, + }) +} + +func TestAuthOverrideQuery(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + var got string + if err := ref.OrderByChild("foo").Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "auth_variable_override": testAuthOverrides, + "orderBy": "\"foo\"", + }, + }) +} + +func TestAuthOverrideRangeQuery(t *testing.T) { + mock := &mockServer{Resp: "data"} + srv := mock.Start(aoClient) + defer srv.Close() + + ref := aoClient.NewRef("peter") + var got string + if err := ref.OrderByChild("foo").StartAt(1).EndAt(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "data" { + t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "auth_variable_override": testAuthOverrides, + "orderBy": "\"foo\"", + "startAt": "1", + "endAt": "10", + }, + }) +} diff --git a/db/db.go b/db/db.go new file mode 100644 index 00000000..6bed3922 --- /dev/null +++ b/db/db.go @@ -0,0 +1,134 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db contains functions for accessing the Firebase Realtime Database. +package db + +import ( + "encoding/json" + "fmt" + "runtime" + "strings" + + "firebase.google.com/go/internal" + + "net/url" + + "golang.org/x/net/context" + "google.golang.org/api/option" + "google.golang.org/api/transport" +) + +const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" +const invalidChars = "[].#$" +const authVarOverride = "auth_variable_override" + +// Client is the interface for the Firebase Realtime Database service. +type Client struct { + hc *internal.HTTPClient + url string + authOverride string +} + +// NewClient creates a new instance of the Firebase Database Client. +// +// This function can only be invoked from within the SDK. Client applications should access the +// Database service through firebase.App. +func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { + opts := append([]option.ClientOption{}, c.Opts...) + ua := fmt.Sprintf(userAgentFormat, c.Version, runtime.Version()) + opts = append(opts, option.WithUserAgent(ua)) + hc, _, err := transport.NewHTTPClient(ctx, opts...) + if err != nil { + return nil, err + } + + p, err := url.ParseRequestURI(c.URL) + if err != nil { + return nil, err + } else if p.Scheme != "https" { + return nil, fmt.Errorf("invalid database URL: %q; want scheme: %q", c.URL, "https") + } else if !strings.HasSuffix(p.Host, ".firebaseio.com") { + return nil, fmt.Errorf("invalid database URL: %q; want host: %q", c.URL, "firebaseio.com") + } + + var ao []byte + if c.AuthOverride == nil || len(c.AuthOverride) > 0 { + ao, err = json.Marshal(c.AuthOverride) + if err != nil { + return nil, err + } + } + + ep := func(b []byte) string { + var p struct { + Error string `json:"error"` + } + if err := json.Unmarshal(b, &p); err != nil { + return "" + } + return p.Error + } + return &Client{ + hc: &internal.HTTPClient{Client: hc, ErrParser: ep}, + url: fmt.Sprintf("https://%s", p.Host), + authOverride: string(ao), + }, nil +} + +// NewRef returns a new database reference representing the node at the specified path. +func (c *Client) NewRef(path string) *Ref { + segs := parsePath(path) + key := "" + if len(segs) > 0 { + key = segs[len(segs)-1] + } + + return &Ref{ + Key: key, + Path: "/" + strings.Join(segs, "/"), + client: c, + segs: segs, + } +} + +func (c *Client) send( + ctx context.Context, + method, path string, + body internal.HTTPEntity, + opts ...internal.HTTPOption) (*internal.Response, error) { + + if strings.ContainsAny(path, invalidChars) { + return nil, fmt.Errorf("invalid path with illegal characters: %q", path) + } + if c.authOverride != "" { + opts = append(opts, internal.WithQueryParam(authVarOverride, c.authOverride)) + } + return c.hc.Do(ctx, &internal.Request{ + Method: method, + URL: fmt.Sprintf("%s%s.json", c.url, path), + Body: body, + Opts: opts, + }) +} + +func parsePath(path string) []string { + var segs []string + for _, s := range strings.Split(path, "/") { + if s != "" { + segs = append(segs, s) + } + } + return segs +} diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 00000000..01234504 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,404 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "runtime" + "testing" + + "golang.org/x/net/context" + "golang.org/x/oauth2" + + "encoding/json" + + "reflect" + + "io/ioutil" + + "net/url" + + "firebase.google.com/go/internal" + "google.golang.org/api/option" +) + +const testURL = "https://test-db.firebaseio.com" + +var testUserAgent string +var testAuthOverrides string +var testOpts = []option.ClientOption{ + option.WithTokenSource(&mockTokenSource{"mock-token"}), +} + +var client *Client +var aoClient *Client +var testref *Ref + +func TestMain(m *testing.M) { + var err error + client, err = NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + Version: "1.2.3", + AuthOverride: map[string]interface{}{}, + }) + if err != nil { + log.Fatalln(err) + } + + ao := map[string]interface{}{"uid": "user1"} + aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + Version: "1.2.3", + AuthOverride: ao, + }) + if err != nil { + log.Fatalln(err) + } + + b, err := json.Marshal(ao) + if err != nil { + log.Fatalln(err) + } + testAuthOverrides = string(b) + + testref = client.NewRef("peter") + testUserAgent = fmt.Sprintf(userAgentFormat, "1.2.3", runtime.Version()) + os.Exit(m.Run()) +} + +func TestNewClient(t *testing.T) { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AuthOverride: make(map[string]interface{}), + }) + if err != nil { + t.Fatal(err) + } + if c.url != testURL { + t.Errorf("NewClient().url = %q; want = %q", c.url, testURL) + } + if c.hc == nil { + t.Errorf("NewClient().hc = nil; want non-nil") + } + if c.authOverride != "" { + t.Errorf("NewClient().ao = %q; want = %q", c.authOverride, "") + } +} + +func TestNewClientAuthOverrides(t *testing.T) { + cases := []map[string]interface{}{ + nil, + map[string]interface{}{"uid": "user1"}, + } + for _, tc := range cases { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AuthOverride: tc, + }) + if err != nil { + t.Fatal(err) + } + if c.url != testURL { + t.Errorf("NewClient(%v).url = %q; want = %q", tc, c.url, testURL) + } + if c.hc == nil { + t.Errorf("NewClient(%v).hc = nil; want non-nil", tc) + } + b, err := json.Marshal(tc) + if err != nil { + t.Fatal(err) + } + if c.authOverride != string(b) { + t.Errorf("NewClient(%v).ao = %q; want = %q", tc, c.authOverride, string(b)) + } + } +} + +func TestInvalidURL(t *testing.T) { + cases := []string{ + "", + "foo", + "http://db.firebaseio.com", + "https://firebase.google.com", + } + for _, tc := range cases { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: tc, + }) + if c != nil || err == nil { + t.Errorf("NewClient(%q) = (%v, %v); want = (nil, error)", tc, c, err) + } + } +} + +func TestInvalidAuthOverride(t *testing.T) { + c, err := NewClient(context.Background(), &internal.DatabaseConfig{ + Opts: testOpts, + URL: testURL, + AuthOverride: map[string]interface{}{"uid": func() {}}, + }) + if c != nil || err == nil { + t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) + } +} + +func TestNewRef(t *testing.T) { + cases := []struct { + Path string + WantPath string + WantKey string + }{ + {"", "/", ""}, + {"/", "/", ""}, + {"foo", "/foo", "foo"}, + {"/foo", "/foo", "foo"}, + {"foo/bar", "/foo/bar", "bar"}, + {"/foo/bar", "/foo/bar", "bar"}, + {"/foo/bar/", "/foo/bar", "bar"}, + } + for _, tc := range cases { + r := client.NewRef(tc.Path) + if r.client == nil { + t.Errorf("NewRef(%q).client = nil; want = %v", tc.Path, r.client) + } + if r.Path != tc.WantPath { + t.Errorf("NewRef(%q).Path = %q; want = %q", tc.Path, r.Path, tc.WantPath) + } + if r.Key != tc.WantKey { + t.Errorf("NewRef(%q).Key = %q; want = %q", tc.Path, r.Key, tc.WantKey) + } + } +} + +func TestParent(t *testing.T) { + cases := []struct { + Path string + HasParent bool + Want string + }{ + {"", false, ""}, + {"/", false, ""}, + {"foo", true, ""}, + {"/foo", true, ""}, + {"foo/bar", true, "foo"}, + {"/foo/bar", true, "foo"}, + {"/foo/bar/", true, "foo"}, + } + for _, tc := range cases { + r := client.NewRef(tc.Path).Parent() + if tc.HasParent { + if r == nil { + t.Fatalf("Parent(%q) = nil; want = Ref(%q)", tc.Path, tc.Want) + } + if r.client == nil { + t.Errorf("Parent(%q).client = nil; want = %v", tc.Path, client) + } + if r.Key != tc.Want { + t.Errorf("Parent(%q).Key = %q; want = %q", tc.Path, r.Key, tc.Want) + } + } else if r != nil { + t.Fatalf("Parent(%q) = %v; want = nil", tc.Path, r) + } + } +} + +func TestChild(t *testing.T) { + r := client.NewRef("/test") + cases := []struct { + Path string + Want string + Parent string + }{ + {"", "/test", "/"}, + {"foo", "/test/foo", "/test"}, + {"/foo", "/test/foo", "/test"}, + {"foo/", "/test/foo", "/test"}, + {"/foo/", "/test/foo", "/test"}, + {"//foo//", "/test/foo", "/test"}, + {"foo/bar", "/test/foo/bar", "/test/foo"}, + {"/foo/bar", "/test/foo/bar", "/test/foo"}, + {"foo/bar/", "/test/foo/bar", "/test/foo"}, + {"/foo/bar/", "/test/foo/bar", "/test/foo"}, + {"//foo/bar", "/test/foo/bar", "/test/foo"}, + {"foo//bar/", "/test/foo/bar", "/test/foo"}, + {"foo/bar//", "/test/foo/bar", "/test/foo"}, + } + for _, tc := range cases { + c := r.Child(tc.Path) + if c.Path != tc.Want { + t.Errorf("Child(%q) = %q; want = %q", tc.Path, c.Path, tc.Want) + } + if c.Parent().Path != tc.Parent { + t.Errorf("Child(%q).Parent() = %q; want = %q", tc.Path, c.Parent().Path, tc.Parent) + } + } +} + +func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { + checkAllRequests(t, got, []*testReq{want}) +} + +func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { + if len(got) != len(want) { + t.Errorf("Request Count = %d; want = %d", len(got), len(want)) + } else { + for i, r := range got { + checkRequest(t, r, want[i]) + } + } +} + +func checkRequest(t *testing.T, got, want *testReq) { + if h := got.Header.Get("Authorization"); h != "Bearer mock-token" { + t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") + } + if h := got.Header.Get("User-Agent"); h != testUserAgent { + t.Errorf("User-Agent = %q; want = %q", h, testUserAgent) + } + + if got.Method != want.Method { + t.Errorf("Method = %q; want = %q", got.Method, want.Method) + } + + if got.Path != want.Path { + t.Errorf("Path = %q; want = %q", got.Path, want.Path) + } + if len(want.Query) != len(got.Query) { + t.Errorf("QueryParam = %v; want = %v", got.Query, want.Query) + } + for k, v := range want.Query { + if got.Query[k] != v { + t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query[k], v) + } + } + for k, v := range want.Header { + if got.Header.Get(k) != v[0] { + t.Errorf("Header(%q) = %q; want = %q", k, got.Header.Get(k), v[0]) + } + } + if want.Body != nil { + if h := got.Header.Get("Content-Type"); h != "application/json" { + t.Errorf("User-Agent = %q; want = %q", h, "application/json") + } + var wi, gi interface{} + if err := json.Unmarshal(want.Body, &wi); err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(got.Body, &gi); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(gi, wi) { + t.Errorf("Body = %v; want = %v", gi, wi) + } + } else if len(got.Body) != 0 { + t.Errorf("Body = %v; want empty", got.Body) + } +} + +type testReq struct { + Method string + Path string + Header http.Header + Body []byte + Query map[string]string +} + +func newTestReq(r *http.Request) (*testReq, error) { + defer r.Body.Close() + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + + u, err := url.Parse(r.RequestURI) + if err != nil { + return nil, err + } + + query := make(map[string]string) + for k, v := range u.Query() { + query[k] = v[0] + } + return &testReq{ + Method: r.Method, + Path: u.Path, + Header: r.Header, + Body: b, + Query: query, + }, nil +} + +type mockServer struct { + Resp interface{} + Header map[string]string + Status int + Reqs []*testReq + srv *httptest.Server +} + +func (s *mockServer) Start(c *Client) *httptest.Server { + if s.srv != nil { + return s.srv + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tr, _ := newTestReq(r) + s.Reqs = append(s.Reqs, tr) + + for k, v := range s.Header { + w.Header().Set(k, v) + } + + print := r.URL.Query().Get("print") + if s.Status != 0 { + w.WriteHeader(s.Status) + } else if print == "silent" { + w.WriteHeader(http.StatusNoContent) + return + } + b, _ := json.Marshal(s.Resp) + w.Header().Set("Content-Type", "application/json") + w.Write(b) + }) + s.srv = httptest.NewServer(handler) + c.url = s.srv.URL + return s.srv +} + +type mockTokenSource struct { + AccessToken string +} + +func (ts *mockTokenSource) Token() (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: ts.AccessToken}, nil +} + +type person struct { + Name string `json:"name"` + Age int32 `json:"age"` +} + +func serialize(v interface{}) []byte { + b, _ := json.Marshal(v) + return b +} diff --git a/db/query.go b/db/query.go new file mode 100644 index 00000000..c6013483 --- /dev/null +++ b/db/query.go @@ -0,0 +1,423 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + + "firebase.google.com/go/internal" + + "golang.org/x/net/context" +) + +// QueryNode represents a data node retrieved from an ordered query. +type QueryNode interface { + Key() string + Unmarshal(v interface{}) error +} + +// Query represents a complex query that can be executed on a Ref. +// +// Complex queries can consist of up to 2 components: a required ordering constraint, and an +// optional filtering constraint. At the server, data is first sorted according to the given +// ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) is +// applied on the sorted data to produce the final result. Despite the ordering constraint, the +// final result is returned by the server as an unordered collection. Therefore the values read +// from a Query instance are not ordered. +type Query struct { + client *Client + path string + order orderBy + limFirst, limLast int + start, end, equalTo interface{} +} + +// StartAt returns a shallow copy of the Query with v set as a lower bound of a range query. +// +// The resulting Query will only return child nodes with a value greater than or equal to v. +func (q *Query) StartAt(v interface{}) *Query { + q2 := &Query{} + *q2 = *q + q2.start = v + return q2 +} + +// EndAt returns a shallow copy of the Query with v set as a upper bound of a range query. +// +// The resulting Query will only return child nodes with a value less than or equal to v. +func (q *Query) EndAt(v interface{}) *Query { + q2 := &Query{} + *q2 = *q + q2.end = v + return q2 +} + +// EqualTo returns a shallow copy of the Query with v set as an equals constraint. +// +// The resulting Query will only return child nodes whose values equal to v. +func (q *Query) EqualTo(v interface{}) *Query { + q2 := &Query{} + *q2 = *q + q2.equalTo = v + return q2 +} + +// LimitToFirst returns a shallow copy of the Query, which is anchored to the first n +// elements of the window. +func (q *Query) LimitToFirst(n int) *Query { + q2 := &Query{} + *q2 = *q + q2.limFirst = n + return q2 +} + +// LimitToLast returns a shallow copy of the Query, which is anchored to the last n +// elements of the window. +func (q *Query) LimitToLast(n int) *Query { + q2 := &Query{} + *q2 = *q + q2.limLast = n + return q2 +} + +// Get executes the Query and populates v with the results. +// +// Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and +// therefore v has the same requirements as the json package. Specifically, it must be a pointer, +// and must not be nil. +// +// Despite the ordering constraint of the Query, results are not stored in any particular order +// in v. Use GetOrdered() to obtain ordered results. +func (q *Query) Get(ctx context.Context, v interface{}) error { + qp := make(map[string]string) + if err := initQueryParams(q, qp); err != nil { + return err + } + resp, err := q.client.send(ctx, "GET", q.path, nil, internal.WithQueryParams(qp)) + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + +// GetOrdered executes the Query and returns the results as an ordered slice. +func (q *Query) GetOrdered(ctx context.Context) ([]QueryNode, error) { + var temp interface{} + if err := q.Get(ctx, &temp); err != nil { + return nil, err + } + if temp == nil { + return nil, nil + } + + sn := newSortableNodes(temp, q.order) + sort.Sort(sn) + result := make([]QueryNode, len(sn)) + for i, v := range sn { + result[i] = v + } + return result, nil +} + +// OrderByChild returns a Query that orders data by child values before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. +func (r *Ref) OrderByChild(child string) *Query { + return newQuery(r, orderByChild(child)) +} + +// OrderByKey returns a Query that orders data by key before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. +func (r *Ref) OrderByKey() *Query { + return newQuery(r, orderByProperty("$key")) +} + +// OrderByValue returns a Query that orders data by value before applying filters. +// +// Returned Query can be used to set additional parameters, and execute complex database queries +// (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query +// will inherit it. +func (r *Ref) OrderByValue() *Query { + return newQuery(r, orderByProperty("$value")) +} + +func newQuery(r *Ref, ob orderBy) *Query { + return &Query{ + client: r.client, + path: r.Path, + order: ob, + } +} + +func initQueryParams(q *Query, qp map[string]string) error { + ob, err := q.order.encode() + if err != nil { + return err + } + qp["orderBy"] = ob + + if q.limFirst > 0 && q.limLast > 0 { + return fmt.Errorf("cannot set both limit parameter: first = %d, last = %d", q.limFirst, q.limLast) + } else if q.limFirst < 0 { + return fmt.Errorf("limit first cannot be negative: %d", q.limFirst) + } else if q.limLast < 0 { + return fmt.Errorf("limit last cannot be negative: %d", q.limLast) + } + + if q.limFirst > 0 { + qp["limitToFirst"] = strconv.Itoa(q.limFirst) + } else if q.limLast > 0 { + qp["limitToLast"] = strconv.Itoa(q.limLast) + } + + if err := encodeFilter("startAt", q.start, qp); err != nil { + return err + } + if err := encodeFilter("endAt", q.end, qp); err != nil { + return err + } + return encodeFilter("equalTo", q.equalTo, qp) +} + +func encodeFilter(key string, val interface{}, m map[string]string) error { + if val == nil { + return nil + } + b, err := json.Marshal(val) + if err != nil { + return err + } + m[key] = string(b) + return nil +} + +type orderBy interface { + encode() (string, error) +} + +type orderByChild string + +func (p orderByChild) encode() (string, error) { + if p == "" { + return "", fmt.Errorf("empty child path") + } else if strings.ContainsAny(string(p), invalidChars) { + return "", fmt.Errorf("invalid child path with illegal characters: %q", p) + } + segs := parsePath(string(p)) + if len(segs) == 0 { + return "", fmt.Errorf("invalid child path: %q", p) + } + b, err := json.Marshal(strings.Join(segs, "/")) + if err != nil { + return "", nil + } + return string(b), nil +} + +type orderByProperty string + +func (p orderByProperty) encode() (string, error) { + b, err := json.Marshal(p) + if err != nil { + return "", err + } + return string(b), nil +} + +// Firebase type ordering: https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data +const ( + typeNull = 0 + typeBoolFalse = 1 + typeBoolTrue = 2 + typeNumeric = 3 + typeString = 4 + typeObject = 5 +) + +// comparableKey is a union type of numeric values and strings. +type comparableKey struct { + Num *float64 + Str *string +} + +func (k *comparableKey) Compare(o *comparableKey) int { + if k.Str != nil && o.Str != nil { + return strings.Compare(*k.Str, *o.Str) + } else if k.Num != nil && o.Num != nil { + if *k.Num < *o.Num { + return -1 + } else if *k.Num == *o.Num { + return 0 + } + return 1 + } else if k.Num != nil { + // numeric keys appear before string keys + return -1 + } + return 1 +} + +func newComparableKey(v interface{}) *comparableKey { + if s, ok := v.(string); ok { + return &comparableKey{Str: &s} + } + + // Numeric values could be int (in the case of array indices and type constants), or float64 (if + // the value was received as json). + if i, ok := v.(int); ok { + f := float64(i) + return &comparableKey{Num: &f} + } + + f := v.(float64) + return &comparableKey{Num: &f} +} + +type queryNodeImpl struct { + CompKey *comparableKey + Value interface{} + Index interface{} + IndexType int +} + +func (q *queryNodeImpl) Key() string { + if q.CompKey.Str != nil { + return *q.CompKey.Str + } + // Numeric keys in queryNodeImpl are always array indices, and can be safely coverted into int. + return strconv.Itoa(int(*q.CompKey.Num)) +} + +func (q *queryNodeImpl) Unmarshal(v interface{}) error { + b, err := json.Marshal(q.Value) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func newQueryNode(key, val interface{}, order orderBy) *queryNodeImpl { + var index interface{} + if prop, ok := order.(orderByProperty); ok { + if prop == "$value" { + index = val + } else { + index = key + } + } else { + path := order.(orderByChild) + index = extractChildValue(val, string(path)) + } + return &queryNodeImpl{ + CompKey: newComparableKey(key), + Value: val, + Index: index, + IndexType: getIndexType(index), + } +} + +type sortableNodes []*queryNodeImpl + +func (s sortableNodes) Len() int { + return len(s) +} + +func (s sortableNodes) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s sortableNodes) Less(i, j int) bool { + a, b := s[i], s[j] + var aKey, bKey *comparableKey + if a.IndexType == b.IndexType { + // If the indices have the same type and are comparable (i.e. numeric or string), compare + // them directly. Otherwise, compare the keys. + if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { + aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) + } else { + aKey, bKey = a.CompKey, b.CompKey + } + } else { + // If the indices are of different types, use the type ordering of Firebase. + aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) + } + + return aKey.Compare(bKey) < 0 +} + +func newSortableNodes(values interface{}, order orderBy) sortableNodes { + var entries sortableNodes + if m, ok := values.(map[string]interface{}); ok { + for key, val := range m { + entries = append(entries, newQueryNode(key, val, order)) + } + } else if l, ok := values.([]interface{}); ok { + for key, val := range l { + entries = append(entries, newQueryNode(key, val, order)) + } + } else { + entries = append(entries, newQueryNode(0, values, order)) + } + return entries +} + +// extractChildValue retrieves the value at path from val. +// +// If the given path does not exist in val, or val does not support child path traversal, +// extractChildValue returns nil. +func extractChildValue(val interface{}, path string) interface{} { + segments := parsePath(path) + curr := val + for _, s := range segments { + if curr == nil { + return nil + } + + currMap, ok := curr.(map[string]interface{}) + if !ok { + return nil + } + if curr, ok = currMap[s]; !ok { + return nil + } + } + return curr +} + +func getIndexType(index interface{}) int { + if index == nil { + return typeNull + } else if b, ok := index.(bool); ok { + if b { + return typeBoolTrue + } + return typeBoolFalse + } else if _, ok := index.(float64); ok { + return typeNumeric + } else if _, ok := index.(string); ok { + return typeString + } + return typeObject +} diff --git a/db/query_test.go b/db/query_test.go new file mode 100644 index 00000000..4473daff --- /dev/null +++ b/db/query_test.go @@ -0,0 +1,774 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package db + +import ( + "fmt" + "reflect" + "testing" + + "golang.org/x/net/context" +) + +var sortableKeysResp = map[string]interface{}{ + "bob": person{Name: "bob", Age: 20}, + "alice": person{Name: "alice", Age: 30}, + "charlie": person{Name: "charlie", Age: 15}, + "dave": person{Name: "dave", Age: 25}, + "ernie": person{Name: "ernie"}, +} + +var sortableValuesResp = []struct { + resp map[string]interface{} + want []interface{} + wantKeys []string +}{ + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k1", "k2", "k3"}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k3", "k2", "k1"}, + }, + { + resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"k2", "k3", "k1"}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k1", "k3", "k2"}, + }, + { + resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k1", "k2", "k3"}, + }, + { + resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"k2", "k3", "k1"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + wantKeys: []string{"k2", "k3", "k1"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, + want: []interface{}{10.0, "bar", "foo"}, + wantKeys: []string{"k3", "k2", "k1"}, + }, + { + resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, + want: []interface{}{nil, "bar", "foo"}, + wantKeys: []string{"k3", "k2", "k1"}, + }, + { + resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, + want: []interface{}{nil, 5.0, "bar"}, + wantKeys: []string{"k3", "k1", "k2"}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, + }, + want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, + wantKeys: []string{"k5", "k1", "k2", "k3", "k4", "k6"}, + }, + { + resp: map[string]interface{}{ + "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, + "k6": map[string]interface{}{"k1": true}, "k7": nil, + "k8": map[string]interface{}{"k0": true}, + }, + want: []interface{}{ + nil, false, true, 0.0, "foo", "foo", + map[string]interface{}{"k1": true}, map[string]interface{}{"k0": true}, + }, + wantKeys: []string{"k7", "k5", "k1", "k2", "k3", "k4", "k6", "k8"}, + }, +} + +func TestChildQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "messages", "messages/", "/messages", + } + var reqs []*testReq + for _, tc := range cases { + var got map[string]interface{} + if err := testref.OrderByChild(tc).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild(%q) = %v; want = %v", tc, got, want) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages\""}, + }) + } + + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestNestedChildQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages/ratings").Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild(%q) = %v; want = %v", "messages/ratings", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"messages/ratings\""}, + }) +} + +func TestChildQueryWithParams(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages").StartAt("m4").EndAt("m50").LimitToFirst(10) + var got map[string]interface{} + if err := q.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "orderBy": "\"messages\"", + "startAt": "\"m4\"", + "endAt": "\"m50\"", + "limitToFirst": "10", + }, + }) +} + +func TestInvalidOrderByChild(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + r := client.NewRef("/") + cases := []string{ + "", "/", "foo$", "foo.", "foo#", "foo]", + "foo[", "$key", "$value", "$priority", + } + for _, tc := range cases { + var got string + if err := r.OrderByChild(tc).Get(context.Background(), &got); got != "" || err == nil { + t.Errorf("OrderByChild(%q) = (%q, %v); want = (%q, error)", tc, got, err, "") + } + } + if len(mock.Reqs) != 0 { + t.Errorf("OrderByChild() = %v; want = empty", mock.Reqs) + } +} + +func TestKeyQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByKey().Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByKey() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$key\""}, + }) +} + +func TestValueQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByValue().Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByValue() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) +} + +func TestLimitFirstQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").LimitToFirst(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("LimitToFirst() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"limitToFirst": "10", "orderBy": "\"messages\""}, + }) +} + +func TestLimitLastQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").LimitToLast(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"limitToLast": "10", "orderBy": "\"messages\""}, + }) +} + +func TestInvalidLimitQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages") + cases := []struct { + name string + q *Query + }{ + {"BothLimits", q.LimitToFirst(10).LimitToLast(10)}, + {"NegativeFirst", q.LimitToFirst(-10)}, + {"NegativeLast", q.LimitToLast(-10)}, + } + for _, tc := range cases { + var got map[string]interface{} + if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { + t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) + } + if len(mock.Reqs) != 0 { + t.Errorf("OrderByChild(%q): %v; want: empty", tc.name, mock.Reqs) + } + } +} + +func TestStartAtQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").StartAt(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("StartAt() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"startAt": "10", "orderBy": "\"messages\""}, + }) +} + +func TestEndAtQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").EndAt(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("EndAt() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"endAt": "10", "orderBy": "\"messages\""}, + }) +} + +func TestEqualToQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + if err := testref.OrderByChild("messages").EqualTo(10).Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("EqualTo() = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"equalTo": "10", "orderBy": "\"messages\""}, + }) +} + +func TestInvalidFilterQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages") + cases := []struct { + name string + q *Query + }{ + {"InvalidStartAt", q.StartAt(func() {})}, + {"InvalidEndAt", q.EndAt(func() {})}, + {"InvalidEqualTo", q.EqualTo(func() {})}, + } + for _, tc := range cases { + var got map[string]interface{} + if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { + t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) + } + if len(mock.Reqs) != 0 { + t.Errorf("OrdderByChild(%q) = %v; want = empty", tc.name, mock.Reqs) + } + } +} + +func TestAllParamsQuery(t *testing.T) { + want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + q := testref.OrderByChild("messages").LimitToFirst(100).StartAt("bar").EndAt("foo") + var got map[string]interface{} + if err := q.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("OrderByChild(AllParams) = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{ + "limitToFirst": "100", + "startAt": "\"bar\"", + "endAt": "\"foo\"", + "orderBy": "\"messages\"", + }, + }) +} + +func TestChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{Resp: sortableKeysResp} + srv := mock.Start(client) + defer srv.Close() + + cases := []struct { + child string + want []string + }{ + {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, + {"nonexisting", []string{"alice", "bob", "charlie", "dave", "ernie"}}, + } + + var reqs []*testReq + for idx, tc := range cases { + result, err := testref.OrderByChild(tc.child).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, + }) + + var gotKeys, gotVals []string + for _, r := range result { + var p person + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Name) + } + if !reflect.DeepEqual(tc.want, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotKeys, tc.want) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestImmediateChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + type parsedMap struct { + Child interface{} `json:"child"` + } + + var reqs []*testReq + for idx, tc := range sortableValuesResp { + resp := map[string]interface{}{} + for k, v := range tc.resp { + resp[k] = map[string]interface{}{"child": v} + } + mock.Resp = resp + + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"child\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var p parsedMap + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Child) + } + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestNestedChildQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + type grandChild struct { + GrandChild interface{} `json:"grandchild"` + } + type parsedMap struct { + Child grandChild `json:"child"` + } + + var reqs []*testReq + for idx, tc := range sortableValuesResp { + resp := map[string]interface{}{} + for k, v := range tc.resp { + resp[k] = map[string]interface{}{"child": map[string]interface{}{"grandchild": v}} + } + mock.Resp = resp + + q := testref.OrderByChild("child/grandchild") + result, err := q.GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"child/grandchild\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var p parsedMap + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Child.GrandChild) + } + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestKeyQueryGetOrdered(t *testing.T) { + mock := &mockServer{Resp: sortableKeysResp} + srv := mock.Start(client) + defer srv.Close() + + result, err := testref.OrderByKey().GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + req := &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$key\""}, + } + + var gotKeys, gotVals []string + for _, r := range result { + var p person + if err := r.Unmarshal(&p); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, p.Name) + } + + want := []string{"alice", "bob", "charlie", "dave", "ernie"} + if !reflect.DeepEqual(want, gotKeys) { + t.Errorf("GetOrdered(key) = %v; want = %v", gotKeys, want) + } + if !reflect.DeepEqual(want, gotVals) { + t.Errorf("GetOrdered(key) = %v; want = %v", gotVals, want) + } + checkOnlyRequest(t, mock.Reqs, req) +} + +func TestValueQueryGetOrdered(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + var reqs []*testReq + for idx, tc := range sortableValuesResp { + mock.Resp = tc.resp + + result, err := testref.OrderByValue().GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var v interface{} + if err := r.Unmarshal(&v); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, v) + } + + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestValueQueryGetOrderedWithList(t *testing.T) { + cases := []struct { + resp []interface{} + want []interface{} + wantKeys []string + }{ + { + resp: []interface{}{1, 2, 3}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"0", "1", "2"}, + }, + { + resp: []interface{}{3, 2, 1}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"2", "1", "0"}, + }, + { + resp: []interface{}{1, 3, 2}, + want: []interface{}{1.0, 2.0, 3.0}, + wantKeys: []string{"0", "2", "1"}, + }, + { + resp: []interface{}{1, 3, 3}, + want: []interface{}{1.0, 3.0, 3.0}, + wantKeys: []string{"0", "1", "2"}, + }, + { + resp: []interface{}{1, 2, 1}, + want: []interface{}{1.0, 1.0, 2.0}, + wantKeys: []string{"0", "2", "1"}, + }, + { + resp: []interface{}{"foo", "bar", "baz"}, + want: []interface{}{"bar", "baz", "foo"}, + wantKeys: []string{"1", "2", "0"}, + }, + { + resp: []interface{}{"foo", 1, false, nil, 0, true}, + want: []interface{}{nil, false, true, 0.0, 1.0, "foo"}, + wantKeys: []string{"3", "2", "5", "4", "1", "0"}, + }, + } + + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + var reqs []*testReq + for _, tc := range cases { + mock.Resp = tc.resp + + result, err := testref.OrderByValue().GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + reqs = append(reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Query: map[string]string{"orderBy": "\"$value\""}, + }) + + var gotKeys []string + var gotVals []interface{} + for _, r := range result { + var v interface{} + if err := r.Unmarshal(&v); err != nil { + t.Fatal(err) + } + gotKeys = append(gotKeys, r.Key()) + gotVals = append(gotVals, v) + } + + if !reflect.DeepEqual(tc.wantKeys, gotKeys) { + t.Errorf("GetOrdered(value) = %v; want = %v", gotKeys, tc.wantKeys) + } + if !reflect.DeepEqual(tc.want, gotVals) { + t.Errorf("GetOrdered(value) = %v; want = %v", gotVals, tc.want) + } + } + checkAllRequests(t, mock.Reqs, reqs) +} + +func TestGetOrderedWithNilResult(t *testing.T) { + mock := &mockServer{Resp: nil} + srv := mock.Start(client) + defer srv.Close() + + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + if result != nil { + t.Errorf("GetOrdered(value) = %v; want = nil", result) + } +} + +func TestGetOrderedWithLeafNode(t *testing.T) { + mock := &mockServer{Resp: "foo"} + srv := mock.Start(client) + defer srv.Close() + + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(result) != 1 { + t.Fatalf("GetOrdered(chid) = %d; want = 1", len(result)) + } + if result[0].Key() != "0" { + t.Errorf("GetOrdered(value).Key() = %v; want = %q", result[0].Key(), 0) + } + + var v interface{} + if err := result[0].Unmarshal(&v); err != nil { + t.Fatal(err) + } + if v != "foo" { + t.Errorf("GetOrdered(value) = %v; want = %v", v, "foo") + } +} + +func TestQueryHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: test error" + result, err := testref.OrderByChild("child").GetOrdered(context.Background()) + if err == nil || err.Error() != want { + t.Errorf("GetOrdered() = %v; want = %v", err, want) + } + if result != nil { + t.Errorf("GetOrdered() = %v; want = nil", result) + } +} diff --git a/db/ref.go b/db/ref.go new file mode 100644 index 00000000..8fbadf84 --- /dev/null +++ b/db/ref.go @@ -0,0 +1,262 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "firebase.google.com/go/internal" + + "golang.org/x/net/context" +) + +// txnRetires is the maximum number of times a transaction is retried before giving up. Transaction +// retries are triggered by concurrent conflicting updates to the same database location. +const txnRetries = 25 + +// Ref represents a node in the Firebase Realtime Database. +type Ref struct { + Key string + Path string + + segs []string + client *Client +} + +// TransactionNode represents the value of a node within the scope of a transaction. +type TransactionNode interface { + Unmarshal(v interface{}) error +} + +type transactionNodeImpl struct { + Raw []byte +} + +func (t *transactionNodeImpl) Unmarshal(v interface{}) error { + return json.Unmarshal(t.Raw, v) +} + +// Parent returns a reference to the parent of the current node. +// +// If the current reference points to the root of the database, Parent returns nil. +func (r *Ref) Parent() *Ref { + l := len(r.segs) + if l > 0 { + path := strings.Join(r.segs[:l-1], "/") + return r.client.NewRef(path) + } + return nil +} + +// Child returns a reference to the specified child node. +func (r *Ref) Child(path string) *Ref { + fp := fmt.Sprintf("%s/%s", r.Path, path) + return r.client.NewRef(fp) +} + +// Get retrieves the value at the current database location, and stores it in the value pointed to +// by v. +// +// Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and +// therefore v has the same requirements as the json package. Specifically, it must be a pointer, +// and must not be nil. +func (r *Ref) Get(ctx context.Context, v interface{}) error { + resp, err := r.send(ctx, "GET") + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + +// GetWithETag retrieves the value at the current database location, along with its ETag. +func (r *Ref) GetWithETag(ctx context.Context, v interface{}) (string, error) { + resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) + if err != nil { + return "", err + } else if err := resp.Unmarshal(http.StatusOK, v); err != nil { + return "", err + } + return resp.Header.Get("Etag"), nil +} + +// GetShallow performs a shallow read on the current database location. +// +// Shallow reads do not retrieve the child nodes of the current reference. +func (r *Ref) GetShallow(ctx context.Context, v interface{}) error { + resp, err := r.send(ctx, "GET", internal.WithQueryParam("shallow", "true")) + if err != nil { + return err + } + return resp.Unmarshal(http.StatusOK, v) +} + +// GetIfChanged retrieves the value and ETag of the current database location only if the specified +// ETag does not match. +// +// If the specified ETag does not match, returns true along with the latest ETag of the database +// location. The value of the database location will be stored in v just like a regular Get() call. +// If the etag matches, returns false along with the same ETag passed into the function. No data +// will be stored in v in this case. +func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (bool, string, error) { + resp, err := r.send(ctx, "GET", internal.WithHeader("If-None-Match", etag)) + if err != nil { + return false, "", err + } + if resp.Status == http.StatusNotModified { + return false, etag, nil + } + if err := resp.Unmarshal(http.StatusOK, v); err != nil { + return false, "", err + } + return true, resp.Header.Get("ETag"), nil +} + +// Set stores the value v in the current database node. +// +// Set uses https://golang.org/pkg/encoding/json/#Marshal to serialize values into JSON. Therefore +// v has the same requirements as the json package. Values like functions and channels cannot be +// saved into Realtime Database. +func (r *Ref) Set(ctx context.Context, v interface{}) error { + resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithQueryParam("print", "silent")) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusNoContent) +} + +// SetIfUnchanged conditionally sets the data at this location to the given value. +// +// Sets the data at this location to v only if the specified ETag matches. Returns true if the +// value is written. Returns false if no changes are made to the database. +func (r *Ref) SetIfUnchanged(ctx context.Context, etag string, v interface{}) (bool, error) { + resp, err := r.sendWithBody(ctx, "PUT", v, internal.WithHeader("If-Match", etag)) + if err != nil { + return false, err + } + if resp.Status == http.StatusPreconditionFailed { + return false, nil + } + if err := resp.CheckStatus(http.StatusOK); err != nil { + return false, err + } + return true, nil +} + +// Push creates a new child node at the current location, and returns a reference to it. +// +// If v is not nil, it will be set as the initial value of the new child node. If v is nil, the +// new child node will be created with empty string as the value. +func (r *Ref) Push(ctx context.Context, v interface{}) (*Ref, error) { + if v == nil { + v = "" + } + resp, err := r.sendWithBody(ctx, "POST", v) + if err != nil { + return nil, err + } + var d struct { + Name string `json:"name"` + } + if err := resp.Unmarshal(http.StatusOK, &d); err != nil { + return nil, err + } + return r.Child(d.Name), nil +} + +// Update modifies the specified child keys of the current location to the provided values. +func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { + if len(v) == 0 { + return fmt.Errorf("value argument must be a non-empty map") + } + resp, err := r.sendWithBody(ctx, "PATCH", v, internal.WithQueryParam("print", "silent")) + if err != nil { + return err + } + return resp.CheckStatus(http.StatusNoContent) +} + +// UpdateFn represents a function type that can be passed into Transaction(). +type UpdateFn func(TransactionNode) (interface{}, error) + +// Transaction atomically modifies the data at this location. +// +// Unlike a normal Set(), which just overwrites the data regardless of its previous state, +// Transaction() is used to modify the existing value to a new value, ensuring there are no +// conflicts with other clients simultaneously writing to the same location. +// +// This is accomplished by passing an update function which is used to transform the current value +// of this reference into a new value. If another client writes to this location before the new +// value is successfully saved, the update function is called again with the new current value, and +// the write will be retried. In case of repeated failures, this method will retry the transaction up +// to 25 times before giving up and returning an error. +// +// The update function may also force an early abort by returning an error instead of returning a +// value. +func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { + resp, err := r.send(ctx, "GET", internal.WithHeader("X-Firebase-ETag", "true")) + if err != nil { + return err + } else if err := resp.CheckStatus(http.StatusOK); err != nil { + return err + } + etag := resp.Header.Get("Etag") + + for i := 0; i < txnRetries; i++ { + new, err := fn(&transactionNodeImpl{resp.Body}) + if err != nil { + return err + } + resp, err = r.sendWithBody(ctx, "PUT", new, internal.WithHeader("If-Match", etag)) + if err != nil { + return err + } + if resp.Status == http.StatusOK { + return nil + } else if err := resp.CheckStatus(http.StatusPreconditionFailed); err != nil { + return err + } + etag = resp.Header.Get("ETag") + } + return fmt.Errorf("transaction aborted after failed retries") +} + +// Delete removes this node from the database. +func (r *Ref) Delete(ctx context.Context) error { + resp, err := r.send(ctx, "DELETE") + if err != nil { + return err + } + return resp.CheckStatus(http.StatusOK) +} + +func (r *Ref) send( + ctx context.Context, + method string, + opts ...internal.HTTPOption) (*internal.Response, error) { + + return r.client.send(ctx, method, r.Path, nil, opts...) +} + +func (r *Ref) sendWithBody( + ctx context.Context, + method string, + body interface{}, + opts ...internal.HTTPOption) (*internal.Response, error) { + + return r.client.send(ctx, method, r.Path, internal.NewJSONEntity(body), opts...) +} diff --git a/db/ref_test.go b/db/ref_test.go new file mode 100644 index 00000000..93e348d0 --- /dev/null +++ b/db/ref_test.go @@ -0,0 +1,729 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "fmt" + "net/http" + "reflect" + "testing" + + "golang.org/x/net/context" +) + +type refOp func(r *Ref) error + +var testOps = []struct { + name string + resp interface{} + op refOp +}{ + { + "Get()", + "test", + func(r *Ref) error { + var got string + return r.Get(context.Background(), &got) + }, + }, + { + "GetWithETag()", + "test", + func(r *Ref) error { + var got string + _, err := r.GetWithETag(context.Background(), &got) + return err + }, + }, + { + "GetShallow()", + "test", + func(r *Ref) error { + var got string + return r.GetShallow(context.Background(), &got) + }, + }, + { + "GetIfChanged()", + "test", + func(r *Ref) error { + var got string + _, _, err := r.GetIfChanged(context.Background(), "etag", &got) + return err + }, + }, + { + "Set()", + nil, + func(r *Ref) error { + return r.Set(context.Background(), "foo") + }, + }, + { + "SetIfUnchanged()", + nil, + func(r *Ref) error { + _, err := r.SetIfUnchanged(context.Background(), "etag", "foo") + return err + }, + }, + { + "Push()", + map[string]interface{}{"name": "test"}, + func(r *Ref) error { + _, err := r.Push(context.Background(), "foo") + return err + }, + }, + { + "Update()", + nil, + func(r *Ref) error { + return r.Update(context.Background(), map[string]interface{}{"foo": "bar"}) + }, + }, + { + "Delete()", + nil, + func(r *Ref) error { + return r.Delete(context.Background()) + }, + }, + { + "Transaction()", + nil, + func(r *Ref) error { + fn := func(t TransactionNode) (interface{}, error) { + var v interface{} + if err := t.Unmarshal(&v); err != nil { + return nil, err + } + return v, nil + } + return r.Transaction(context.Background(), fn) + }, + }, +} + +func TestGet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + nil, float64(1), true, "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + } + var want []*testReq + for _, tc := range cases { + mock.Resp = tc + var got interface{} + if err := testref.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc, got) { + t.Errorf("Get() = %v; want = %v", got, tc) + } + want = append(want, &testReq{Method: "GET", Path: "/peter.json"}) + } + checkAllRequests(t, mock.Reqs, want) +} + +func TestInvalidGet(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + got := func() {} + if err := testref.Get(context.Background(), &got); err == nil { + t.Errorf("Get(func) = nil; want error") + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestGetWithStruct(t *testing.T) { + want := person{Name: "Peter Parker", Age: 17} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + var got person + if err := testref.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("Get(struct) = %v; want = %v", got, want) + } + checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) +} + +func TestGetShallow(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + nil, float64(1), true, "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + map[string]interface{}{"name": "Peter Parker", "nestedChild": true}, + } + wantQuery := map[string]string{"shallow": "true"} + var want []*testReq + for _, tc := range cases { + mock.Resp = tc + var got interface{} + if err := testref.GetShallow(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tc, got) { + t.Errorf("GetShallow() = %v; want = %v", got, tc) + } + want = append(want, &testReq{Method: "GET", Path: "/peter.json", Query: wantQuery}) + } + checkAllRequests(t, mock.Reqs, want) +} + +func TestGetWithETag(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + etag, err := testref.GetWithETag(context.Background(), &got) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("GetWithETag() = %v; want = %v", got, want) + } + if etag != "mock-etag" { + t.Errorf("GetWithETag() = %q; want = %q", etag, "mock-etag") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }) +} + +func TestGetIfChanged(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{ + Resp: want, + Header: map[string]string{"ETag": "new-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var got map[string]interface{} + ok, etag, err := testref.GetIfChanged(context.Background(), "old-etag", &got) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("GetIfChanged() = %v; want = %v", ok, true) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("GetIfChanged() = %v; want = %v", got, want) + } + if etag != "new-etag" { + t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") + } + + mock.Status = http.StatusNotModified + mock.Resp = nil + var got2 map[string]interface{} + ok, etag, err = testref.GetIfChanged(context.Background(), "new-etag", &got2) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("GetIfChanged() = %v; want = %v", ok, false) + } + if got2 != nil { + t.Errorf("GetIfChanged() = %v; want nil", got2) + } + if etag != "new-etag" { + t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") + } + + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"old-etag"}}, + }, + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"If-None-Match": []string{"new-etag"}}, + }, + }) +} + +func TestWelformedHttpError(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: test error" + for _, tc := range testOps { + err := tc.op(testref) + if err == nil || err.Error() != want { + t.Errorf("%s = %v; want = %v", tc.name, err, want) + } + } + + if len(mock.Reqs) != len(testOps) { + t.Errorf("Requests = %d; want = %d", len(mock.Reqs), len(testOps)) + } +} + +func TestUnexpectedHttpError(t *testing.T) { + mock := &mockServer{Resp: "unexpected error", Status: 500} + srv := mock.Start(client) + defer srv.Close() + + want := "http error status: 500; reason: \"unexpected error\"" + for _, tc := range testOps { + err := tc.op(testref) + if err == nil || err.Error() != want { + t.Errorf("%s = %v; want = %v", tc.name, err, want) + } + } + + if len(mock.Reqs) != len(testOps) { + t.Errorf("Requests = %d; want = %d", len(mock.Reqs), len(testOps)) + } +} + +func TestInvalidPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + r := client.NewRef(tc) + for _, o := range testOps { + err := o.op(r) + if err == nil { + t.Errorf("%s = nil; want = error", o.name) + } + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests = %v; want = empty", mock.Reqs) + } +} + +func TestInvalidChildPath(t *testing.T) { + mock := &mockServer{Resp: "test"} + srv := mock.Start(client) + defer srv.Close() + + cases := []string{ + "foo$", "foo.", "foo#", "foo]", "foo[", + } + for _, tc := range cases { + r := testref.Child(tc) + for _, o := range testOps { + err := o.op(r) + if err == nil { + t.Errorf("%s = nil; want = error", o.name) + } + } + } + + if len(mock.Reqs) != 0 { + t.Errorf("Requests = %v; want = empty", mock.Reqs) + } +} + +func TestSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + 1, + true, + "foo", + map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, + &person{"Peter Parker", 17}, + } + var want []*testReq + for _, tc := range cases { + if err := testref.Set(context.Background(), tc); err != nil { + t.Fatal(err) + } + want = append(want, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(tc), + Query: map[string]string{"print": "silent"}, + }) + } + checkAllRequests(t, mock.Reqs, want) +} + +func TestInvalidSet(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + cases := []interface{}{ + func() {}, + make(chan int), + } + for _, tc := range cases { + if err := testref.Set(context.Background(), tc); err == nil { + t.Errorf("Set(%v) = nil; want = error", tc) + } + } + if len(mock.Reqs) != 0 { + t.Errorf("Set() = %v; want = empty", mock.Reqs) + } +} + +func TestSetIfUnchanged(t *testing.T) { + mock := &mockServer{} + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestSetIfUnchangedError(t *testing.T) { + mock := &mockServer{ + Status: http.StatusPreconditionFailed, + Resp: &person{"Tony Stark", 39}, + } + srv := mock.Start(client) + defer srv.Close() + + want := &person{"Peter Parker", 17} + ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(want), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }) +} + +func TestPush(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + child, err := testref.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(""), + }) +} + +func TestPushWithValue(t *testing.T) { + mock := &mockServer{Resp: map[string]string{"name": "new_key"}} + srv := mock.Start(client) + defer srv.Close() + + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + child, err := testref.Push(context.Background(), want) + if err != nil { + t.Fatal(err) + } + + if child.Key != "new_key" { + t.Errorf("Push() = %q; want = %q", child.Key, "new_key") + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "POST", + Path: "/peter.json", + Body: serialize(want), + }) +} + +func TestUpdate(t *testing.T) { + want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} + mock := &mockServer{Resp: want} + srv := mock.Start(client) + defer srv.Close() + + if err := testref.Update(context.Background(), want); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "PATCH", + Path: "/peter.json", + Body: serialize(want), + Query: map[string]string{"print": "silent"}, + }) +} + +func TestInvalidUpdate(t *testing.T) { + cases := []map[string]interface{}{ + nil, + make(map[string]interface{}), + map[string]interface{}{"foo": func() {}}, + } + for _, tc := range cases { + if err := testref.Update(context.Background(), tc); err == nil { + t.Errorf("Update(%v) = nil; want error", tc) + } + } +} + +func TestTransaction(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag"}, + } + srv := mock.Start(client) + defer srv.Close() + + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + if err := testref.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag"}}, + }, + }) +} + +func TestTransactionRetry(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + mock.Status = http.StatusOK + } + cnt++ + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + if err := testref.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + if cnt != 2 { + t.Errorf("Transaction() retries = %d; want = %d", cnt, 2) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 20, + }), + Header: http.Header{"If-Match": []string{"mock-etag2"}}, + }, + }) +} + +func TestTransactionError(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + want := "user error" + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag2"} + mock.Resp = &person{"Peter Parker", 19} + } else if cnt == 1 { + return nil, fmt.Errorf(want) + } + cnt++ + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + if err := testref.Transaction(context.Background(), fn); err == nil || err.Error() != want { + t.Errorf("Transaction() = %v; want = %q", err, want) + } + if cnt != 1 { + t.Errorf("Transaction() retries = %d; want = %d", cnt, 1) + } + checkAllRequests(t, mock.Reqs, []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }, + }) +} + +func TestTransactionAbort(t *testing.T) { + mock := &mockServer{ + Resp: &person{"Peter Parker", 17}, + Header: map[string]string{"ETag": "mock-etag1"}, + } + srv := mock.Start(client) + defer srv.Close() + + cnt := 0 + var fn UpdateFn = func(t TransactionNode) (interface{}, error) { + if cnt == 0 { + mock.Status = http.StatusPreconditionFailed + mock.Header = map[string]string{"ETag": "mock-etag1"} + } + cnt++ + var p person + if err := t.Unmarshal(&p); err != nil { + return nil, err + } + p.Age++ + return &p, nil + } + err := testref.Transaction(context.Background(), fn) + if err == nil { + t.Errorf("Transaction() = nil; want error") + } + wanted := []*testReq{ + &testReq{ + Method: "GET", + Path: "/peter.json", + Header: http.Header{"X-Firebase-ETag": []string{"true"}}, + }, + } + for i := 0; i < txnRetries; i++ { + wanted = append(wanted, &testReq{ + Method: "PUT", + Path: "/peter.json", + Body: serialize(map[string]interface{}{ + "name": "Peter Parker", + "age": 18, + }), + Header: http.Header{"If-Match": []string{"mock-etag1"}}, + }) + } + checkAllRequests(t, mock.Reqs, wanted) +} + +func TestDelete(t *testing.T) { + mock := &mockServer{Resp: "null"} + srv := mock.Start(client) + defer srv.Close() + + if err := testref.Delete(context.Background()); err != nil { + t.Fatal(err) + } + checkOnlyRequest(t, mock.Reqs, &testReq{ + Method: "DELETE", + Path: "/peter.json", + }) +} diff --git a/firebase.go b/firebase.go index ed09ac6d..0e34c058 100644 --- a/firebase.go +++ b/firebase.go @@ -18,6 +18,7 @@ package firebase import ( + "context" "encoding/json" "errors" "io/ioutil" @@ -26,36 +27,30 @@ import ( "cloud.google.com/go/firestore" "firebase.google.com/go/auth" + "firebase.google.com/go/db" "firebase.google.com/go/iid" "firebase.google.com/go/internal" "firebase.google.com/go/messaging" "firebase.google.com/go/storage" - "golang.org/x/net/context" "golang.org/x/oauth2/google" "google.golang.org/api/option" "google.golang.org/api/transport" ) -var firebaseScopes = []string{ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/datastore", - "https://www.googleapis.com/auth/devstorage.full_control", - "https://www.googleapis.com/auth/firebase", - "https://www.googleapis.com/auth/identitytoolkit", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/firebase.messaging", -} +var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. -const Version = "2.5.0" +const Version = "2.6.0" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. type App struct { + authOverride map[string]interface{} creds *google.DefaultCredentials + dbURL string projectID string storageBucket string opts []option.ClientOption @@ -63,8 +58,10 @@ type App struct { // Config represents the configuration used to initialize an App. type Config struct { - ProjectID string `json:"projectId"` - StorageBucket string `json:"storageBucket"` + AuthOverride *map[string]interface{} `json:"databaseAuthVariableOverride"` + DatabaseURL string `json:"databaseURL"` + ProjectID string `json:"projectId"` + StorageBucket string `json:"storageBucket"` } // Auth returns an instance of auth.Client. @@ -78,6 +75,17 @@ func (a *App) Auth(ctx context.Context) (*auth.Client, error) { return auth.NewClient(ctx, conf) } +// Database returns an instance of db.Client. +func (a *App) Database(ctx context.Context) (*db.Client, error) { + conf := &internal.DatabaseConfig{ + AuthOverride: a.authOverride, + URL: a.dbURL, + Opts: a.opts, + Version: Version, + } + return db.NewClient(ctx, conf) +} + // Storage returns a new instance of storage.Client. func (a *App) Storage(ctx context.Context) (*storage.Client, error) { conf := &internal.StorageConfig{ @@ -124,7 +132,7 @@ func (a *App) Messaging(ctx context.Context) (*messaging.Client, error) { // `FIREBASE_CONFIG` environment variable. If the value in it starts with a `{` it is parsed as a // JSON object, otherwise it is assumed to be the name of the JSON file containing the options. func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (*App, error) { - o := []option.ClientOption{option.WithScopes(firebaseScopes...)} + o := []option.ClientOption{option.WithScopes(internal.FirebaseScopes...)} o = append(o, opts...) creds, err := transport.Creds(ctx, o...) if err != nil { @@ -145,8 +153,15 @@ func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (* pid = os.Getenv("GCLOUD_PROJECT") } + ao := defaultAuthOverrides + if config.AuthOverride != nil { + ao = *config.AuthOverride + } + return &App{ + authOverride: ao, creds: creds, + dbURL: config.DatabaseURL, projectID: pid, storageBucket: config.StorageBucket, opts: o, @@ -170,6 +185,19 @@ func getConfigDefaults() (*Config, error) { return nil, err } } - err := json.Unmarshal(dat, fbc) - return fbc, err + if err := json.Unmarshal(dat, fbc); err != nil { + return nil, err + } + + // Some special handling necessary for db auth overrides + var m map[string]interface{} + if err := json.Unmarshal(dat, &m); err != nil { + return nil, err + } + if ao, ok := m["databaseAuthVariableOverride"]; ok && ao == nil { + // Auth overrides are explicitly set to null + var nullMap map[string]interface{} + fbc.AuthOverride = &nullMap + } + return fbc, nil } diff --git a/firebase_test.go b/firebase_test.go index fc33ba20..bbf1a4c3 100644 --- a/firebase_test.go +++ b/firebase_test.go @@ -15,12 +15,14 @@ package firebase import ( + "context" "fmt" "io/ioutil" "log" "net/http" "net/http/httptest" "os" + "reflect" "strconv" "strings" "testing" @@ -32,7 +34,6 @@ import ( "encoding/json" - "golang.org/x/net/context" "golang.org/x/oauth2" "google.golang.org/api/option" ) @@ -227,6 +228,48 @@ func TestAuth(t *testing.T) { } } +func TestDatabase(t *testing.T) { + ctx := context.Background() + conf := &Config{DatabaseURL: "https://mock-db.firebaseio.com"} + app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if app.authOverride == nil || len(app.authOverride) != 0 { + t.Errorf("AuthOverrides = %v; want = empty map", app.authOverride) + } + if c, err := app.Database(ctx); c == nil || err != nil { + t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) + } +} + +func TestDatabaseAuthOverrides(t *testing.T) { + cases := []map[string]interface{}{ + nil, + map[string]interface{}{}, + map[string]interface{}{"uid": "user1"}, + } + for _, tc := range cases { + ctx := context.Background() + conf := &Config{ + AuthOverride: &tc, + DatabaseURL: "https://mock-db.firebaseio.com", + } + app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(app.authOverride, tc) { + t.Errorf("AuthOverrides = %v; want = %v", app.authOverride, tc) + } + if c, err := app.Database(ctx); c == nil || err != nil { + t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) + } + } +} + func TestStorage(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) @@ -360,7 +403,10 @@ func TestVersion(t *testing.T) { } } } + func TestAutoInit(t *testing.T) { + var nullMap map[string]interface{} + uidMap := map[string]interface{}{"uid": "test"} tests := []struct { name string optionsConfig string @@ -378,6 +424,7 @@ func TestAutoInit(t *testing.T) { "testdata/firebase_config.json", nil, &Config{ + DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, @@ -385,11 +432,13 @@ func TestAutoInit(t *testing.T) { { "", `{ + "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ + DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, @@ -456,6 +505,34 @@ func TestAutoInit(t *testing.T) { StorageBucket: "auto-init.storage.bucket", }, }, + { + "", + `{ + "databaseURL": "https://auto-init.database.url", + "projectId": "auto-init-project-id", + "databaseAuthVariableOverride": null + }`, + nil, + &Config{ + DatabaseURL: "https://auto-init.database.url", + ProjectID: "auto-init-project-id", + AuthOverride: &nullMap, + }, + }, + { + "", + `{ + "databaseURL": "https://auto-init.database.url", + "projectId": "auto-init-project-id", + "databaseAuthVariableOverride": {"uid": "test"} + }`, + nil, + &Config{ + DatabaseURL: "https://auto-init.database.url", + ProjectID: "auto-init-project-id", + AuthOverride: &uidMap, + }, + }, } credOld := overwriteEnv(credEnvVar, "testdata/service_account.json") @@ -523,6 +600,16 @@ func (t *testTokenSource) Token() (*oauth2.Token, error) { } func compareConfig(got *App, want *Config, t *testing.T) { + if got.dbURL != want.DatabaseURL { + t.Errorf("app.dbURL = %q; want = %q", got.dbURL, want.DatabaseURL) + } + if want.AuthOverride != nil { + if !reflect.DeepEqual(got.authOverride, *want.AuthOverride) { + t.Errorf("app.ao = %#v; want = %#v", got.authOverride, *want.AuthOverride) + } + } else if !reflect.DeepEqual(got.authOverride, defaultAuthOverrides) { + t.Errorf("app.ao = %#v; want = nil", got.authOverride) + } if got.projectID != want.ProjectID { t.Errorf("app.projectID = %q; want = %q", got.projectID, want.ProjectID) } @@ -574,7 +661,7 @@ func overwriteEnv(varName, newVal string) string { return oldVal } -// reinstateEnv restores the enviornment variable, will usually be used deferred with overwriteEnv. +// reinstateEnv restores the environment variable, will usually be used deferred with overwriteEnv. func reinstateEnv(varName, oldVal string) { if len(varName) > 0 { os.Setenv(varName, oldVal) diff --git a/iid/iid.go b/iid/iid.go index 980a7bed..b282db40 100644 --- a/iid/iid.go +++ b/iid/iid.go @@ -16,6 +16,7 @@ package iid import ( + "context" "errors" "fmt" "net/http" @@ -23,8 +24,6 @@ import ( "google.golang.org/api/transport" "firebase.google.com/go/internal" - - "golang.org/x/net/context" ) const iidEndpoint = "https://console.firebase.google.com/v1" diff --git a/iid/iid_test.go b/iid/iid_test.go index 6d154650..b3e69638 100644 --- a/iid/iid_test.go +++ b/iid/iid_test.go @@ -15,6 +15,7 @@ package iid import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -23,8 +24,6 @@ import ( "google.golang.org/api/option" "firebase.google.com/go/internal" - - "golang.org/x/net/context" ) var testIIDConfig = &internal.InstanceIDConfig{ diff --git a/integration/auth/auth_test.go b/integration/auth/auth_test.go index 5e027d44..2b8d6bd6 100644 --- a/integration/auth/auth_test.go +++ b/integration/auth/auth_test.go @@ -17,6 +17,7 @@ package auth import ( "bytes" + "context" "encoding/json" "flag" "fmt" @@ -29,8 +30,6 @@ import ( "firebase.google.com/go/auth" "firebase.google.com/go/integration/internal" - - "golang.org/x/net/context" ) const apiURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken?key=%s" @@ -45,7 +44,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/auth/user_mgt_test.go b/integration/auth/user_mgt_test.go index 3013fbe0..4121d62c 100644 --- a/integration/auth/user_mgt_test.go +++ b/integration/auth/user_mgt_test.go @@ -16,6 +16,7 @@ package auth import ( + "context" "fmt" "reflect" "testing" @@ -24,8 +25,6 @@ import ( "google.golang.org/api/iterator" "firebase.google.com/go/auth" - - "golang.org/x/net/context" ) var testFixtures = struct { @@ -371,6 +370,9 @@ func testRemoveCustomClaims(t *testing.T) { t.Fatal(err) } u, err = client.GetUser(context.Background(), testFixtures.sampleUserBlank.UID) + if err != nil { + t.Fatal(err) + } if u.CustomClaims != nil { t.Errorf("CustomClaims() = %#v; want = nil", u.CustomClaims) } diff --git a/integration/db/db_test.go b/integration/db/db_test.go new file mode 100644 index 00000000..0754d5bf --- /dev/null +++ b/integration/db/db_test.go @@ -0,0 +1,709 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package db contains integration tests for the firebase.google.com/go/db package. +package db + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "net/http" + "os" + "reflect" + "testing" + + "golang.org/x/net/context" + + "firebase.google.com/go" + "firebase.google.com/go/db" + "firebase.google.com/go/integration/internal" +) + +var client *db.Client +var aoClient *db.Client +var guestClient *db.Client + +var ref *db.Ref +var users *db.Ref +var dinos *db.Ref + +var testData map[string]interface{} +var parsedTestData map[string]Dinosaur + +const permDenied = "http error status: 401; reason: Permission denied" + +func TestMain(m *testing.M) { + flag.Parse() + if testing.Short() { + log.Println("skipping database integration tests in short mode.") + os.Exit(0) + } + + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + + client, err = initClient(pid) + if err != nil { + log.Fatalln(err) + } + + aoClient, err = initOverrideClient(pid) + if err != nil { + log.Fatalln(err) + } + + guestClient, err = initGuestClient(pid) + if err != nil { + log.Fatalln(err) + } + + ref = client.NewRef("_adminsdk/go/dinodb") + dinos = ref.Child("dinosaurs") + users = ref.Parent().Child("users") + + initRules() + initData() + + os.Exit(m.Run()) +} + +func initClient(pid string) (*db.Client, error) { + ctx := context.Background() + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initOverrideClient(pid string) (*db.Client, error) { + ctx := context.Background() + ao := map[string]interface{}{"uid": "user1"} + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverride: &ao, + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initGuestClient(pid string) (*db.Client, error) { + ctx := context.Background() + var nullMap map[string]interface{} + app, err := internal.NewTestApp(ctx, &firebase.Config{ + DatabaseURL: fmt.Sprintf("https://%s.firebaseio.com", pid), + AuthOverride: &nullMap, + }) + if err != nil { + return nil, err + } + + return app.Database(ctx) +} + +func initRules() { + b, err := ioutil.ReadFile(internal.Resource("dinosaurs_index.json")) + if err != nil { + log.Fatalln(err) + } + + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + + url := fmt.Sprintf("https://%s.firebaseio.com/.settings/rules.json", pid) + req, err := http.NewRequest("PUT", url, bytes.NewBuffer(b)) + if err != nil { + log.Fatalln(err) + } + req.Header.Set("Content-Type", "application/json") + + hc, err := internal.NewHTTPClient(context.Background()) + if err != nil { + log.Fatalln(err) + } + resp, err := hc.Do(req) + if err != nil { + log.Fatalln(err) + } + defer resp.Body.Close() + + b, err = ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatalln(err) + } else if resp.StatusCode != http.StatusOK { + log.Fatalln("failed to update rules:", string(b)) + } +} + +func initData() { + b, err := ioutil.ReadFile(internal.Resource("dinosaurs.json")) + if err != nil { + log.Fatalln(err) + } + if err = json.Unmarshal(b, &testData); err != nil { + log.Fatalln(err) + } + + b, err = json.Marshal(testData["dinosaurs"]) + if err != nil { + log.Fatalln(err) + } + if err = json.Unmarshal(b, &parsedTestData); err != nil { + log.Fatalln(err) + } + + if err = ref.Set(context.Background(), testData); err != nil { + log.Fatalln(err) + } +} + +func TestRef(t *testing.T) { + if ref.Key != "dinodb" { + t.Errorf("Key = %q; want = %q", ref.Key, "dinodb") + } + if ref.Path != "/_adminsdk/go/dinodb" { + t.Errorf("Path = %q; want = %q", ref.Path, "/_adminsdk/go/dinodb") + } +} + +func TestChild(t *testing.T) { + c := ref.Child("dinosaurs") + if c.Key != "dinosaurs" { + t.Errorf("Key = %q; want = %q", c.Key, "dinosaurs") + } + if c.Path != "/_adminsdk/go/dinodb/dinosaurs" { + t.Errorf("Path = %q; want = %q", c.Path, "/_adminsdk/go/dinodb/dinosaurs") + } +} + +func TestParent(t *testing.T) { + p := ref.Parent() + if p.Key != "go" { + t.Errorf("Key = %q; want = %q", p.Key, "go") + } + if p.Path != "/_adminsdk/go" { + t.Errorf("Path = %q; want = %q", p.Path, "/_adminsdk/go") + } +} + +func TestGet(t *testing.T) { + var m map[string]interface{} + if err := ref.Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("Get() = %v; want = %v", m, testData) + } +} + +func TestGetWithETag(t *testing.T) { + var m map[string]interface{} + etag, err := ref.GetWithETag(context.Background(), &m) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("GetWithETag() = %v; want = %v", m, testData) + } + if etag == "" { + t.Errorf("GetWithETag() = \"\"; want non-empty") + } +} + +func TestGetShallow(t *testing.T) { + var m map[string]interface{} + if err := ref.GetShallow(context.Background(), &m); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{} + for k := range testData { + want[k] = true + } + if !reflect.DeepEqual(want, m) { + t.Errorf("GetShallow() = %v; want = %v", m, want) + } +} + +func TestGetIfChanged(t *testing.T) { + var m map[string]interface{} + ok, etag, err := ref.GetIfChanged(context.Background(), "wrong-etag", &m) + if err != nil { + t.Fatal(err) + } + if !ok || etag == "" { + t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag, true, "non-empty") + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("GetWithETag() = %v; want = %v", m, testData) + } + + var m2 map[string]interface{} + ok, etag2, err := ref.GetIfChanged(context.Background(), etag, &m2) + if err != nil { + t.Fatal(err) + } + if ok || etag != etag2 { + t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag2, false, etag) + } + if len(m2) != 0 { + t.Errorf("GetWithETag() = %v; want empty", m) + } +} + +func TestGetChildValue(t *testing.T) { + c := ref.Child("dinosaurs") + var m map[string]interface{} + if err := c.Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData["dinosaurs"], m) { + t.Errorf("Get() = %v; want = %v", m, testData["dinosaurs"]) + } +} + +func TestGetGrandChildValue(t *testing.T) { + c := ref.Child("dinosaurs/lambeosaurus") + var got Dinosaur + if err := c.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := parsedTestData["lambeosaurus"] + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestGetNonExistingChild(t *testing.T) { + c := ref.Child("non_existing") + var i interface{} + if err := c.Get(context.Background(), &i); err != nil { + t.Fatal(err) + } + if i != nil { + t.Errorf("Get() = %v; want nil", i) + } +} + +func TestPush(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if u.Path != "/_adminsdk/go/users/"+u.Key { + t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) + } + + var i interface{} + if err := u.Get(context.Background(), &i); err != nil { + t.Fatal(err) + } + if i != "" { + t.Errorf("Get() = %v; want empty string", i) + } +} + +func TestPushWithValue(t *testing.T) { + want := User{"Luis Alvarez", 1911} + u, err := users.Push(context.Background(), &want) + if err != nil { + t.Fatal(err) + } + if u.Path != "/_adminsdk/go/users/"+u.Key { + t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) + } + + var got User + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if want != got { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestSetPrimitiveValue(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if err := u.Set(context.Background(), "value"); err != nil { + t.Fatal(err) + } + var got string + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "value" { + t.Errorf("Get() = %q; want = %q", got, "value") + } +} + +func TestSetComplexValue(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + want := User{"Mary Anning", 1799} + if err := u.Set(context.Background(), &want); err != nil { + t.Fatal(err) + } + var got User + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateChildren(t *testing.T) { + u, err := users.Push(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + want := map[string]interface{}{ + "name": "Robert Bakker", + "since": float64(1945), + } + if err := u.Update(context.Background(), want); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateChildrenWithExistingValue(t *testing.T) { + u, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Edwin Colbert", + "since": float64(1900), + }) + if err != nil { + t.Fatal(err) + } + + update := map[string]interface{}{"since": float64(1905)} + if err := u.Update(context.Background(), update); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{ + "name": "Edwin Colbert", + "since": float64(1905), + } + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestUpdateNestedChildren(t *testing.T) { + edward, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Edward Cope", "since": float64(1800), + }) + if err != nil { + t.Fatal(err) + } + jack, err := users.Push(context.Background(), map[string]interface{}{ + "name": "Jack Horner", "since": float64(1940), + }) + if err != nil { + t.Fatal(err) + } + delta := map[string]interface{}{ + fmt.Sprintf("%s/since", edward.Key): 1840, + fmt.Sprintf("%s/since", jack.Key): 1946, + } + if err := users.Update(context.Background(), delta); err != nil { + t.Fatal(err) + } + var got map[string]interface{} + if err := edward.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := map[string]interface{}{"name": "Edward Cope", "since": float64(1840)} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } + + if err := jack.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want = map[string]interface{}{"name": "Jack Horner", "since": float64(1946)} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestSetIfChanged(t *testing.T) { + edward, err := users.Push(context.Background(), &User{"Edward Cope", 1800}) + if err != nil { + t.Fatal(err) + } + + update := User{"Jack Horner", 1940} + ok, err := edward.SetIfUnchanged(context.Background(), "invalid-etag", &update) + if err != nil { + t.Fatal(err) + } + if ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) + } + + var u User + etag, err := edward.GetWithETag(context.Background(), &u) + if err != nil { + t.Fatal(err) + } + ok, err = edward.SetIfUnchanged(context.Background(), etag, &update) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) + } + + if err := edward.Get(context.Background(), &u); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(update, u) { + t.Errorf("Get() = %v; want = %v", u, update) + } +} + +func TestTransaction(t *testing.T) { + u, err := users.Push(context.Background(), &User{Name: "Richard"}) + if err != nil { + t.Fatal(err) + } + fn := func(t db.TransactionNode) (interface{}, error) { + var user User + if err := t.Unmarshal(&user); err != nil { + return nil, err + } + user.Name = "Richard Owen" + user.Since = 1804 + return &user, nil + } + if err := u.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + var got User + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + want := User{"Richard Owen", 1804} + if !reflect.DeepEqual(want, got) { + t.Errorf("Get() = %v; want = %v", got, want) + } +} + +func TestTransactionScalar(t *testing.T) { + cnt := users.Child("count") + if err := cnt.Set(context.Background(), 42); err != nil { + t.Fatal(err) + } + fn := func(t db.TransactionNode) (interface{}, error) { + var snap float64 + if err := t.Unmarshal(&snap); err != nil { + return nil, err + } + return snap + 1, nil + } + if err := cnt.Transaction(context.Background(), fn); err != nil { + t.Fatal(err) + } + var got float64 + if err := cnt.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != 43.0 { + t.Errorf("Get() = %v; want = %v", got, 43.0) + } +} + +func TestDelete(t *testing.T) { + u, err := users.Push(context.Background(), "foo") + if err != nil { + t.Fatal(err) + } + var got string + if err := u.Get(context.Background(), &got); err != nil { + t.Fatal(err) + } + if got != "foo" { + t.Errorf("Get() = %q; want = %q", got, "foo") + } + if err := u.Delete(context.Background()); err != nil { + t.Fatal(err) + } + + var got2 string + if err := u.Get(context.Background(), &got2); err != nil { + t.Fatal(err) + } + if got2 != "" { + t.Errorf("Get() = %q; want = %q", got2, "") + } +} + +func TestNoAccess(t *testing.T) { + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) + var got string + if err := r.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + if err := r.Set(context.Background(), "update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestReadAccess(t *testing.T) { + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) + var got string + if err := r.Get(context.Background(), &got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set(context.Background(), "update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestReadWriteAccess(t *testing.T) { + r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) + var got string + if err := r.Get(context.Background(), &got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set(context.Background(), "update"); err != nil { + t.Errorf("Set() = %v; want = nil", err) + } +} + +func TestQueryAccess(t *testing.T) { + r := aoClient.NewRef("_adminsdk/go/protected") + got := make(map[string]interface{}) + if err := r.OrderByKey().LimitToFirst(2).Get(context.Background(), &got); err == nil { + t.Errorf("OrderByQuery() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestGuestAccess(t *testing.T) { + r := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) + var got string + if err := r.Get(context.Background(), &got); err != nil || got != "test" { + t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") + } + if err := r.Set(context.Background(), "update"); err == nil { + t.Errorf("Set() = nil; want = error") + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + got = "" + r = guestClient.NewRef("_adminsdk/go") + if err := r.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + c := r.Child("protected/user2") + if err := c.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } + + c = r.Child("admin") + if err := c.Get(context.Background(), &got); err == nil || got != "" { + t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) + } else if err.Error() != permDenied { + t.Errorf("Error = %q; want = %q", err.Error(), permDenied) + } +} + +func TestWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var m map[string]interface{} + if err := ref.Get(ctx, &m); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(testData, m) { + t.Errorf("Get() = %v; want = %v", m, testData) + } + + cancel() + m = nil + if err := ref.Get(ctx, &m); len(m) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) + } +} + +func protectedRef(t *testing.T, p string) string { + r := client.NewRef(p) + if err := r.Set(context.Background(), "test"); err != nil { + t.Fatal(err) + } + return p +} + +type Dinosaur struct { + Appeared int `json:"appeared"` + Height float64 `json:"height"` + Length float64 `json:"length"` + Order string `json:"order"` + Vanished int `json:"vanished"` + Weight int `json:"weight"` + Ratings Ratings `json:"ratings"` +} + +type Ratings struct { + Pos int `json:"pos"` +} + +type User struct { + Name string `json:"name"` + Since int `json:"since"` +} diff --git a/integration/db/query_test.go b/integration/db/query_test.go new file mode 100644 index 00000000..6573d915 --- /dev/null +++ b/integration/db/query_test.go @@ -0,0 +1,266 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package db + +import ( + "testing" + + "firebase.google.com/go/db" + + "reflect" + + "golang.org/x/net/context" +) + +var heightSorted = []string{ + "linhenykus", "pterodactyl", "lambeosaurus", + "triceratops", "stegosaurus", "bruhathkayosaurus", +} + +func TestLimitToFirst(t *testing.T) { + for _, tc := range []int{2, 10} { + results, err := dinos.OrderByChild("height").LimitToFirst(tc).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + wl := min(tc, len(heightSorted)) + want := heightSorted[:wl] + if len(results) != wl { + t.Errorf("LimitToFirst() = %d; want = %d", len(results), wl) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) + } +} + +func TestLimitToLast(t *testing.T) { + for _, tc := range []int{2, 10} { + results, err := dinos.OrderByChild("height").LimitToLast(tc).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + wl := min(tc, len(heightSorted)) + want := heightSorted[len(heightSorted)-wl:] + if len(results) != wl { + t.Errorf("LimitToLast() = %d; want = %d", len(results), wl) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) + } +} + +func TestStartAt(t *testing.T) { + results, err := dinos.OrderByChild("height").StartAt(3.5).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-2:] + if len(results) != len(want) { + t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestEndAt(t *testing.T) { + results, err := dinos.OrderByChild("height").EndAt(3.5).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[:4] + if len(results) != len(want) { + t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestStartAndEndAt(t *testing.T) { + results, err := dinos.OrderByChild("height").StartAt(2.5).EndAt(5).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] + if len(results) != len(want) { + t.Errorf("StartAt(), EndAt() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestEqualTo(t *testing.T) { + results, err := dinos.OrderByChild("height").EqualTo(0.6).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := heightSorted[:2] + if len(results) != len(want) { + t.Errorf("EqualTo() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestOrderByNestedChild(t *testing.T) { + results, err := dinos.OrderByChild("ratings/pos").StartAt(4).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := []string{"pterodactyl", "stegosaurus", "triceratops"} + if len(results) != len(want) { + t.Errorf("OrderByChild(ratings/pos) = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestOrderByKey(t *testing.T) { + results, err := dinos.OrderByKey().LimitToFirst(2).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := []string{"bruhathkayosaurus", "lambeosaurus"} + if len(results) != len(want) { + t.Errorf("OrderByKey() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + compareValues(t, results) +} + +func TestOrderByValue(t *testing.T) { + scores := ref.Child("scores") + results, err := scores.OrderByValue().LimitToLast(2).GetOrdered(context.Background()) + if err != nil { + t.Fatal(err) + } + + want := []string{"linhenykus", "pterodactyl"} + if len(results) != len(want) { + t.Errorf("OrderByValue() = %d; want = %d", len(results), len(want)) + } + got := getNames(results) + if !reflect.DeepEqual(got, want) { + t.Errorf("LimitToLast() = %v; want = %v", got, want) + } + wantScores := []int{80, 93} + for i, r := range results { + var val int + if err := r.Unmarshal(&val); err != nil { + t.Fatalf("queryNode.Unmarshal() = %v", err) + } + if val != wantScores[i] { + t.Errorf("queryNode.Unmarshal() = %d; want = %d", val, wantScores[i]) + } + } +} + +func TestQueryWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + q := dinos.OrderByKey().LimitToFirst(2) + var m map[string]Dinosaur + if err := q.Get(ctx, &m); err != nil { + t.Fatal(err) + } + + want := []string{"bruhathkayosaurus", "lambeosaurus"} + if len(m) != len(want) { + t.Errorf("OrderByKey() = %d; want = %d", len(m), len(want)) + } + + cancel() + m = nil + if err := q.Get(ctx, &m); len(m) != 0 || err == nil { + t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) + } +} + +func TestUnorderedQuery(t *testing.T) { + var m map[string]Dinosaur + if err := dinos.OrderByChild("height"). + StartAt(2.5). + EndAt(5). + Get(context.Background(), &m); err != nil { + t.Fatal(err) + } + + want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] + if len(m) != len(want) { + t.Errorf("Get() = %d; want = %d", len(m), len(want)) + } + for i, w := range want { + if _, ok := m[w]; !ok { + t.Errorf("[%d] result[%q] not present", i, w) + } + } +} + +func min(i, j int) int { + if i < j { + return i + } + return j +} + +func getNames(results []db.QueryNode) []string { + s := make([]string, len(results)) + for i, v := range results { + s[i] = v.Key() + } + return s +} + +func compareValues(t *testing.T, results []db.QueryNode) { + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + t.Fatalf("queryNode.Unmarshal(%q) = %v", r.Key(), err) + } + if !reflect.DeepEqual(d, parsedTestData[r.Key()]) { + t.Errorf("queryNode.Unmarshal(%q) = %v; want = %v", r.Key(), d, parsedTestData[r.Key()]) + } + } +} diff --git a/integration/firestore/firestore_test.go b/integration/firestore/firestore_test.go index 6e7b4e28..1b861d92 100644 --- a/integration/firestore/firestore_test.go +++ b/integration/firestore/firestore_test.go @@ -15,13 +15,12 @@ package firestore import ( + "context" "log" "reflect" "testing" "firebase.google.com/go/integration/internal" - - "golang.org/x/net/context" ) func TestFirestore(t *testing.T) { @@ -30,7 +29,7 @@ func TestFirestore(t *testing.T) { return } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { t.Fatal(err) } diff --git a/integration/iid/iid_test.go b/integration/iid/iid_test.go index 9be5dce0..2b1b1c9c 100644 --- a/integration/iid/iid_test.go +++ b/integration/iid/iid_test.go @@ -16,6 +16,7 @@ package iid import ( + "context" "flag" "log" "os" @@ -23,8 +24,6 @@ import ( "firebase.google.com/go/iid" "firebase.google.com/go/integration/internal" - - "golang.org/x/net/context" ) var client *iid.Client @@ -37,7 +36,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/internal/internal.go b/integration/internal/internal.go index bc52a16a..a5cd7af6 100644 --- a/integration/internal/internal.go +++ b/integration/internal/internal.go @@ -16,16 +16,18 @@ package internal import ( + "context" "encoding/json" "go/build" "io/ioutil" + "net/http" "path/filepath" "strings" - "golang.org/x/net/context" - firebase "firebase.google.com/go" + "firebase.google.com/go/internal" "google.golang.org/api/option" + "google.golang.org/api/transport" ) const certPath = "integration_cert.json" @@ -42,15 +44,8 @@ func Resource(name string) string { // NewTestApp looks for a service account JSON file named integration_cert.json // in the testdata directory. This file is used to initialize the newly created // App instance. -func NewTestApp(ctx context.Context) (*firebase.App, error) { - pid, err := ProjectID() - if err != nil { - return nil, err - } - config := &firebase.Config{ - StorageBucket: pid + ".appspot.com", - } - return firebase.NewApp(ctx, config, option.WithCredentialsFile(Resource(certPath))) +func NewTestApp(ctx context.Context, conf *firebase.Config) (*firebase.App, error) { + return firebase.NewApp(ctx, conf, option.WithCredentialsFile(Resource(certPath))) } // APIKey fetches a Firebase API key for integration tests. @@ -79,3 +74,14 @@ func ProjectID() (string, error) { } return serviceAccount.ProjectID, nil } + +// NewHTTPClient creates an HTTP client for making authorized requests during tests. +func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*http.Client, error) { + opts = append( + opts, + option.WithCredentialsFile(Resource(certPath)), + option.WithScopes(internal.FirebaseScopes...), + ) + hc, _, err := transport.NewHTTPClient(ctx, opts...) + return hc, err +} diff --git a/integration/messaging/messaging_test.go b/integration/messaging/messaging_test.go index d7bb0693..4b8ef6d7 100644 --- a/integration/messaging/messaging_test.go +++ b/integration/messaging/messaging_test.go @@ -15,14 +15,13 @@ package messaging import ( + "context" "flag" "log" "os" "regexp" "testing" - "golang.org/x/net/context" - "firebase.google.com/go/integration/internal" "firebase.google.com/go/messaging" ) @@ -45,7 +44,7 @@ func TestMain(m *testing.M) { } ctx := context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } diff --git a/integration/storage/storage_test.go b/integration/storage/storage_test.go index 5efe92d2..b5a205d5 100644 --- a/integration/storage/storage_test.go +++ b/integration/storage/storage_test.go @@ -15,6 +15,7 @@ package storage import ( + "context" "flag" "fmt" "io/ioutil" @@ -22,10 +23,11 @@ import ( "os" "testing" + "firebase.google.com/go" + gcs "cloud.google.com/go/storage" "firebase.google.com/go/integration/internal" "firebase.google.com/go/storage" - "golang.org/x/net/context" ) var ctx context.Context @@ -38,8 +40,15 @@ func TestMain(m *testing.M) { os.Exit(0) } + pid, err := internal.ProjectID() + if err != nil { + log.Fatalln(err) + } + ctx = context.Background() - app, err := internal.NewTestApp(ctx) + app, err := internal.NewTestApp(ctx, &firebase.Config{ + StorageBucket: fmt.Sprintf("%s.appspot.com", pid), + }) if err != nil { log.Fatalln(err) } diff --git a/internal/http_client.go b/internal/http_client.go index bd40c366..984e8a1d 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -16,13 +16,12 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" - - "golang.org/x/net/context" ) // HTTPClient is a convenient API to make HTTP calls. diff --git a/internal/http_client_test.go b/internal/http_client_test.go index bdac7474..14729d17 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -14,14 +14,13 @@ package internal import ( + "context" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" "reflect" "testing" - - "golang.org/x/net/context" ) var cases = []struct { diff --git a/internal/internal.go b/internal/internal.go index 225edc9e..bc4f41d1 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -21,6 +21,16 @@ import ( "google.golang.org/api/option" ) +// FirebaseScopes is the set of OAuth2 scopes used by the Admin SDK. +var FirebaseScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/datastore", + "https://www.googleapis.com/auth/devstorage.full_control", + "https://www.googleapis.com/auth/firebase", + "https://www.googleapis.com/auth/identitytoolkit", + "https://www.googleapis.com/auth/userinfo.email", +} + // AuthConfig represents the configuration of Firebase Auth service. type AuthConfig struct { Opts []option.ClientOption @@ -35,6 +45,14 @@ type InstanceIDConfig struct { ProjectID string } +// DatabaseConfig represents the configuration of Firebase Database service. +type DatabaseConfig struct { + Opts []option.ClientOption + URL string + Version string + AuthOverride map[string]interface{} +} + // StorageConfig represents the configuration of Google Cloud Storage service. type StorageConfig struct { Opts []option.ClientOption diff --git a/messaging/messaging.go b/messaging/messaging.go index 97b77d64..231e7212 100644 --- a/messaging/messaging.go +++ b/messaging/messaging.go @@ -17,6 +17,7 @@ package messaging import ( + "context" "encoding/json" "errors" "fmt" @@ -25,8 +26,6 @@ import ( "strings" "time" - "golang.org/x/net/context" - "firebase.google.com/go/internal" "google.golang.org/api/transport" ) @@ -42,13 +41,20 @@ var ( topicNamePattern = regexp.MustCompile("^(/topics/)?(private/)?[a-zA-Z0-9-_.~%]+$") fcmErrorCodes = map[string]string{ + // FCM v1 canonical error codes + "NOT_FOUND": "app instance has been unregistered; code: registration-token-not-registered", + "PERMISSION_DENIED": "sender id does not match regisration token; code: mismatched-credential", + "RESOURCE_EXHAUSTED": "messaging service quota exceeded; code: message-rate-exceeded", + "UNAUTHENTICATED": "apns certificate or auth key was invalid; code: invalid-apns-credentials", + + // FCM v1 new error codes + "APNS_AUTH_ERROR": "apns certificate or auth key was invalid; code: invalid-apns-credentials", + "INTERNAL": "back servers encountered an unknown internl error; code: internal-error", "INVALID_ARGUMENT": "request contains an invalid argument; code: invalid-argument", - "UNREGISTERED": "app instance has been unregistered; code: registration-token-not-registered", - "SENDER_ID_MISMATCH": "sender id does not match regisration token; code: authentication-error", + "SENDER_ID_MISMATCH": "sender id does not match regisration token; code: mismatched-credential", "QUOTA_EXCEEDED": "messaging service quota exceeded; code: message-rate-exceeded", - "APNS_AUTH_ERROR": "apns certificate or auth key was invalid; code: authentication-error", "UNAVAILABLE": "backend servers are temporarily unavailable; code: server-unavailable", - "INTERNAL": "back servers encountered an unknown internl error; code: internal-error", + "UNREGISTERED": "app instance has been unregistered; code: registration-token-not-registered", } iidErrorCodes = map[string]string{ diff --git a/messaging/messaging_test.go b/messaging/messaging_test.go index 27808e84..79ef498b 100644 --- a/messaging/messaging_test.go +++ b/messaging/messaging_test.go @@ -25,9 +25,8 @@ import ( "testing" "time" - "google.golang.org/api/option" - "firebase.google.com/go/internal" + "google.golang.org/api/option" ) const testMessageID = "projects/test-project/messages/msg_id" @@ -633,6 +632,10 @@ func TestSendError(t *testing.T) { resp: "{\"error\": {\"status\": \"INVALID_ARGUMENT\", \"message\": \"test error\"}}", want: "http error status: 500; reason: request contains an invalid argument; code: invalid-argument", }, + { + resp: "{\"error\": {\"status\": \"NOT_FOUND\", \"message\": \"test error\"}}", + want: "http error status: 500; reason: app instance has been unregistered; code: registration-token-not-registered", + }, { resp: "not json", want: "http error status: 500; reason: server responded with an unknown error; response: not json", diff --git a/snippets/auth.go b/snippets/auth.go index 31d24837..9fb739ba 100644 --- a/snippets/auth.go +++ b/snippets/auth.go @@ -20,7 +20,6 @@ import ( firebase "firebase.google.com/go" "firebase.google.com/go/auth" - "google.golang.org/api/iterator" ) @@ -94,10 +93,8 @@ func verifyIDToken(app *firebase.App, idToken string) *auth.Token { // https://firebase.google.com/docs/auth/admin/manage-sessions // ================================================================== -func revokeRefreshTokens(app *firebase.App, uid string) { - +func revokeRefreshTokens(ctx context.Context, app *firebase.App, uid string) { // [START revoke_tokens_golang] - ctx := context.Background() client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) @@ -115,8 +112,7 @@ func revokeRefreshTokens(app *firebase.App, uid string) { // [END revoke_tokens_golang] } -func verifyIDTokenAndCheckRevoked(app *firebase.App, idToken string) *auth.Token { - ctx := context.Background() +func verifyIDTokenAndCheckRevoked(ctx context.Context, app *firebase.App, idToken string) *auth.Token { // [START verify_id_token_and_check_revoked_golang] client, err := app.Auth(ctx) if err != nil { @@ -145,7 +141,7 @@ func getUser(ctx context.Context, app *firebase.App) *auth.UserRecord { // [START get_user_golang] // Get an auth client from the firebase.App - client, err := app.Auth(context.Background()) + client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } @@ -193,7 +189,7 @@ func createUser(ctx context.Context, client *auth.Client) *auth.UserRecord { DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(false) - u, err := client.CreateUser(context.Background(), params) + u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } @@ -209,7 +205,7 @@ func createUserWithUID(ctx context.Context, client *auth.Client) *auth.UserRecor UID(uid). Email("user@example.com"). PhoneNumber("+15555550100") - u, err := client.CreateUser(context.Background(), params) + u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } @@ -229,7 +225,7 @@ func updateUser(ctx context.Context, client *auth.Client) { DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(true) - u, err := client.UpdateUser(context.Background(), uid, params) + u, err := client.UpdateUser(ctx, uid, params) if err != nil { log.Fatalf("error updating user: %v\n", err) } @@ -240,7 +236,7 @@ func updateUser(ctx context.Context, client *auth.Client) { func deleteUser(ctx context.Context, client *auth.Client) { uid := "d" // [START delete_user_golang] - err := client.DeleteUser(context.Background(), uid) + err := client.DeleteUser(ctx, uid) if err != nil { log.Fatalf("error deleting user: %v\n", err) } @@ -252,14 +248,14 @@ func customClaimsSet(ctx context.Context, app *firebase.App) { uid := "uid" // [START set_custom_user_claims_golang] // Get an auth client from the firebase.App - client, err := app.Auth(context.Background()) + client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // Set admin privilege on the user corresponding to uid. claims := map[string]interface{}{"admin": true} - err = client.SetCustomUserClaims(context.Background(), uid, claims) + err = client.SetCustomUserClaims(ctx, uid, claims) if err != nil { log.Fatalf("error setting custom claims %v\n", err) } @@ -351,7 +347,7 @@ func customClaimsIncremental(ctx context.Context, client *auth.Client) { func listUsers(ctx context.Context, client *auth.Client) { // [START list_all_users_golang] // Note, behind the scenes, the Users() iterator will retrive 1000 Users at a time through the API - iter := client.Users(context.Background(), "") + iter := client.Users(ctx, "") for { user, err := iter.Next() if err == iterator.Done { @@ -366,7 +362,7 @@ func listUsers(ctx context.Context, client *auth.Client) { // Iterating by pages 100 users at a time. // Note that using both the Next() function on an iterator and the NextPage() // on a Pager wrapping that same iterator will result in an error. - pager := iterator.NewPager(client.Users(context.Background(), ""), 100, "") + pager := iterator.NewPager(client.Users(ctx, ""), 100, "") for { var users []*auth.ExportedUserRecord nextPageToken, err := pager.NextPage(&users) diff --git a/snippets/db.go b/snippets/db.go new file mode 100644 index 00000000..8e0bea71 --- /dev/null +++ b/snippets/db.go @@ -0,0 +1,528 @@ +// Copyright 2018 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snippets + +// [START authenticate_db_imports] +import ( + "context" + "fmt" + "log" + + "firebase.google.com/go/db" + + "firebase.google.com/go" + "google.golang.org/api/option" +) + +// [END authenticate_db_imports] + +func authenticateWithAdminPrivileges() { + // [START authenticate_with_admin_privileges] + ctx := context.Background() + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + } + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + // Initialize the app with a service account, granting admin privileges + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // As an admin, the app has access to read and write all data, regradless of Security Rules + ref := client.NewRef("restricted_access/secret_document") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_admin_privileges] +} + +func authenticateWithLimitedPrivileges() { + // [START authenticate_with_limited_privileges] + ctx := context.Background() + // Initialize the app with a custom auth variable, limiting the server's access + ao := map[string]interface{}{"uid": "my-service-worker"} + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + AuthOverride: &ao, + } + + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // The app only has access as defined in the Security Rules + ref := client.NewRef("/some_resource") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_limited_privileges] +} + +func authenticateWithGuestPrivileges() { + // [START authenticate_with_guest_privileges] + ctx := context.Background() + // Initialize the app with a nil auth variable, limiting the server's access + var nilMap map[string]interface{} + conf := &firebase.Config{ + DatabaseURL: "https://databaseName.firebaseio.com", + AuthOverride: &nilMap, + } + + // Fetch the service account key JSON file contents + opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") + + app, err := firebase.NewApp(ctx, conf, opt) + if err != nil { + log.Fatalln("Error initializing app:", err) + } + + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // The app only has access to public data as defined in the Security Rules + ref := client.NewRef("/some_resource") + var data map[string]interface{} + if err := ref.Get(ctx, &data); err != nil { + log.Fatalln("Error reading from database:", err) + } + fmt.Println(data) + // [END authenticate_with_guest_privileges] +} + +func getReference(ctx context.Context, app *firebase.App) { + // [START get_reference] + // Create a database client from App. + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // Get a database reference to our blog. + ref := client.NewRef("server/saving-data/fireblog") + // [END get_reference] + fmt.Println(ref.Path) +} + +// [START user_type] + +// User is a json-serializable type. +type User struct { + DateOfBirth string `json:"date_of_birth,omitempty"` + FullName string `json:"full_name,omitempty"` + Nickname string `json:"nickname,omitempty"` +} + +// [END user_type] + +func setValue(ctx context.Context, ref *db.Ref) { + // [START set_value] + usersRef := ref.Child("users") + err := usersRef.Set(ctx, map[string]*User{ + "alanisawesome": &User{ + DateOfBirth: "June 23, 1912", + FullName: "Alan Turing", + }, + "gracehop": &User{ + DateOfBirth: "December 9, 1906", + FullName: "Grace Hopper", + }, + }) + if err != nil { + log.Fatalln("Error setting value:", err) + } + // [END set_value] +} + +func setChildValue(ctx context.Context, usersRef *db.Ref) { + // [START set_child_value] + if err := usersRef.Child("alanisawesome").Set(ctx, &User{ + DateOfBirth: "June 23, 1912", + FullName: "Alan Turing", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + + if err := usersRef.Child("gracehop").Set(ctx, &User{ + DateOfBirth: "December 9, 1906", + FullName: "Grace Hopper", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + // [END set_child_value] +} + +func updateChild(ctx context.Context, usersRef *db.Ref) { + // [START update_child] + hopperRef := usersRef.Child("gracehop") + if err := hopperRef.Update(ctx, map[string]interface{}{ + "nickname": "Amazing Grace", + }); err != nil { + log.Fatalln("Error updating child:", err) + } + // [END update_child] +} + +func updateChildren(ctx context.Context, usersRef *db.Ref) { + // [START update_children] + if err := usersRef.Update(ctx, map[string]interface{}{ + "alanisawesome/nickname": "Alan The Machine", + "gracehop/nickname": "Amazing Grace", + }); err != nil { + log.Fatalln("Error updating children:", err) + } + // [END update_children] +} + +func overwriteValue(ctx context.Context, usersRef *db.Ref) { + // [START overwrite_value] + if err := usersRef.Update(ctx, map[string]interface{}{ + "alanisawesome": &User{Nickname: "Alan The Machine"}, + "gracehop": &User{Nickname: "Amazing Grace"}, + }); err != nil { + log.Fatalln("Error updating children:", err) + } + // [END overwrite_value] +} + +// [START post_type] + +// Post is a json-serializable type. +type Post struct { + Author string `json:"author,omitempty"` + Title string `json:"title,omitempty"` +} + +// [END post_type] + +func pushValue(ctx context.Context, ref *db.Ref) { + // [START push_value] + postsRef := ref.Child("posts") + + newPostRef, err := postsRef.Push(ctx, nil) + if err != nil { + log.Fatalln("Error pushing child node:", err) + } + + if err := newPostRef.Set(ctx, &Post{ + Author: "gracehop", + Title: "Announcing COBOL, a New Programming Language", + }); err != nil { + log.Fatalln("Error setting value:", err) + } + + // We can also chain the two calls together + if _, err := postsRef.Push(ctx, &Post{ + Author: "alanisawesome", + Title: "The Turing Machine", + }); err != nil { + log.Fatalln("Error pushing child node:", err) + } + // [END push_value] +} + +func pushAndSetValue(ctx context.Context, postsRef *db.Ref) { + // [START push_and_set_value] + if _, err := postsRef.Push(ctx, &Post{ + Author: "gracehop", + Title: "Announcing COBOL, a New Programming Language", + }); err != nil { + log.Fatalln("Error pushing child node:", err) + } + // [END push_and_set_value] +} + +func pushKey(ctx context.Context, postsRef *db.Ref) { + // [START push_key] + // Generate a reference to a new location and add some data using Push() + newPostRef, err := postsRef.Push(ctx, nil) + if err != nil { + log.Fatalln("Error pushing child node:", err) + } + + // Get the unique key generated by Push() + postID := newPostRef.Key + // [END push_key] + fmt.Println(postID) +} + +func transaction(ctx context.Context, client *db.Client) { + // [START transaction] + fn := func(t db.TransactionNode) (interface{}, error) { + var currentValue int + if err := t.Unmarshal(¤tValue); err != nil { + return nil, err + } + return currentValue + 1, nil + } + + ref := client.NewRef("server/saving-data/fireblog/posts/-JRHTHaIs-jNPLXOQivY/upvotes") + if err := ref.Transaction(ctx, fn); err != nil { + log.Fatalln("Transaction failed to commit:", err) + } + // [END transaction] +} + +func readValue(ctx context.Context, app *firebase.App) { + // [START read_value] + // Create a database client from App. + client, err := app.Database(ctx) + if err != nil { + log.Fatalln("Error initializing database client:", err) + } + + // Get a database reference to our posts + ref := client.NewRef("server/saving-data/fireblog/posts") + + // Read the data at the posts reference (this is a blocking operation) + var post Post + if err := ref.Get(ctx, &post); err != nil { + log.Fatalln("Error reading value:", err) + } + // [END read_value] + fmt.Println(ref.Path) +} + +// [START dinosaur_type] + +// Dinosaur is a json-serializable type. +type Dinosaur struct { + Height int `json:"height"` + Width int `json:"width"` +} + +// [END dinosaur_type] + +func orderByChild(ctx context.Context, client *db.Client) { + // [START order_by_child] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) + } + // [END order_by_child] +} + +func orderByNestedChild(ctx context.Context, client *db.Client) { + // [START order_by_nested_child] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("dimensions/height").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) + } + // [END order_by_nested_child] +} + +func orderByKey(ctx context.Context, client *db.Client) { + // [START order_by_key] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + snapshot := make([]Dinosaur, len(results)) + for i, r := range results { + var d Dinosaur + if err := r.Unmarshal(&d); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + snapshot[i] = d + } + fmt.Println(snapshot) + // [END order_by_key] +} + +func orderByValue(ctx context.Context, client *db.Client) { + // [START order_by_value] + ref := client.NewRef("scores") + + results, err := ref.OrderByValue().GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var score int + if err := r.Unmarshal(&score); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score) + } + // [END order_by_value] +} + +func limitToLast(ctx context.Context, client *db.Client) { + // [START limit_query_1] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("weight").LimitToLast(2).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END limit_query_1] +} + +func limitToFirst(ctx context.Context, client *db.Client) { + // [START limit_query_2] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").LimitToFirst(2).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END limit_query_2] +} + +func limitWithValueOrder(ctx context.Context, client *db.Client) { + // [START limit_query_3] + ref := client.NewRef("scores") + + results, err := ref.OrderByValue().LimitToLast(3).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + var score int + if err := r.Unmarshal(&score); err != nil { + log.Fatalln("Error unmarshaling result:", err) + } + fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score) + } + // [END limit_query_3] +} + +func startAt(ctx context.Context, client *db.Client) { + // [START range_query_1] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").StartAt(3).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_1] +} + +func endAt(ctx context.Context, client *db.Client) { + // [START range_query_2] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().EndAt("pterodactyl").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_2] +} + +func startAndEndAt(ctx context.Context, client *db.Client) { + // [START range_query_3] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByKey().StartAt("b").EndAt("b\uf8ff").GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_3] +} + +func equalTo(ctx context.Context, client *db.Client) { + // [START range_query_4] + ref := client.NewRef("dinosaurs") + + results, err := ref.OrderByChild("height").EqualTo(25).GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + for _, r := range results { + fmt.Println(r.Key()) + } + // [END range_query_4] +} + +func complexQuery(ctx context.Context, client *db.Client) { + // [START complex_query] + ref := client.NewRef("dinosaurs") + + var favDinoHeight int + if err := ref.Child("stegosaurus").Child("height").Get(ctx, &favDinoHeight); err != nil { + log.Fatalln("Error querying database:", err) + } + + query := ref.OrderByChild("height").EndAt(favDinoHeight).LimitToLast(2) + results, err := query.GetOrdered(ctx) + if err != nil { + log.Fatalln("Error querying database:", err) + } + if len(results) == 2 { + // Data is ordered by increasing height, so we want the first entry. + // Second entry is stegosarus. + fmt.Printf("The dinosaur just shorter than the stegosaurus is %s\n", results[0].Key()) + } else { + fmt.Println("The stegosaurus is the shortest dino") + } + // [END complex_query] +} diff --git a/snippets/messaging.go b/snippets/messaging.go index f558b40c..18f6e462 100644 --- a/snippets/messaging.go +++ b/snippets/messaging.go @@ -15,13 +15,13 @@ package snippets import ( + "context" "fmt" "log" "time" "firebase.google.com/go" "firebase.google.com/go/messaging" - "golang.org/x/net/context" ) func sendToToken(app *firebase.App) { diff --git a/storage/storage.go b/storage/storage.go index 985b6eb7..878e2175 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -16,12 +16,11 @@ package storage import ( + "context" "errors" "cloud.google.com/go/storage" "firebase.google.com/go/internal" - - "golang.org/x/net/context" ) // Client is the interface for the Firebase Storage service. diff --git a/storage/storage_test.go b/storage/storage_test.go index 833aedf3..7a77e60c 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -15,12 +15,12 @@ package storage import ( + "context" "testing" "google.golang.org/api/option" "firebase.google.com/go/internal" - "golang.org/x/net/context" ) var opts = []option.ClientOption{ diff --git a/testdata/dinosaurs.json b/testdata/dinosaurs.json new file mode 100644 index 00000000..9d7afaab --- /dev/null +++ b/testdata/dinosaurs.json @@ -0,0 +1,78 @@ +{ + "dinosaurs": { + "bruhathkayosaurus": { + "appeared": -70000000, + "height": 25, + "length": 44, + "order": "saurischia", + "vanished": -70000000, + "weight": 135000, + "ratings": { + "pos": 1 + } + }, + "lambeosaurus": { + "appeared": -76000000, + "height": 2.1, + "length": 12.5, + "order": "ornithischia", + "vanished": -75000000, + "weight": 5000, + "ratings": { + "pos": 2 + } + }, + "linhenykus": { + "appeared": -85000000, + "height": 0.6, + "length": 1, + "order": "theropoda", + "vanished": -75000000, + "weight": 3, + "ratings": { + "pos": 3 + } + }, + "pterodactyl": { + "appeared": -150000000, + "height": 0.6, + "length": 0.8, + "order": "pterosauria", + "vanished": -148500000, + "weight": 2, + "ratings": { + "pos": 4 + } + }, + "stegosaurus": { + "appeared": -155000000, + "height": 4, + "length": 9, + "order": "ornithischia", + "vanished": -150000000, + "weight": 2500, + "ratings": { + "pos": 5 + } + }, + "triceratops": { + "appeared": -68000000, + "height": 3, + "length": 8, + "order": "ornithischia", + "vanished": -66000000, + "weight": 11000, + "ratings": { + "pos": 6 + } + } + }, + "scores": { + "bruhathkayosaurus": 55, + "lambeosaurus": 21, + "linhenykus": 80, + "pterodactyl": 93, + "stegosaurus": 5, + "triceratops": 22 + } +} diff --git a/testdata/dinosaurs_index.json b/testdata/dinosaurs_index.json new file mode 100644 index 00000000..bf4a2551 --- /dev/null +++ b/testdata/dinosaurs_index.json @@ -0,0 +1,29 @@ +{ + "rules": { + "_adminsdk": { + "go": { + "dinodb": { + "dinosaurs": { + ".indexOn": ["height", "ratings/pos"] + }, + "scores": { + ".indexOn": ".value" + } + }, + "protected": { + "$uid": { + ".read": "auth != null", + ".write": "$uid === auth.uid" + } + }, + "admin": { + ".read": "false", + ".write": "false" + }, + "public": { + ".read": "true" + } + } + } + } +} \ No newline at end of file diff --git a/testdata/firebase_config.json b/testdata/firebase_config.json index e9a3b5bc..772da62d 100644 --- a/testdata/firebase_config.json +++ b/testdata/firebase_config.json @@ -1,4 +1,5 @@ { + "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" }