
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 15, 2024
1 parent d9fc7de commit 3434e0a
Showing 16 changed files with 212 additions and 66 deletions.
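
Note: the changes in this commit are mechanical line-length fixes; statements that exceed the formatter's limit are wrapped in parentheses, one argument or string per line, with no change in behavior. The hook list itself is not part of this diff, so the exact formatter is an assumption (the style matches black). A minimal, self-contained sketch of the pattern, using the schema constant from the first file below (the _BEFORE/_AFTER names are illustrative only, not part of the repository):

    # Illustrative only: the same string assigned two ways.
    # Wrapping a long literal in parentheses changes the layout, not the value.
    HISTORY_JSON_SCHEMA_BEFORE = "https://harmony.earthdata.nasa.gov/schemas/history/0.1.0/history-v0.1.0.json"

    HISTORY_JSON_SCHEMA_AFTER = (
        "https://harmony.earthdata.nasa.gov/schemas/history/0.1.0/history-v0.1.0.json"
    )

    assert HISTORY_JSON_SCHEMA_BEFORE == HISTORY_JSON_SCHEMA_AFTER
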
10 changes: 7 additions & 3 deletions concatenator/attribute_handling.py
@@ -1,5 +1,5 @@
"""Functions for converting "coordinates" in netCDF variable attributes
between paths that reference a group hierarchy and flattened paths.
between paths that reference a group hierarchy and flattened paths.
"""

import json
@@ -12,9 +12,13 @@
import concatenator

# Values needed for history_json attribute
HISTORY_JSON_SCHEMA = "https://harmony.earthdata.nasa.gov/schemas/history/0.1.0/history-v0.1.0.json"
HISTORY_JSON_SCHEMA = (
"https://harmony.earthdata.nasa.gov/schemas/history/0.1.0/history-v0.1.0.json"
)
PROGRAM = "stitchee"
PROGRAM_REF = "https://cmr.earthdata.nasa.gov:443/search/concepts/S2940253910-LARC_CLOUD"
PROGRAM_REF = (
"https://cmr.earthdata.nasa.gov:443/search/concepts/S2940253910-LARC_CLOUD"
)
VERSION = importlib_metadata.distribution("stitchee").version


32 changes: 23 additions & 9 deletions concatenator/dataset_and_group_handling.py
@@ -52,9 +52,9 @@ def walk(
# Flatten the paths of variables referenced in the 'coordinates' attribute
flatten_coordinate_attribute_paths(new_dataset, var, var_group_name)

if (len(var.dimensions) == 1) and _string_dimension_name_pattern.fullmatch(
var.dimensions[0]
):
if (
len(var.dimensions) == 1
) and _string_dimension_name_pattern.fullmatch(var.dimensions[0]):
list_of_character_string_vars.append(var_group_name)

# Delete variables
@@ -145,7 +145,9 @@ def flatten_grouped_dataset(
if ensure_all_dims_are_coords and (
new_dim_name not in list(nc_dataset.variables.keys())
):
nc_dataset.createVariable(dim.name, datatype=np.int32, dimensions=(dim.name,))
nc_dataset.createVariable(
dim.name, datatype=np.int32, dimensions=(dim.name,)
)
temporary_coordinate_variables.append(dim.name)

list_of_character_string_vars: list[str] = []
@@ -185,7 +187,9 @@ def regroup_flattened_dataset(
group_lst = []
# need logic if there is data in the top level not in a group
for var_name, _ in dataset.variables.items():
group_lst.append("/".join(str(var_name).split(concatenator.group_delim)[:-1]))
group_lst.append(
"/".join(str(var_name).split(concatenator.group_delim)[:-1])
)
group_lst = ["/" if group == "" else group for group in group_lst]
groups = set(group_lst)
for group in groups:
@@ -226,9 +230,13 @@ def regroup_flattened_dataset(
new_var_dims = tuple(
str(d).rsplit(concatenator.group_delim, 1)[-1] for d in var.dims
)
dim_sizes = [_get_dimension_size(base_dataset, dim) for dim in new_var_dims]
dim_sizes = [
_get_dimension_size(base_dataset, dim) for dim in new_var_dims
]

chunk_sizes = _calculate_chunks(dim_sizes, default_low_dim_chunksize=4000)
chunk_sizes = _calculate_chunks(
dim_sizes, default_low_dim_chunksize=4000
)

# Do the variable creation
if var.dtype == "O":
@@ -237,7 +245,11 @@
vartype = str(var.dtype)

compression: str | None = "zlib"
if vartype.startswith("<U") and len(var.shape) == 1 and var.shape[0] < 10:
if (
vartype.startswith("<U")
and len(var.shape) == 1
and var.shape[0] < 10
):
compression = None

var_group.createVariable(
@@ -269,7 +281,9 @@

def _get_nested_group(dataset: nc.Dataset, group_path: str) -> nc.Group:
nested_group = dataset
for group in group_path.strip(concatenator.group_delim).split(concatenator.group_delim)[:-1]:
for group in group_path.strip(concatenator.group_delim).split(
concatenator.group_delim
)[:-1]:
nested_group = nested_group.groups[group]
return nested_group

35 changes: 25 additions & 10 deletions concatenator/dimension_cleanup.py
@@ -25,8 +25,12 @@ def remove_duplicate_dims(nc_dataset: nc.Dataset) -> nc.Dataset:

for var_name, var in nc_dataset.variables.items():
dim_list = list(var.dimensions)
if len(set(dim_list)) != len(dim_list): # get true if var.dimensions has a duplicate
dup_vars[var_name] = var # populate dictionary with variables with vars with dup dims
if len(set(dim_list)) != len(
dim_list
): # get true if var.dimensions has a duplicate
dup_vars[var_name] = (
var # populate dictionary with variables with vars with dup dims
)

for dup_var_name, dup_var in dup_vars.items():
dim_list = list(
@@ -35,7 +39,9 @@ def remove_duplicate_dims(nc_dataset: nc.Dataset) -> nc.Dataset:

# Dimension(s) that are duplicated are retrieved.
# Note: this is not yet tested for more than one duplicated dimension.
dim_dup = [item for item, count in collections.Counter(dim_list).items() if count > 1][0]
dim_dup = [
item for item, count in collections.Counter(dim_list).items() if count > 1
][0]
dim_dup_length = dup_var.shape[
dup_var.dimensions.index(dim_dup)
] # length of the duplicated dimension
@@ -53,7 +59,9 @@ def remove_duplicate_dims(nc_dataset: nc.Dataset) -> nc.Dataset:

# Attributes for the original variable are retrieved.
attrs_contents = get_attributes_minus_fillvalue_and_renamed_coords(
original_var_name=dup_var_name, new_var_name=dim_dup_new, original_dataset=nc_dataset
original_var_name=dup_var_name,
new_var_name=dim_dup_new,
original_dataset=nc_dataset,
)
# for attrname in dup_var.ncattrs():
# if attrname != '_FillValue':
Expand All @@ -67,22 +75,24 @@ def remove_duplicate_dims(nc_dataset: nc.Dataset) -> nc.Dataset:

# Only create a new *Dimension* if it doesn't already exist.
if dim_dup_new not in nc_dataset.dimensions.keys():

# New dimension is created by copying from the duplicated dimension.
nc_dataset.createDimension(dim_dup_new, dim_dup_length)

# Only create a new dimension *Variable* if it existed originally in the NetCDF structure.
if dim_dup in nc_dataset.variables.keys():

# New variable object is created for the renamed, previously duplicated dimension.
new_dup_var[dim_dup_new] = nc_dataset.createVariable(
dim_dup_new,
nc_dataset.variables[dim_dup].dtype,
(dim_dup_new,),
fill_value=fill_value,
)
dim_var_attr_contents = get_attributes_minus_fillvalue_and_renamed_coords(
original_var_name=dim_dup, new_var_name=dim_dup_new, original_dataset=nc_dataset
dim_var_attr_contents = (
get_attributes_minus_fillvalue_and_renamed_coords(
original_var_name=dim_dup,
new_var_name=dim_dup_new,
original_dataset=nc_dataset,
)
)
for attr_name, contents in dim_var_attr_contents.items():
new_dup_var[dim_dup_new].setncattr(attr_name, contents)
@@ -94,7 +104,10 @@ def remove_duplicate_dims(nc_dataset: nc.Dataset) -> nc.Dataset:

# Replace original *Variable* with new variable with no duplicated dimensions.
new_dup_var[dup_var_name] = nc_dataset.createVariable(
dup_var_name, str(dup_var[:].dtype), tuple(new_dim_list), fill_value=fill_value
dup_var_name,
str(dup_var[:].dtype),
tuple(new_dim_list),
fill_value=fill_value,
)
for attr_name, contents in attrs_contents.items():
new_dup_var[dup_var_name].setncattr(attr_name, contents)
Expand All @@ -111,7 +124,9 @@ def get_attributes_minus_fillvalue_and_renamed_coords(

for ncattr in original_dataset.variables[original_var_name].ncattrs():
if ncattr != "_FillValue":
contents: str = original_dataset.variables[original_var_name].getncattr(ncattr)
contents: str = original_dataset.variables[original_var_name].getncattr(
ncattr
)
if ncattr == "coordinates":
contents.replace(original_var_name, new_var_name)
attrs_contents[ncattr] = contents
8 changes: 6 additions & 2 deletions concatenator/file_ops.py
@@ -22,7 +22,9 @@ def make_temp_dir_with_input_file_copies(
) -> tuple[list[str], str]:
"""Creates temporary directory and copies input files."""
new_data_dir = Path(
add_label_to_path(str(output_path.parent / "temp_copy"), label=str(uuid.uuid4()))
add_label_to_path(
str(output_path.parent / "temp_copy"), label=str(uuid.uuid4())
)
).resolve()
os.makedirs(new_data_dir, exist_ok=True)
logger.info("Created temporary directory: %s", str(new_data_dir))
@@ -52,7 +54,9 @@ def validate_output_path(filepath: str, overwrite: bool = False) -> str:
f"Run again with `overwrite` option to overwrite existing file."
)
if path.is_dir(): # the specified path is an existing directory
raise TypeError("Output path cannot be a directory. Please specify a new filepath.")
raise TypeError(
"Output path cannot be a directory. Please specify a new filepath."
)
return str(path)


16 changes: 13 additions & 3 deletions concatenator/harmony/download_worker.py
@@ -13,7 +13,11 @@


def multi_core_download(
urls: list, destination_dir: str, access_token: str, cfg: dict, process_count: int | None = None
urls: list,
destination_dir: str,
access_token: str,
cfg: dict,
process_count: int | None = None,
) -> list[Path]:
"""
A method which automagically scales downloads to the number of CPU
@@ -74,7 +78,11 @@ def multi_core_download(


def _download_worker(
url_queue: queue.Queue, path_list: list, destination_dir: str, access_token: str, cfg: dict
url_queue: queue.Queue,
path_list: list,
destination_dir: str,
access_token: str,
cfg: dict,
) -> None:
"""
A method to be executed in a separate process which processes the url_queue
@@ -104,7 +112,9 @@ def _download_worker(
break

path = Path(
download(url, destination_dir, logger=logger, access_token=access_token, cfg=cfg)
download(
url, destination_dir, logger=logger, access_token=access_token, cfg=cfg
)
)
filename_match = re.match(r".*\/(.+\..+)", urlparse(url).path)

22 changes: 17 additions & 5 deletions concatenator/harmony/service_adapter.py
@@ -53,7 +53,9 @@ def invoke(self):
# Message-only support is being depreciated in Harmony, so we should expect to
# only see requests with catalogs when invoked with a newer Harmony instance
# https://github.com/nasa/harmony-service-lib-py/blob/21bcfbda17caf626fb14d2ac4f8673be9726b549/harmony/adapter.py#L71
raise RuntimeError("Invoking Batchee without a STAC catalog is not supported")
raise RuntimeError(
"Invoking Batchee without a STAC catalog is not supported"
)

return self.message, self.process_catalog(self.catalog)

@@ -104,7 +106,9 @@ def process_catalog(self, catalog: pystac.Catalog) -> pystac.Catalog:
history_json: list[dict] = []
for file_count, file in enumerate(input_files):
file_size = sizeof_fmt(file.stat().st_size)
self.logger.info(f"File {file_count} is size <{file_size}>. Path={file}")
self.logger.info(
f"File {file_count} is size <{file_size}>. Path={file}"
)

with nc.Dataset(file, "r") as dataset:
history_json.extend(retrieve_history(dataset))
@@ -134,14 +138,22 @@ def process_catalog(self, catalog: pystac.Catalog) -> pystac.Catalog:
# -- Output to STAC catalog --
result.clear_items()
properties = dict(
start_datetime=datetimes["start_datetime"], end_datetime=datetimes["end_datetime"]
start_datetime=datetimes["start_datetime"],
end_datetime=datetimes["end_datetime"],
)

item = Item(
str(uuid4()), bbox_to_geometry(bounding_box), bounding_box, None, properties
str(uuid4()),
bbox_to_geometry(bounding_box),
bounding_box,
None,
properties,
)
asset = Asset(
staged_url, title=filename, media_type="application/x-netcdf4", roles=["data"]
staged_url,
title=filename,
media_type="application/x-netcdf4",
roles=["data"],
)
item.add_asset("data", asset)
result.add_item(item)
6 changes: 5 additions & 1 deletion concatenator/harmony/util.py
@@ -1,4 +1,5 @@
"""Misc utility functions"""

from datetime import datetime

from pystac import Asset, Item
@@ -78,7 +79,10 @@ def _get_output_date_range(input_items: list[Item]) -> dict[str, str]:
start_datetime = min(start_datetime, new_start_datetime)
end_datetime = max(end_datetime, new_end_datetime)

return {"start_datetime": start_datetime.isoformat(), "end_datetime": end_datetime.isoformat()}
return {
"start_datetime": start_datetime.isoformat(),
"end_datetime": end_datetime.isoformat(),
}


def _get_item_date_range(item: Item) -> tuple[datetime, datetime]:
5 changes: 4 additions & 1 deletion concatenator/run_stitchee.py
@@ -93,7 +93,10 @@ def parse_args(args: list) -> argparse.Namespace:
default="__",
)
parser.add_argument(
"-O", "--overwrite", action="store_true", help="Overwrite output file if it already exists."
"-O",
"--overwrite",
action="store_true",
help="Overwrite output file if it already exists.",
)
parser.add_argument(
"-v",
19 changes: 14 additions & 5 deletions concatenator/stitchee.py
@@ -88,7 +88,11 @@ def stitchee(
concatenator.group_delim = group_delimiter

intermediate_flat_filepaths: list[str] = []
benchmark_log = {"flattening": 0.0, "concatenating": 0.0, "reconstructing_groups": 0.0}
benchmark_log = {
"flattening": 0.0,
"concatenating": 0.0,
"reconstructing_groups": 0.0,
}

# Proceed to concatenate only files that are workable (can be opened and are not empty).
input_files, num_input_files = validate_workable_files(files_to_concat, logger)
@@ -103,7 +107,9 @@
# Exit cleanly with the file copied if one workable netCDF file found.
if num_input_files == 1:
shutil.copyfile(input_files[0], output_file)
logger.info("One workable netCDF file. Copied to output path without modification.")
logger.info(
"One workable netCDF file. Copied to output path without modification."
)
return output_file

if concat_dim and (concat_method == "xarray-combine"):
@@ -123,14 +129,15 @@
# Instead of "with nc.Dataset() as" inside the loop, we use a context manager stack.
# This way all files are cleanly closed outside the loop.
with ExitStack() as context_stack:

logger.info("Flattening all input files...")
xrdataset_list = []
concat_dim_order = []
for i, filepath in enumerate(input_files):
# The group structure is flattened.
start_time = time.time()
logger.info(" ..file %03d/%03d <%s>..", i + 1, num_input_files, filepath)
logger.info(
" ..file %03d/%03d <%s>..", i + 1, num_input_files, filepath
)

ncfile = context_stack.enter_context(nc.Dataset(filepath, "r+"))

@@ -171,7 +178,9 @@
# Reorder the xarray datasets according to the concat dim values.
xrdataset_list = [
dataset
for _, dataset in sorted(zip(concat_dim_order, xrdataset_list), key=lambda x: x[0])
for _, dataset in sorted(
zip(concat_dim_order, xrdataset_list), key=lambda x: x[0]
)
]

# Flattened files are concatenated together (Using XARRAY).
