Skip to content

Commit

Permalink
feat: much faster file name resolution using ListFiles
Browse files Browse the repository at this point in the history
  • Loading branch information
sg-s committed Aug 30, 2024
1 parent 5ddf8a0 commit 07256ac
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 76 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
"tabulate",
"filetype",
"httpx",
"deeporigin-data-sdk==0.1.0a7",
"deeporigin-data-sdk==0.1.0a8",
"humanize",
"packaging",
]
Expand Down
125 changes: 50 additions & 75 deletions src/data_hub/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,14 @@ def download_database(
df.to_csv(os.path.join(destination, database_hid + ".csv"))


def _replace_with_mapper(item, mapper: dict):
"""utility function to replace items in a nested list with values generated by a mapper (a dict)."""

if isinstance(item, list):
return [_replace_with_mapper(sub_item, mapper) for sub_item in item]
return mapper.get(item, item)


@beartype
def get_dataframe(
database_id: str,
Expand All @@ -817,25 +825,29 @@ def get_dataframe(
return_type: Whether to return a `pandas.Dataframe` or a dictionary.
"""

# TODO: list_database_rows and describe_row can be called in parallel

# figure out the rows
rows = _api.list_database_rows(
database_row_id=database_id,
client=client,
)

# filter out template rows
rows = [
row for row in rows if not (hasattr(row, "is_template") and row.is_template)
]

# figure out the column names and ID of the database
db_row = _api.describe_row(
row_id=database_id,
client=client,
)
assert (
db_row.type == "database"
), f"Expected database_id: {database_id} to resolve to a database, but instead, it resolved to a {db_row.type}"

# filter out template rows
rows = [
row for row in rows if not (hasattr(row, "is_template") and row.is_template)
]

if db_row.type != "database":
raise DeepOriginException(
f"Expected database_id: {database_id} to resolve to a database, but instead, it resolved to a {db_row.type}"
)

columns = db_row.cols
database_id = db_row.id
Expand All @@ -846,8 +858,8 @@ def get_dataframe(
data["Validation Status"] = []

# keep track of all files and references in this database
file_ids = []
reference_ids = []
file_ids = []

# remove notebook columns because they are not
# shown in the UI as columns
Expand All @@ -862,12 +874,33 @@ def get_dataframe(
data[column["id"]] = []

for row in rows:
# warning: add_row_to_data mutates file_ids
add_row_to_data(
data=data,
row=row,
columns=columns,
use_file_names=use_file_names,
file_ids=file_ids,
)

# make a dict that maps from file IDs to file names
if use_file_names:
file_id_mapper = dict()

# determine file name for every file ID in the dataframe
files = _api.list_files(
filters=[dict(fileIds=file_ids)],
client=client,
)
for file in files:
file_id_mapper[file.file.id] = file.file.name

for column in columns:
if column["type"] == "file":
inputs = data[column["id"]]

data[column["id"]] = [
_replace_with_mapper(item, file_id_mapper) for item in inputs
]

if return_type == "dataframe":
# make the dataframe
Expand All @@ -877,7 +910,6 @@ def get_dataframe(
from deeporigin.data_hub.dataframe import DataFrame

df = DataFrame(data)
df.attrs["file_ids"] = list(set(file_ids))
df.attrs["reference_ids"] = list(set(reference_ids))
df.attrs["id"] = database_id
df.attrs["metadata"] = dict(db_row)
Expand All @@ -904,65 +936,15 @@ def get_dataframe(
return renamed_data


@beartype
def _parse_column_value(
def add_row_to_data(
*,
column: dict,
fields: Optional[list],
data: dict,
row,
columns: list,
file_ids: list,
reference_ids: list,
use_file_names: bool,
reference_format: IDFormat,
):
"""Internal function parse column values
Warning: Internal function
Do not use this function.
"""

if fields is None:
return None

field = [field for field in fields if field.column_id == column["id"]]

if len(field) == 0:
return None

if not hasattr(field[0], "value"):
return None
value = [field[0].value]

# special treatment for some column types
if column["type"] == "select" and len(value) == 1 and value[0] is not None:
value = value[0].selected_options
elif column["type"] == "file" and len(value) == 1 and value[0] is not None:
value = value[0].file_ids

file_ids.extend(value)

if use_file_names:
try:
value = [_api.describe_file(file_id=file_id).name for file_id in value]
except DeepOriginException:
# something went wrong, ignore
pass
elif column["type"] == "reference" and len(value) == 1 and value[0] is not None:
value = value[0].row_ids
reference_ids.extend(value)

if reference_format == "human-id":
value = convert_id_format(ids=value)
value = [thing.hid for thing in value]

if len(value) == 0:
value = None

return value


def add_row_to_data(*, data: dict, row, columns: list, use_file_names: bool = True):
"""utility function to combine data from a row into a dataframe"""
row_data = _row_to_dict(row, use_file_names=use_file_names)
row_data = _row_to_dict(row, file_ids=file_ids)
data["ID"].append(row_data["ID"])
data["Validation Status"].append(row_data["Validation Status"])

Expand All @@ -977,7 +959,7 @@ def add_row_to_data(*, data: dict, row, columns: list, use_file_names: bool = Tr
data[col_id].append(None)


def _row_to_dict(row, *, use_file_names: bool = True):
def _row_to_dict(row, *, file_ids: list):
"""utility function to convert a row to a dictionary"""
fields = row.fields

Expand All @@ -999,14 +981,7 @@ def _row_to_dict(row, *, use_file_names: bool = True):

elif field.type == "file":
value = field.value.file_ids
if use_file_names and value is not None:
try:
value = [
_api.describe_file(file_id=file_id).name for file_id in value
]
except DeepOriginException:
# something went wrong, ignore
pass
file_ids.extend(value)
else:
value = field.value
values[field.column_id] = value
Expand Down

0 comments on commit 07256ac

Please sign in to comment.