Skip to content

Commit

Permalink
🐛 FIX distinct queries (#5654)
Browse files Browse the repository at this point in the history
This commit principally fixes issues with distinct queries.
Since refactoring for sqlalchemy 1.4,
the query has been initialised with a starting "dummy" projection from which to join.

```python
query = session.query(starting_table.id)
```

The dummy projection was then removed from the returned rows.
This is problematic, though, when requesting distinct result rows,
because the dummy projection is still part of the query and so is also used to calculate uniqueness.

This has now been changed to use the `select_from` method
(https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.select_from),
so that the query can now be initialised without any projection.

```python
query = session.query().select_from(starting_table)
```

The backend QueryBuilder code is also refactored,
principally to make the `SqlaQueryBuilder._build` logic more understandable.

Cherry-pick: 9fa2d88
  • Loading branch information
chrisjsewell authored and sphuber committed Sep 23, 2022
1 parent 8b55a2a commit 7a4532e
Show file tree
Hide file tree
Showing 8 changed files with 614 additions and 531 deletions.
124 changes: 63 additions & 61 deletions aiida/storage/psql_dos/orm/querybuilder/joiner.py

Large diffs are not rendered by default.

946 changes: 498 additions & 448 deletions aiida/storage/psql_dos/orm/querybuilder/main.py

Large diffs are not rendered by default.

34 changes: 18 additions & 16 deletions aiida/storage/sqlite_zip/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
"""
from functools import singledispatch
import json
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

from sqlalchemy import JSON, case, func
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import ColumnElement

from aiida.common.lang import type_check
Expand All @@ -30,6 +31,7 @@
QueryableAttribute,
SqlaQueryBuilder,
String,
get_column,
)

from . import models
Expand Down Expand Up @@ -188,12 +190,18 @@ def Log(self):
def table_groups_nodes(self):
return models.DbGroupNodes.__table__ # type: ignore[attr-defined] # pylint: disable=no-member

def get_projectable_attribute(
self, alias, column_name: str, attrpath: List[str], cast: Optional[str] = None
) -> ColumnElement:
"""Return an attribute store in a JSON field of the give column"""
# pylint: disable=unused-argument
entity = self.get_column(column_name, alias)[attrpath]
@staticmethod
def _get_projectable_entity(
alias: AliasedClass,
column_name: str,
attrpath: List[str],
cast: Optional[str] = None,
) -> Union[ColumnElement, InstrumentedAttribute]:

if not (attrpath or column_name in ('attributes', 'extras')):
return get_column(column_name, alias)

entity = get_column(column_name, alias)[attrpath]
if cast is None:
pass
elif cast == 'f':
Expand All @@ -212,15 +220,16 @@ def get_projectable_attribute(
raise ValueError(f'Unknown casting key {cast}')
return entity

@staticmethod
def get_filter_expr_from_jsonb( # pylint: disable=too-many-return-statements,too-many-branches
self, operator: str, value, attr_key: List[str], column=None, column_name=None, alias=None
operator: str, value, attr_key: List[str], column=None, column_name=None, alias=None
):
"""Return a filter expression.
See: https://www.sqlite.org/json1.html
"""
if column is None:
column = self.get_column(column_name, alias)
column = get_column(column_name, alias)

query_str = f'{alias or ""}.{column_name or ""}.{attr_key} {operator} {value}'

Expand Down Expand Up @@ -325,13 +334,6 @@ def _cast_json_type(comparator: JSON.Comparator, value: Any) -> Tuple[ColumnElem

@staticmethod
def get_filter_expr_from_column(operator: str, value: Any, column) -> BinaryExpression:
"""A method that returns an valid SQLAlchemy expression.
:param operator: The operator provided by the user ('==', '>', ...)
:param value: The value to compare with, e.g. (5.0, 'foo', ['a','b'])
:param column: an instance of sqlalchemy.orm.attributes.InstrumentedAttribute or
"""
# Label is used because it is what is returned for the
# 'state' column by the hybrid_column construct
if not isinstance(column, (Cast, InstrumentedAttribute, QueryableAttribute, Label, ColumnClause)):
Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ py:obj typing.AbstractContextManager
py:class callable
py:class function
py:class traceback
py:class NoneType
py:class AbstractContextManager
py:class BinaryIO
py:class IO
Expand Down
34 changes: 31 additions & 3 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def test_simple_query_2(self):
assert qb.count() == 1

# Test the hashing:
query1 = qb._impl._update_query(qb.as_dict()) # pylint: disable=protected-access
query1 = qb._impl.get_query(qb.as_dict()) # pylint: disable=protected-access
qb.add_filter('n2', {'label': 'nonexistentlabel'})
assert qb.count() == 0

Expand All @@ -381,8 +381,8 @@ def test_simple_query_2(self):
with pytest.raises(MultipleObjectsError):
orm.QueryBuilder().append(orm.Node).one()

query2 = qb._impl._update_query(qb.as_dict()) # pylint: disable=protected-access
query3 = qb._impl._update_query(qb.as_dict()) # pylint: disable=protected-access
query2 = qb._impl.get_query(qb.as_dict()) # pylint: disable=protected-access
query3 = qb._impl.get_query(qb.as_dict()) # pylint: disable=protected-access

assert id(query1) != id(query2)
assert id(query2) == id(query3)
Expand Down Expand Up @@ -1199,6 +1199,34 @@ def test_joins_group_node(self):
for curr_id in [n1.pk, n2.pk, n3.pk, n4.pk]:
assert curr_id in id_res

def test_joins_group_node_distinct(self):
"""Test that when projecting only the group for a join on nodes, only unique groups are returned.
Regression test for #5535
"""
group = orm.Group(label='mygroup').store()
node_a = orm.Data().store()
node_b = orm.Data().store()
group.add_nodes([node_a, node_b])

# First join the group on the data
query = orm.QueryBuilder()
query.append(orm.Group, project='id', tag='group')
query.append(orm.Data, with_group='group')
query.distinct()

assert query.count() == 1
assert query.all(flat=True) == [group.pk]

# Then reverse and join the data on the group
query = orm.QueryBuilder()
query.append(orm.Data, tag='node')
query.append(orm.Group, with_node='node', project='uuid')
query.distinct()

assert query.all(flat=True) == [group.uuid]
assert query.count() == 1


class QueryBuilderPath:

Expand Down
2 changes: 1 addition & 1 deletion tests/orm/test_querybuilder/test_as_sql.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
'SELECT db_dbnode_1.id, db_dbnode_1.uuid \nFROM db_dbnode AS db_dbnode_1 \nWHERE CAST(db_dbnode_1.node_type AS VARCHAR) LIKE %(param_1)s AND CASE WHEN (jsonb_typeof((db_dbnode_1.extras #> %(extras_1)s)) = %(jsonb_typeof_1)s) THEN (db_dbnode_1.extras #>> %(extras_1)s) = %(param_2)s ELSE %(param_3)s END' % {'param_1': '%', 'extras_1': ('tag4',), 'jsonb_typeof_1': 'string', 'param_2': 'appl_pecoal', 'param_3': False}
'SELECT db_dbnode_1.uuid \nFROM db_dbnode AS db_dbnode_1 \nWHERE CAST(db_dbnode_1.node_type AS VARCHAR) LIKE %(param_1)s AND CASE WHEN (jsonb_typeof((db_dbnode_1.extras #> %(extras_1)s)) = %(jsonb_typeof_1)s) THEN (db_dbnode_1.extras #>> %(extras_1)s) = %(param_2)s ELSE %(param_3)s END' % {'param_1': '%', 'extras_1': ('tag4',), 'jsonb_typeof_1': 'string', 'param_2': 'appl_pecoal', 'param_3': False}
2 changes: 1 addition & 1 deletion tests/orm/test_querybuilder/test_as_sql_inline.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT db_dbnode_1.id, db_dbnode_1.uuid
SELECT db_dbnode_1.uuid
FROM db_dbnode AS db_dbnode_1
WHERE CAST(db_dbnode_1.node_type AS VARCHAR) LIKE '%%' AND CASE WHEN (jsonb_typeof((db_dbnode_1.extras #> '{tag4}')) = 'string') THEN (db_dbnode_1.extras #>> '{tag4}') = 'appl_pecoal' ELSE false END
2 changes: 1 addition & 1 deletion tests/orm/test_querybuilder/test_as_sql_literal_quote.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT db_dbnode_1.id, db_dbnode_1.uuid
SELECT db_dbnode_1.uuid
FROM db_dbnode AS db_dbnode_1
WHERE CAST(db_dbnode_1.node_type AS VARCHAR) LIKE 'data.core.structure.%%' AND CAST((db_dbnode_1.extras #> '{elements}') AS JSONB) @> '["Si"]'

0 comments on commit 7a4532e

Please sign in to comment.