diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0325c936..ac86adc1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -84,13 +84,15 @@ jobs: echo "::set-output name=dir::$(pip cache dir)" - name: Install linter run: | - pip install ruff + pip install ruff==0.0.280 - name: Lint project run: ruff check . type-check: runs-on: ubuntu-latest strategy: fail-fast: false + matrix: + python-version: ["3.8"] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.7.0 @@ -98,28 +100,24 @@ jobs: access_token: ${{ github.token }} if: ${{github.ref != 'refs/head/main'}} - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v2 + - name: conda cache + uses: actions/cache@v2 + env: + # Increase this value to reset cache if etc/example-environment.yml has not changed + CACHE_NUMBER: 0 with: - python-version: "3.8" - - name: Setup conda - uses: s-weigand/setup-conda@v1 + path: ~/conda_pkgs_dir + key: + ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ matrix.python-version }}-${{ hashFiles('ci3.8.yml') }} + - name: Setup conda environment + uses: conda-incubator/setup-miniconda@v2 with: - update-conda: true - python-version: "3.8" - conda-channels: anaconda, conda-forge - - name: Install ecCodes - run: | - conda install -y eccodes>=2.21.0 -c conda-forge - conda install -y pyproj -c conda-forge - conda install -y gdal -c conda-forge - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip wheel - echo "::set-output name=dir::$(pip cache dir)" - - name: Install weather-tools + python-version: ${{ matrix.python-version }} + channels: conda-forge + environment-file: ci${{ matrix.python-version}}.yml + activate-environment: weather-tools + - name: Install weather-tools[test] run: | - pip install -e .[test] --use-deprecated=legacy-resolver + conda run -n weather-tools pip install -e .[test] --use-deprecated=legacy-resolver - name: Run type checker - run: pytype + run: conda run -n weather-tools pytype diff --git a/Dockerfile b/Dockerfile index 057442f5..30a72f57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,6 +29,7 @@ ARG weather_tools_git_rev=main RUN git clone https://github.com/google/weather-tools.git /weather WORKDIR /weather RUN git checkout "${weather_tools_git_rev}" +RUN rm -r /weather/weather_*/test_data/ RUN conda env create -f environment.yml --debug # Activate the conda env and update the PATH diff --git a/ci3.8.yml b/ci3.8.yml index d6a1e0bd..803e59c4 100644 --- a/ci3.8.yml +++ b/ci3.8.yml @@ -16,7 +16,7 @@ dependencies: - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -33,6 +33,8 @@ dependencies: - ruff==0.0.260 - google-cloud-sdk=410.0.0 - aria2=1.36.0 + - zarr=2.15.0 - pip: + - cython==0.29.34 - earthengine-api==0.1.329 - .[test] diff --git a/ci3.9.yml b/ci3.9.yml index a43cec16..e9e0671f 100644 --- a/ci3.9.yml +++ b/ci3.9.yml @@ -16,7 +16,7 @@ dependencies: - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -33,6 +33,8 @@ dependencies: - aria2=1.36.0 - xarray==2023.1.0 - ruff==0.0.260 + - zarr=2.15.0 - pip: + - cython==0.29.34 - earthengine-api==0.1.329 - .[test] diff --git a/environment.yml b/environment.yml index eae35f9c..0b043980 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: dependencies: - python=3.8.13 - 
apache-beam=2.40.0 - - xarray-beam=0.3.1 + - xarray-beam=0.6.2 - xarray=2023.1.0 - fsspec=2022.11.0 - gcsfs=2022.11.0 @@ -25,7 +25,9 @@ dependencies: - google-cloud-sdk=410.0.0 - aria2=1.36.0 - pip=22.3 + - zarr=2.15.0 - pip: + - cython==0.29.34 - earthengine-api==0.1.329 - firebase-admin==6.0.1 - . diff --git a/setup.py b/setup.py index af798889..dedb552e 100644 --- a/setup.py +++ b/setup.py @@ -57,8 +57,9 @@ "earthengine-api>=0.1.263", "pyproj", # requires separate binary installation! "gdal", # requires separate binary installation! - "xarray-beam==0.3.1", + "xarray-beam==0.6.2", "gcsfs==2022.11.0", + "zarr==2.15.0", ] weather_sp_requirements = [ @@ -82,6 +83,7 @@ "memray", "pytest-memray", "h5py", + "pooch", ] all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \ @@ -115,7 +117,7 @@ ], python_requires='>=3.8, <3.10', - install_requires=['apache-beam[gcp]==2.40.0'], + install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'], use_scm_version=True, setup_requires=['setuptools_scm'], scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'], diff --git a/weather_dl/README.md b/weather_dl/README.md index 57924111..d5608c67 100644 --- a/weather_dl/README.md +++ b/weather_dl/README.md @@ -57,6 +57,8 @@ _Common options_: that partitions will be processed in sequential order of each config; 'fair' means that partitions from each config will be interspersed evenly. Note: When using 'fair' scheduling, we recommend you set the '--partition-chunks' to a much smaller number. Default: 'in-order'. +* `--log-level`: An integer to configure log level. Default: 2(INFO). +* `--use-local-code`: Supply local code to the Runner. Default: False. > Note: > * In case of BigQuery manifest tool will create the BQ table itself, if not already present. @@ -93,6 +95,17 @@ weather-dl configs/mars_example_config.cfg \ --job_name $JOB_NAME ``` +Using DataflowRunner and using local code for pipeline + +```bash +weather-dl configs/mars_example_config.cfg \ + --runner DataflowRunner \ + --project $PROJECT \ + --temp_location gs://$BUCKET/tmp \ + --job_name $JOB_NAME \ + --use-local-code +``` + Using the DataflowRunner and specifying 3 requests per license ```bash diff --git a/weather_dl/download_pipeline/manifest.py b/weather_dl/download_pipeline/manifest.py index b07d779f..d2a72550 100644 --- a/weather_dl/download_pipeline/manifest.py +++ b/weather_dl/download_pipeline/manifest.py @@ -368,6 +368,9 @@ def set_stage(self, stage: Stage) -> None: if stage == Stage.FETCH: new_status.fetch_start_time = current_utc_time + new_status.fetch_end_time = None + new_status.download_start_time = None + new_status.download_end_time = None elif stage == Stage.RETRIEVE: new_status.retrieve_start_time = current_utc_time elif stage == Stage.DOWNLOAD: diff --git a/weather_dl/download_pipeline/pipeline.py b/weather_dl/download_pipeline/pipeline.py index fa4983e2..71bdb71f 100644 --- a/weather_dl/download_pipeline/pipeline.py +++ b/weather_dl/download_pipeline/pipeline.py @@ -177,6 +177,7 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs: help="Update the manifest for the already downloaded shards and exit. Default: 'false'.") parser.add_argument('--log-level', type=int, default=2, help='An integer to configure log level. 
Default: 2(INFO)') + parser.add_argument('--use-local-code', action='store_true', default=False, help='Supply local code to the Runner.') known_args, pipeline_args = parser.parse_known_args(argv[1:]) diff --git a/weather_dl/download_pipeline/pipeline_test.py b/weather_dl/download_pipeline/pipeline_test.py index 9e18a4bc..c6370c6f 100644 --- a/weather_dl/download_pipeline/pipeline_test.py +++ b/weather_dl/download_pipeline/pipeline_test.py @@ -58,7 +58,8 @@ schedule='in-order', check_skip_in_dry_run=False, update_manifest=False, - log_level=20), + log_level=20, + use_local_code=False), pipeline_options=PipelineOptions('--save_main_session True'.split()), configs=[Config.from_dict(CONFIG)], client_name='cds', diff --git a/weather_dl/download_pipeline/util.py b/weather_dl/download_pipeline/util.py index 1ee9e24e..3e92b8d8 100644 --- a/weather_dl/download_pipeline/util.py +++ b/weather_dl/download_pipeline/util.py @@ -105,7 +105,7 @@ def to_json_serializable_type(value: t.Any) -> t.Any: elif type(value) == np.ndarray: # Will return a scaler if array is of size 1, else will return a list. return value.tolist() - elif type(value) == datetime.datetime or type(value) == str or type(value) == np.datetime64: + elif isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) diff --git a/weather_dl/setup.py b/weather_dl/setup.py index 60c8732f..b96466fd 100644 --- a/weather_dl/setup.py +++ b/weather_dl/setup.py @@ -48,7 +48,7 @@ setup( name='download_pipeline', packages=find_packages(), - version='0.1.19', + version='0.1.20', author='Anthromets', author_email='anthromets-ecmwf@google.com', url='https://weather-tools.readthedocs.io/en/latest/weather_dl/', diff --git a/weather_dl/weather-dl b/weather_dl/weather-dl index 6fc95b7d..5b73411a 100755 --- a/weather_dl/weather-dl +++ b/weather_dl/weather-dl @@ -24,6 +24,8 @@ import tempfile import weather_dl +SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) @@ -48,10 +50,16 @@ if __name__ == '__main__': from download_pipeline import cli except ImportError as e: raise ImportError('please re-install package in a clean python environment.') from e - - if '-h' in sys.argv or '--help' in sys.argv or len(sys.argv) == 1: - cli() - else: + + args = [] + + if "DataflowRunner" in sys.argv and "--sdk_container_image" not in sys.argv: + args.extend(['--sdk_container_image', + os.getenv('SDK_CONTAINER_IMAGE', SDK_CONTAINER_IMAGE), + '--experiments', + 'use_runner_v2']) + + if "--use-local-code" in sys.argv: with tempfile.TemporaryDirectory() as tmpdir: original_dir = os.getcwd() @@ -72,5 +80,7 @@ if __name__ == '__main__': # cleanup memory to prevent pickling error. tar = None weather_dl = None - - cli(['--extra_package', pkg_archive]) + args.extend(['--extra_package', pkg_archive]) + cli(args) + else: + cli(args) diff --git a/weather_mv/README.md b/weather_mv/README.md index af77766d..e99f5b4c 100644 --- a/weather_mv/README.md +++ b/weather_mv/README.md @@ -49,6 +49,8 @@ _Common options_ * `--num_shards`: Number of shards to use when writing windowed elements to cloud storage. Only used with the `topic` flag. Default: 5 shards. * `-d, --dry-run`: Preview the load into BigQuery. Default: off. +* `--log-level`: An integer to configure log level. Default: 2(INFO). +* `--use-local-code`: Supply local code to the Runner. Default: False. 
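
The new `--use-local-code` flag listed above is handled by the CLI wrappers rather than by the pipeline itself. Based on the `weather_dl/weather-dl` wrapper changes earlier in this diff (the `weather-mv` wrapper presumably mirrors them), the wrapper packages the local source tree and forwards a few extra arguments to Beam. The sketch below only illustrates that argument assembly; `build_extra_args` and the tarball path are illustrative, not part of the repository.

```python
# Illustrative sketch of the wrapper logic added in this diff (not the shipped script).
# SDK_CONTAINER_IMAGE and the flag names come from the weather-dl wrapper above;
# the packaging step is reduced to a hypothetical placeholder path.
import os

SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0'


def build_extra_args(argv: list) -> list:
    args = []
    # On Dataflow, default to the prebuilt weather-tools image and runner v2
    # unless the user already chose an image.
    if 'DataflowRunner' in argv and '--sdk_container_image' not in argv:
        args += ['--sdk_container_image',
                 os.getenv('SDK_CONTAINER_IMAGE', SDK_CONTAINER_IMAGE),
                 '--experiments', 'use_runner_v2']
    # With --use-local-code, the wrapper tars up the local package and ships it
    # to the workers via --extra_package.
    if '--use-local-code' in argv:
        args += ['--extra_package', '/tmp/weather_tools_local.tar.gz']  # hypothetical archive path
    return args


print(build_extra_args(['weather-mv', 'bq', '--runner', 'DataflowRunner', '--use-local-code']))
```
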
 Invoke with `-h` or `--help` to see the full range of options.
@@ -59,8 +61,9 @@ usage: weather-mv bigquery [-h] -i URIS [--topic TOPIC] [--window_size WINDOW_SI
                            -o OUTPUT_TABLE [-v variables [variables ...]] [-a area [area ...]]
                            [--import_time IMPORT_TIME] [--infer_schema]
                            [--xarray_open_dataset_kwargs XARRAY_OPEN_DATASET_KWARGS]
-                           [--tif_metadata_for_datetime TIF_METADATA_FOR_DATETIME] [-s]
-                           [--coordinate_chunk_size COORDINATE_CHUNK_SIZE]
+                           [--tif_metadata_for_start_time TIF_METADATA_FOR_START_TIME]
+                           [--tif_metadata_for_end_time TIF_METADATA_FOR_END_TIME] [-s]
+                           [--coordinate_chunk_size COORDINATE_CHUNK_SIZE] ['--skip_creating_polygon']
 ```
 
 The `bigquery` subcommand loads weather data into BigQuery. In addition to the common options above, users may specify
@@ -78,9 +81,13 @@ _Command options_:
 * `--xarray_open_dataset_kwargs`: Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.
 * `--coordinate_chunk_size`: The size of the chunk of coordinates used for extracting vector data into BigQuery. Used
   to tune parallel uploads.
-* `--tif_metadata_for_datetime` : Metadata that contains tif file's timestamp. Applicable only for tif files.
+* `--tif_metadata_for_start_time` : Metadata that contains the tif file's start/initialization time. Applicable only for tif files.
+* `--tif_metadata_for_end_time` : Metadata that contains the tif file's end/forecast time. Applicable only for tif files (optional).
 * `-s, --skip-region-validation` : Skip validation of regions for data migration. Default: off.
 * `--disable_grib_schema_normalization` : To disable grib's schema normalization. Default: off.
+* `--skip_creating_polygon` : Do not ingest grid points as polygons in BigQuery. Default: ingest grid points as polygons
+  in BigQuery. Note: this feature relies on the assumption that the provided grid has an equal distance between
+  consecutive points of latitude and longitude (see the sketch below).
 
 Invoke with `bq -h` or `bigquery --help` to see the full range of options.
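
For intuition on the polygon that is written when `--skip_creating_polygon` is *not* passed: judging from the `bq.py` changes later in this diff, half of the mean grid spacing is added and subtracted around each point to form its grid cell. A rough sketch follows; the 0.25-degree grid and the sample point are made up for illustration, and `polygon_ring` is not a name used by the pipeline.

```python
# Sketch of the grid-cell polygon derivation assumed from the bq.py changes in this diff.
import numpy as np

latitudes = np.arange(90, -90.25, -0.25)    # hypothetical 0.25-degree global grid
longitudes = np.arange(-180, 180, 0.25)

lat_res = abs(np.ptp(latitudes) / len(latitudes)) / 2    # roughly 0.125 degrees
lon_res = abs(np.ptp(longitudes) / len(longitudes)) / 2  # roughly 0.125 degrees

lat, lon = 49.0, -108.0  # an example grid point
polygon_ring = [
    (lon - lon_res, lat - lat_res),  # lower left
    (lon - lon_res, lat + lat_res),  # upper left
    (lon + lon_res, lat + lat_res),  # upper right
    (lon + lon_res, lat - lat_res),  # lower right
    (lon - lon_res, lat - lat_res),  # close the ring
]
print(polygon_ring)
```
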
@@ -117,6 +124,16 @@ weather-mv bq --uris "gs://your-bucket/*.nc" \ --dry-run ``` +Ingest grid points with skip creating polygon in BigQuery: + +```bash +weather-mv bq --uris "gs://your-bucket/*.nc" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery + --direct_num_workers 2 \ + --skip_creating_polygon +``` + Load COG's (.tif) files: ```bash @@ -124,7 +141,8 @@ weather-mv bq --uris "gs://your-bucket/*.tif" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 \ - --tif_metadata_for_datetime start_time + --tif_metadata_for_start_time start_time \ + --tif_metadata_for_end_time end_time ``` Upload only a subset of variables: @@ -147,6 +165,39 @@ weather-mv bq --uris "gs://your-bucket/*.nc" \ --direct_num_workers 2 ``` +Upload a zarr file: + +```bash +weather-mv bq --uris "gs://your-bucket/*.zarr" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ + --use-local-code \ + --zarr \ + --direct_num_workers 2 +``` + +Upload a specific date range's data from the zarr file: + +```bash +weather-mv bq --uris "gs://your-bucket/*.zarr" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ + --use-local-code \ + --zarr \ + --zarr_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \ + --direct_num_workers 2 +``` + +Upload a specific date range's data from the file: + +```bash +weather-mv bq --uris "gs://your-bucket/*.nc" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --temp_location "gs://$BUCKET/tmp" \ + --use-local-code \ + --xarray_open_dataset_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \ +``` + Control how weather data is opened with XArray: ```bash @@ -169,6 +220,19 @@ weather-mv bq --uris "gs://your-bucket/*.nc" \ --job_name $JOB_NAME ``` +Using DataflowRunner and using local code for pipeline + +```bash +weather-mv bq --uris "gs://your-bucket/*.nc" \ + --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ + --runner DataflowRunner \ + --project $PROJECT \ + --region $REGION \ + --temp_location "gs://$BUCKET/tmp" \ + --job_name $JOB_NAME \ + --use-local-code +``` + For a full list of how to configure the Dataflow pipeline, please review [this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options). diff --git a/weather_mv/loader_pipeline/bq.py b/weather_mv/loader_pipeline/bq.py index 4c86e116..bc71da01 100644 --- a/weather_mv/loader_pipeline/bq.py +++ b/weather_mv/loader_pipeline/bq.py @@ -24,8 +24,10 @@ import geojson import numpy as np import xarray as xr +import xarray_beam as xbeam from apache_beam.io import WriteToBigQuery, BigQueryDisposition from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms import window from google.cloud import bigquery from xarray.core.utils import ensure_us_time_resolution @@ -45,7 +47,9 @@ DATA_URI_COLUMN = 'data_uri' DATA_FIRST_STEP = 'data_first_step' GEO_POINT_COLUMN = 'geo_point' +GEO_POLYGON_COLUMN = 'geo_polygon' LATITUDE_RANGE = (-90, 90) +LONGITUDE_RANGE = (-180, 180) @dataclasses.dataclass @@ -73,8 +77,10 @@ class ToBigQuery(ToDataSink): infer_schema: If true, this sink will attempt to read in an example data file read all its variables, and generate a BigQuery schema. xarray_open_dataset_kwargs: A dictionary of kwargs to pass to xr.open_dataset(). 
- tif_metadata_for_datetime: If the input is a .tif file, parse the tif metadata at - this location for a timestamp. + tif_metadata_for_start_time: If the input is a .tif file, parse the tif metadata at + this location for a start time / initialization time. + tif_metadata_for_end_time: If the input is a .tif file, parse the tif metadata at + this location for a end/forecast time. skip_region_validation: Turn off validation that checks if all Cloud resources are in the same region. disable_grib_schema_normalization: Turn off grib's schema normalization; Default: normalization enabled. @@ -90,10 +96,14 @@ class ToBigQuery(ToDataSink): import_time: t.Optional[datetime.datetime] infer_schema: bool xarray_open_dataset_kwargs: t.Dict - tif_metadata_for_datetime: t.Optional[str] + tif_metadata_for_start_time: t.Optional[str] + tif_metadata_for_end_time: t.Optional[str] skip_region_validation: bool disable_grib_schema_normalization: bool coordinate_chunk_size: int = 10_000 + skip_creating_polygon: bool = False + lat_grid_resolution: t.Optional[float] = None + lon_grid_resolution: t.Optional[float] = None @classmethod def add_parser_arguments(cls, subparser: argparse.ArgumentParser): @@ -105,6 +115,11 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser): 'all data variables as columns.') subparser.add_argument('-a', '--area', metavar='area', type=float, nargs='+', default=list(), help='Target area in [N, W, S, E]. Default: Will include all available area.') + subparser.add_argument('--skip_creating_polygon', action='store_true', + help='Not ingest grid points as polygons in BigQuery. Default: Ingest grid points as ' + 'Polygon in BigQuery. Note: This feature relies on the assumption that the ' + 'provided grid has an equal distance between consecutive points of latitude and ' + 'longitude.') subparser.add_argument('--import_time', type=str, default=datetime.datetime.utcnow().isoformat(), help=("When writing data to BigQuery, record that data import occurred at this " "time (format: YYYY-MM-DD HH:MM:SS.usec+offset). Default: now in UTC.")) @@ -113,8 +128,11 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser): 'off') subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}', help='Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.') - subparser.add_argument('--tif_metadata_for_datetime', type=str, default=None, - help='Metadata that contains tif file\'s timestamp. ' + subparser.add_argument('--tif_metadata_for_start_time', type=str, default=None, + help='Metadata that contains tif file\'s start/initialization time. ' + 'Applicable only for tif files.') + subparser.add_argument('--tif_metadata_for_end_time', type=str, default=None, + help='Metadata that contains tif file\'s end/forecast time. ' 'Applicable only for tif files.') subparser.add_argument('-s', '--skip-region-validation', action='store_true', default=False, help='Skip validation of regions for data migration. Default: off') @@ -138,10 +156,14 @@ def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.Lis # Check that all arguments are supplied for COG input. 
_, uri_extension = os.path.splitext(known_args.uris) - if uri_extension == '.tif' and not known_args.tif_metadata_for_datetime: - raise RuntimeError("'--tif_metadata_for_datetime' is required for tif files.") - elif uri_extension != '.tif' and known_args.tif_metadata_for_datetime: - raise RuntimeError("'--tif_metadata_for_datetime' can be specified only for tif files.") + if (uri_extension in ['.tif', '.tiff'] and not known_args.tif_metadata_for_start_time): + raise RuntimeError("'--tif_metadata_for_start_time' is required for tif files.") + elif uri_extension not in ['.tif', '.tiff'] and ( + known_args.tif_metadata_for_start_time + or known_args.tif_metadata_for_end_time + ): + raise RuntimeError("'--tif_metadata_for_start_time' and " + "'--tif_metadata_for_end_time' can be specified only for tif files.") # Check that Cloud resource regions are consistent. if not (known_args.dry_run or known_args.skip_region_validation): @@ -156,8 +178,29 @@ def __post_init__(self): if self.zarr: self.xarray_open_dataset_kwargs = self.zarr_kwargs with open_dataset(self.first_uri, self.xarray_open_dataset_kwargs, - self.disable_grib_schema_normalization, self.tif_metadata_for_datetime, - is_zarr=self.zarr) as open_ds: + self.disable_grib_schema_normalization, self.tif_metadata_for_start_time, + self.tif_metadata_for_end_time, is_zarr=self.zarr) as open_ds: + + if not self.skip_creating_polygon: + logger.warning("Assumes that equal distance between consecutive points of latitude " + "and longitude for the entire grid.") + # Find the grid_resolution. + if open_ds['latitude'].size > 1 and open_ds['longitude'].size > 1: + latitude_length = len(open_ds['latitude']) + longitude_length = len(open_ds['longitude']) + + latitude_range = np.ptp(open_ds["latitude"].values) + longitude_range = np.ptp(open_ds["longitude"].values) + + self.lat_grid_resolution = abs(latitude_range / latitude_length) / 2 + self.lon_grid_resolution = abs(longitude_range / longitude_length) / 2 + + else: + self.skip_creating_polygon = True + logger.warning("Polygon can't be genereated as provided dataset has a only single grid point.") + else: + logger.info("Polygon is not created as '--skip_creating_polygon' flag passed.") + # Define table from user input if self.variables and not self.infer_schema and not open_ds.attrs['is_normalized']: logger.info('Creating schema from input variables.') @@ -188,7 +231,7 @@ def prepare_coordinates(self, uri: str) -> t.Iterator[t.Tuple[str, t.List[t.Dict logger.info(f'Preparing coordinates for: {uri!r}.') with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, - self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds: + self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) if self.area: n, w, s, e = self.area @@ -207,69 +250,100 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Di self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, - self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds: + self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) + yield from self.to_rows(coordinates, data_ds, uri) - first_ts_raw = data_ds.time[0].values if isinstance(data_ds.time.values, - np.ndarray) else data_ds.time.values - 
first_time_step = to_json_serializable_type(first_ts_raw) - - for it in coordinates: - # Use those index values to select a Dataset containing one row of data. - row_ds = data_ds.loc[it] - - # Create a Name-Value map for data columns. Result looks like: - # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None} - row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values)) - for n, v in row_ds.data_vars.items()} - - # Serialize coordinates. - it = {k: to_json_serializable_type(v) for k, v in it.items()} - - # Add indexed coordinates. - row.update(it) - # Add un-indexed coordinates. - for c in row_ds.coords: - if c not in it and (not self.variables or c in self.variables): - row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values)) - - # Add import metadata. - row[DATA_IMPORT_TIME_COLUMN] = self.import_time - row[DATA_URI_COLUMN] = uri - row[DATA_FIRST_STEP] = first_time_step - row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], row['longitude']) - - # 'row' ends up looking like: - # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812, - # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...} - beam.metrics.Metrics.counter('Success', 'ExtractRows').inc() - yield row + def to_rows(self, coordinates: t.Iterable[t.Dict], ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]: + first_ts_raw = ( + ds.time[0].values if isinstance(ds.time.values, np.ndarray) + else ds.time.values + ) + first_time_step = to_json_serializable_type(first_ts_raw) + for it in coordinates: + # Use those index values to select a Dataset containing one row of data. + row_ds = ds.loc[it] + + # Create a Name-Value map for data columns. Result looks like: + # {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None} + row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values)) + for n, v in row_ds.data_vars.items()} + + # Serialize coordinates. + it = {k: to_json_serializable_type(v) for k, v in it.items()} + + # Add indexed coordinates. + row.update(it) + # Add un-indexed coordinates. + for c in row_ds.coords: + if c not in it and (not self.variables or c in self.variables): + row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values)) + + # Add import metadata. + row[DATA_IMPORT_TIME_COLUMN] = self.import_time + row[DATA_URI_COLUMN] = uri + row[DATA_FIRST_STEP] = first_time_step + + longitude = ((row['longitude'] + 180) % 360) - 180 + row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude) + row[GEO_POLYGON_COLUMN] = ( + fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution) + if not self.skip_creating_polygon + else None + ) + # 'row' ends up looking like: + # {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812, + # 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...} + beam.metrics.Metrics.counter('Success', 'ExtractRows').inc() + yield row + + def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]: + uri = ds.attrs.get(DATA_URI_COLUMN, '') + # Re-calculate import time for streaming extractions. 
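+        # (With Zarr input the import time is refreshed for every chunk;
+        # `timestamp_row` later uses this value to window rows before writing to BigQuery.)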
+ if not self.import_time or self.zarr: + self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + yield from self.to_rows(get_coordinates(ds, uri), ds, uri) def expand(self, paths): """Extract rows of variables from data paths into a BigQuery table.""" - extracted_rows = ( + if not self.zarr: + extracted_rows = ( paths | 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates) | beam.Reshuffle() | 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows) - ) - - if not self.dry_run: - ( - extracted_rows - | 'WriteToBigQuery' >> WriteToBigQuery( - project=self.table.project, - dataset=self.table.dataset_id, - table=self.table.table_id, - write_disposition=BigQueryDisposition.WRITE_APPEND, - create_disposition=BigQueryDisposition.CREATE_NEVER) ) else: - ( - extracted_rows - | 'Log Extracted Rows' >> beam.Map(logger.debug) + xarray_open_dataset_kwargs = self.xarray_open_dataset_kwargs.copy() + xarray_open_dataset_kwargs.pop('chunks') + start_date = xarray_open_dataset_kwargs.pop('start_date', None) + end_date = xarray_open_dataset_kwargs.pop('end_date', None) + ds, chunks = xbeam.open_zarr(self.first_uri, **xarray_open_dataset_kwargs) + + if start_date is not None and end_date is not None: + ds = ds.sel(time=slice(start_date, end_date)) + + ds.attrs[DATA_URI_COLUMN] = self.first_uri + extracted_rows = ( + paths + | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks) + | 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows) + | 'Window' >> beam.WindowInto(window.FixedWindows(60)) + | 'AddTimestamp' >> beam.Map(timestamp_row) ) + if self.dry_run: + return extracted_rows | 'Log Rows' >> beam.Map(logger.info) + return ( + extracted_rows + | 'WriteToBigQuery' >> WriteToBigQuery( + project=self.table.project, + dataset=self.table.dataset_id, + table=self.table.table_id, + write_disposition=BigQueryDisposition.WRITE_APPEND, + create_disposition=BigQueryDisposition.CREATE_NEVER) + ) + def map_dtype_to_sql_type(var_type: np.dtype) -> str: """Maps a np.dtype to a suitable BigQuery column type.""" @@ -305,14 +379,95 @@ def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.Schem fields.append(bigquery.SchemaField(DATA_URI_COLUMN, 'STRING', mode='NULLABLE')) fields.append(bigquery.SchemaField(DATA_FIRST_STEP, 'TIMESTAMP', mode='NULLABLE')) fields.append(bigquery.SchemaField(GEO_POINT_COLUMN, 'GEOGRAPHY', mode='NULLABLE')) + fields.append(bigquery.SchemaField(GEO_POLYGON_COLUMN, 'GEOGRAPHY', mode='NULLABLE')) return fields +def timestamp_row(it: t.Dict) -> window.TimestampedValue: + """Associate an extracted row with the import_time timestamp.""" + timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp() + return window.TimestampedValue(it, timestamp) + + def fetch_geo_point(lat: float, long: float) -> str: """Calculates a geography point from an input latitude and longitude.""" if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]: raise ValueError(f"Invalid latitude value '{lat}'") - long = ((long + 180) % 360) - 180 + if long > LONGITUDE_RANGE[1] or long < LONGITUDE_RANGE[0]: + raise ValueError(f"Invalid longitude value '{long}'") point = geojson.dumps(geojson.Point((long, lat))) return point + + +def fetch_geo_polygon(latitude: float, longitude: float, lat_grid_resolution: float, lon_grid_resolution: float) -> str: + """Create a Polygon based on latitude, longitude and resolution. + + Example :: + * - . - * + | | + . • . + | | + * - . - * + In order to create the polygon, we require the `*` point as indicated in the above example. 
+ To determine the position of the `*` point, we find the `.` point. + The `get_lat_lon_range` function gives the `.` point and `bound_point` gives the `*` point. + """ + lat_lon_bound = bound_point(latitude, longitude, lat_grid_resolution, lon_grid_resolution) + polygon = geojson.dumps(geojson.Polygon([[ + (lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left + (lat_lon_bound[1][0], lat_lon_bound[1][1]), # upper_left + (lat_lon_bound[2][0], lat_lon_bound[2][1]), # upper_right + (lat_lon_bound[3][0], lat_lon_bound[3][1]), # lower_right + (lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left + ]])) + return polygon + + +def bound_point(latitude: float, longitude: float, lat_grid_resolution: float, lon_grid_resolution: float) -> t.List: + """Calculate the bound point based on latitude, longitude and grid resolution. + + Example :: + * - . - * + | | + . • . + | | + * - . - * + This function gives the `*` point in the above example. + """ + lat_in_bound = latitude in [90.0, -90.0] + lon_in_bound = longitude in [-180.0, 180.0] + + lat_range = get_lat_lon_range(latitude, "latitude", lat_in_bound, + lat_grid_resolution, lon_grid_resolution) + lon_range = get_lat_lon_range(longitude, "longitude", lon_in_bound, + lat_grid_resolution, lon_grid_resolution) + lower_left = [lon_range[1], lat_range[1]] + upper_left = [lon_range[1], lat_range[0]] + upper_right = [lon_range[0], lat_range[0]] + lower_right = [lon_range[0], lat_range[1]] + return [lower_left, upper_left, upper_right, lower_right] + + +def get_lat_lon_range(value: float, lat_lon: str, is_point_out_of_bound: bool, + lat_grid_resolution: float, lon_grid_resolution: float) -> t.List: + """Calculate the latitude, longitude point range point latitude, longitude and grid resolution. + + Example :: + * - . - * + | | + . • . + | | + * - . - * + This function gives the `.` point in the above example. 
+ """ + if lat_lon == 'latitude': + if is_point_out_of_bound: + return [-90 + lat_grid_resolution, 90 - lat_grid_resolution] + else: + return [value + lat_grid_resolution, value - lat_grid_resolution] + else: + if is_point_out_of_bound: + return [-180 + lon_grid_resolution, 180 - lon_grid_resolution] + else: + return [value + lon_grid_resolution, value - lon_grid_resolution] diff --git a/weather_mv/loader_pipeline/bq_test.py b/weather_mv/loader_pipeline/bq_test.py index e7a6e5b6..fae7ab31 100644 --- a/weather_mv/loader_pipeline/bq_test.py +++ b/weather_mv/loader_pipeline/bq_test.py @@ -15,6 +15,7 @@ import json import logging import os +import tempfile import typing as t import unittest @@ -23,12 +24,15 @@ import pandas as pd import simplejson import xarray as xr +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that, is_not_empty from google.cloud.bigquery import SchemaField from .bq import ( DEFAULT_IMPORT_TIME, dataset_to_table_schema, fetch_geo_point, + fetch_geo_polygon, ToBigQuery, ) from .sinks_test import TestDataBase, _handle_missing_grib_be @@ -74,6 +78,7 @@ def test_schema_generation(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -89,6 +94,7 @@ def test_schema_generation__with_schema_normalization(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -104,6 +110,7 @@ def test_schema_generation__with_target_columns(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -119,6 +126,7 @@ def test_schema_generation__with_target_columns__with_schema_normalization(self) SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -135,6 +143,7 @@ def test_schema_generation__no_targets_specified(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -151,6 +160,7 @@ def test_schema_generation__no_targets_specified__with_schema_normalization(self SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -184,6 +194,7 @@ 
def test_schema_generation__non_index_coords(self): SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), + SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) @@ -193,14 +204,18 @@ class ExtractRowsTestBase(TestDataBase): def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=None, import_time=DEFAULT_IMPORT_TIME, disable_grib_schema_normalization=False, - tif_metadata_for_datetime=None, zarr: bool = False, zarr_kwargs=None) -> t.Iterator[t.Dict]: + tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, zarr: bool = False, zarr_kwargs=None, + skip_creating_polygon: bool = False) -> t.Iterator[t.Dict]: if zarr_kwargs is None: zarr_kwargs = {} - op = ToBigQuery.from_kwargs(first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, - output_table='foo.bar.baz', variables=variables, area=area, - xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False, - tif_metadata_for_datetime=tif_metadata_for_datetime, skip_region_validation=True, - disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000) + op = ToBigQuery.from_kwargs( + first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, + output_table='foo.bar.baz', variables=variables, area=area, + xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False, + tif_metadata_for_start_time=tif_metadata_for_start_time, + tif_metadata_for_end_time=tif_metadata_for_end_time, skip_region_validation=True, + disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000, + skip_creating_polygon=skip_creating_polygon) coords = op.prepare_coordinates(data_path) for uri, chunk in coords: yield from op.extract_rows(uri, chunk) @@ -233,7 +248,7 @@ def setUp(self) -> None: self.test_data_path = f'{self.test_data_folder}/test_data_20180101.nc' def test_extract_rows(self): - actual = next(self.extract(self.test_data_path)) + actual = next(self.extract(self.test_data_path, skip_creating_polygon=True)) expected = { 'd2m': 242.3035430908203, 'data_import_time': '1970-01-01T00:00:00+00:00', @@ -245,6 +260,7 @@ def test_extract_rows(self): 'u10': 3.4776244163513184, 'v10': 0.03294110298156738, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), + 'geo_polygon': None } self.assertRowsEqual(actual, expected) @@ -259,11 +275,15 @@ def test_extract_rows__with_subset_variables(self): 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-108.098837, 48.900826), (-108.098837, 49.099174), + (-107.901163, 49.099174), (-107.901163, 48.900826), + (-108.098837, 48.900826)])) } self.assertRowsEqual(actual, expected) def test_extract_rows__specific_area(self): - actual = next(self.extract(self.test_data_path, area=[45, -103, 33, -92])) + actual = next(self.extract(self.test_data_path, area=[45, -103, 33, -92], skip_creating_polygon=True)) expected = { 'd2m': 246.19993591308594, 'data_import_time': '1970-01-01T00:00:00+00:00', @@ -275,6 +295,7 @@ def test_extract_rows__specific_area(self): 'u10': 2.73445987701416, 'v10': 0.08277571201324463, 'geo_point': geojson.dumps(geojson.Point((-103.0, 45.0))), + 'geo_polygon': None } 
self.assertRowsEqual(actual, expected) @@ -291,6 +312,10 @@ def test_extract_rows__specific_area_float_points(self): 'u10': 3.94743275642395, 'v10': -0.19749987125396729, 'geo_point': geojson.dumps(geojson.Point((-103.400002, 45.200001))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-103.498839, 45.100827), (-103.498839, 45.299174), + (-103.301164, 45.299174), (-103.301164, 45.100827), + (-103.498839, 45.100827)])) } self.assertRowsEqual(actual, expected) @@ -307,7 +332,11 @@ def test_extract_rows__specify_import_time(self): 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'v10': 0.03294110298156738, - 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))) + 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-108.098837, 48.900826), (-108.098837, 49.099174), + (-107.901163, 49.099174), (-107.901163, 48.900826), + (-108.098837, 48.900826)])) } self.assertRowsEqual(actual, expected) @@ -324,7 +353,8 @@ def test_extract_rows_single_point(self): 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'v10': 0.03294110298156738, - 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))) + 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), + 'geo_polygon': None } self.assertRowsEqual(actual, expected) @@ -342,12 +372,16 @@ def test_extract_rows_nan(self): 'u10': None, 'v10': 0.03294110298156738, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-108.098837, 48.900826), (-108.098837, 49.099174), + (-107.901163, 49.099174), (-107.901163, 48.900826), + (-108.098837, 48.900826)])) } self.assertRowsEqual(actual, expected) - def test_extract_rows__with_valid_lat_long(self): - valid_lat_long = [[-90, -360], [-90, -359], [-45, -180], [-45, -45], [0, 0], [45, 45], [45, 180], [90, 359], - [90, 360]] + def test_extract_rows__with_valid_lat_long_with_point(self): + valid_lat_long = [[-90, 0], [-90, 1], [-45, -180], [-45, -45], [0, 0], [45, 45], [45, -180], [90, -1], + [90, 0]] actual_val = [ '{"type": "Point", "coordinates": [0, -90]}', '{"type": "Point", "coordinates": [1, -90]}', @@ -364,7 +398,31 @@ def test_extract_rows__with_valid_lat_long(self): expected = fetch_geo_point(lat, long) self.assertEqual(actual, expected) - def test_extract_rows__with_invalid_lat(self): + def test_extract_rows__with_valid_lat_long_with_polygon(self): + valid_lat_long = [[-90, 0], [-90, -180], [-45, -180], [-45, 180], [0, 0], [90, 180], [45, -180], [-90, 180], + [90, 1], [0, 180], [1, -180], [90, -180]] + actual_val = [ + '{"type": "Polygon", "coordinates": [[[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}', + '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}', + '{"type": "Polygon", "coordinates": [[[-1, -1], [-1, 1], [1, 1], [1, -1], [-1, -1]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, 44], [179, 46], [-179, 46], [-179, 44], [179, 44]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', + '{"type": "Polygon", "coordinates": [[[0, 89], [0, -89], [2, -89], [2, 89], [0, 89]]]}', + '{"type": "Polygon", "coordinates": [[[179, -1], [179, 1], 
[-179, 1], [-179, -1], [179, -1]]]}', + '{"type": "Polygon", "coordinates": [[[179, 0], [179, 2], [-179, 2], [-179, 0], [179, 0]]]}', + '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}' + ] + lat_grid_resolution = 1 + lon_grid_resolution = 1 + for actual, (lat, long) in zip(actual_val, valid_lat_long): + with self.subTest(): + expected = fetch_geo_polygon(lat, long, lat_grid_resolution, lon_grid_resolution) + self.assertEqual(actual, expected) + + def test_extract_rows__with_invalid_lat_lon(self): invalid_lat_long = [[-100, -2000], [-100, -500], [100, 500], [100, 2000]] for (lat, long) in invalid_lat_long: with self.subTest(): @@ -384,12 +442,16 @@ def test_extract_rows_zarr(self): 'longitude': 0, 'time': '1959-01-01T00:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((0.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-0.124913, 89.875173), (-0.124913, -89.875173), + (0.124913, -89.875173), (0.124913, 89.875173), + (-0.124913, 89.875173)])) } self.assertRowsEqual(actual, expected) def test_droping_variable_while_opening_zarr(self): input_path = os.path.join(self.test_data_folder, 'test_data.zarr') - actual = next(self.extract(input_path, zarr=True, zarr_kwargs={ 'drop_variables': ['cape'] })) + actual = next(self.extract(input_path, zarr=True, zarr_kwargs={'drop_variables': ['cape']})) expected = { 'd2m': 237.5404052734375, 'data_import_time': '1970-01-01T00:00:00+00:00', @@ -399,6 +461,10 @@ def test_droping_variable_while_opening_zarr(self): 'longitude': 0, 'time': '1959-01-01T00:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((0.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-0.124913, 89.875173), (-0.124913, -89.875173), + (0.124913, -89.875173), (0.124913, 89.875173), + (-0.124913, 89.875173)])) } self.assertRowsEqual(actual, expected) @@ -407,10 +473,13 @@ class ExtractRowsTifSupportTest(ExtractRowsTestBase): def setUp(self) -> None: super().setUp() - self.test_data_path = f'{self.test_data_folder}/test_data_tif_start_time.tif' + self.test_data_path = f'{self.test_data_folder}/test_data_tif_time.tif' - def test_extract_rows(self): - actual = next(self.extract(self.test_data_path, tif_metadata_for_datetime='start_time')) + def test_extract_rows_with_end_time(self): + actual = next( + self.extract(self.test_data_path, tif_metadata_for_start_time='start_time', + tif_metadata_for_end_time='end_time') + ) expected = { 'dewpoint_temperature_2m': 281.09349060058594, 'temperature_2m': 296.8329772949219, @@ -420,7 +489,33 @@ def test_extract_rows(self): 'latitude': 42.09783344918844, 'longitude': -123.66686981141397, 'time': '2020-07-01T00:00:00+00:00', - 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))) + 'valid_time': '2020-07-01T00:00:00+00:00', + 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-123.669853, 42.095605), (-123.669853, 42.100066), + (-123.663885, 42.100066), (-123.663885, 42.095605), + (-123.669853, 42.095605)])) + } + self.assertRowsEqual(actual, expected) + + def test_extract_rows_without_end_time(self): + actual = next( + self.extract(self.test_data_path, tif_metadata_for_start_time='start_time') + ) + expected = { + 'dewpoint_temperature_2m': 281.09349060058594, + 'temperature_2m': 296.8329772949219, + 'data_import_time': '1970-01-01T00:00:00+00:00', + 'data_first_step': '2020-07-01T00:00:00+00:00', + 'data_uri': self.test_data_path, + 'latitude': 42.09783344918844, + 'longitude': 
-123.66686981141397, + 'time': '2020-07-01T00:00:00+00:00', + 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (-123.669853, 42.095605), (-123.669853, 42.100066), + (-123.663885, 42.100066), (-123.663885, 42.095605), + (-123.669853, 42.095605)])) } self.assertRowsEqual(actual, expected) @@ -447,6 +542,10 @@ def test_extract_rows(self): 'valid_time': '2021-10-18T06:00:00+00:00', 'z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -461,6 +560,10 @@ def test_extract_rows__with_vars__excludes_non_index_coords__without_schema_norm 'longitude': -180.0, 'z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -477,6 +580,10 @@ def test_extract_rows__with_vars__includes_coordinates_in_vars__without_schema_n 'step': 0, 'z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -491,6 +598,10 @@ def test_extract_rows__with_vars__excludes_non_index_coords__with_schema_normali 'longitude': -180.0, 'surface_0_00_instant_z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -506,6 +617,10 @@ def test_extract_rows__with_vars__includes_coordinates_in_vars__with_schema_norm 'step': 0, 'surface_0_00_instant_z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -559,6 +674,10 @@ def test_multiple_editions__without_schema_normalization(self): 'v200': -3.6647186279296875, 'valid_time': '2021-12-10T20:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -614,7 +733,11 @@ def test_multiple_editions__with_schema_normalization(self): 'surface_0_00_instant_tprate': 0.0, 'surface_0_00_instant_ceil': 179.17018127441406, 'valid_time': '2021-12-10T20:00:00+00:00', - 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))) + 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @@ -634,9 +757,47 @@ def 
test_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema 'depthBelowLandLayer_0_00_instant_stl1': 251.02520751953125, 'depthBelowLandLayer_7_00_instant_stl2': 253.54124450683594, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), + 'geo_polygon': geojson.dumps(geojson.Polygon([ + (179.950014, 89.950028), (179.950014, -89.950028), + (-179.950014, -89.950028), (-179.950014, 89.950028), + (179.950014, 89.950028)])) + } self.assertRowsEqual(actual, expected) +class ExtractRowsFromZarrTest(ExtractRowsTestBase): + + def setUp(self) -> None: + super().setUp() + self.tmpdir = tempfile.TemporaryDirectory() + + def tearDown(self) -> None: + super().tearDown() + self.tmpdir.cleanup() + + def test_extracts_rows(self): + input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr') + + ds = ( + xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder) + .isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4)) + .rename(dict(lon='longitude', lat='latitude')) + ) + ds.to_zarr(input_zarr) + + op = ToBigQuery.from_kwargs( + first_uri=input_zarr, zarr_kwargs=dict(chunks=None, consolidated=True), dry_run=True, zarr=True, + output_table='foo.bar.baz', + variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False, + tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, skip_region_validation=True, + disable_grib_schema_normalization=False, + ) + + with TestPipeline() as p: + result = p | op + assert_that(result, is_not_empty()) + + if __name__ == '__main__': unittest.main() diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py index eb473b0d..73bd0d67 100644 --- a/weather_mv/loader_pipeline/ee.py +++ b/weather_mv/loader_pipeline/ee.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse +import csv import dataclasses import json import logging +import math import os import re import shutil @@ -27,7 +29,6 @@ import apache_beam as beam import ee import numpy as np -import xarray as xr from apache_beam.io.filesystems import FileSystems from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE from apache_beam.options.pipeline_options import PipelineOptions @@ -36,8 +37,8 @@ from google.auth.transport import requests from rasterio.io import MemoryFile -from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin -from .util import make_attrs_ee_compatible, RateLimit, validate_region +from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload +from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp logger = logging.getLogger(__name__) @@ -51,6 +52,7 @@ 'IMAGE': '.tiff', 'TABLE': '.csv' } +ROWS_PER_WRITE = 10_000 # Number of rows per feature collection write. def is_compute_engine() -> bool: @@ -155,7 +157,12 @@ def setup(self): def check_setup(self): """Ensures that setup has been called.""" if not self._has_setup: - self.setup() + try: + # This throws an exception if ee is not initialized. 
+ ee.data.getAlgorithms() + self._has_setup = True + except ee.EEException: + self.setup() def process(self, *args, **kwargs): """Checks that setup has been called then call the process implementation.""" @@ -436,6 +443,8 @@ def add_to_queue(self, queue: Queue, item: t.Any): def convert_to_asset(self, queue: Queue, uri: str): """Converts source data into EE asset (GeoTiff or CSV) and uploads it to the bucket.""" logger.info(f'Converting {uri!r} to COGs...') + job_start_time = get_utc_timestamp() + with open_dataset(uri, self.open_dataset_kwargs, self.disable_grib_schema_normalization, @@ -455,6 +464,8 @@ def convert_to_asset(self, queue: Queue, uri: str): ('start_time', 'end_time', 'is_normalized')) dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform']) attrs.update({'is_normalized': str(is_normalized)}) # EE properties does not support bool. + # Adding job_start_time to properites. + attrs["job_start_time"] = job_start_time # Make attrs EE ingestable. attrs = make_attrs_ee_compatible(attrs) @@ -496,17 +507,40 @@ def convert_to_asset(self, queue: Queue, uri: str): channel_names = [] file_name = f'{asset_name}.csv' - df = xr.Dataset.to_dataframe(ds) - df = df.reset_index() + shape = math.prod(list(ds.dims.values())) + # Names of dimesions, coordinates and data variables. + dims = list(ds.dims) + coords = [c for c in list(ds.coords) if c not in dims] + vars = list(ds.data_vars) + header = dims + coords + vars + + # Data of dimesions, coordinates and data variables. + dims_data = [ds[dim].data for dim in dims] + coords_data = [np.full((shape,), ds[coord].data) for coord in coords] + vars_data = [ds[var].data.flatten() for var in vars] + data = coords_data + vars_data + + dims_shape = [len(ds[dim].data) for dim in dims] - # Copy in-memory dataframe to gcs. + def get_dims_data(index: int) -> t.List[t.Any]: + """Returns dimensions for the given flattened index.""" + return [ + dim[int(index / math.prod(dims_shape[i+1:])) % len(dim)] for (i, dim) in enumerate(dims_data) + ] + + # Copy CSV to gcs. target_path = os.path.join(self.asset_location, file_name) - with tempfile.NamedTemporaryFile() as tmp_df: - df.to_csv(tmp_df.name, index=False) - tmp_df.flush() - tmp_df.seek(0) - with FileSystems().create(target_path) as dst: - shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE) + with tempfile.NamedTemporaryFile() as temp: + with open(temp.name, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerows([header]) + # Write rows in batches. + for i in range(0, shape, ROWS_PER_WRITE): + writer.writerows( + [get_dims_data(i) + list(row) for row in zip(*[d[i:i + ROWS_PER_WRITE] for d in data])] + ) + + upload(temp.name, target_path) asset_data = AssetData( name=asset_name, @@ -612,6 +646,8 @@ def start_ingestion(self, asset_request: t.Dict) -> str: """Creates COG-backed asset in earth engine. 
Returns the asset id.""" self.check_setup() + asset_request['properties']['ingestion_time'] = get_utc_timestamp() + try: if self.ee_asset_type == 'IMAGE': result = ee.data.createAsset(asset_request) diff --git a/weather_mv/loader_pipeline/pipeline.py b/weather_mv/loader_pipeline/pipeline.py index 10136747..ef685473 100644 --- a/weather_mv/loader_pipeline/pipeline.py +++ b/weather_mv/loader_pipeline/pipeline.py @@ -17,6 +17,7 @@ import json import logging import typing as t +import warnings import apache_beam as beam from apache_beam.io.filesystems import FileSystems @@ -27,6 +28,7 @@ from .streaming import GroupMessagesByFixedWindows, ParsePaths logger = logging.getLogger(__name__) +SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0' def configure_logger(verbosity: int) -> None: @@ -54,8 +56,9 @@ def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None known_args.first_uri = next(iter(all_uris)) with beam.Pipeline(argv=pipeline_args) as p: - if known_args.topic or known_args.subscription: - + if known_args.zarr: + paths = p + elif known_args.topic or known_args.subscription: paths = ( p # Windowing is based on this code sample: @@ -113,6 +116,7 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: help='Preview the weather-mv job. Default: off') base.add_argument('--log-level', type=int, default=2, help='An integer to configure log level. Default: 2(INFO)') + base.add_argument('--use-local-code', action='store_true', default=False, help='Supply local code to the Runner.') subparsers = parser.add_subparsers(help='help for subcommand', dest='subcommand') @@ -138,11 +142,19 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: # Validate Zarr arguments if known_args.uris.endswith('.zarr'): known_args.zarr = True - known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None) if known_args.zarr_kwargs and not known_args.zarr: raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.') + if known_args.zarr_kwargs: + if not known_args.zarr_kwargs.get('start_date') or not known_args.zarr_kwargs.get('end_date'): + warnings.warn('`--zarr_kwargs` not contains both `start_date` and `end_date`' + 'so whole zarr-dataset will ingested.') + + if known_args.zarr: + known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None) + known_args.zarr_kwargs['consolidated'] = known_args.zarr_kwargs.get('consolidated', True) + # Validate subcommand if known_args.subcommand == 'bigquery' or known_args.subcommand == 'bq': ToBigQuery.validate_arguments(known_args, pipeline_args) diff --git a/weather_mv/loader_pipeline/pipeline_test.py b/weather_mv/loader_pipeline/pipeline_test.py index 09e53b43..3834b537 100644 --- a/weather_mv/loader_pipeline/pipeline_test.py +++ b/weather_mv/loader_pipeline/pipeline_test.py @@ -30,7 +30,7 @@ def setUp(self) -> None: ).split() self.tif_base_cli_args = ( 'weather-mv bq ' - f'-i {self.test_data_folder}/test_data_tif_start_time.tif ' + f'-i {self.test_data_folder}/test_data_tif_time.tif ' '-o myproject.mydataset.mytable ' '--import_time 2022-02-04T22:22:12.125893 ' '-s' @@ -62,10 +62,13 @@ def setUp(self) -> None: 'xarray_open_dataset_kwargs': {}, 'coordinate_chunk_size': 10_000, 'disable_grib_schema_normalization': False, - 'tif_metadata_for_datetime': None, + 'tif_metadata_for_start_time': None, + 'tif_metadata_for_end_time': None, 'zarr': False, 'zarr_kwargs': {}, 'log_level': 2, + 'use_local_code': False, + 'skip_creating_polygon': False, } 
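
To make the `--zarr_kwargs` date handling above concrete, here is a rough sketch of how `start_date`/`end_date` appear to be consumed, mirroring the Zarr branch of `expand()` in `bq.py` earlier in this diff. The bucket URI is a hypothetical placeholder, and the sketch assumes the dataset has a `time` coordinate.

```python
# Sketch of the Zarr date-range handling shown in this diff (not the pipeline itself).
import xarray_beam as xbeam

zarr_kwargs = {'chunks': None, 'consolidated': True,
               'start_date': '2021-07-18', 'end_date': '2021-07-19'}

open_kwargs = zarr_kwargs.copy()
open_kwargs.pop('chunks')                      # xbeam.open_zarr picks its own chunking
start_date = open_kwargs.pop('start_date', None)
end_date = open_kwargs.pop('end_date', None)

ds, chunks = xbeam.open_zarr('gs://your-bucket/example.zarr', **open_kwargs)  # placeholder URI
if start_date is not None and end_date is not None:
    ds = ds.sel(time=slice(start_date, end_date))  # only this window is extracted
```
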
@@ -81,7 +84,8 @@ def test_log_level_arg(self):
 
     def test_tif_metadata_for_datetime_raise_error_for_non_tif_file(self):
         with self.assertRaisesRegex(RuntimeError, 'can be specified only for tif files.'):
-            run(self.base_cli_args + '--tif_metadata_for_datetime start_time'.split())
+            run(self.base_cli_args + '--tif_metadata_for_start_time start_time '
+                                     '--tif_metadata_for_end_time end_time'.split())
 
     def test_tif_metadata_for_datetime_raise_error_if_flag_is_absent(self):
         with self.assertRaisesRegex(RuntimeError, 'is required for tif files.'):
diff --git a/weather_mv/loader_pipeline/regrid_test.py b/weather_mv/loader_pipeline/regrid_test.py
index 5cc5b2a1..87ffaad4 100644
--- a/weather_mv/loader_pipeline/regrid_test.py
+++ b/weather_mv/loader_pipeline/regrid_test.py
@@ -122,7 +122,7 @@ def test_zarr__coarsen(self):
             self.Op,
             first_uri=input_zarr,
             output_path=output_zarr,
-            zarr_input_chunks={"time": 5},
+            zarr_input_chunks={"time": 25},
             zarr=True
         )
diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py
index 1ab11263..e755205b 100644
--- a/weather_mv/loader_pipeline/sinks.py
+++ b/weather_mv/loader_pipeline/sinks.py
@@ -138,8 +138,9 @@ def rearrange_time_list(order_list: t.List, time_list: t.List) -> t.List:
     return datetime.datetime(*time_list)
 
 
-def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_datetime: str, uri: str,
-                    band_names_dict: t.Dict, initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset:
+def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_start_time: str,
+                    tif_metadata_for_end_time: str, uri: str, band_names_dict: t.Dict,
+                    initialization_time_regex: str, forecast_time_regex: str) -> xr.Dataset:
     """Transforms (y, x) coordinates into (lat, long) and adds bands data in data variables.
 
     This also retrieves datetime from tif's metadata and stores it into dataset.
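For context on the split into `tif_metadata_for_start_time` and `tif_metadata_for_end_time`: the preprocessing below accepts each metadata value either as epoch milliseconds or as a UTC string of the form `%Y-%m-%dT%H:%M:%SZ`. A standalone sketch of that conversion; the helper name `parse_tif_time` and the sample values are illustrative only:

```python
import datetime

# Illustrative only: a TIFF metadata value may be epoch milliseconds or a
# UTC string; both are normalized to a datetime, mirroring the logic below.
def parse_tif_time(value: str) -> datetime.datetime:
    try:
        return datetime.datetime.utcfromtimestamp(int(value) / 1000.0)
    except ValueError:
        return datetime.datetime.strptime(value, '%Y-%m-%dT%H:%M:%SZ')

print(parse_tif_time('1641081600000'))         # -> 2022-01-02 00:00:00
print(parse_tif_time('2022-01-02T00:00:00Z'))  # -> 2022-01-02 00:00:00
```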
@@ -162,6 +163,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
         ds = _replace_dataarray_names_with_long_names(ds)
 
     end_time = None
+    start_time = None
     if initialization_time_regex and forecast_time_regex:
         try:
             start_time = match_datetime(uri, initialization_time_regex)
@@ -174,15 +176,40 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
         ds.attrs['start_time'] = start_time
         ds.attrs['end_time'] = end_time
 
-    datetime_value_ms = None
+    init_time = None
+    forecast_time = None
+    coords = {}
     try:
-        datetime_value_s = (int(end_time.timestamp()) if end_time is not None
-                            else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
-        ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)})
-    except KeyError:
-        raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.")
+        # if start_time/end_time is in integer milliseconds
+        init_time = (int(start_time.timestamp()) if start_time is not None
+                     else int(ds.attrs[tif_metadata_for_start_time]) / 1000.0)
+        coords['time'] = datetime.datetime.utcfromtimestamp(init_time)
+
+        if tif_metadata_for_end_time:
+            forecast_time = (int(end_time.timestamp()) if end_time is not None
+                             else int(ds.attrs[tif_metadata_for_end_time]) / 1000.0)
+            coords['valid_time'] = datetime.datetime.utcfromtimestamp(forecast_time)
+
+        ds = ds.assign_coords(coords)
+    except KeyError as e:
+        raise RuntimeError(f"Invalid datetime metadata of tif: {e}.")
     except ValueError:
-        raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.")
+        try:
+            # if start_time/end_time is in UTC string format
+            init_time = (int(start_time.timestamp()) if start_time is not None
+                         else datetime.datetime.strptime(ds.attrs[tif_metadata_for_start_time],
+                                                         '%Y-%m-%dT%H:%M:%SZ'))
+            coords['time'] = init_time
+
+            if tif_metadata_for_end_time:
+                forecast_time = (int(end_time.timestamp()) if end_time is not None
+                                 else datetime.datetime.strptime(ds.attrs[tif_metadata_for_end_time],
+                                                                 '%Y-%m-%dT%H:%M:%SZ'))
+                coords['valid_time'] = forecast_time
+
+            ds = ds.assign_coords(coords)
+        except ValueError as e:
+            raise RuntimeError(f"Invalid datetime value in tif's metadata: {e}.")
 
     return ds
 
@@ -349,6 +376,11 @@ def __open_dataset_file(filename: str,
                                      False)
 
 
+def upload(src: str, dst: str) -> None:
+    """Uploads a file to the specified GCS bucket destination."""
+    subprocess.run(f'gsutil -m cp {src} {dst}'.split(), check=True, capture_output=True, text=True, input="n/n")
+
+
 def copy(src: str, dst: str) -> None:
     """Copy data via `gcloud alpha storage` or `gsutil`."""
     errors: t.List[subprocess.CalledProcessError] = []
@@ -390,16 +422,26 @@ def open_local(uri: str) -> t.Iterator[str]:
 def open_dataset(uri: str,
                  open_dataset_kwargs: t.Optional[t.Dict] = None,
                  disable_grib_schema_normalization: bool = False,
-                 tif_metadata_for_datetime: t.Optional[str] = None,
                  group_common_hypercubes: t.Optional[bool] = False,
+                 tif_metadata_for_start_time: t.Optional[str] = None,
+                 tif_metadata_for_end_time: t.Optional[str] = None,
                  band_names_dict: t.Optional[t.Dict] = None,
                  initialization_time_regex: t.Optional[str] = None,
                  forecast_time_regex: t.Optional[str] = None,
                  is_zarr: bool = False) -> t.Iterator[xr.Dataset]:
     """Open the dataset at 'uri' and return a xarray.Dataset."""
     try:
+        local_open_dataset_kwargs = start_date = end_date = None
+        if open_dataset_kwargs is not None:
+            local_open_dataset_kwargs = open_dataset_kwargs.copy()
+            start_date = local_open_dataset_kwargs.pop('start_date', None)
+            end_date = local_open_dataset_kwargs.pop('end_date', None)
+
         if is_zarr:
-            ds: xr.Dataset = xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs)
+            ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr',
+                                                                     **local_open_dataset_kwargs), False)
+            if start_date is not None and end_date is not None:
+                ds = ds.sel(time=slice(start_date, end_date))
             beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc()
             yield ds
             ds.close()
@@ -409,7 +451,7 @@ def open_dataset(uri: str,
             xr_datasets: xr.Dataset = __open_dataset_file(local_path,
                                                           uri_extension,
                                                           disable_grib_schema_normalization,
-                                                          open_dataset_kwargs,
+                                                          local_open_dataset_kwargs,
                                                           group_common_hypercubes)
             # Extracting dtype, crs and transform from the dataset.
             try:
@@ -427,16 +469,20 @@ def open_dataset(uri: str,
                     logger.info(f'opened dataset size: {total_size_in_bytes}')
             else:
+                if start_date is not None and end_date is not None:
+                    xr_datasets = xr_datasets.sel(time=slice(start_date, end_date))
                 if uri_extension in ['.tif', '.tiff']:
-                    xr_dataset = _preprocess_tif(xr_datasets,
-                                                 local_path,
-                                                 tif_metadata_for_datetime,
-                                                 uri,
-                                                 band_names_dict,
-                                                 initialization_time_regex,
-                                                 forecast_time_regex)
+                    xr_dataset = _preprocess_tif(xr_datasets,
+                                                 local_path,
+                                                 tif_metadata_for_start_time,
+                                                 tif_metadata_for_end_time,
+                                                 uri,
+                                                 band_names_dict,
+                                                 initialization_time_regex,
+                                                 forecast_time_regex)
                 else:
                     xr_dataset = xr_datasets
+
                 # Extracting dtype, crs and transform from the dataset & storing them as attributes.
                 xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform})
                 logger.info(f'opened dataset size: {xr_dataset.nbytes}')
diff --git a/weather_mv/loader_pipeline/sinks_test.py b/weather_mv/loader_pipeline/sinks_test.py
index d9fad23a..0d8838da 100644
--- a/weather_mv/loader_pipeline/sinks_test.py
+++ b/weather_mv/loader_pipeline/sinks_test.py
@@ -84,7 +84,7 @@ def setUp(self) -> None:
         super().setUp()
         self.test_data_path = os.path.join(self.test_data_folder, 'test_data_20180101.nc')
         self.test_grib_path = os.path.join(self.test_data_folder, 'test_data_grib_single_timestep')
-        self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_start_time.tif')
+        self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_time.tif')
         self.test_zarr_path = os.path.join(self.test_data_folder, 'test_data.zarr')
 
     def test_opens_grib_files(self):
@@ -104,7 +104,8 @@ def test_accepts_xarray_kwargs(self):
         self.assertDictContainsSubset({'is_normalized': False}, ds2.attrs)
 
     def test_opens_tif_files(self):
-        with open_dataset(self.test_tif_path, tif_metadata_for_datetime='start_time') as ds:
+        with open_dataset(self.test_tif_path, tif_metadata_for_start_time='start_time',
+                          tif_metadata_for_end_time='end_time') as ds:
             self.assertIsNotNone(ds)
             self.assertDictContainsSubset({'is_normalized': False}, ds.attrs)
 
@@ -112,6 +113,7 @@ def test_opens_zarr(self):
         with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds:
             self.assertIsNotNone(ds)
             self.assertEqual(list(ds.data_vars), ['cape', 'd2m'])
+
     def test_open_dataset__fits_memory_bounds(self):
         with write_netcdf() as test_netcdf_path:
             with limit_memory(max_memory=30):
diff --git a/weather_mv/loader_pipeline/streaming.py b/weather_mv/loader_pipeline/streaming.py
index 7210b2e7..3a7a8f49 100644
--- a/weather_mv/loader_pipeline/streaming.py
+++ b/weather_mv/loader_pipeline/streaming.py
@@ -84,7 +84,7 @@ def try_parse_message(cls, message_body: t.Union[str, t.Dict]) -> t.Dict:
         try:
             return json.loads(message_body)
         except (json.JSONDecodeError, TypeError):
-            if type(message_body) is dict:
+            if isinstance(message_body, dict):
                 return message_body
             raise
diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py
index a31a06a9..079b86de 100644
--- a/weather_mv/loader_pipeline/util.py
+++ b/weather_mv/loader_pipeline/util.py
@@ -28,7 +28,6 @@
 import uuid
 from functools import partial
 from urllib.parse import urlparse
-
 import apache_beam as beam
 import numpy as np
 import pandas as pd
@@ -134,6 +133,9 @@ def _check_for_coords_vars(ds_data_var: str, target_var: str) -> bool:
        specified by the user."""
     return ds_data_var.endswith('_'+target_var) or ds_data_var.startswith(target_var+'_')
 
+def get_utc_timestamp() -> float:
+    """Returns the current UTC Timestamp."""
+    return datetime.datetime.now().timestamp()
 
 def _only_target_coordinate_vars(ds: xr.Dataset, data_vars: t.List[str]) -> t.List[str]:
     """If the user specifies target fields in the dataset, get all the matching coords & data vars."""
diff --git a/weather_mv/loader_pipeline/util_test.py b/weather_mv/loader_pipeline/util_test.py
index dae9c873..65d9169e 100644
--- a/weather_mv/loader_pipeline/util_test.py
+++ b/weather_mv/loader_pipeline/util_test.py
@@ -38,9 +38,10 @@ def test_gets_indexed_coordinates(self):
         ds = xr.open_dataset(self.test_data_path)
         self.assertEqual(
             next(get_coordinates(ds)),
-            {'latitude': 49.0,
-             'longitude': -108.0,
-             'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)}
+            {
+                'latitude': 49.0,
+                'longitude': -108.0,
+                'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)}
         )
 
     def test_no_duplicate_coordinates(self):
@@ -91,24 +92,28 @@ def test_get_coordinates(self):
             actual,
             [
                 [
-                    {'longitude': -108.0,
-                     'latitude': 49.0,
-                     'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)
+                    {
+                        'longitude': -108.0,
+                        'latitude': 49.0,
+                        'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)
                     },
-                    {'longitude': -108.0,
-                     'latitude': 49.0,
-                     'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None)
+                    {
+                        'longitude': -108.0,
+                        'latitude': 49.0,
+                        'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None)
                     },
-                    {'longitude': -108.0,
-                     'latitude': 49.0,
-                     'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None)
+                    {
+                        'longitude': -108.0,
+                        'latitude': 49.0,
+                        'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None)
                     },
                 ],
                 [
-                    {'longitude': -108.0,
-                     'latitude': 49.0,
-                     'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None)
-                    }
+                    {
+                        'longitude': -108.0,
+                        'latitude': 49.0,
+                        'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None)
+                    }
                 ]
             ]
         )
diff --git a/weather_mv/setup.py b/weather_mv/setup.py
index 26201eef..4bdb4a0b 100644
--- a/weather_mv/setup.py
+++ b/weather_mv/setup.py
@@ -45,6 +45,7 @@
         "numpy==1.22.4",
         "pandas==1.5.1",
         "xarray==2023.1.0",
+        "xarray-beam==0.6.2",
         "cfgrib==0.9.10.2",
         "netcdf4==1.6.1",
         "geojson==2.5.0",
@@ -55,6 +56,8 @@
         "earthengine-api>=0.1.263",
         "pyproj==3.4.0",  # requires separate binary installation!
         "gdal==3.5.1",  # requires separate binary installation!
+ "gcsfs==2022.11.0", + "zarr==2.15.0", ] setup( @@ -62,7 +65,7 @@ packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', - version='0.2.15', + version='0.2.19', url='https://weather-tools.readthedocs.io/en/latest/weather_mv/', description='A tool to load weather data into BigQuery.', install_requires=beam_gcp_requirements + base_requirements, diff --git a/weather_mv/test_data/test_data_tif_start_time.tif b/weather_mv/test_data/test_data_tif_start_time.tif deleted file mode 100644 index 82f32dd7..00000000 Binary files a/weather_mv/test_data/test_data_tif_start_time.tif and /dev/null differ diff --git a/weather_mv/test_data/test_data_tif_time.tif b/weather_mv/test_data/test_data_tif_time.tif new file mode 100644 index 00000000..32b7f63b Binary files /dev/null and b/weather_mv/test_data/test_data_tif_time.tif differ diff --git a/weather_mv/weather-mv b/weather_mv/weather-mv index e8be4fcd..886c3d1f 100755 --- a/weather_mv/weather-mv +++ b/weather_mv/weather-mv @@ -23,6 +23,8 @@ import tempfile import weather_mv +SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) @@ -47,9 +49,15 @@ if __name__ == '__main__': except ImportError as e: raise ImportError('please re-install package in a clean python environment.') from e - if '-h' in sys.argv or '--help' in sys.argv or len(sys.argv) == 1: - cli() - else: + args = [] + + if "DataflowRunner" in sys.argv and "--sdk_container_image" not in sys.argv: + args.extend(['--sdk_container_image', + os.getenv('SDK_CONTAINER_IMAGE', SDK_CONTAINER_IMAGE), + '--experiments', + 'use_runner_v2']) + + if "--use-local-code" in sys.argv: with tempfile.TemporaryDirectory() as tmpdir: original_dir = os.getcwd() @@ -70,5 +78,7 @@ if __name__ == '__main__': # cleanup memory to prevent pickling error. tar = None weather_mv = None - - cli(['--extra_package', pkg_archive]) + args.extend(['--extra_package', pkg_archive]) + cli(args) + else: + cli(args) diff --git a/weather_sp/README.md b/weather_sp/README.md index 019217a6..93f81907 100644 --- a/weather_sp/README.md +++ b/weather_sp/README.md @@ -27,6 +27,8 @@ _Common options_: using Python formatting, see [Output section](#output) below. * `-f, --force`: Force re-splitting of the pipeline. Turns of skipping of already split data. * `-d, --dry-run`: Test the input file matching and the output file scheme without splitting. +* `--log-level`: An integer to configure log level. Default: 2(INFO). +* `--use-local-code`: Supply local code to the Runner. Default: False. Invoke with `-h` or `--help` to see the full range of options. 
@@ -59,6 +61,19 @@ weather-sp --input-pattern 'gs://test-tmp/era5/2015/**' \
    --job_name $JOB_NAME
 ```
 
+Using the DataflowRunner and supplying local code to the pipeline:
+
+```bash
+weather-sp --input-pattern 'gs://test-tmp/era5/2015/**' \
+    --output-dir 'gs://test-tmp/era5/splits' \
+    --formatting '.{typeOfLevel}' \
+    --runner DataflowRunner \
+    --project $PROJECT \
+    --temp_location gs://$BUCKET/tmp \
+    --job_name $JOB_NAME \
+    --use-local-code
+```
+
 Using ecCodes-powered grib splitting on Dataflow (this is often more robust, especially when
 splitting multiple dimensions at once):
diff --git a/weather_sp/setup.py b/weather_sp/setup.py
index d56b351e..e22279cd 100644
--- a/weather_sp/setup.py
+++ b/weather_sp/setup.py
@@ -44,7 +44,7 @@
     packages=find_packages(),
    author='Anthromets',
     author_email='anthromets-ecmwf@google.com',
-    version='0.3.0',
+    version='0.3.2',
     url='https://weather-tools.readthedocs.io/en/latest/weather_sp/',
     description='A tool to split weather data files into per-variable files.',
     install_requires=beam_gcp_requirements + base_requirements,
diff --git a/weather_sp/splitter_pipeline/file_splitters.py b/weather_sp/splitter_pipeline/file_splitters.py
index 6456c426..7a4d3c77 100644
--- a/weather_sp/splitter_pipeline/file_splitters.py
+++ b/weather_sp/splitter_pipeline/file_splitters.py
@@ -16,6 +16,7 @@
 import itertools
 import logging
 import os
+import re
 import shutil
 import string
 import subprocess
@@ -158,6 +159,10 @@ class GribSplitterV2(GribSplitter):
     See https://confluence.ecmwf.int/display/ECC/grib_copy.
     """
 
+    def replace_non_numeric_bracket(self, match: re.Match) -> str:
+        value = match.group(1)
+        return f"[{value}]" if not value.isdigit() else "{" + value + "}"
+
     def split_data(self) -> None:
         if not self.output_info.split_dims():
             raise ValueError('No splitting specified in template.')
@@ -172,7 +177,10 @@ def split_data(self) -> None:
         unformatted_output_path = self.output_info.unformatted_output_path()
         prefix, _ = os.path.split(next(iter(string.Formatter().parse(unformatted_output_path)))[0])
         _, tail = unformatted_output_path.split(prefix)
-        output_template = tail.replace('{', '[').replace('}', ']')
+
+        # Replace '{' with '[' and '}' with ']' only for non-numeric values inside {} of tail.
+        output_str = re.sub(r'\{(\w+)\}', self.replace_non_numeric_bracket, tail)
+        output_template = output_str.format(*self.output_info.template_folders)
 
         slash = '/'
         delimiter = 'DELIMITER'
diff --git a/weather_sp/splitter_pipeline/pipeline.py b/weather_sp/splitter_pipeline/pipeline.py
index c3dcd470..bbcee909 100644
--- a/weather_sp/splitter_pipeline/pipeline.py
+++ b/weather_sp/splitter_pipeline/pipeline.py
@@ -26,6 +26,7 @@
 from .file_splitters import get_splitter
 
 logger = logging.getLogger(__name__)
+SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0'
 
 def configure_logger(verbosity: int) -> None:
@@ -88,6 +89,7 @@ def run(argv: t.List[str], save_main_session: bool = True):
     )
     parser.add_argument('-i', '--input-pattern', type=str, required=True,
                         help='Pattern for input weather data.')
+    parser.add_argument('--use-local-code', action='store_true', default=False, help='Supply local code to the Runner.')
     output_options = parser.add_mutually_exclusive_group(required=True)
     output_options.add_argument(
         '--output-template', type=str,
diff --git a/weather_sp/weather-sp b/weather_sp/weather-sp
index a22b75b1..4db3dc20 100755
--- a/weather_sp/weather-sp
+++ b/weather_sp/weather-sp
@@ -23,6 +23,8 @@
 import tempfile
 
 import weather_sp
+
+SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0'
 
 if __name__ == '__main__':
     logging.getLogger().setLevel(logging.INFO)
@@ -50,9 +52,15 @@
         raise ImportError(
             'please re-install package in a clean python environment.') from e
 
-    if '-h' in sys.argv or '--help' in sys.argv or len(sys.argv) == 1:
-        cli()
-    else:
+    args = []
+
+    if "DataflowRunner" in sys.argv and "--sdk_container_image" not in sys.argv:
+        args.extend(['--sdk_container_image',
+                     os.getenv('SDK_CONTAINER_IMAGE', SDK_CONTAINER_IMAGE),
+                     '--experiments',
+                     'use_runner_v2'])
+
+    if "--use-local-code" in sys.argv:
         with tempfile.TemporaryDirectory() as tmpdir:
             original_dir = os.getcwd()
@@ -68,11 +76,12 @@
             pkg_archive = glob.glob(os.path.join(tmpdir, '*.tar.gz'))[0]
 
             with tarfile.open(pkg_archive, 'r') as tar:
-                assert any([f.endswith('.py') for f in
-                            tar.getnames()]), 'extra_package must include python files!'
+                assert any([f.endswith('.py') for f in tar.getnames()]), 'extra_package must include python files!'
 
             # cleanup memory to prevent pickling error.
             tar = None
             weather_sp = None
-
-            cli(['--extra_package', pkg_archive])
+            args.extend(['--extra_package', pkg_archive])
+            cli(args)
+    else:
+        cli(args)
\ No newline at end of file
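As a closing note on the `GribSplitterV2.split_data` change earlier in this diff: named template placeholders become `grib_copy` `[...]` keys, while purely numeric placeholders remain Python format fields filled from `template_folders`. A standalone sketch with a made-up template and folder value:

```python
import re

# Sketch of the output-template rewrite shown in the diff: '{typeOfLevel}' and
# '{shortName}' become grib_copy keys, '{0}' stays a format field.
def replace_non_numeric_bracket(match: re.Match) -> str:
    value = match.group(1)
    return f"[{value}]" if not value.isdigit() else "{" + value + "}"

tail = '/{0}/{typeOfLevel}_{shortName}.grib'  # made-up template
output_template = re.sub(r'\{(\w+)\}', replace_non_numeric_bracket, tail).format('era5')
print(output_template)  # /era5/[typeOfLevel]_[shortName].grib
```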