Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Make rtdb ref.push() only create local node when no value is passed #739

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
88 changes: 88 additions & 0 deletions firebase_admin/_db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2023 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Internal utilities for Firebase Realtime Database module"""

import time
import random
import math

_PUSH_CHARS = '-0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz'

def time_now():
return int(time.time()*1000)

def _generate_next_push_id():
"""Creates a unique push id generator.

Creates 20-character string identifiers with the following properties:
1. They're based on timestamps so that they sort after any existing ids.

2. They contain 96-bits of random data after the timestamp so that IDs won't
collide with other clients' IDs.

3. They sort lexicographically*(so the timestamp is converted to characters
that will sort properly).

4. They're monotonically increasing. Even if you generate more than one in
the same timestamp, the latter ones will sort after the former ones. We do
this by using the previous random bits but "incrementing" them by 1 (only
in the case of a timestamp collision).
"""

# Timestamp of last push, used to prevent local collisions if you push twice
# in one ms.
last_push_time = 0

# We generate 96-bits of randomness which get turned into 12 characters and
# appended to the timestamp to prevent collisions with other clients. We
# store the last characters we generated because in the event of a collision,
# we'll use those same characters except "incremented" by one.
last_rand_chars_indexes = []

def next_push_id(now):
nonlocal last_push_time
nonlocal last_rand_chars_indexes
is_duplicate_time = now == last_push_time
last_push_time = now

push_id = ''
for _ in range(8):
push_id = _PUSH_CHARS[now % 64] + push_id
now = math.floor(now / 64)

if not is_duplicate_time:
last_rand_chars_indexes = []
for _ in range(12):
last_rand_chars_indexes.append(random.randrange(64))
else:
for index in range(11, -1, -1):
if last_rand_chars_indexes[index] == 63:
last_rand_chars_indexes[index] = 0
else:
break
if index != 0:
last_rand_chars_indexes[index] += 1
elif index == 0 and last_rand_chars_indexes[index] != 0:
last_rand_chars_indexes[index] += 1

for index in range(12):
push_id += _PUSH_CHARS[last_rand_chars_indexes[index]]

if len(push_id) != 20:
raise ValueError("push_id length should be 20")
return push_id
return next_push_id

get_next_push_id = _generate_next_push_id()
22 changes: 12 additions & 10 deletions firebase_admin/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from firebase_admin import _http_client
from firebase_admin import _sseclient
from firebase_admin import _utils

from firebase_admin import _db_utils

_DB_ATTRIBUTE = '_database'
_INVALID_PATH_CHARACTERS = '[].?#$'
Expand Down Expand Up @@ -301,27 +301,29 @@ def set_if_unchanged(self, expected_etag, value):

raise error

def push(self, value=''):
def push(self, value=None):
"""Creates a new child node.

The optional value argument can be used to provide an initial value for the child node. If
no value is provided, child node will have empty string as the default value.

The optional value argument can be used to provide an initial value for the child node.
If you provide a value, a child node is created and the value written to that location.
If you don't provide a value, the child node is created but nothing is written to the
database and the child remains empty (but you can use the Reference elsewhere).
Args:
value: JSON-serializable initial value for the child node (optional).

Returns:
Reference: A Reference representing the newly created child node.

Raises:
ValueError: If the value is None.
TypeError: If the value is not JSON-serializable.
FirebaseError: If an error occurs while communicating with the remote database server.
"""
if value is None:
raise ValueError('Value must not be None.')
output = self._client.body('post', self._add_suffix(), json=value)
push_id = output.get('name')
now = _db_utils.time_now()
push_id = _db_utils.get_next_push_id(now)
push_ref = self.child(push_id)

if value is not None:
push_ref.set(value)
return self.child(push_id)

def update(self, value):
Expand Down
21 changes: 20 additions & 1 deletion integration/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_push(self, testref):
python = testref.parent
ref = python.child('users').push()
assert ref.path == '/_adminsdk/python/users/' + ref.key
assert ref.get() == ''
assert ref.get() is None

def test_push_with_value(self, testref):
python = testref.parent
Expand All @@ -158,6 +158,25 @@ def test_push_with_value(self, testref):
assert ref.path == '/_adminsdk/python/users/' + ref.key
assert ref.get() == value

def test_push_to_local_ref(self, testref):
python = testref.parent
ref1 = python.child('games').push()
assert ref1.get() is None
ref2 = ref1.push("card")
assert ref2.parent.key == ref1.key
assert ref1.get() == {ref2.key: 'card'}
assert ref2.get() == 'card'

def test_push_set_local_ref(self, testref):
python = testref.parent
ref1 = python.child('games').push().child('card')
ref2 = ref1.push()
assert ref2.get() is None
ref3 = ref1.push('heart')
ref2.set('spade')
assert ref2.get() == 'spade'
assert ref1.parent.get() == {'card': {ref2.key: 'spade', ref3.key: 'heart'}}

def test_set_primitive_value(self, testref):
python = testref.parent
ref = python.child('users').push()
Expand Down
50 changes: 33 additions & 17 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sys
import time

from unittest import mock
import pytest

import firebase_admin
Expand Down Expand Up @@ -145,7 +146,7 @@ def get(cls, ref):

@classmethod
def push(cls, ref):
ref.push()
ref.push({'foo': 'bar'})

@classmethod
def set(cls, ref):
Expand Down Expand Up @@ -179,6 +180,8 @@ class TestReference:
500: exceptions.InternalError,
}

duplicate_timestamp = time.time()

@classmethod
def setup_class(cls):
firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : cls.test_url})
Expand Down Expand Up @@ -392,33 +395,46 @@ def test_set_invalid_update(self, update):
@pytest.mark.parametrize('data', valid_values)
def test_push(self, data):
ref = db.reference('/test')
recorder = self.instrument(ref, json.dumps({'name' : 'testkey'}))
recorder = self.instrument(ref, json.dumps({}))
child = ref.push(data)
assert isinstance(child, db.Reference)
assert child.key == 'testkey'
assert len(child.key) == 20
assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == 'https://test.firebaseio.com/test.json'
assert recorder[0].method == 'PUT'
assert recorder[0].url == f'https://test.firebaseio.com/test/{child.key}.json?print=silent'
assert json.loads(recorder[0].body.decode()) == data
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
assert recorder[0].headers['User-Agent'] == db._USER_AGENT

def test_push_default(self):
ref = db.reference('/test')
recorder = self.instrument(ref, json.dumps({'name' : 'testkey'}))
assert ref.push().key == 'testkey'
assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == 'https://test.firebaseio.com/test.json'
assert json.loads(recorder[0].body.decode()) == ''
assert recorder[0].headers['Authorization'] == 'Bearer mock-token'
assert recorder[0].headers['User-Agent'] == db._USER_AGENT
recorder = self.instrument(ref, json.dumps({}))
child = ref.push()
assert isinstance(child, db.Reference)
assert len(child.key) == 20
assert len(recorder) == 0

def test_push_none_value(self):
@pytest.mark.parametrize('data', valid_values)
@mock.patch('time.time', mock.MagicMock(return_value=duplicate_timestamp))
def test_push_duplicate_timestamp(self, data):
ref = db.reference('/test')
self.instrument(ref, '')
with pytest.raises(ValueError):
ref.push(None)
recorder = self.instrument(ref, json.dumps({}))
child = []
child.append(ref.push(data))
child.append(ref.push(data))
key1 = child[0].key
key2 = child[1].key
# First 8 digits are the encoded timestamp
assert key1[:8] == key2[:8]
assert key2 > key1
assert len(recorder) == 2
for index, record in enumerate(recorder):
assert record.method == 'PUT'
assert record.url == \
f'https://test.firebaseio.com/test/{child[index].key}.json?print=silent'
assert json.loads(record.body.decode()) == data
assert record.headers['Authorization'] == 'Bearer mock-token'
assert record.headers['User-Agent'] == db._USER_AGENT

def test_delete(self):
ref = db.reference('/test')
Expand Down
Loading