diff --git a/edgedb/protocol/codecs/base.pyx b/edgedb/protocol/codecs/base.pyx index a40f6e57..8b3d4349 100644 --- a/edgedb/protocol/codecs/base.pyx +++ b/edgedb/protocol/codecs/base.pyx @@ -19,6 +19,8 @@ import codecs +from collections.abc import Mapping as MappingABC + cdef uint64_t RECORD_ENCODER_CHECKED = 1 << 0 cdef uint64_t RECORD_ENCODER_INVALID = 1 << 1 @@ -225,6 +227,86 @@ cdef class BaseNamedRecordCodec(BaseRecordCodec): (codec).dump(level + 1).strip())) return '\n'.join(buf) + cdef encode(self, WriteBuffer buf, object obj): + cdef: + WriteBuffer elem_data + Py_ssize_t objlen + Py_ssize_t i + BaseCodec sub_codec + Py_ssize_t is_dict + Py_ssize_t is_namedtuple + + self._check_encoder() + + # We check in this order (dict, _is_array_iterable, + # MappingABC) so that in the common case of dict or tuple, we + # never do an ABC check. + if cpython.PyDict_Check(obj): + is_dict = True + elif _is_array_iterable(obj): + is_dict = False + elif isinstance(obj, MappingABC): + is_dict = True + else: + raise TypeError( + 'a sized iterable container or mapping ' + 'expected (got type {!r})'.format( + type(obj).__name__)) + is_namedtuple = not is_dict and hasattr(obj, '_fields') + + objlen = len(obj) + if objlen == 0: + buf.write_bytes(EMPTY_RECORD_DATA) + return + + if objlen > _MAXINT32: + raise ValueError('too many elements for a tuple') + + if objlen != len(self.fields_codecs): + raise ValueError( + f'expected {len(self.fields_codecs)} elements in the tuple, ' + f'got {objlen}') + + elem_data = WriteBuffer.new() + for i in range(objlen): + if is_dict: + name = datatypes.record_desc_pointer_name(self.descriptor, i) + try: + item = obj[name] + except KeyError: + raise ValueError( + f"named tuple dict is missing '{name}' key", + ) from None + elif is_namedtuple: + name = datatypes.record_desc_pointer_name(self.descriptor, i) + try: + item = getattr(obj, name) + except AttributeError: + raise ValueError( + f"named tuple is missing '{name}' attribute", + ) from None + else: + item = obj[i] + + elem_data.write_int32(0) # reserved bytes + if item is None: + elem_data.write_int32(-1) + else: + sub_codec = (self.fields_codecs[i]) + try: + sub_codec.encode(elem_data, item) + except (TypeError, ValueError) as e: + value_repr = repr(item) + if len(value_repr) > 40: + value_repr = value_repr[:40] + '...' + raise errors.InvalidArgumentError( + 'invalid input for query argument' + ' ${n}: {v} ({msg})'.format( + n=i, v=value_repr, msg=e)) from e + + buf.write_int32(4 + elem_data.len()) # buffer length + buf.write_int32(objlen) + buf.write_buffer(elem_data) @cython.final cdef class EdegDBCodecContext(pgproto.CodecContext): diff --git a/tests/test_namedtuples.py b/tests/test_namedtuples.py new file mode 100644 index 00000000..04a74ddc --- /dev/null +++ b/tests/test_namedtuples.py @@ -0,0 +1,52 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from collections import namedtuple, UserDict + +import edgedb +from edgedb import _testbase as tb + + +class TestNamedTupleTypes(tb.SyncQueryTestCase): + + async def test_namedtuple_01(self): + NT1 = namedtuple('NT2', ['x', 'y']) + NT2 = namedtuple('NT2', ['y', 'x']) + + ctors = [dict, UserDict, NT1, NT2] + for ctor in ctors: + val = ctor(x=10, y='y') + res = self.client.query_single(''' + select >$0 + ''', val) + + self.assertEqual(res, (10, 'y')) + + async def test_namedtuple_02(self): + NT1 = namedtuple('NT2', ['x', 'z']) + + with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'is missing'): + self.client.query_single(''' + select >$0 + ''', dict(x=20, z='test')) + + with self.assertRaisesRegex(edgedb.InvalidArgumentError, 'is missing'): + self.client.query_single(''' + select >$0 + ''', NT1(x=20, z='test'))