Added a flag to allow for level-wise ingestion of data. #304

Merged (24 commits, Dec 7, 2023). The file changes shown below reflect 16 of the 24 commits.

Commits:
- 63d9b94: Initial commit - COG naming as {filename}_level_{level_no} (j9sh264, Dec 7, 2022)
- e0b5161: changed the naming convention to {asset_name}_level_{level}_{height} (j9sh264, Dec 16, 2022)
- 71e94d6: Code cleanup, removed the repeated code blocks. (j9sh264, Mar 1, 2023)
- 0a9af55: Addressed internal team review comments. (j9sh264, Mar 17, 2023)
- fda5c96: Merge 'main' into mv-group-common-hypercubes (j9sh264, Mar 20, 2023)
- 8466b99: flake8 changes (j9sh264, Mar 20, 2023)
- 87e9965: New changes to maintain backward compatibility (j9sh264, May 17, 2023)
- b2942f7: BigQuery changes (j9sh264, May 17, 2023)
- acf5395: Merge branch 'main' into mv-group-common-hypercubes (j9sh264, May 18, 2023)
- 8c8da57: Linting changes (j9sh264, May 18, 2023)
- a8ab4e8: bigquery minor yield change (j9sh264, May 18, 2023)
- ebdb0e3: Added test cases (j9sh264, May 22, 2023)
- 8ad45a7: Addressed comments (j9sh264, May 22, 2023)
- 3b5763d: Filename correction (j9sh264, May 22, 2023)
- ef7a348: Merge branch 'main' into mv-group-common-hypercubes (j9sh264, Jun 5, 2023)
- 18e738d: Removed zarr argument check from bq.py (j9sh264, Jun 5, 2023)
- 0043409: Addressed Rahul's comments (j9sh264, Jun 6, 2023)
- 1a63361: Resolved comments (j9sh264, Jun 13, 2023)
- 8780a79: Version bump change (j9sh264, Jun 14, 2023)
- 437fc58: Merge branch 'main' into mv-group-common-hypercubes (j9sh264, Jul 4, 2023)
- 943cae0: Merge branch 'main' into mv-group-common-hypercubes (j9sh264, Oct 17, 2023)
- 6abe689: flake8 changes (j9sh264, Oct 17, 2023)
- a3e54d7: Minor change to solve sinks test cases. (j9sh264, Nov 23, 2023)
- 464a20d: Updated position of group_common_hypercubes function arg. (j9sh264, Nov 27, 2023)
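At a high level, the new --group_common_hypercubes flag groups a grib file's data variables by vertical level so that each level becomes its own dataset and, downstream, its own Earth Engine asset. A minimal, self-contained sketch of that grouping idea in plain xarray (illustration only; the variable names and level labels below are invented, not taken from the repo):

import numpy as np
import xarray as xr

# Toy data arrays tagged with the vertical level they belong to.
t_500 = xr.DataArray(np.zeros((2, 2)), dims=('lat', 'lon'), name='t',
                     attrs={'level': 'isobaricInhPa_500'})
r_500 = xr.DataArray(np.zeros((2, 2)), dims=('lat', 'lon'), name='r',
                     attrs={'level': 'isobaricInhPa_500'})
gh_sfc = xr.DataArray(np.zeros((2, 2)), dims=('lat', 'lon'), name='gh',
                      attrs={'level': 'surface'})

# Group by level, then merge each group into its own dataset, mirroring what
# __normalize_grib_dataset returns when group_common_hypercubes is set.
by_level = {}
for da in (t_500, r_500, gh_sfc):
    by_level.setdefault(da.attrs['level'], []).append(da)

per_level_datasets = [xr.merge(das) for das in by_level.values()]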
4 changes: 4 additions & 0 deletions weather_mv/loader_pipeline/bq.py
@@ -133,6 +133,10 @@ def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.Lis
if known_args.area:
assert len(known_args.area) == 4, 'Must specify exactly 4 lat/long values for area: N, W, S, E boundaries.'

# Add a check for group_common_hypercubes.
if pipeline_options_dict.get('group_common_hypercubes'):
raise RuntimeError('--group_common_hypercubes can be specified only for earth engine ingestions.')

# Check that all arguments are supplied for COG input.
_, uri_extension = os.path.splitext(known_args.uris)
if uri_extension == '.tif' and not known_args.tif_metadata_for_datetime:
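For context, a minimal standalone sketch (not part of the diff) of what this guard does; pipeline_options_dict below is a stand-in for the parsed pipeline options that validate_arguments already inspects:

# Hypothetical illustration of the bq-side guard.
pipeline_options_dict = {'group_common_hypercubes': True}  # assumed example input

if pipeline_options_dict.get('group_common_hypercubes'):
    # The BigQuery path does not support level-wise grouping, so it fails fast.
    raise RuntimeError('--group_common_hypercubes can be specified only for earth engine ingestions.')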
152 changes: 81 additions & 71 deletions weather_mv/loader_pipeline/ee.py
@@ -233,6 +233,7 @@ class ToEarthEngine(ToDataSink):
ee_qps: int
ee_latency: float
ee_max_concurrent: int
group_common_hypercubes: bool
band_names_mapping: str
initialization_time_regex: str
forecast_time_regex: str
@@ -268,6 +269,8 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser):
help='The expected latency per requests, in seconds. Default: 0.5')
subparser.add_argument('--ee_max_concurrent', type=int, default=10,
help='Maximum concurrent api requests to EE allowed for your project. Default: 10')
subparser.add_argument('--group_common_hypercubes', action='store_true', default=False,
help='To group common hypercubes into image collections when loading grib data.')
subparser.add_argument('--band_names_mapping', type=str, default=None,
help='A JSON file which contains the band names for the TIFF file.')
subparser.add_argument('--initialization_time_regex', type=str, default=None,
@@ -417,6 +420,7 @@ class ConvertToAsset(beam.DoFn, beam.PTransform, KwargsFactoryMixin):
ee_asset_type: str = 'IMAGE'
open_dataset_kwargs: t.Optional[t.Dict] = None
disable_grib_schema_normalization: bool = False
group_common_hypercubes: t.Optional[bool] = False
band_names_dict: t.Optional[t.Dict] = None
initialization_time_regex: t.Optional[str] = None
forecast_time_regex: t.Optional[str] = None
@@ -436,80 +440,86 @@ def convert_to_asset(self, queue: Queue, uri: str):
with open_dataset(uri,
self.open_dataset_kwargs,
self.disable_grib_schema_normalization,
group_common_hypercubes=self.group_common_hypercubes,
band_names_dict=self.band_names_dict,
initialization_time_regex=self.initialization_time_regex,
forecast_time_regex=self.forecast_time_regex) as ds:

attrs = ds.attrs
data = list(ds.values())
asset_name = get_ee_safe_name(uri)
channel_names = [da.name for da in data]
start_time, end_time, is_normalized = (attrs.get(key) for key in
('start_time', 'end_time', 'is_normalized'))
dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
attrs.update({'is_normalized': str(is_normalized)}) # EE properties does not support bool.
# Make attrs EE ingestable.
attrs = make_attrs_ee_compatible(attrs)

# For tiff ingestions.
if self.ee_asset_type == 'IMAGE':
file_name = f'{asset_name}.tiff'

with MemoryFile() as memfile:
with memfile.open(driver='COG',
dtype=dtype,
width=data[0].data.shape[1],
height=data[0].data.shape[0],
count=len(data),
nodata=np.nan,
crs=crs,
transform=transform,
compress='lzw') as f:
for i, da in enumerate(data):
f.write(da, i+1)
# Making the channel name EE-safe before adding it as a band name.
f.set_band_description(i+1, get_ee_safe_name(channel_names[i]))
f.update_tags(i+1, band_name=channel_names[i])
f.update_tags(i+1, **da.attrs)
# Write attributes as tags in tiff.
f.update_tags(**attrs)

# Copy in-memory tiff to gcs.
forecast_time_regex=self.forecast_time_regex) as ds_list:
if not isinstance(ds_list, list):
ds_list = [ds_list]

for ds in ds_list:
attrs = ds.attrs
data = list(ds.values())
asset_name = get_ee_safe_name(uri)
channel_names = [da.name for da in data]
start_time, end_time, is_normalized = (attrs.get(key) for key in
('start_time', 'end_time', 'is_normalized'))
dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
attrs.update({'is_normalized': str(is_normalized)}) # EE properties does not support bool.
# Make attrs EE ingestable.
attrs = make_attrs_ee_compatible(attrs)

if self.group_common_hypercubes:
level, height = (attrs.pop(key) for key in ['level', 'height'])
safe_level_name = get_ee_safe_name(level)
asset_name = f'{asset_name}_{safe_level_name}'

# For tiff ingestions.
if self.ee_asset_type == 'IMAGE':
file_name = f'{asset_name}.tiff'

with MemoryFile() as memfile:
with memfile.open(driver='COG',
dtype=dtype,
width=data[0].data.shape[1],
height=data[0].data.shape[0],
count=len(data),
nodata=np.nan,
crs=crs,
transform=transform,
compress='lzw') as f:
for i, da in enumerate(data):
f.write(da, i+1)
# Making the channel name EE-safe before adding it as a band name.
f.set_band_description(i+1, get_ee_safe_name(channel_names[i]))
f.update_tags(i+1, band_name=channel_names[i])
f.update_tags(i+1, **da.attrs)

# Write attributes as tags in tiff.
f.update_tags(**attrs)

# Copy in-memory tiff to gcs.
target_path = os.path.join(self.asset_location, file_name)
with FileSystems().create(target_path) as dst:
shutil.copyfileobj(memfile, dst, WRITE_CHUNK_SIZE)
# For feature collection ingestions.
elif self.ee_asset_type == 'TABLE':
channel_names = []
file_name = f'{asset_name}.csv'

df = xr.Dataset.to_dataframe(ds)
df = df.reset_index()

# Copy in-memory dataframe to gcs.
target_path = os.path.join(self.asset_location, file_name)
with FileSystems().create(target_path) as dst:
shutil.copyfileobj(memfile, dst, WRITE_CHUNK_SIZE)
# For feature collection ingestions.
elif self.ee_asset_type == 'TABLE':
channel_names = []
file_name = f'{asset_name}.csv'

df = xr.Dataset.to_dataframe(ds)
df = df.reset_index()
# NULL and NaN create data-type mismatch issue in ee therefore replacing all of them.
# fillna fills in NaNs, NULLs, and NaTs but we have to exclude NaTs.
non_nat = df.select_dtypes(exclude=['datetime', 'timedelta', 'datetimetz'])
df[non_nat.columns] = non_nat.fillna(-9999)

# Copy in-memory dataframe to gcs.
target_path = os.path.join(self.asset_location, file_name)
with tempfile.NamedTemporaryFile() as tmp_df:
df.to_csv(tmp_df.name, index=False)
tmp_df.flush()
tmp_df.seek(0)
with FileSystems().create(target_path) as dst:
shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE)

asset_data = AssetData(
name=asset_name,
target_path=target_path,
channel_names=channel_names,
start_time=start_time,
end_time=end_time,
properties=attrs
)

self.add_to_queue(queue, asset_data)
self.add_to_queue(queue, None) # Indicates end of the subprocess.
with tempfile.NamedTemporaryFile() as tmp_df:
df.to_csv(tmp_df.name, index=False)
tmp_df.flush()
tmp_df.seek(0)
with FileSystems().create(target_path) as dst:
shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE)

asset_data = AssetData(
name=asset_name,
target_path=target_path,
channel_names=channel_names,
start_time=start_time,
end_time=end_time,
properties=attrs
)

self.add_to_queue(queue, asset_data)
self.add_to_queue(queue, None) # Indicates end of the subprocess.

def process(self, uri: str) -> t.Iterator[AssetData]:
"""Opens grib files and yields AssetData.
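To summarize the convert_to_asset change above: the opened result may now be a list of per-level datasets, and each dataset gets a level-specific asset name before it is written out. A condensed, self-contained sketch of that naming step (get_ee_safe_name here is a simplified stand-in for the helper in ee.py; the URI and level labels are hypothetical):

import re

def get_ee_safe_name(name):
    # Simplified stand-in: replace characters an EE asset name cannot contain.
    return re.sub(r'[^a-zA-Z0-9_-]+', '_', str(name))

uri = 'gs://bucket/forecast.grib2'        # hypothetical input file
levels = ['isobaricInhPa', 'surface']     # hypothetical 'level' attrs popped from each dataset

asset_names = [f'{get_ee_safe_name(uri)}_{get_ee_safe_name(level)}' for level in levels]
# With this stand-in: ['gs_bucket_forecast_grib2_isobaricInhPa', 'gs_bucket_forecast_grib2_surface']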
95 changes: 69 additions & 26 deletions weather_mv/loader_pipeline/sinks.py
@@ -220,9 +220,12 @@ def _is_3d_da(da):
return len(da.shape) == 3


def __normalize_grib_dataset(filename: str) -> xr.Dataset:
def __normalize_grib_dataset(filename: str,
group_common_hypercubes: t.Optional[bool] = False) -> t.Union[xr.Dataset,
t.List[xr.Dataset]]:
"""Reads a list of datasets and merge them into a single dataset."""
_data_array_list = []
_level_data_dict = {}

list_ds = cfgrib.open_datasets(filename)
ds_attrs = list_ds[0].attrs
dv_units_dict = {}
@@ -250,6 +253,13 @@ def __normalize_grib_dataset(filename: str) -> xr.Dataset:
attrs['start_time'] = start_time
attrs['end_time'] = end_time

if group_common_hypercubes:
attrs['level'] = level # Adding the level in the metadata, will remove in further steps.
attrs['is_normalized'] = True # Adding the 'is_normalized' attribute in the metadata.

if level not in _level_data_dict:
_level_data_dict[level] = []

no_of_levels = da.shape[0] if _is_3d_da(da) else 1

# Deal with the randomness that is 3d data interspersed with 2d.
@@ -271,7 +281,8 @@ def __normalize_grib_dataset(filename: str) -> xr.Dataset:
logger.debug('Found channel %s', channel_name)

# Add the height as a metadata field, that seems useful.
copied_da.attrs['height'] = height
copied_da.attrs['height'] = height_string

# Add the units of each band as a metadata field.
dv_units_dict['unit_'+channel_name] = None
if 'units' in attrs:
@@ -281,26 +292,44 @@ def __normalize_grib_dataset(filename: str) -> xr.Dataset:
if _is_3d_da(da):
copied_da = copied_da.sel({level: height})
copied_da = copied_da.drop_vars(level)
_data_array_list.append(copied_da)

# Stick the forecast hour, start_time, end_time, data variables units
# in the ds attrs as well, that's useful.
ds_attrs['forecast_hour'] = _data_array_list[0].attrs['forecast_hour']
ds_attrs['start_time'] = _data_array_list[0].attrs['start_time']
ds_attrs['end_time'] = _data_array_list[0].attrs['end_time']
ds_attrs.update(**dv_units_dict)
_level_data_dict[level].append(copied_da)

_data_array_list = []
for level, ds in _level_data_dict.items():
if len(ds) == 1:
dataset = ds[0].to_dataset(promote_attrs=True)
else:
dataset = xr.merge(ds)
_data_array_list.append(dataset)

if not group_common_hypercubes:
# Stick the forecast hour, start_time, end_time, data variables units
# in the ds attrs as well, that's useful.
ds_attrs['forecast_hour'] = _data_array_list[0].attrs['forecast_hour']
ds_attrs['start_time'] = _data_array_list[0].attrs['start_time']
ds_attrs['end_time'] = _data_array_list[0].attrs['end_time']
ds_attrs.update(**dv_units_dict)

merged_dataset = xr.merge(_data_array_list)
merged_dataset.attrs.clear()
merged_dataset.attrs.update(ds_attrs)
return merged_dataset

merged_dataset = xr.merge(_data_array_list)
merged_dataset.attrs.clear()
merged_dataset.attrs.update(ds_attrs)
return merged_dataset
return _data_array_list


def __open_dataset_file(filename: str,
uri_extension: str,
disable_grib_schema_normalization: bool,
open_dataset_kwargs: t.Optional[t.Dict] = None) -> xr.Dataset:
open_dataset_kwargs: t.Optional[t.Dict] = None,
group_common_hypercubes: t.Optional[bool] = False) -> t.Union[xr.Dataset, t.List[xr.Dataset]]:
"""Opens the dataset at 'uri' and returns a xarray.Dataset."""
# add a flag to group common hypercubes
if group_common_hypercubes:
return __normalize_grib_dataset(filename, group_common_hypercubes)

# add a flag to use cfgrib
if open_dataset_kwargs:
return _add_is_normalized_attr(xr.open_dataset(filename, **open_dataset_kwargs), False)

@@ -380,6 +409,7 @@ def open_dataset(uri: str,
open_dataset_kwargs: t.Optional[t.Dict] = None,
disable_grib_schema_normalization: bool = False,
tif_metadata_for_datetime: t.Optional[str] = None,
group_common_hypercubes: t.Optional[bool] = False,
band_names_dict: t.Optional[t.Dict] = None,
initialization_time_regex: t.Optional[str] = None,
forecast_time_regex: t.Optional[str] = None,
@@ -394,29 +424,42 @@ def open_dataset(uri: str,
return
with open_local(uri) as local_path:
_, uri_extension = os.path.splitext(uri)
xr_dataset: xr.Dataset = __open_dataset_file(local_path,
uri_extension,
disable_grib_schema_normalization,
open_dataset_kwargs)
if uri_extension in ['.tif', '.tiff']:
xr_dataset = _preprocess_tif(xr_dataset,
xr_datasets: xr.Dataset = __open_dataset_file(local_path,
uri_extension,
disable_grib_schema_normalization,
open_dataset_kwargs,
group_common_hypercubes)
# Extracting dtype, crs and transform from the dataset.
with rasterio.open(local_path, 'r') as f:
dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform'])

if group_common_hypercubes:
total_size_in_bytes = 0

for xr_dataset in xr_datasets:
xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform})
total_size_in_bytes += xr_dataset.nbytes

logger.info(f'opened dataset size: {total_size_in_bytes}')
elif uri_extension in ['.tif', '.tiff']:
xr_dataset = _preprocess_tif(xr_datasets,
local_path,
tif_metadata_for_datetime,
uri,
band_names_dict,
initialization_time_regex,
forecast_time_regex)
else:
xr_dataset = xr_datasets

# Extracting dtype, crs and transform from the dataset & storing them as attributes.
with rasterio.open(local_path, 'r') as f:
dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform'])
xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform})
xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform})

logger.info(f'opened dataset size: {xr_dataset.nbytes}')

beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc()
yield xr_dataset
yield xr_datasets if group_common_hypercubes else xr_dataset
xr_dataset.close()

except Exception as e:
beam.metrics.Metrics.counter('Failure', 'ReadNetcdfData').inc()
logger.error(f'Unable to open file {uri!r}: {e}')
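A hypothetical caller-side sketch of the new open_dataset behavior (the import path and file path below are assumptions, not taken from the PR): with group_common_hypercubes=True, the context manager yields a list of per-level xarray Datasets instead of a single Dataset, which is what the new test in sinks_test.py asserts.

# Hypothetical usage sketch; adjust the import and file path to your checkout.
from weather_mv.loader_pipeline.sinks import open_dataset

with open_dataset('test_data_grib_multi_levels.grib2', group_common_hypercubes=True) as ds_list:
    for ds in ds_list:
        # Per the diff above, each element is one level's hypercube and its attrs
        # carry the 'level' it was grouped by.
        print(ds.attrs.get('level'), ds.nbytes)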
6 changes: 6 additions & 0 deletions weather_mv/loader_pipeline/sinks_test.py
@@ -86,6 +86,7 @@ def setUp(self) -> None:
self.test_grib_path = os.path.join(self.test_data_folder, 'test_data_grib_single_timestep')
self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_start_time.tif')
self.test_zarr_path = os.path.join(self.test_data_folder, 'test_data.zarr')
self.test_grib_multi_level_path = os.path.join(self.test_data_folder, 'test_data_grib_multi_levels.grib2')

def test_opens_grib_files(self):
with open_dataset(self.test_grib_path) as ds1:
@@ -118,6 +119,11 @@ def test_open_dataset__fits_memory_bounds(self):
with open_dataset(test_netcdf_path) as _:
pass

def test_group_common_hypercubes(self):
with open_dataset(self.test_grib_multi_level_path,
group_common_hypercubes=True) as ds:
self.assertEqual(isinstance(ds, list), True)


class DatetimeTest(unittest.TestCase):

Binary file not shown.