diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 0749373..113e4df 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -57,5 +57,9 @@ jobs: env: TEST_DATABASE_URL: "mysql://username:password@localhost:3306/testsuite" run: "scripts/test" + - name: "Run tests with SQLite" + env: + TEST_DATABASE_URL: "sqlite:///testsuite" + run: "scripts/test" - name: "Enforce coverage" run: "scripts/coverage" diff --git a/README.md b/README.md index 633a1d7..940b63e 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,6 @@ $ pip install orm[sqlite] ``` Driver support is provided using one of [asyncpg][asyncpg], [aiomysql][aiomysql], or [aiosqlite][aiosqlite]. -Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL. --- @@ -52,23 +51,22 @@ Note that if you are using any synchronous SQLAlchemy functions such as `engine. ```python import databases import orm -import sqlalchemy database = databases.Database("sqlite:///db.sqlite") -metadata = sqlalchemy.MetaData() +models = orm.ModelRegistry(database=database) class Note(orm.Model): - __tablename__ = "notes" - __database__ = database - __metadata__ = metadata - id = orm.Integer(primary_key=True) - text = orm.String(max_length=100) - completed = orm.Boolean(default=False) + tablename = "notes" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "text": orm.String(max_length=100), + "completed": orm.Boolean(default=False), + } -# Create the database and tables -engine = sqlalchemy.create_engine(str(database.url)) -metadata.create_all(engine) +# Create the tables +models.create_all() await Note.objects.create(text="Buy the groceries.", completed=False) @@ -78,9 +76,6 @@ print(note) ``` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ -[alembic]: https://alembic.sqlalchemy.org/en/latest/ -[psycopg2]: https://www.psycopg.org/ -[pymysql]: https://github.com/PyMySQL/PyMySQL [asyncpg]: https://github.com/MagicStack/asyncpg [aiomysql]: https://github.com/aio-libs/aiomysql [aiosqlite]: https://github.com/jreese/aiosqlite diff --git a/docs/declaring_models.md b/docs/declaring_models.md index 414ec1c..ea6cfa0 100644 --- a/docs/declaring_models.md +++ b/docs/declaring_models.md @@ -1,42 +1,39 @@ ## Declaring models You can define models by inheriting from `orm.Model` and -defining model fields as attributes in the class. +defining model fields in the `fields` attribute. For each defined model you need to set two special variables: -* `__database__` for database connection. -* `__metadata__` for `SQLAlchemy` functions and migrations. +* `registry` an instance of `orm.ModelRegistry` +* `fields` a `dict` of `orm` fields -You can also specify the table name in database by setting `__tablename__` attribute. +You can also specify the table name in database by setting `tablename` attribute. ```python import databases import orm -import sqlalchemy database = databases.Database("sqlite:///db.sqlite") -metadata = sqlalchemy.MetaData() +models = orm.ModelRegistry(database=database) class Note(orm.Model): - __tablename__ = "notes" - __database__ = database - __metadata__ = metadata - - id = orm.Integer(primary_key=True) - text = orm.String(max_length=100) - completed = orm.Boolean(default=False) + tablename = "notes" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "text": orm.String(max_length=100), + "completed": orm.Boolean(default=False), + } ``` ORM can create or drop database and tables from models using SQLAlchemy. -For using these functions or `Alembic` migrations, you still have to -install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL. - -Afer installing a synchronous DB driver, you can create tables for the models using: +You can use the following methods: ```python -engine = sqlalchemy.create_engine(str(database.url)) -metadata.create_all(engine) +models.create_all() + +models.drop_all() ``` ## Data types @@ -72,6 +69,4 @@ See `TypeSystem` for [type-specific validation keyword arguments][typesystem-fie * `orm.UUID()` * `orm.JSON()` -[psycopg2]: https://www.psycopg.org/ -[pymysql]: https://github.com/PyMySQL/PyMySQL [typesystem-fields]: https://www.encode.io/typesystem/fields/ diff --git a/docs/index.md b/docs/index.md index 633a1d7..b19cbac 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,7 +41,6 @@ $ pip install orm[sqlite] ``` Driver support is provided using one of [asyncpg][asyncpg], [aiomysql][aiomysql], or [aiosqlite][aiosqlite]. -Note that if you are using any synchronous SQLAlchemy functions such as `engine.create_all()` or [alembic][alembic] migrations then you still have to install a synchronous DB driver: [psycopg2][psycopg2] for PostgreSQL and [pymysql][pymysql] for MySQL. --- @@ -52,23 +51,22 @@ Note that if you are using any synchronous SQLAlchemy functions such as `engine. ```python import databases import orm -import sqlalchemy database = databases.Database("sqlite:///db.sqlite") -metadata = sqlalchemy.MetaData() +models = orm.ModelRegistry(database=database) class Note(orm.Model): - __tablename__ = "notes" - __database__ = database - __metadata__ = metadata - id = orm.Integer(primary_key=True) - text = orm.String(max_length=100) - completed = orm.Boolean(default=False) + tablename = "notes" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "text": orm.String(max_length=100), + "completed": orm.Boolean(default=False), + } # Create the database and tables -engine = sqlalchemy.create_engine(str(database.url)) -metadata.create_all(engine) +models.create_all() await Note.objects.create(text="Buy the groceries.", completed=False) @@ -78,9 +76,6 @@ print(note) ``` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ -[alembic]: https://alembic.sqlalchemy.org/en/latest/ -[psycopg2]: https://www.psycopg.org/ -[pymysql]: https://github.com/PyMySQL/PyMySQL [asyncpg]: https://github.com/MagicStack/asyncpg [aiomysql]: https://github.com/aio-libs/aiomysql [aiosqlite]: https://github.com/jreese/aiosqlite diff --git a/docs/making_queries.md b/docs/making_queries.md index cfb13a1..367a557 100644 --- a/docs/making_queries.md +++ b/docs/making_queries.md @@ -7,20 +7,19 @@ Let's say you have the following model defined: ```python import databases import orm -import sqlalchemy database = databases.Database("sqlite:///db.sqlite") -metadata = sqlalchemy.MetaData() +models = orm.ModelRegistry(database=database) class Note(orm.Model): - __tablename__ = "notes" - __database__ = database - __metadata__ = metadata - - id = orm.Integer(primary_key=True) - text = orm.String(max_length=100) - completed = orm.Boolean(default=False) + tablename = "notes" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "text": orm.String(max_length=100), + "completed": orm.Boolean(default=False), + } ``` You can use the following queryset methods: diff --git a/docs/relationships.md b/docs/relationships.md index 19b579c..0561de4 100644 --- a/docs/relationships.md +++ b/docs/relationships.md @@ -7,30 +7,29 @@ Let's say you have the following models defined: ```python import databases import orm -import sqlalchemy database = databases.Database("sqlite:///db.sqlite") -metadata = sqlalchemy.MetaData() +models = orm.ModelRegistry(database=database) class Album(orm.Model): - __tablename__ = "album" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - name = orm.String(max_length=100) + tablename = "albums" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.String(max_length=100), + } class Track(orm.Model): - __tablename__ = "track" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(max_length=100) - position = orm.Integer() + tablename = "tracks" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "album": orm.ForeignKey(Album), + "title": orm.String(max_length=100), + "position": orm.Integer(), + } ``` You can create some `Album` and `Track` instances: diff --git a/orm/__init__.py b/orm/__init__.py index 1ff29ed..ae58d78 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -15,9 +15,9 @@ Text, Time, ) -from orm.models import Model +from orm.models import Model, ModelRegistry -__version__ = "0.1.9" +__version__ = "0.2.0" __all__ = [ "NoMatch", "MultipleMatches", @@ -36,4 +36,5 @@ "UUID", "ForeignKey", "Model", + "ModelRegistry", ] diff --git a/orm/fields.py b/orm/fields.py index 16bd83b..d8a301c 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -15,26 +15,29 @@ def __init__( **kwargs: typing.Any, ) -> None: if primary_key: - kwargs["allow_null"] = True - super().__init__(**kwargs) # type: ignore + kwargs["read_only"] = True + self.allow_null = kwargs.get("allow_null", False) self.primary_key = primary_key self.index = index self.unique = unique + self.validator = self.get_validator(**kwargs) def get_column(self, name: str) -> sqlalchemy.Column: column_type = self.get_column_type() - allow_null = getattr(self, "allow_null", False) constraints = self.get_constraints() return sqlalchemy.Column( name, column_type, *constraints, primary_key=self.primary_key, - nullable=allow_null and not self.primary_key, + nullable=self.allow_null and not self.primary_key, index=self.index, unique=self.unique, ) + def get_validator(self, **kwargs) -> typesystem.Field: + raise NotImplementedError() # pragma: no cover + def get_column_type(self) -> sqlalchemy.types.TypeEngine: raise NotImplementedError() # pragma: no cover @@ -45,92 +48,146 @@ def expand_relationship(self, value): return value -class String(ModelField, typesystem.String): +class String(ModelField): def __init__(self, **kwargs): assert "max_length" in kwargs, "max_length is required" super().__init__(**kwargs) + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.String(**kwargs) + def get_column_type(self): - return sqlalchemy.String(length=self.max_length) + return sqlalchemy.String(length=self.validator.max_length) + +class Text(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Text(**kwargs) -class Text(ModelField, typesystem.Text): def get_column_type(self): return sqlalchemy.Text() -class Integer(ModelField, typesystem.Integer): +class Integer(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Integer(**kwargs) + def get_column_type(self): return sqlalchemy.Integer() -class BigInteger(ModelField, typesystem.Integer): +class Float(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Float(**kwargs) + def get_column_type(self): - return sqlalchemy.BigInteger() + return sqlalchemy.Float() -class Float(ModelField, typesystem.Float): +class BigInteger(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Integer(**kwargs) + def get_column_type(self): - return sqlalchemy.Float() + return sqlalchemy.BigInteger() + +class Boolean(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Boolean(**kwargs) -class Boolean(ModelField, typesystem.Boolean): def get_column_type(self): return sqlalchemy.Boolean() -class DateTime(ModelField, typesystem.DateTime): +class DateTime(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.DateTime(**kwargs) + def get_column_type(self): return sqlalchemy.DateTime() -class Date(ModelField, typesystem.Date): +class Date(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Date(**kwargs) + def get_column_type(self): return sqlalchemy.Date() -class Time(ModelField, typesystem.Time): +class Time(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Time(**kwargs) + def get_column_type(self): return sqlalchemy.Time() -class JSON(ModelField, typesystem.Any): +class JSON(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Any(**kwargs) + def get_column_type(self): return sqlalchemy.JSON() -class ForeignKey(ModelField, typesystem.Field): +class ForeignKey(ModelField): + class ForeignKeyValidator(typesystem.Field): + def validate(self, value): + return value.pk + def __init__(self, to, allow_null: bool = False): super().__init__(allow_null=allow_null) self.to = to - def validate(self, value, strict=False): - return value.pk + @property + def target(self): + if not hasattr(self, "_target"): + if isinstance(self.to, str): + self._target = self.registry.models[self.to] + else: + self._target = self.to + return self._target - def get_constraints(self): - fk_string = self.to.__tablename__ + "." + self.to.__pkname__ - return [sqlalchemy.schema.ForeignKey(fk_string)] + def get_validator(self, **kwargs) -> typesystem.Field: + return self.ForeignKeyValidator() - def get_column_type(self): - to_column = self.to.fields[self.to.__pkname__] - return to_column.get_column_type() + def get_column(self, name: str) -> sqlalchemy.Column: + target = self.target + to_field = target.fields[target.pkname] + + column_type = to_field.get_column_type() + constraints = [ + sqlalchemy.schema.ForeignKey(f"{target.tablename}.{target.pkname}") + ] + return sqlalchemy.Column( + name, + column_type, + *constraints, + nullable=self.allow_null, + ) def expand_relationship(self, value): - if isinstance(value, self.to): + target = self.target + if isinstance(value, target): return value - return self.to({self.to.__pkname__: value}) + return target(pk=value) -class Enum(ModelField, typesystem.Any): +class Enum(ModelField): def __init__(self, enum, **kwargs): super().__init__(**kwargs) self.enum = enum + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Any(**kwargs) + def get_column_type(self): return sqlalchemy.Enum(self.enum) -class Decimal(ModelField, typesystem.Decimal): +class Decimal(ModelField): def __init__(self, max_digits: int, decimal_places: int, **kwargs): assert max_digits, "max_digits is required" assert decimal_places, "decimal_places is required" @@ -138,10 +195,16 @@ def __init__(self, max_digits: int, decimal_places: int, **kwargs): self.decimal_places = decimal_places super().__init__(**kwargs) + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.Decimal(**kwargs) + def get_column_type(self): return sqlalchemy.Numeric(precision=self.max_digits, scale=self.decimal_places) -class UUID(ModelField, typesystem.UUID): +class UUID(ModelField): + def get_validator(self, **kwargs) -> typesystem.Field: + return typesystem.UUID(**kwargs) + def get_column_type(self): return GUID() diff --git a/orm/models.py b/orm/models.py index 2b15db9..590d975 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,10 +1,13 @@ import typing +import anyio +import databases import sqlalchemy import typesystem -from typesystem.schemas import SchemaMetaclass +from sqlalchemy.ext.asyncio import create_async_engine from orm.exceptions import MultipleMatches, NoMatch +from orm.fields import String, Text FILTER_OPERATORS = { "exact": "__eq__", @@ -19,31 +22,79 @@ } -class ModelMetaclass(SchemaMetaclass): - def __new__( - cls: type, name: str, bases: typing.Sequence[type], attrs: dict - ) -> type: - new_model = super(ModelMetaclass, cls).__new__( # type: ignore - cls, name, bases, attrs - ) +class ModelRegistry: + def __init__(self, database: databases.Database) -> None: + self.database = database + self.models = {} + self.metadata = sqlalchemy.MetaData() - if attrs.get("__abstract__"): - return new_model + def create_all(self): + url = self._get_database_url() + anyio.run(self._create_all, url) - tablename = attrs["__tablename__"] - metadata = attrs["__metadata__"] - pkname = None + def drop_all(self): + url = self._get_database_url() + anyio.run(self._drop_all, url) - columns = [] - for name, field in new_model.fields.items(): + async def _create_all(self, url: str): + engine = create_async_engine(url) + + for model_cls in self.models.values(): + model_cls.build_table() + + async with self.database: + async with engine.begin() as conn: + await conn.run_sync(self.metadata.create_all) + + await engine.dispose() + + async def _drop_all(self, url: str): + engine = create_async_engine(url) + + for model_cls in self.models.values(): + model_cls.build_table() + + async with self.database: + async with engine.begin() as conn: + await conn.run_sync(self.metadata.drop_all) + + await engine.dispose() + + def _get_database_url(self) -> str: + url = self.database.url + if not url.driver: + if url.dialect == "postgresql": + url = url.replace(driver="asyncpg") + elif url.dialect == "mysql": + url = url.replace(driver="aiomysql") + elif url.dialect == "sqlite": + url = url.replace(driver="aiosqlite") + return str(url) + + +class ModelMeta(type): + def __new__(cls, name, bases, attrs): + model_class = super().__new__(cls, name, bases, attrs) + + if "registry" in attrs: + model_class.database = attrs["registry"].database + attrs["registry"].models[name] = model_class + + if "tablename" not in attrs: + setattr(model_class, "tablename", name.lower()) + + for name, field in attrs.get("fields", {}).items(): + setattr(field, "registry", attrs.get("registry")) if field.primary_key: - pkname = name - columns.append(field.get_column(name)) + model_class.pkname = name - new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns) - new_model.__pkname__ = pkname + return model_class - return new_model + @property + def table(cls): + if not hasattr(cls, "_table"): + cls._table = cls.build_table() + return cls._table class QuerySet: @@ -70,11 +121,20 @@ def __get__(self, instance, owner): @property def database(self): - return self.model_cls.__database__ + return self.model_cls.registry.database @property def table(self): - return self.model_cls.__table__ + return self.model_cls.table + + @property + def schema(self): + fields = {key: field.validator for key, field in self.model_cls.fields.items()} + return typesystem.Schema(fields=fields) + + @property + def pkname(self): + return self.model_cls.pkname def build_select_expression(self): tables = [self.table] @@ -84,9 +144,10 @@ def build_select_expression(self): model_cls = self.model_cls select_from = self.table for part in item.split("__"): - model_cls = model_cls.fields[part].to - select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) - tables.append(model_cls.__table__) + model_cls = model_cls.fields[part].target + table = model_cls.table + select_from = sqlalchemy.sql.join(select_from, table) + tables.append(table) expr = sqlalchemy.sql.select(tables) expr = expr.select_from(select_from) @@ -122,7 +183,7 @@ def _filter_query(self, _exclude: bool = False, **kwargs): select_related = list(self._select_related) if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ + pk_name = self.model_cls.pkname kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): @@ -150,9 +211,9 @@ def _filter_query(self, _exclude: bool = False, **kwargs): # Walk the relationships to the actual model class # against which the comparison is being made. for part in related_parts: - model_cls = model_cls.fields[part].to + model_cls = model_cls.fields[part].target - column = model_cls.__table__.columns[field_name] + column = model_cls.table.columns[field_name] else: op = "exact" @@ -195,15 +256,40 @@ def _filter_query(self, _exclude: bool = False, **kwargs): order_by=self._order_by, ) - def select_related(self, related): - if not isinstance(related, (list, tuple)): - related = [related] + def search(self, term: typing.Any): + if not term: + return self + + filter_clauses = list(self.filter_clauses) + value = f"%{term}%" + + # has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in term) + # if has_escaped_character: + # # enable escape modifier + # for char in self.ESCAPE_CHARACTERS: + # term = term.replace(char, f'\\{char}') + # term = f"%{value}%" + # + # clause.modifiers['escape'] = '\\' if has_escaped_character else None + + search_fields = [ + name + for name, field in self.model_cls.fields.items() + if isinstance(field, (String, Text)) + ] + search_clauses = [ + self.table.columns[name].ilike(value) for name in search_fields + ] + + if len(search_clauses) > 1: + filter_clauses.append(sqlalchemy.sql.or_(*search_clauses)) + else: + filter_clauses.extend(search_clauses) - related = list(self._select_related) + related return self.__class__( model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=related, + filter_clauses=filter_clauses, + select_related=self._select_related, limit_count=self.limit_count, offset=self.query_offset, order_by=self._order_by, @@ -219,6 +305,20 @@ def order_by(self, *order_by): order_by=order_by, ) + def select_related(self, related): + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset, + order_by=self._order_by, + ) + async def exists(self) -> bool: expr = self.build_select_expression() expr = sqlalchemy.exists(expr).select() @@ -284,24 +384,22 @@ async def first(self, **kwargs): async def create(self, **kwargs): # Validate the keyword arguments. fields = self.model_cls.fields - required = [key for key, value in fields.items() if not value.has_default()] - validator = typesystem.Object( - properties=fields, required=required, additional_properties=False + validator = typesystem.Schema( + fields={key: value.validator for key, value in fields.items()} ) kwargs = validator.validate(kwargs) - # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.fields[pkname] - if kwargs[pkname] is None and pk.allow_null: - del kwargs[pkname] + # TODO: Better to implement after UUID, probably need another database + # for key, value in fields.items(): + # if value.validator.read_only and value.validator.has_default(): + # kwargs[key] = value.validator.get_default_value() # Build the insert expression. expr = self.table.insert() expr = expr.values(**kwargs) # Execute the insert, and return a new model instance. - instance = self.model_cls(kwargs) + instance = self.model_cls(**kwargs) instance.pk = await self.database.execute(expr) return instance @@ -320,37 +418,55 @@ def _prepare_order_by(self, order_by: str): return order_col.desc() if reverse else order_col -class Model(typesystem.Schema, metaclass=ModelMetaclass): - __abstract__ = True - +class Model(metaclass=ModelMeta): objects = QuerySet() - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") - super().__init__(*args, **kwargs) + kwargs[self.pkname] = kwargs.pop("pk") + for key, value in kwargs.items(): + if key not in self.fields: + raise ValueError( + f"Invalid keyword {key} for class {self.__class__.__name__}" + ) + setattr(self, key, value) @property def pk(self): - return getattr(self, self.__pkname__) + return getattr(self, self.pkname) @pk.setter def pk(self, value): - setattr(self, self.__pkname__, value) + setattr(self, self.pkname, value) + + @classmethod + def build_table(cls): + tablename = cls.tablename + metadata = cls.registry.metadata + columns = [] + for name, field in cls.fields.items(): + columns.append(field.get_column(name)) + return sqlalchemy.Table(tablename, metadata, *columns, extend_existing=True) + + @property + def table(self): + return self.__class__.table async def update(self, **kwargs): # Validate the keyword arguments. - fields = {key: field for key, field in self.fields.items() if key in kwargs} - validator = typesystem.Object(properties=fields) + fields = { + key: field.validator for key, field in self.fields.items() if key in kwargs + } + validator = typesystem.Schema(fields=fields) kwargs = validator.validate(kwargs) # Build the update expression. - pk_column = getattr(self.__table__.c, self.__pkname__) - expr = self.__table__.update() + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.update() expr = expr.values(**kwargs).where(pk_column == self.pk) # Perform the update. - await self.__database__.execute(expr) + await self.database.execute(expr) # Update the model instance. for key, value in kwargs.items(): @@ -358,19 +474,19 @@ async def update(self, **kwargs): async def delete(self): # Build the delete expression. - pk_column = getattr(self.__table__.c, self.__pkname__) - expr = self.__table__.delete().where(pk_column == self.pk) + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.delete().where(pk_column == self.pk) # Perform the delete. - await self.__database__.execute(expr) + await self.database.execute(expr) async def load(self): # Build the select expression. - pk_column = getattr(self.__table__.c, self.__pkname__) - expr = self.__table__.select().where(pk_column == self.pk) + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.select().where(pk_column == self.pk) # Perform the fetch. - row = await self.__database__.fetch_one(expr) + row = await self.database.fetch_one(expr) # Update the instance. for key, value in dict(row._mapping).items(): @@ -387,18 +503,18 @@ def from_row(cls, row, select_related=[]): for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) - model_cls = cls.fields[first_part].to + model_cls = cls.fields[first_part].target item[first_part] = model_cls.from_row(row, select_related=[remainder]) else: - model_cls = cls.fields[related].to + model_cls = cls.fields[related].target item[related] = model_cls.from_row(row) # Pull out the regular column values. - for column in cls.__table__.columns: + for column in cls.table.columns: if column.name not in item: item[column.name] = row[column] - return cls(item) + return cls(**item) def __setattr__(self, key, value): if key in self.fields: @@ -406,3 +522,11 @@ def __setattr__(self, key, value): # fully-fledged relationship instance, with just the pk loaded. value = self.fields[key].expand_relationship(value) super().__setattr__(key, value) + + def __eq__(self, other): + if self.__class__ != other.__class__: + return False + for key in self.fields.keys(): + if getattr(self, key, None) != getattr(other, key, None): + return False + return True diff --git a/requirements.txt b/requirements.txt index 48f32cf..2dbfb80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ -databases[postgresql, mysql] -psycopg2-binary -pymysql +databases[postgresql, mysql, sqlite] typesystem # Packaging @@ -8,7 +6,6 @@ twine wheel # Testing -anyio>=3.0.0,<4 autoflake black codecov diff --git a/setup.cfg b/setup.cfg index cbd8464..e03bee4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,7 @@ filterwarnings= # Turn warnings that aren't filtered into exceptions error ignore::DeprecationWarning + ignore::sqlalchemy.exc.SAWarning [coverage:run] source_pkgs = orm, tests diff --git a/setup.py b/setup.py index 5e55658..629a8d1 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ PACKAGE = "orm" URL = "https://github.com/encode/orm" + def get_version(package): """ Return package version as listed in `__version__` in `init.py`. @@ -50,12 +51,12 @@ def get_packages(package): packages=get_packages(PACKAGE), package_data={PACKAGE: ["py.typed"]}, data_files=[("", ["LICENSE.md"])], - install_requires=["databases>=0.2.1", "typesystem"], + install_requires=["anyio>=3.0.0,<4", "databases>=0.5.0", "typesystem>=0.3.0"], extras_require={ "postgresql": ["asyncpg"], "mysql": ["aiomysql"], "sqlite": ["aiosqlite"], - "postgresql+aiopg": ["aiopg"] + "postgresql+aiopg": ["aiopg"], }, classifiers=[ "Development Status :: 3 - Alpha", diff --git a/tests/test_columns.py b/tests/test_columns.py index 52aeb07..7faeae1 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -5,15 +5,14 @@ import databases import pytest -import sqlalchemy import orm from tests.settings import DATABASE_URL pytestmark = pytest.mark.anyio -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL) +models = orm.ModelRegistry(database=database) def time(): @@ -26,63 +25,61 @@ class StatusEnum(Enum): class Example(orm.Model): - __tablename__ = "example" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - uuid = orm.UUID(allow_null=True) - huge_number = orm.BigInteger(default=9223372036854775807) - created = orm.DateTime(default=datetime.datetime.now) - created_day = orm.Date(default=datetime.date.today) - created_time = orm.Time(default=time) - description = orm.Text(allow_blank=True) - value = orm.Float(allow_null=True) - price = orm.Decimal(max_digits=5, decimal_places=2, allow_null=True) - data = orm.JSON(default={}) - status = orm.Enum(StatusEnum, default=StatusEnum.DRAFT) + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "uuid": orm.UUID(allow_null=True), + "created": orm.DateTime(default=datetime.datetime.now), + "created_day": orm.Date(default=datetime.date.today), + "created_time": orm.Time(default=time), + "data": orm.JSON(default={}), + "description": orm.Text(allow_blank=True), + "huge_number": orm.BigInteger(default=0), + "price": orm.Decimal(max_digits=5, decimal_places=2, allow_null=True), + "status": orm.Enum(StatusEnum, default=StatusEnum.DRAFT), + "value": orm.Float(allow_null=True), + } @pytest.fixture(autouse=True, scope="module") def create_test_database(): - database_url = databases.DatabaseURL(DATABASE_URL) - if database_url.scheme == "mysql": - url = str(database_url.replace(driver="pymysql")) - else: - url = str(database_url) - - engine = sqlalchemy.create_engine(url) - metadata.create_all(engine) + models.create_all() yield - metadata.drop_all(engine) + models.drop_all() + + +@pytest.fixture(autouse=True) +async def rollback_transactions(): + with database.force_rollback(): + async with database: + yield async def test_model_crud(): - async with database: - await Example.objects.create() - - example = await Example.objects.get() - assert example.huge_number == 9223372036854775807 - assert example.created.year == datetime.datetime.now().year - assert example.created_day == datetime.date.today() - assert example.description == "" - assert example.value is None - assert example.price is None - assert example.data == {} - assert example.status == StatusEnum.DRAFT - assert example.uuid is None - - await example.update( - data={"foo": 123}, - value=123.456, - status=StatusEnum.RELEASED, - price=decimal.Decimal("999.99"), - uuid=uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b"), - ) - - example = await Example.objects.get() - assert example.value == 123.456 - assert example.data == {"foo": 123} - assert example.status == StatusEnum.RELEASED - assert example.price == decimal.Decimal("999.99") - assert example.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") + await Example.objects.create() + + example = await Example.objects.get() + assert example.created.year == datetime.datetime.now().year + assert example.created_day == datetime.date.today() + assert example.data == {} + assert example.description == "" + assert example.huge_number == 0 + assert example.price is None + assert example.status == StatusEnum.DRAFT + assert example.uuid is None + assert example.value is None + + await example.update( + data={"foo": 123}, + value=123.456, + status=StatusEnum.RELEASED, + price=decimal.Decimal("999.99"), + uuid=uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b"), + ) + + example = await Example.objects.get() + assert example.value == 123.456 + assert example.data == {"foo": 123} + assert example.status == StatusEnum.RELEASED + assert example.price == decimal.Decimal("999.99") + assert example.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") diff --git a/tests/test_foreignkey.py b/tests/test_foreignkey.py index 83a8a26..4815c33 100644 --- a/tests/test_foreignkey.py +++ b/tests/test_foreignkey.py @@ -1,178 +1,168 @@ import databases import pytest -import sqlalchemy import orm from tests.settings import DATABASE_URL pytestmark = pytest.mark.anyio -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL) +models = orm.ModelRegistry(database=database) class Album(orm.Model): - __tablename__ = "album" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - name = orm.String(max_length=100) + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.String(max_length=100), + } class Track(orm.Model): - __tablename__ = "track" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(max_length=100) - position = orm.Integer() + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "album": orm.ForeignKey("Album"), + "title": orm.String(max_length=100), + "position": orm.Integer(), + } class Organisation(orm.Model): - __tablename__ = "org" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - ident = orm.String(max_length=100) + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "ident": orm.String(max_length=100), + } class Team(orm.Model): - __tablename__ = "team" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - org = orm.ForeignKey(Organisation) - name = orm.String(max_length=100) + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "org": orm.ForeignKey(Organisation), + "name": orm.String(max_length=100), + } class Member(orm.Model): - __tablename__ = "member" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - team = orm.ForeignKey(Team) - email = orm.String(max_length=100) + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "team": orm.ForeignKey(Team), + "email": orm.String(max_length=100), + } @pytest.fixture(autouse=True, scope="module") def create_test_database(): - database_url = databases.DatabaseURL(DATABASE_URL) - if database_url.scheme == "mysql": - url = str(database_url.replace(driver="pymysql")) - else: - url = str(database_url) - - engine = sqlalchemy.create_engine(url) - metadata.create_all(engine) + models.create_all() yield - metadata.drop_all(engine) + models.drop_all() + + +@pytest.fixture(autouse=True) +async def rollback_connections(): + with database.force_rollback(): + async with database: + yield async def test_model_crud(): - async with database: - album = await Album.objects.create(name="Malibu") - await Track.objects.create(album=album, title="The Bird", position=1) - await Track.objects.create( - album=album, title="Heart don't stand a chance", position=2 - ) - await Track.objects.create(album=album, title="The Waters", position=3) - - track = await Track.objects.get(title="The Bird") - assert track.album.pk == album.pk - assert not hasattr(track.album, "name") - await track.album.load() - assert track.album.name == "Malibu" + album = await Album.objects.create(name="Malibu") + await Track.objects.create(album=album, title="The Bird", position=1) + await Track.objects.create( + album=album, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=album, title="The Waters", position=3) + + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert not hasattr(track.album, "name") + await track.album.load() + assert track.album.name == "Malibu" async def test_select_related(): - async with database: - album = await Album.objects.create(name="Malibu") - await Track.objects.create(album=album, title="The Bird", position=1) - await Track.objects.create( - album=album, title="Heart don't stand a chance", position=2 - ) - await Track.objects.create(album=album, title="The Waters", position=3) - - fantasies = await Album.objects.create(name="Fantasies") - await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) - await Track.objects.create(album=fantasies, title="Sick Muse", position=2) - await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - - track = await Track.objects.select_related("album").get(title="The Bird") - assert track.album.name == "Malibu" + album = await Album.objects.create(name="Malibu") + await Track.objects.create(album=album, title="The Bird", position=1) + await Track.objects.create( + album=album, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=album, title="The Waters", position=3) - tracks = await Track.objects.select_related("album").all() - assert len(tracks) == 6 + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 async def test_fk_filter(): - async with database: - malibu = await Album.objects.create(name="Malibu") - await Track.objects.create(album=malibu, title="The Bird", position=1) - await Track.objects.create( - album=malibu, title="Heart don't stand a chance", position=2 - ) - await Track.objects.create(album=malibu, title="The Waters", position=3) - - fantasies = await Album.objects.create(name="Fantasies") - await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) - await Track.objects.create(album=fantasies, title="Sick Muse", position=2) - await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - - tracks = ( - await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() - ) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" - - tracks = ( - await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() - ) - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" - - tracks = await Track.objects.filter(album__name__icontains="fan").all() - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Fantasies" - - tracks = await Track.objects.filter(album=malibu).select_related("album").all() - assert len(tracks) == 3 - for track in tracks: - assert track.album.name == "Malibu" + malibu = await Album.objects.create(name="Malibu") + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=malibu, title="The Waters", position=3) + + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__icontains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu" async def test_multiple_fk(): - async with database: - acme = await Organisation.objects.create(ident="ACME Ltd") - red_team = await Team.objects.create(org=acme, name="Red Team") - blue_team = await Team.objects.create(org=acme, name="Blue Team") - await Member.objects.create(team=red_team, email="a@example.org") - await Member.objects.create(team=red_team, email="b@example.org") - await Member.objects.create(team=blue_team, email="c@example.org") - await Member.objects.create(team=blue_team, email="d@example.org") - - other = await Organisation.objects.create(ident="Other ltd") - team = await Team.objects.create(org=other, name="Green Team") - await Member.objects.create(team=team, email="e@example.org") - - members = ( - await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() - ) - assert len(members) == 4 - for member in members: - assert member.team.org.ident == "ACME Ltd" + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") + + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") + + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" diff --git a/tests/test_models.py b/tests/test_models.py index ade2678..a8a72eb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,251 +1,256 @@ import databases import pytest -import sqlalchemy +import typesystem import orm from tests.settings import DATABASE_URL pytestmark = pytest.mark.anyio -database = databases.Database(DATABASE_URL, force_rollback=True) -metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL) +models = orm.ModelRegistry(database=database) class User(orm.Model): - __tablename__ = "users" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - name = orm.String(max_length=100) + tablename = "users" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.String(max_length=100), + "language": orm.String(max_length=100, allow_null=True), + } class Product(orm.Model): - __tablename__ = "product" - __metadata__ = metadata - __database__ = database - - id = orm.Integer(primary_key=True) - name = orm.String(max_length=100) - rating = orm.Integer(minimum=1, maximum=5) - in_stock = orm.Boolean(default=False) + tablename = "products" + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.String(max_length=100), + "rating": orm.Integer(minimum=1, maximum=5), + "in_stock": orm.Boolean(default=False), + } -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="function") def create_test_database(): - database_url = databases.DatabaseURL(DATABASE_URL) - if database_url.scheme == "mysql": - url = str(database_url.replace(driver="pymysql")) - else: - url = str(database_url) - - engine = sqlalchemy.create_engine(url) - metadata.create_all(engine) + models.create_all() yield - metadata.drop_all(engine) + models.drop_all() + + +@pytest.fixture(autouse=True) +async def rollback_connections(): + with database.force_rollback(): + async with database: + yield def test_model_class(): - assert list(User.fields.keys()) == ["id", "name"] + assert list(User.fields.keys()) == ["id", "name", "language"] assert isinstance(User.fields["id"], orm.Integer) assert User.fields["id"].primary_key is True assert isinstance(User.fields["name"], orm.String) - assert User.fields["name"].max_length == 100 - assert isinstance(User.__table__, sqlalchemy.Table) + assert User.fields["name"].validator.max_length == 100 + + with pytest.raises(ValueError): + User(invalid="123") + + assert User(id=1) != Product(id=1) + assert User(id=1) != User(id=2) + assert User(id=1) == User(id=1) + + assert isinstance(User.objects.schema.fields["id"], typesystem.Integer) + assert isinstance(User.objects.schema.fields["name"], typesystem.String) def test_model_pk(): user = User(pk=1) assert user.pk == 1 assert user.id == 1 + assert User.objects.pkname == "id" async def test_model_crud(): - async with database: - users = await User.objects.all() - assert users == [] + users = await User.objects.all() + assert users == [] - user = await User.objects.create(name="Tom") - users = await User.objects.all() - assert user.name == "Tom" - assert user.pk is not None - assert users == [user] + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] - lookup = await User.objects.get() - assert lookup == user + lookup = await User.objects.get() + assert lookup == user - await user.update(name="Jane") - users = await User.objects.all() - assert user.name == "Jane" - assert user.pk is not None - assert users == [user] + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] - await user.delete() - users = await User.objects.all() - assert users == [] + await user.delete() + users = await User.objects.all() + assert users == [] async def test_model_get(): - async with database: - with pytest.raises(orm.NoMatch): - await User.objects.get() + with pytest.raises(orm.NoMatch): + await User.objects.get() - user = await User.objects.create(name="Tom") - lookup = await User.objects.get() - assert lookup == user + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user - user = await User.objects.create(name="Jane") - with pytest.raises(orm.MultipleMatches): - await User.objects.get() + user = await User.objects.create(name="Jane") + with pytest.raises(orm.MultipleMatches): + await User.objects.get() - same_user = await User.objects.get(pk=user.id) - assert same_user.id == user.id - assert same_user.pk == user.pk + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk async def test_model_filter(): - async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") - - user = await User.objects.get(name="Lucy") - assert user.name == "Lucy" - - with pytest.raises(orm.NoMatch): - await User.objects.get(name="Jim") + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) - await Product.objects.create(name="Dress", rating=4) - await Product.objects.create(name="Coat", rating=3, in_stock=True) + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" - product = await Product.objects.get(name__iexact="t-shirt", rating=5) - assert product.pk is not None - assert product.name == "T-Shirt" - assert product.rating == 5 + with pytest.raises(orm.NoMatch): + await User.objects.get(name="Jim") - products = await Product.objects.all(rating__gte=2, in_stock=True) - assert len(products) == 2 + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) - products = await Product.objects.exclude(rating__gte=4, in_stock=True).all() - assert len(products) == 2 + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 - products = await Product.objects.exclude(in_stock=True).all() - assert len(products) == 1 + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 - products = await Product.objects.all(name__icontains="T") - assert len(products) == 2 + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 - products = await Product.objects.exclude(name__icontains="T").all() - assert len(products) == 1 + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 - # Test escaping % character from icontains, contains, and iexact - await Product.objects.create(name="100%-Cotton", rating=3) - await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) - await Product.objects.create(name="Cotton-100%", rating=3) - products = Product.objects.filter(name__iexact="100%-cotton") - assert await products.count() == 1 + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 - products = Product.objects.exclude(name__iexact="100%-cotton") - assert await products.count() == 5 + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 - products = Product.objects.filter(name__contains="%") - assert await products.count() == 3 + products = Product.objects.exclude(name__iexact="100%-cotton") + assert await products.count() == 5 - products = Product.objects.exclude(name__contains="%") - assert await products.count() == 3 + products = Product.objects.exclude(name__contains="%") + assert await products.count() == 3 - products = Product.objects.filter(name__icontains="%") - assert await products.count() == 3 - - products = Product.objects.exclude(name__icontains="%") - assert await products.count() == 3 + products = Product.objects.exclude(name__icontains="%") + assert await products.count() == 3 async def test_model_order_by(): - async with database: - await User.objects.create(name="Bob") - await User.objects.create(name="Allen") - await User.objects.create(name="Bob") + await User.objects.create(name="Bob") + await User.objects.create(name="Allen") + await User.objects.create(name="Bob") - users = await User.objects.order_by("name").all() - assert users[0].name == "Allen" - assert users[1].name == "Bob" + users = await User.objects.order_by("name").all() + assert users[0].name == "Allen" + assert users[1].name == "Bob" - users = await User.objects.order_by("-name").all() - assert users[1].name == "Bob" - assert users[2].name == "Allen" + users = await User.objects.order_by("-name").all() + assert users[1].name == "Bob" + assert users[2].name == "Allen" - users = await User.objects.order_by("name", "-id").all() - assert users[0].name == "Allen" - assert users[1].name == "Bob" - assert users[0].id < users[1].id + users = await User.objects.order_by("name", "-id").all() + assert users[0].name == "Allen" + assert users[0].id == 2 + assert users[1].name == "Bob" + assert users[1].id == 3 - users = await User.objects.filter(name="Bob").order_by("-id").all() - assert users[0].name == "Bob" - assert users[1].name == "Bob" - assert users[0].id > users[1].id + users = await User.objects.filter(name="Bob").order_by("-id").all() + assert users[0].name == "Bob" + assert users[0].id == 3 + assert users[1].name == "Bob" + assert users[1].id == 1 - users = await User.objects.order_by("id").limit(1).all() - assert users[0].name == "Bob" + users = await User.objects.order_by("id").limit(1).all() + assert users[0].name == "Bob" + assert users[0].id == 1 - users = await User.objects.order_by("id").limit(1).offset(1).all() - assert users[0].name == "Allen" + users = await User.objects.order_by("id").limit(1).offset(1).all() + assert users[0].name == "Allen" + assert users[0].id == 2 async def test_model_exists(): - async with database: - await User.objects.create(name="Tom") - assert await User.objects.filter(name="Tom").exists() is True - assert await User.objects.filter(name="Jane").exists() is False + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False async def test_model_count(): - async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - assert await User.objects.count() == 3 - assert await User.objects.filter(name__icontains="T").count() == 1 + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 async def test_model_limit(): - async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") - await User.objects.create(name="Lucy") + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") - assert len(await User.objects.limit(2).all()) == 2 + assert len(await User.objects.limit(2).all()) == 2 async def test_model_limit_with_filter(): - async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Tom") - await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") - assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 + assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 async def test_offset(): - async with database: - await User.objects.create(name="Tom") - await User.objects.create(name="Jane") + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") - users = await User.objects.offset(1).limit(1).all() - assert users[0].name == "Jane" + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == "Jane" async def test_model_first(): - async with database: - tom = await User.objects.create(name="Tom") - jane = await User.objects.create(name="Jane") + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + assert await User.objects.filter(name="Lucy").first() is None + + +async def test_model_search(): + tom = await User.objects.create(name="Tom", language="English") + tshirt = await Product.objects.create(name="T-Shirt", rating=5) - assert await User.objects.first() == tom - assert await User.objects.first(name="Jane") == jane - assert await User.objects.filter(name="Jane").first() == jane - assert await User.objects.filter(name="Lucy").first() is None + assert await User.objects.search(term="").first() == tom + assert await User.objects.search(term="tom").first() == tom + assert await Product.objects.search(term="shirt").first() == tshirt async def test_model_get_or_create():