diff --git a/ci3.8.yml b/ci3.8.yml index d6a1e0bd..0ce25520 100644 --- a/ci3.8.yml +++ b/ci3.8.yml @@ -16,7 +16,7 @@ dependencies: - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -33,6 +33,7 @@ dependencies: - ruff==0.0.260 - google-cloud-sdk=410.0.0 - aria2=1.36.0 + - zarr=2.15.0 - pip: - earthengine-api==0.1.329 - .[test] diff --git a/ci3.9.yml b/ci3.9.yml index a43cec16..33471870 100644 --- a/ci3.9.yml +++ b/ci3.9.yml @@ -16,7 +16,7 @@ dependencies: - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -33,6 +33,7 @@ dependencies: - aria2=1.36.0 - xarray==2023.1.0 - ruff==0.0.260 + - zarr=2.15.0 - pip: - earthengine-api==0.1.329 - .[test] diff --git a/environment.yml b/environment.yml index eae35f9c..79000c6b 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: dependencies: - python=3.8.13 - apache-beam=2.40.0 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - xarray=2023.1.0 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -25,6 +25,7 @@ dependencies: - google-cloud-sdk=410.0.0 - aria2=1.36.0 - pip=22.3 + - zarr=2.15.0 - pip: - earthengine-api==0.1.329 - firebase-admin==6.0.1 diff --git a/setup.py b/setup.py index af798889..dedb552e 100644 --- a/setup.py +++ b/setup.py @@ -57,8 +57,9 @@ "earthengine-api>=0.1.263", "pyproj", # requires separate binary installation! "gdal", # requires separate binary installation! - "xarray-beam==0.3.1", + "xarray-beam==0.6.2", "gcsfs==2022.11.0", + "zarr==2.15.0", ] weather_sp_requirements = [ @@ -82,6 +83,7 @@ "memray", "pytest-memray", "h5py", + "pooch", ] all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \ @@ -115,7 +117,7 @@ ], python_requires='>=3.8, <3.10', - install_requires=['apache-beam[gcp]==2.40.0'], + install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'], use_scm_version=True, setup_requires=['setuptools_scm'], scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'], diff --git a/weather_mv/loader_pipeline/bq.py b/weather_mv/loader_pipeline/bq.py index 58120940..5f466419 100644 --- a/weather_mv/loader_pipeline/bq.py +++ b/weather_mv/loader_pipeline/bq.py @@ -24,8 +24,10 @@ import geojson import numpy as np import xarray as xr +import xarray_beam as xbeam from apache_beam.io import WriteToBigQuery, BigQueryDisposition from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms import window from google.cloud import bigquery from xarray.core.utils import ensure_us_time_resolution @@ -236,73 +238,90 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Di with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) + yield from self.to_rows(coordinates, data_ds, uri) - first_ts_raw = data_ds.time[0].values if isinstance(data_ds.time.values, - np.ndarray) else data_ds.time.values - first_time_step = to_json_serializable_type(first_ts_raw) - - for it in coordinates: - # Use those index values to select a Dataset containing one row of data. - row_ds = data_ds.loc[it] - - # Create a Name-Value map for data columns. 
Result looks like: - # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None} - row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values)) - for n, v in row_ds.data_vars.items()} - - # Serialize coordinates. - it = {k: to_json_serializable_type(v) for k, v in it.items()} - - # Add indexed coordinates. - row.update(it) - # Add un-indexed coordinates. - for c in row_ds.coords: - if c not in it and (not self.variables or c in self.variables): - row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values)) - - # Add import metadata. - row[DATA_IMPORT_TIME_COLUMN] = self.import_time - row[DATA_URI_COLUMN] = uri - row[DATA_FIRST_STEP] = first_time_step - - longitude = ((row['longitude'] + 180) % 360) - 180 - row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude) - row[GEO_POLYGON_COLUMN] = ( - fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution) - if not self.skip_creating_polygon - else None - ) - # 'row' ends up looking like: - # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812, - # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...} - beam.metrics.Metrics.counter('Success', 'ExtractRows').inc() - yield row + def to_rows(self, coordinates: t.Iterable[t.Dict], ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]: + first_ts_raw = ( + ds.time[0].values if isinstance(ds.time.values, np.ndarray) + else ds.time.values + ) + first_time_step = to_json_serializable_type(first_ts_raw) + for it in coordinates: + # Use those index values to select a Dataset containing one row of data. + row_ds = ds.loc[it] + + # Create a Name-Value map for data columns. Result looks like: + # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None} + row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values)) + for n, v in row_ds.data_vars.items()} + + # Serialize coordinates. + it = {k: to_json_serializable_type(v) for k, v in it.items()} + + # Add indexed coordinates. + row.update(it) + # Add un-indexed coordinates. + for c in row_ds.coords: + if c not in it and (not self.variables or c in self.variables): + row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values)) + + # Add import metadata. + row[DATA_IMPORT_TIME_COLUMN] = self.import_time + row[DATA_URI_COLUMN] = uri + row[DATA_FIRST_STEP] = first_time_step + + longitude = ((row['longitude'] + 180) % 360) - 180 + row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude) + row[GEO_POLYGON_COLUMN] = ( + fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution) + if not self.skip_creating_polygon + else None + ) + # 'row' ends up looking like: + # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812, + # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...} + beam.metrics.Metrics.counter('Success', 'ExtractRows').inc() + yield row + + def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]: + uri = ds.attrs.get(DATA_URI_COLUMN, '') + # Re-calculate import time for streaming extractions. 
+ if not self.import_time or self.zarr: + self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + yield from self.to_rows(get_coordinates(ds, uri), ds, uri) def expand(self, paths): """Extract rows of variables from data paths into a BigQuery table.""" - extracted_rows = ( + if not self.zarr: + extracted_rows = ( paths | 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates) | beam.Reshuffle() | 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows) - ) - - if not self.dry_run: - ( - extracted_rows - | 'WriteToBigQuery' >> WriteToBigQuery( - project=self.table.project, - dataset=self.table.dataset_id, - table=self.table.table_id, - write_disposition=BigQueryDisposition.WRITE_APPEND, - create_disposition=BigQueryDisposition.CREATE_NEVER) ) else: - ( - extracted_rows - | 'Log Extracted Rows' >> beam.Map(logger.debug) + ds, chunks = xbeam.open_zarr(self.first_uri, **self.xarray_open_dataset_kwargs) + ds.attrs[DATA_URI_COLUMN] = self.first_uri + extracted_rows = ( + paths + | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks) + | 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows) + | 'Window' >> beam.WindowInto(window.FixedWindows(60)) + | 'AddTimestamp' >> beam.Map(timestamp_row) ) + if self.dry_run: + return extracted_rows | 'Log Rows' >> beam.Map(logger.info) + return ( + extracted_rows + | 'WriteToBigQuery' >> WriteToBigQuery( + project=self.table.project, + dataset=self.table.dataset_id, + table=self.table.table_id, + write_disposition=BigQueryDisposition.WRITE_APPEND, + create_disposition=BigQueryDisposition.CREATE_NEVER) + ) + def map_dtype_to_sql_type(var_type: np.dtype) -> str: """Maps a np.dtype to a suitable BigQuery column type.""" @@ -343,6 +362,12 @@ def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.Schem return fields +def timestamp_row(it: t.Dict) -> window.TimestampedValue: + """Associate an extracted row with the import_time timestamp.""" + timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp() + return window.TimestampedValue(it, timestamp) + + def fetch_geo_point(lat: float, long: float) -> str: """Calculates a geography point from an input latitude and longitude.""" if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]: diff --git a/weather_mv/loader_pipeline/bq_test.py b/weather_mv/loader_pipeline/bq_test.py index ed96cd9d..224a7a93 100644 --- a/weather_mv/loader_pipeline/bq_test.py +++ b/weather_mv/loader_pipeline/bq_test.py @@ -15,6 +15,7 @@ import json import logging import os +import tempfile import typing as t import unittest @@ -23,6 +24,8 @@ import pandas as pd import simplejson import xarray as xr +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, is_not_empty from google.cloud.bigquery import SchemaField from .bq import ( @@ -205,13 +208,13 @@ def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=N skip_creating_polygon: bool = False) -> t.Iterator[t.Dict]: if zarr_kwargs is None: zarr_kwargs = {} - op = ToBigQuery.from_kwargs(first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, - output_table='foo.bar.baz', variables=variables, area=area, - xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, - infer_schema=False, tif_metadata_for_datetime=tif_metadata_for_datetime, - skip_region_validation=True, - disable_grib_schema_normalization=disable_grib_schema_normalization, - coordinate_chunk_size=1000, skip_creating_polygon=skip_creating_polygon) + op = ToBigQuery.from_kwargs( + 
first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, + output_table='foo.bar.baz', variables=variables, area=area, + xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False, + tif_metadata_for_datetime=tif_metadata_for_datetime, skip_region_validation=True, + disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000, + skip_creating_polygon=skip_creating_polygon) coords = op.prepare_coordinates(data_path) for uri, chunk in coords: yield from op.extract_rows(uri, chunk) @@ -737,5 +740,36 @@ def test_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema self.assertRowsEqual(actual, expected) +class ExtractRowsFromZarrTest(ExtractRowsTestBase): + + def setUp(self) -> None: + super().setUp() + self.tmpdir = tempfile.TemporaryDirectory() + + def tearDown(self) -> None: + super().tearDown() + self.tmpdir.cleanup() + + def test_extracts_rows(self): + input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr') + + ds = ( + xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder) + .isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4)) + .rename(dict(lon='longitude', lat='latitude')) + ) + ds.to_zarr(input_zarr) + + op = ToBigQuery.from_kwargs( + first_uri=input_zarr, zarr_kwargs=dict(), dry_run=True, zarr=True, output_table='foo.bar.baz', + variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False, + tif_metadata_for_datetime=None, skip_region_validation=True, disable_grib_schema_normalization=False, + ) + + with TestPipeline() as p: + result = p | op + assert_that(result, is_not_empty()) + + if __name__ == '__main__': unittest.main() diff --git a/weather_mv/loader_pipeline/pipeline.py b/weather_mv/loader_pipeline/pipeline.py index c12bd5f5..f6d40c41 100644 --- a/weather_mv/loader_pipeline/pipeline.py +++ b/weather_mv/loader_pipeline/pipeline.py @@ -27,7 +27,7 @@ from .streaming import GroupMessagesByFixedWindows, ParsePaths logger = logging.getLogger(__name__) -SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' +SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0' def configure_logger(verbosity: int) -> None: @@ -55,8 +55,9 @@ def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None known_args.first_uri = next(iter(all_uris)) with beam.Pipeline(argv=pipeline_args) as p: - if known_args.topic or known_args.subscription: - + if known_args.zarr: + paths = p + elif known_args.topic or known_args.subscription: paths = ( p # Windowing is based on this code sample: @@ -140,7 +141,6 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: # Validate Zarr arguments if known_args.uris.endswith('.zarr'): known_args.zarr = True - known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None) if known_args.zarr_kwargs and not known_args.zarr: raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.') diff --git a/weather_mv/loader_pipeline/regrid_test.py b/weather_mv/loader_pipeline/regrid_test.py index 5cc5b2a1..87ffaad4 100644 --- a/weather_mv/loader_pipeline/regrid_test.py +++ b/weather_mv/loader_pipeline/regrid_test.py @@ -122,7 +122,7 @@ def test_zarr__coarsen(self): self.Op, first_uri=input_zarr, output_path=output_zarr, - zarr_input_chunks={"time": 5}, + zarr_input_chunks={"time": 25}, zarr=True ) diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py index 0f8cd561..f412f9bd 
100644 --- a/weather_mv/loader_pipeline/sinks.py +++ b/weather_mv/loader_pipeline/sinks.py @@ -177,7 +177,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset): datetime_value_ms = None try: datetime_value_s = (int(end_time.timestamp()) if end_time is not None - else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0) + else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0) ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)}) except KeyError: raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.") @@ -375,7 +375,7 @@ def open_dataset(uri: str, """Open the dataset at 'uri' and return a xarray.Dataset.""" try: if is_zarr: - ds: xr.Dataset = xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs) + ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs), False) beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc() yield ds ds.close() diff --git a/weather_mv/loader_pipeline/sinks_test.py b/weather_mv/loader_pipeline/sinks_test.py index f7ca1641..a759c586 100644 --- a/weather_mv/loader_pipeline/sinks_test.py +++ b/weather_mv/loader_pipeline/sinks_test.py @@ -112,6 +112,7 @@ def test_opens_zarr(self): with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds: self.assertIsNotNone(ds) self.assertEqual(list(ds.data_vars), ['cape', 'd2m']) + def test_open_dataset__fits_memory_bounds(self): with write_netcdf() as test_netcdf_path: with limit_memory(max_memory=30): diff --git a/weather_mv/loader_pipeline/util_test.py b/weather_mv/loader_pipeline/util_test.py index dae9c873..65d9169e 100644 --- a/weather_mv/loader_pipeline/util_test.py +++ b/weather_mv/loader_pipeline/util_test.py @@ -38,9 +38,10 @@ def test_gets_indexed_coordinates(self): ds = xr.open_dataset(self.test_data_path) self.assertEqual( next(get_coordinates(ds)), - {'latitude': 49.0, - 'longitude':-108.0, - 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)} + { + 'latitude': 49.0, + 'longitude': -108.0, + 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)} ) def test_no_duplicate_coordinates(self): @@ -91,24 +92,28 @@ def test_get_coordinates(self): actual, [ [ - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None) + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None) }, - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None) + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None) }, - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None) + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None) }, ], [ - {'longitude': -108.0, - 'latitude': 49.0, - 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None) - } + { + 'longitude': -108.0, + 'latitude': 49.0, + 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None) + } ] ] ) diff --git a/weather_mv/setup.py b/weather_mv/setup.py index bfe09713..f200a822 100644 --- a/weather_mv/setup.py +++ b/weather_mv/setup.py @@ -45,6 +45,7 @@ "numpy==1.22.4", "pandas==1.5.1", "xarray==2023.1.0", + 
"xarray-beam==0.6.2", "cfgrib==0.9.10.2", "netcdf4==1.6.1", "geojson==2.5.0", @@ -55,6 +56,8 @@ "earthengine-api>=0.1.263", "pyproj==3.4.0", # requires separate binary installation! "gdal==3.5.1", # requires separate binary installation! + "gcsfs==2022.11.0", + "zarr==2.15.0", ] setup(