Skip to content

Commit

Permalink
feat: add query parameter for allowing datasets
Browse files Browse the repository at this point in the history
add include in search
  • Loading branch information
SimonThordal committed Apr 16, 2024
1 parent 22dbb63 commit 538b0a0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 2 deletions.
38 changes: 38 additions & 0 deletions tests/unit/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,44 @@ def test_match_exclude_dataset():
assert len(res["results"]) == 0, res


def test_match_include_dataset():
    """Check that `include_dataset` limits matching to the listed datasets.

    Three scenarios are exercised against `/match/default`:

    1. Only datasets without the example entity are included: no results.
    2. A dataset containing the example entity is included: results.
    3. That same dataset is both included and excluded: no results.

    Every response is checked for HTTP 200 before its body is inspected, so
    a failing request is reported with the response text rather than as a
    `KeyError` on the parsed JSON.
    """
    # When querying Putin
    query = {"queries": {"vv": EXAMPLE}}
    # Using only datasets that do not include Putin
    params = {
        "algorithm": "name-based",
        "include_dataset": ["ae_local_terrorists", "mx_governors"],
    }
    resp = client.post("/match/default", json=query, params=params)
    # We should get a successful response
    assert resp.status_code == 200, resp.text
    data = resp.json()
    res = data["responses"]["vv"]
    # And we should get no matches
    assert len(res["results"]) == 0, res

    # When using a dataset that includes Putin
    params = {
        "algorithm": "name-based",
        "include_dataset": ["eu_fsf", "ae_local_terrorists"],
    }
    resp = client.post("/match/default", json=query, params=params)
    assert resp.status_code == 200, resp.text
    data = resp.json()
    res = data["responses"]["vv"]
    # And we should get matches
    assert len(res["results"]) > 0, res

    # When we exclude the eu_fsf dataset
    params = {
        "algorithm": "name-based",
        "include_dataset": ["eu_fsf", "mx_governors", "ae_local_terrorists"],
        "exclude_dataset": "eu_fsf",
    }
    # We should get no matches
    resp = client.post("/match/default", json=query, params=params)
    assert resp.status_code == 200, resp.text
    data = resp.json()
    res = data["responses"]["vv"]
    assert len(res["results"]) == 0, res


def test_filter_topic():
query = {"queries": {"vv": EXAMPLE}}
params = {"algorithm": "name-based", "topics": "crime.cyber"}
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,29 @@ def test_search_filter_exclude_dataset():
assert new_total == 0


def test_search_filter_include_dataset():
    """Check that `include_dataset` limits `/search/default` results.

    Starts from an unfiltered query that must return hits, then verifies
    that including a dataset without the entity yields nothing, including a
    dataset with the entity yields hits, and combining include and exclude
    for the same dataset yields nothing.

    Every response is checked for HTTP 200 before its body is inspected, so
    a failing request is reported directly rather than as a `KeyError` on
    the parsed JSON.
    """
    res = client.get("/search/default?q=vladimir putin")
    assert res.status_code == 200, res
    total = res.json()["total"]["value"]
    assert total > 0, total

    # When we include a dataset that does not contain Putin or is not available
    # in the collection we should get no results
    res = client.get("/search/default?q=vladimir putin&include_dataset=mx_senators")
    assert res.status_code == 200, res
    new_total = res.json()["total"]["value"]
    assert new_total == 0

    # When we include a dataset that contains Putin we should get results
    res = client.get("/search/default?q=vladimir putin&include_dataset=eu_fsf")
    assert res.status_code == 200, res
    new_total = res.json()["total"]["value"]
    assert new_total > 0

    # When using both include and exclude, the exclude should take precedence
    res = client.get(
        "/search/default?q=vladimir putin&include_dataset=eu_fsf&exclude_dataset=eu_fsf"
    )
    assert res.status_code == 200, res
    new_total = res.json()["total"]["value"]
    assert new_total == 0


def test_search_filter_changed_since():
ts = datetime.utcnow() + timedelta(days=1)
tx = ts.isoformat(sep="T", timespec="minutes")
Expand Down
4 changes: 4 additions & 0 deletions yente/routers/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ async def match(
title="Lower bound of score for results to be returned at all",
),
algorithm: str = Query(settings.DEFAULT_ALGORITHM, title=ALGO_HELP),
include_dataset: List[str] = Query(
[], title="Only include the given datasets in results"
),
exclude_schema: List[str] = Query(
[], title="Remove the given types of entities from results"
),
Expand Down Expand Up @@ -144,6 +147,7 @@ async def match(
entity,
filters=filters,
fuzzy=fuzzy,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
exclude_dataset=exclude_dataset,
changed_since=changed_since,
Expand Down
4 changes: 4 additions & 0 deletions yente/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ async def search(
schema: str = Query(
settings.BASE_SCHEMA, title="Types of entities that can match the search"
),
include_dataset: List[str] = Query(
[], title="Only include the given datasets in results"
),
exclude_schema: List[str] = Query(
[], title="Remove the given types of entities from results"
),
Expand Down Expand Up @@ -109,6 +112,7 @@ async def search(
filters=filters,
fuzzy=fuzzy,
simple=simple,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
exclude_dataset=exclude_dataset,
changed_since=changed_since,
Expand Down
14 changes: 12 additions & 2 deletions yente/search/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ def filter_query(
dataset: Optional[Dataset] = None,
schema: Optional[Schema] = None,
filters: FilterDict = {},
include_dataset: List[str] = [],
exclude_schema: List[str] = [],
exclude_dataset: List[str] = [],
changed_since: Optional[str] = None,
) -> Clause:
filterqs: List[Clause] = []
if dataset is not None:
ds = [d for d in dataset.dataset_names if d not in exclude_dataset]
ds = [
d
for d in dataset.dataset_names
if (len(include_dataset) == 0 or d in include_dataset)
and d not in exclude_dataset
]
filterqs.append({"terms": {"datasets": ds}})
if schema is not None:
schemata = schema.matchable_schemata
Expand Down Expand Up @@ -76,7 +82,7 @@ def names_query(entity: EntityProxy, fuzzy: bool = True) -> List[Clause]:
term = {NAME_KEY_FIELD: {"value": key, "boost": 4.0}}
shoulds.append({"term": term})
for token in set(index_name_parts(names)):
term = {NAME_PART_FIELD: {"value": token, 'boost': 1.0}}
term = {NAME_PART_FIELD: {"value": token, "boost": 1.0}}
shoulds.append({"term": term})
for phoneme in set(phonetic_names(names)):
term = {NAME_PHONETIC_FIELD: {"value": phoneme, "boost": 0.8}}
Expand All @@ -89,6 +95,7 @@ def entity_query(
entity: EntityProxy,
filters: FilterDict = {},
fuzzy: bool = True,
include_dataset: List[str] = [],
exclude_schema: List[str] = [],
exclude_dataset: List[str] = [],
changed_since: Optional[str] = None,
Expand All @@ -110,6 +117,7 @@ def entity_query(
filters=filters,
dataset=dataset,
schema=entity.schema,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
exclude_dataset=exclude_dataset,
changed_since=changed_since,
Expand All @@ -123,6 +131,7 @@ def text_query(
filters: FilterDict = {},
fuzzy: bool = False,
simple: bool = False,
include_dataset: List[str] = [],
exclude_schema: List[str] = [],
exclude_dataset: List[str] = [],
changed_since: Optional[str] = None,
Expand Down Expand Up @@ -156,6 +165,7 @@ def text_query(
dataset=dataset,
schema=schema,
filters=filters,
include_dataset=include_dataset,
exclude_schema=exclude_schema,
exclude_dataset=exclude_dataset,
changed_since=changed_since,
Expand Down

0 comments on commit 538b0a0

Please sign in to comment.