From 861dd5feb9b7b552c71f7428cd72654d7e7f8f1f Mon Sep 17 00:00:00 2001 From: aniketinfocusp <122869307+aniketinfocusp@users.noreply.github.com> Date: Thu, 8 Jun 2023 13:25:38 +0530 Subject: [PATCH 1/5] Using `rioxarray` to open `.tif` files (#342) * upgraded rioxarray * updated opendataset for tifs * remvoe rasterio for meta data * updated dataset rename function. * updated function name * updated ci * moved xarray and ruff to conda * updated dependencies in ci * minor fixes * weather_mv bump version --- .github/workflows/ci.yml | 2 +- ci3.8.yml | 11 +++++--- ci3.9.yml | 11 +++++--- environment.yml | 2 +- weather_mv/loader_pipeline/sinks.py | 43 +++++++++++------------------ weather_mv/setup.py | 4 +-- 6 files changed, 34 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 834fbbe0..0325c936 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,8 +53,8 @@ jobs: uses: conda-incubator/setup-miniconda@v2 with: python-version: ${{ matrix.python-version }} + channels: conda-forge environment-file: ci${{ matrix.python-version}}.yml - use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! activate-environment: weather-tools - name: Check MetView's installation shell: bash -l {0} diff --git a/ci3.8.yml b/ci3.8.yml index 6722655e..d6a1e0bd 100644 --- a/ci3.8.yml +++ b/ci3.8.yml @@ -15,10 +15,11 @@ dependencies: - eccodes=2.27.0 - requests=2.28.1 - netcdf4=1.6.1 - - rioxarray=0.12.2 + - rioxarray=0.13.4 - xarray-beam=0.3.1 - ecmwf-api-client=1.6.3 - - fsspec=2022.10.0 + - fsspec=2022.11.0 + - gcsfs=2022.11.0 - gdal=3.5.1 - pyproj=3.4.0 - geojson=2.5.0=py_0 @@ -28,8 +29,10 @@ dependencies: - pandas=1.5.1 - pip=22.3 - pygrib=2.1.4 + - xarray==2023.1.0 + - ruff==0.0.260 + - google-cloud-sdk=410.0.0 + - aria2=1.36.0 - pip: - earthengine-api==0.1.329 - - xarray==2023.1.0 - - ruff==0.0.260 - .[test] diff --git a/ci3.9.yml b/ci3.9.yml index 88c4e084..a43cec16 100644 --- a/ci3.9.yml +++ b/ci3.9.yml @@ -15,10 +15,11 @@ dependencies: - eccodes=2.27.0 - requests=2.28.1 - netcdf4=1.6.1 - - rioxarray=0.12.2 + - rioxarray=0.13.4 - xarray-beam=0.3.1 - ecmwf-api-client=1.6.3 - - fsspec=2022.10.0 + - fsspec=2022.11.0 + - gcsfs=2022.11.0 - gdal=3.5.1 - pyproj=3.4.0 - geojson=2.5.0=py_0 @@ -28,8 +29,10 @@ dependencies: - pandas=1.5.1 - pip=22.3 - pygrib=2.1.4 + - google-cloud-sdk=410.0.0 + - aria2=1.36.0 + - xarray==2023.1.0 + - ruff==0.0.260 - pip: - earthengine-api==0.1.329 - - xarray==2023.1.0 - - ruff==0.0.260 - .[test] diff --git a/environment.yml b/environment.yml index 820270ef..eae35f9c 100644 --- a/environment.yml +++ b/environment.yml @@ -8,7 +8,7 @@ dependencies: - xarray=2023.1.0 - fsspec=2022.11.0 - gcsfs=2022.11.0 - - rioxarray=0.12.2 + - rioxarray=0.13.4 - gdal=3.5.1 - pyproj=3.4.0 - cdsapi=0.5.1 diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py index 4002eeea..89aaed04 100644 --- a/weather_mv/loader_pipeline/sinks.py +++ b/weather_mv/loader_pipeline/sinks.py @@ -30,6 +30,7 @@ import cfgrib import numpy as np import rasterio +import rioxarray import xarray as xr from apache_beam.io.filesystem import CompressionTypes, FileSystem, CompressedFile, DEFAULT_READ_BUFFER_SIZE from pyproj import Transformer @@ -145,14 +146,9 @@ def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_datetime: st This also retrieves datetime from tif's metadata and stores it into dataset. 
""" - def _get_band_data(i): - if not band_names_dict: - band = ds.band_data[i] - band.name = ds.band_data.attrs['long_name'][i] - else: - band = ds.band_data - band.name = band_names_dict.get(band.name) - return band + def _replace_dataarray_names_with_long_names(ds: xr.Dataset): + rename_dict = {var_name: ds[var_name].attrs.get('long_name', var_name) for var_name in ds.variables} + return ds.rename(rename_dict) y, x = np.meshgrid(ds['y'], ds['x']) transformer = Transformer.from_crs(ds.spatial_ref.crs_wkt, TIF_TRANSFORM_CRS_TO, always_xy=True) @@ -162,14 +158,9 @@ def _get_band_data(i): ds['x'] = lon[:, 0] ds = ds.rename({'y': 'latitude', 'x': 'longitude'}) - band_length = len(ds.band) - ds = ds.squeeze().drop_vars('band').drop_vars('spatial_ref') - - band_data_list = [_get_band_data(i) for i in range(band_length)] + ds = ds.squeeze().drop_vars('spatial_ref') - ds_is_normalized_attr = ds.attrs['is_normalized'] - ds = xr.merge(band_data_list) - ds.attrs['is_normalized'] = ds_is_normalized_attr + ds = _replace_dataarray_names_with_long_names(ds) end_time = None if initialization_time_regex and forecast_time_regex: @@ -184,17 +175,15 @@ def _get_band_data(i): ds.attrs['start_time'] = start_time ds.attrs['end_time'] = end_time - # TODO(#159): Explore ways to capture required metadata using xarray. - with rasterio.open(filename) as f: - datetime_value_ms = None - try: - datetime_value_s = (int(end_time.timestamp()) if end_time is not None - else int(f.tags()[tif_metadata_for_datetime]) / 1000.0) - ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)}) - except KeyError: - raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.") - except ValueError: - raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.") + datetime_value_ms = None + try: + datetime_value_s = (int(end_time.timestamp()) if end_time is not None + else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0) + ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)}) + except KeyError: + raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.") + except ValueError: + raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.") return ds @@ -306,7 +295,7 @@ def __open_dataset_file(filename: str, # If URI extension is .tif, try opening file by specifying engine="rasterio". if uri_extension in ['.tif', '.tiff']: - return _add_is_normalized_attr(xr.open_dataset(filename, engine='rasterio'), False) + return _add_is_normalized_attr(rioxarray.open_rasterio(filename, band_as_variable=True), False) # If no open kwargs are available and URI extension is other than tif, make educated guesses about the dataset. 
try: diff --git a/weather_mv/setup.py b/weather_mv/setup.py index a3f43641..26201eef 100644 --- a/weather_mv/setup.py +++ b/weather_mv/setup.py @@ -49,7 +49,7 @@ "netcdf4==1.6.1", "geojson==2.5.0", "simplejson==3.17.6", - "rioxarray==0.12.2", + "rioxarray==0.13.4", "metview==1.13.1", "rasterio==1.3.1", "earthengine-api>=0.1.263", @@ -62,7 +62,7 @@ packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', - version='0.2.14', + version='0.2.15', url='https://weather-tools.readthedocs.io/en/latest/weather_mv/', description='A tool to load weather data into BigQuery.', install_requires=beam_gcp_requirements + base_requirements, From ec9a71422f82d34d7e2a7e5497217e87d0b6afa4 Mon Sep 17 00:00:00 2001 From: Darshan Prajapati <93967637+DarshanSP19@users.noreply.github.com> Date: Mon, 12 Jun 2023 13:22:53 +0530 Subject: [PATCH 2/5] Configure Log Level From CLI (#332) * Configure Log Level From CLI * Update For Other Tools * Lint Correction --------- Co-authored-by: Darshan --- weather_dl/download_pipeline/fetcher.py | 3 ++- weather_dl/download_pipeline/pipeline.py | 12 +++++++++--- weather_dl/download_pipeline/pipeline_test.py | 3 ++- weather_mv/loader_pipeline/bq.py | 1 - weather_mv/loader_pipeline/ee.py | 1 - weather_mv/loader_pipeline/pipeline.py | 4 +++- weather_mv/loader_pipeline/pipeline_test.py | 5 +++++ weather_mv/loader_pipeline/regrid.py | 1 - weather_mv/loader_pipeline/sinks.py | 1 - weather_mv/loader_pipeline/streaming.py | 2 -- weather_mv/loader_pipeline/util.py | 1 - weather_sp/splitter_pipeline/file_splitters.py | 12 ++++++++---- weather_sp/splitter_pipeline/pipeline.py | 14 ++++++++++---- 13 files changed, 39 insertions(+), 21 deletions(-) diff --git a/weather_dl/download_pipeline/fetcher.py b/weather_dl/download_pipeline/fetcher.py index 8be524b0..8830b232 100644 --- a/weather_dl/download_pipeline/fetcher.py +++ b/weather_dl/download_pipeline/fetcher.py @@ -48,6 +48,7 @@ class Fetcher(beam.DoFn): client_name: str manifest: Manifest = NoOpManifest(Location('noop://in-memory')) store: t.Optional[Store] = None + log_level: t.Optional[int] = logging.INFO def __post_init__(self): if self.store is None: @@ -66,7 +67,7 @@ def fetch_data(self, config: Config, *, worker_name: str = 'default') -> None: if skip_partition(config, self.store, self.manifest): return - client = CLIENTS[self.client_name](config) + client = CLIENTS[self.client_name](config, self.log_level) target = prepare_target_name(config) with tempfile.NamedTemporaryFile() as temp: diff --git a/weather_dl/download_pipeline/pipeline.py b/weather_dl/download_pipeline/pipeline.py index 30aec86d..fa4983e2 100644 --- a/weather_dl/download_pipeline/pipeline.py +++ b/weather_dl/download_pipeline/pipeline.py @@ -125,7 +125,10 @@ def subsection_and_request(it: Config) -> t.Tuple[str, int]: ( partitions | 'GroupBy Request Limits' >> beam.GroupBy(subsection_and_request) - | 'Fetch Data' >> beam.ParDo(Fetcher(args.client_name, args.manifest, args.store)) + | 'Fetch Data' >> beam.ParDo(Fetcher(args.client_name, + args.manifest, + args.store, + args.known_args.log_level)) ) @@ -172,10 +175,12 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs: help="To enable file skipping logic in dry-run mode. Default: 'false'.") parser.add_argument('-u', '--update-manifest', action='store_true', default=False, help="Update the manifest for the already downloaded shards and exit. Default: 'false'.") + parser.add_argument('--log-level', type=int, default=2, + help='An integer to configure log level. 
Default: 2(INFO)') known_args, pipeline_args = parser.parse_known_args(argv[1:]) - configure_logger(3) # 0 = error, 1 = warn, 2 = info, 3 = debug + configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug configs = [] for cfg in known_args.config: @@ -223,7 +228,8 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs: manifest = LocalManifest(Location(local_dir)) num_requesters_per_key = known_args.num_requests_per_key - client = CLIENTS[client_name](configs[0]) + known_args.log_level = 40 - known_args.log_level * 10 + client = CLIENTS[client_name](configs[0], known_args.log_level) if num_requesters_per_key == -1: num_requesters_per_key = client.num_requests_per_key(config.dataset) diff --git a/weather_dl/download_pipeline/pipeline_test.py b/weather_dl/download_pipeline/pipeline_test.py index 40b7ecea..9e18a4bc 100644 --- a/weather_dl/download_pipeline/pipeline_test.py +++ b/weather_dl/download_pipeline/pipeline_test.py @@ -57,7 +57,8 @@ partition_chunks=None, schedule='in-order', check_skip_in_dry_run=False, - update_manifest=False), + update_manifest=False, + log_level=20), pipeline_options=PipelineOptions('--save_main_session True'.split()), configs=[Config.from_dict(CONFIG)], client_name='cds', diff --git a/weather_mv/loader_pipeline/bq.py b/weather_mv/loader_pipeline/bq.py index 10f5aeb6..c62a705b 100644 --- a/weather_mv/loader_pipeline/bq.py +++ b/weather_mv/loader_pipeline/bq.py @@ -39,7 +39,6 @@ ) logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) DEFAULT_IMPORT_TIME = datetime.datetime.utcfromtimestamp(0).replace(tzinfo=datetime.timezone.utc).isoformat() DATA_IMPORT_TIME_COLUMN = 'data_import_time' diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py index b5a211d4..d12c320c 100644 --- a/weather_mv/loader_pipeline/ee.py +++ b/weather_mv/loader_pipeline/ee.py @@ -40,7 +40,6 @@ from .util import make_attrs_ee_compatible, RateLimit, validate_region logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) COMPUTE_ENGINE_STR = 'Metadata-Flavor: Google' # For EE ingestion retry logic. diff --git a/weather_mv/loader_pipeline/pipeline.py b/weather_mv/loader_pipeline/pipeline.py index e8d02ccd..10136747 100644 --- a/weather_mv/loader_pipeline/pipeline.py +++ b/weather_mv/loader_pipeline/pipeline.py @@ -111,6 +111,8 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: 'Default: `{"chunks": null, "consolidated": true}`.') base.add_argument('-d', '--dry-run', action='store_true', default=False, help='Preview the weather-mv job. Default: off') + base.add_argument('--log-level', type=int, default=2, + help='An integer to configure log level. 
Default: 2(INFO)') subparsers = parser.add_subparsers(help='help for subcommand', dest='subcommand') @@ -131,7 +133,7 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: known_args, pipeline_args = parser.parse_known_args(argv[1:]) - configure_logger(2) # 0 = error, 1 = warn, 2 = info, 3 = debug + configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug # Validate Zarr arguments if known_args.uris.endswith('.zarr'): diff --git a/weather_mv/loader_pipeline/pipeline_test.py b/weather_mv/loader_pipeline/pipeline_test.py index c15f6f12..09e53b43 100644 --- a/weather_mv/loader_pipeline/pipeline_test.py +++ b/weather_mv/loader_pipeline/pipeline_test.py @@ -65,6 +65,7 @@ def setUp(self) -> None: 'tif_metadata_for_datetime': None, 'zarr': False, 'zarr_kwargs': {}, + 'log_level': 2, } @@ -74,6 +75,10 @@ def test_dry_runs_are_allowed(self): known_args, _ = run(self.base_cli_args + '--dry-run'.split()) self.assertEqual(known_args.dry_run, True) + def test_log_level_arg(self): + known_args, _ = run(self.base_cli_args + '--log-level 3'.split()) + self.assertEqual(known_args.log_level, 3) + def test_tif_metadata_for_datetime_raise_error_for_non_tif_file(self): with self.assertRaisesRegex(RuntimeError, 'can be specified only for tif files.'): run(self.base_cli_args + '--tif_metadata_for_datetime start_time'.split()) diff --git a/weather_mv/loader_pipeline/regrid.py b/weather_mv/loader_pipeline/regrid.py index d0f290d9..38d016c2 100644 --- a/weather_mv/loader_pipeline/regrid.py +++ b/weather_mv/loader_pipeline/regrid.py @@ -31,7 +31,6 @@ from .sinks import ToDataSink, open_local, copy logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) try: import metview as mv diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py index 89aaed04..ca6b0569 100644 --- a/weather_mv/loader_pipeline/sinks.py +++ b/weather_mv/loader_pipeline/sinks.py @@ -41,7 +41,6 @@ DEFAULT_TIME_ORDER_LIST = ['%Y', '%m', '%d', '%H', '%M', '%S'] logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class KwargsFactoryMixin: diff --git a/weather_mv/loader_pipeline/streaming.py b/weather_mv/loader_pipeline/streaming.py index eaf407c1..7210b2e7 100644 --- a/weather_mv/loader_pipeline/streaming.py +++ b/weather_mv/loader_pipeline/streaming.py @@ -28,9 +28,7 @@ import apache_beam as beam from apache_beam.transforms.window import FixedWindows -logging.getLogger().setLevel(logging.INFO) logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class GroupMessagesByFixedWindows(beam.PTransform): diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py index a9e1454b..ed79c39b 100644 --- a/weather_mv/loader_pipeline/util.py +++ b/weather_mv/loader_pipeline/util.py @@ -41,7 +41,6 @@ from .sinks import DEFAULT_COORD_KEYS logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) CANARY_BUCKET_NAME = 'anthromet_canary_bucket' CANARY_RECORD = {'foo': 'bar'} diff --git a/weather_sp/splitter_pipeline/file_splitters.py b/weather_sp/splitter_pipeline/file_splitters.py index b2fc5f1d..6456c426 100644 --- a/weather_sp/splitter_pipeline/file_splitters.py +++ b/weather_sp/splitter_pipeline/file_splitters.py @@ -292,10 +292,14 @@ def _get_keys(self) -> t.Dict[str, str]: return {name: name for name in self.output_info.split_dims()} -def get_splitter(file_path: str, output_info: OutFileInfo, dry_run: bool, force_split: bool = False) -> FileSplitter: +def get_splitter(file_path: str, + output_info: OutFileInfo, + 
dry_run: bool, + force_split: bool = False, + logging_level: int = logging.INFO) -> FileSplitter: if dry_run: logger.info('Using splitter: DrySplitter') - return DrySplitter(file_path, output_info) + return DrySplitter(file_path, output_info, logging_level=logging_level) with FileSystems.open(file_path) as f: header = f.read(4) @@ -309,10 +313,10 @@ def get_splitter(file_path: str, output_info: OutFileInfo, dry_run: bool, force_ cmd = shutil.which('grib_copy') if cmd: logger.info('Using splitter: GribSplitterV2') - return GribSplitterV2(file_path, output_info, force_split) + return GribSplitterV2(file_path, output_info, force_split, logging_level) else: logger.info('Using splitter: GribSplitter') - return GribSplitter(file_path, output_info, force_split) + return GribSplitter(file_path, output_info, force_split, logging_level) # See the NetCDF Spec docs: # https://docs.unidata.ucar.edu/netcdf-c/current/faq.html#How-can-I-tell-which-format-a-netCDF-file-uses diff --git a/weather_sp/splitter_pipeline/pipeline.py b/weather_sp/splitter_pipeline/pipeline.py index da14ce3e..c3dcd470 100644 --- a/weather_sp/splitter_pipeline/pipeline.py +++ b/weather_sp/splitter_pipeline/pipeline.py @@ -41,7 +41,8 @@ def split_file(input_file: str, output_dir: t.Optional[str], formatting: str, dry_run: bool, - force_split: bool = False): + force_split: bool = False, + logging_level: int = logging.INFO): output_base_name = get_output_base_name(input_path=input_file, input_base=input_base_dir, output_template=output_template, @@ -50,10 +51,12 @@ def split_file(input_file: str, logger.info('Splitting file: %s. Output base name: %s', input_file, output_base_name) metrics.Metrics.counter('pipeline', 'splitting file').inc() + level = 40 - logging_level * 10 splitter = get_splitter(input_file, output_base_name, dry_run, - force_split) + force_split, + level) splitter.split_data() @@ -113,9 +116,11 @@ def run(argv: t.List[str], save_main_session: bool = True): help='Test the input file matching and the output file scheme without splitting.') parser.add_argument('-f', '--force', action='store_true', default=False, help='Force re-splitting of the pipeline. Turns of skipping of already split data.') + parser.add_argument('--log-level', type=int, default=2, + help='An integer to configure log level. Default: 2(INFO)') known_args, pipeline_args = parser.parse_known_args(argv[1:]) - configure_logger(2) # 0 = error, 1 = warn, 2 = info, 3 = debug + configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = save_main_session @@ -153,5 +158,6 @@ def run(argv: t.List[str], save_main_session: bool = True): output_dir, formatting, dry_run, - known_args.force) + known_args.force, + known_args.log_level) ) From e92997206aff52149624b1947c857f2a7bc8169a Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Thu, 22 Jun 2023 07:44:34 -0400 Subject: [PATCH 3/5] Log warning instead at failure to parse projection info with rasterio. (#348) We use rasterio on open datasets to gather projection information and the datatype of data opened with `open_dataset()`. In cases where rasterio (or GDAL) cannot parse the URI, let's log a warning instead of crashing the pipeline. When the CRS info is needed (in the `weather-mv ee` pipeline), users can pass this in via a custom COG conversion transform. 
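For illustration only, a minimal standalone sketch of the fallback this patch introduces; the helper name read_projection_metadata is hypothetical, while rasterio.open, f.profile, and rasterio.errors.RasterioIOError are the same rasterio APIs used in the diff below:

import logging

import rasterio

logger = logging.getLogger(__name__)


def read_projection_metadata(path: str) -> dict:
    """Best-effort read of dtype/CRS/transform; returns {} when rasterio/GDAL cannot parse the file."""
    try:
        with rasterio.open(path, 'r') as f:
            # `profile` is a dict-like mapping that includes 'dtype', 'crs' and 'transform' entries.
            return {key: f.profile.get(key) for key in ('dtype', 'crs', 'transform')}
    except rasterio.errors.RasterioIOError:
        # Log and continue instead of crashing the pipeline, as this patch does.
        logger.warning('Cannot parse projection and data type information for %r.', path)
        return {}
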
--- weather_mv/loader_pipeline/sinks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py index ca6b0569..0f8cd561 100644 --- a/weather_mv/loader_pipeline/sinks.py +++ b/weather_mv/loader_pipeline/sinks.py @@ -396,9 +396,12 @@ def open_dataset(uri: str, forecast_time_regex) # Extracting dtype, crs and transform from the dataset & storing them as attributes. - with rasterio.open(local_path, 'r') as f: - dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform']) - xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform}) + try: + with rasterio.open(local_path, 'r') as f: + dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform']) + xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform}) + except rasterio.errors.RasterioIOError: + logger.warning('Cannot parse projection and data type information for Dataset %r.', uri) logger.info(f'opened dataset size: {xr_dataset.nbytes}') From e49b1693827bb1865c6ce811f92ca03da04ba38e Mon Sep 17 00:00:00 2001 From: aniketinfocusp <122869307+aniketinfocusp@users.noreply.github.com> Date: Fri, 23 Jun 2023 18:14:02 +0530 Subject: [PATCH 4/5] Added timedelta to json serialzation. (#349) * Added timedelta to json serialzation. * Added testcase for timedelta in json serializer * Fixed lint issues --- weather_mv/loader_pipeline/util.py | 2 ++ weather_mv/loader_pipeline/util_test.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/weather_mv/loader_pipeline/util.py b/weather_mv/loader_pipeline/util.py index ed79c39b..8c8c6a54 100644 --- a/weather_mv/loader_pipeline/util.py +++ b/weather_mv/loader_pipeline/util.py @@ -117,6 +117,8 @@ def to_json_serializable_type(value: t.Any) -> t.Any: # We assume here that naive timestamps are in UTC timezone. return value.replace(tzinfo=datetime.timezone.utc).isoformat() + elif isinstance(value, datetime.timedelta): + return value.total_seconds() elif isinstance(value, np.timedelta64): # Return time delta in seconds. 
return float(value / np.timedelta64(1, 's')) diff --git a/weather_mv/loader_pipeline/util_test.py b/weather_mv/loader_pipeline/util_test.py index ddaff050..e3b1b4ed 100644 --- a/weather_mv/loader_pipeline/util_test.py +++ b/weather_mv/loader_pipeline/util_test.py @@ -14,7 +14,7 @@ import itertools import unittest from collections import Counter -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta import xarray import xarray as xr @@ -240,3 +240,6 @@ def test_to_json_serializable_type_datetime(self): self.assertEqual(self._convert(np.datetime64(1, 'Y')), '1971-01-01T00:00:00+00:00') self.assertEqual(self._convert(np.datetime64(30, 'Y')), input_date) self.assertEqual(self._convert(np.timedelta64(1, 'm')), float(60)) + self.assertEqual(self._convert(timedelta(seconds=1)), float(1)) + self.assertEqual(self._convert(timedelta(minutes=1)), float(60)) + self.assertEqual(self._convert(timedelta(days=1)), float(86400)) From edd742d591cb428aa41f6c4d713c969e39f8dbbe Mon Sep 17 00:00:00 2001 From: Stephan Rasp Date: Mon, 26 Jun 2023 17:46:42 +0200 Subject: [PATCH 5/5] Small fix to return correct hdate for leap days (#352) * Small fix to return correct hdate for leap days * Formatting --- weather_dl/download_pipeline/util.py | 5 +++++ weather_dl/download_pipeline/util_test.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/weather_dl/download_pipeline/util.py b/weather_dl/download_pipeline/util.py index 1654c958..1ee9e24e 100644 --- a/weather_dl/download_pipeline/util.py +++ b/weather_dl/download_pipeline/util.py @@ -203,6 +203,8 @@ def download_with_aria2(url: str, path: str) -> None: def generate_hdate(date: str, subtract_year: str) -> str: """Generate a historical date by subtracting a specified number of years from the given date. + If input date is leap day (Feb 29), return Feb 28 even if target hdate is also a leap year. + This is expected in ECMWF API. Args: date (str): The input date in the format 'YYYY-MM-DD'. @@ -213,6 +215,9 @@ def generate_hdate(date: str, subtract_year: str) -> str: """ try: input_date = datetime.datetime.strptime(date, "%Y-%m-%d") + # Check for leap day + if input_date.month == 2 and input_date.day == 29: + input_date = input_date - datetime.timedelta(days=1) subtract_year = int(subtract_year) except (ValueError, TypeError): logger.error("Invalid input.") diff --git a/weather_dl/download_pipeline/util_test.py b/weather_dl/download_pipeline/util_test.py index 6bc37487..fdaa2393 100644 --- a/weather_dl/download_pipeline/util_test.py +++ b/weather_dl/download_pipeline/util_test.py @@ -89,3 +89,13 @@ def test_valid_hdate(self): substract_year = '4' expected_result = '2016-01-02' self.assertEqual(generate_hdate(date, substract_year), expected_result) + + # Also test for leap day correctness + date = '2020-02-29' + substract_year = '3' + expected_result = '2017-02-28' + self.assertEqual(generate_hdate(date, substract_year), expected_result) + + substract_year = '4' + expected_result = '2016-02-28' + self.assertEqual(generate_hdate(date, substract_year), expected_result)
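To make the leap-day behavior of the final patch concrete, a hedged, self-contained sketch of the date arithmetic follows; the Feb 29 -> Feb 28 adjustment mirrors the added lines, while the function name and the replace(year=...) step are assumptions, since the actual year-subtraction code is not visible in this hunk:

import datetime


def hdate_sketch(date: str, subtract_year: str) -> str:
    """Shift a YYYY-MM-DD date back N years, mapping Feb 29 to Feb 28 as the ECMWF API expects."""
    input_date = datetime.datetime.strptime(date, '%Y-%m-%d')
    # Leap-day handling added by this patch: fall back to Feb 28 before shifting years.
    if input_date.month == 2 and input_date.day == 29:
        input_date = input_date - datetime.timedelta(days=1)
    hdate = input_date.replace(year=input_date.year - int(subtract_year))
    return hdate.strftime('%Y-%m-%d')


# Matches the expectations in util_test.py above.
assert hdate_sketch('2020-02-29', '4') == '2016-02-28'
assert hdate_sketch('2020-01-02', '4') == '2016-01-02'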