Merge branch 'main' into mv-group-common-hypercubes
j9sh264 committed Jul 4, 2023
2 parents 8780a79 + ed62877 commit 437fc58
Showing 21 changed files with 123 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -53,8 +53,8 @@ jobs:
uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
channels: conda-forge
environment-file: ci${{ matrix.python-version}}.yml
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
activate-environment: weather-tools
- name: Check MetView's installation
shell: bash -l {0}
11 changes: 7 additions & 4 deletions ci3.8.yml
@@ -15,10 +15,11 @@ dependencies:
- eccodes=2.27.0
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.12.2
- rioxarray=0.13.4
- xarray-beam=0.3.1
- ecmwf-api-client=1.6.3
- fsspec=2022.10.0
- fsspec=2022.11.0
- gcsfs=2022.11.0
- gdal=3.5.1
- pyproj=3.4.0
- geojson=2.5.0=py_0
@@ -28,8 +29,10 @@ dependencies:
- pandas=1.5.1
- pip=22.3
- pygrib=2.1.4
- xarray==2023.1.0
- ruff==0.0.260
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- pip:
- earthengine-api==0.1.329
- xarray==2023.1.0
- ruff==0.0.260
- .[test]
11 changes: 7 additions & 4 deletions ci3.9.yml
@@ -15,10 +15,11 @@ dependencies:
- eccodes=2.27.0
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.12.2
- rioxarray=0.13.4
- xarray-beam=0.3.1
- ecmwf-api-client=1.6.3
- fsspec=2022.10.0
- fsspec=2022.11.0
- gcsfs=2022.11.0
- gdal=3.5.1
- pyproj=3.4.0
- geojson=2.5.0=py_0
@@ -28,8 +29,10 @@ dependencies:
- pandas=1.5.1
- pip=22.3
- pygrib=2.1.4
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- xarray==2023.1.0
- ruff==0.0.260
- pip:
- earthengine-api==0.1.329
- xarray==2023.1.0
- ruff==0.0.260
- .[test]
2 changes: 1 addition & 1 deletion environment.yml
@@ -8,7 +8,7 @@ dependencies:
- xarray=2023.1.0
- fsspec=2022.11.0
- gcsfs=2022.11.0
- rioxarray=0.12.2
- rioxarray=0.13.4
- gdal=3.5.1
- pyproj=3.4.0
- cdsapi=0.5.1
3 changes: 2 additions & 1 deletion weather_dl/download_pipeline/fetcher.py
@@ -48,6 +48,7 @@ class Fetcher(beam.DoFn):
client_name: str
manifest: Manifest = NoOpManifest(Location('noop://in-memory'))
store: t.Optional[Store] = None
log_level: t.Optional[int] = logging.INFO

def __post_init__(self):
if self.store is None:
@@ -66,7 +67,7 @@ def fetch_data(self, config: Config, *, worker_name: str = 'default') -> None:
if skip_partition(config, self.store, self.manifest):
return

client = CLIENTS[self.client_name](config)
client = CLIENTS[self.client_name](config, self.log_level)
target = prepare_target_name(config)

with tempfile.NamedTemporaryFile() as temp:
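For context, a minimal, self-contained sketch of the pattern this hunk introduces: threading a log level from the pipeline into per-worker download clients. FakeClient and its CLIENTS registry below are hypothetical stand-ins, not repository code:

    import logging
    import typing as t
    from dataclasses import dataclass

    class FakeClient:
        """Hypothetical stand-in for a download client that accepts a log level."""
        def __init__(self, config: dict, level: int = logging.INFO) -> None:
            self.logger = logging.getLogger(type(self).__name__)
            self.logger.setLevel(level)  # per-client verbosity, as in the diff
            self.config = config

    CLIENTS: t.Dict[str, t.Type[FakeClient]] = {'cds': FakeClient}

    @dataclass
    class Fetcher:
        """Sketch of the DoFn's fields after this change."""
        client_name: str
        log_level: t.Optional[int] = logging.INFO  # new field

        def fetch_data(self, config: dict) -> None:
            # The stored level is forwarded when the client is constructed.
            client = CLIENTS[self.client_name](config, self.log_level)
            client.logger.info('fetching with config %s', config)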
12 changes: 9 additions & 3 deletions weather_dl/download_pipeline/pipeline.py
@@ -125,7 +125,10 @@ def subsection_and_request(it: Config) -> t.Tuple[str, int]:
(
partitions
| 'GroupBy Request Limits' >> beam.GroupBy(subsection_and_request)
| 'Fetch Data' >> beam.ParDo(Fetcher(args.client_name, args.manifest, args.store))
| 'Fetch Data' >> beam.ParDo(Fetcher(args.client_name,
args.manifest,
args.store,
args.known_args.log_level))
)


@@ -172,10 +175,12 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs:
help="To enable file skipping logic in dry-run mode. Default: 'false'.")
parser.add_argument('-u', '--update-manifest', action='store_true', default=False,
help="Update the manifest for the already downloaded shards and exit. Default: 'false'.")
parser.add_argument('--log-level', type=int, default=2,
help='An integer to configure log level. Default: 2(INFO)')

known_args, pipeline_args = parser.parse_known_args(argv[1:])

configure_logger(3) # 0 = error, 1 = warn, 2 = info, 3 = debug
configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug

configs = []
for cfg in known_args.config:
@@ -223,7 +228,8 @@ def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs:
manifest = LocalManifest(Location(local_dir))

num_requesters_per_key = known_args.num_requests_per_key
client = CLIENTS[client_name](configs[0])
known_args.log_level = 40 - known_args.log_level * 10
client = CLIENTS[client_name](configs[0], known_args.log_level)
if num_requesters_per_key == -1:
num_requesters_per_key = client.num_requests_per_key(config.dataset)

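The "40 - log_level * 10" line converts the CLI's 0-3 scale onto the standard library's logging constants (ERROR=40, WARNING=30, INFO=20, DEBUG=10). A quick sketch verifying the mapping, not repository code:

    import logging

    # CLI scale: 0 = error, 1 = warn, 2 = info, 3 = debug.
    for cli_level in range(4):
        print(cli_level, logging.getLevelName(40 - cli_level * 10))
    # 0 ERROR
    # 1 WARNING
    # 2 INFO
    # 3 DEBUG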
3 changes: 2 additions & 1 deletion weather_dl/download_pipeline/pipeline_test.py
@@ -57,7 +57,8 @@
partition_chunks=None,
schedule='in-order',
check_skip_in_dry_run=False,
update_manifest=False),
update_manifest=False,
log_level=20),
pipeline_options=PipelineOptions('--save_main_session True'.split()),
configs=[Config.from_dict(CONFIG)],
client_name='cds',
5 changes: 5 additions & 0 deletions weather_dl/download_pipeline/util.py
@@ -203,6 +203,8 @@ def download_with_aria2(url: str, path: str) -> None:

def generate_hdate(date: str, subtract_year: str) -> str:
"""Generate a historical date by subtracting a specified number of years from the given date.
If input date is leap day (Feb 29), return Feb 28 even if target hdate is also a leap year.
This is expected in ECMWF API.
Args:
date (str): The input date in the format 'YYYY-MM-DD'.
@@ -213,6 +215,9 @@ def generate_hdate(date: str, subtract_year: str) -> str:
"""
try:
input_date = datetime.datetime.strptime(date, "%Y-%m-%d")
# Check for leap day
if input_date.month == 2 and input_date.day == 29:
input_date = input_date - datetime.timedelta(days=1)
subtract_year = int(subtract_year)
except (ValueError, TypeError):
logger.error("Invalid input.")
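A self-contained sketch of the leap-day behavior this hunk adds. The body is simplified (the real function also validates input and logs errors), and the year-replacement step is an assumption about code outside the hunk; the asserted outputs match the tests below:

    import datetime

    def generate_hdate(date: str, subtract_year: str) -> str:
        """Subtract years from a date, clamping Feb 29 to Feb 28 for the ECMWF API."""
        input_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        if input_date.month == 2 and input_date.day == 29:
            # Step back one day so the target year never needs a Feb 29.
            input_date -= datetime.timedelta(days=1)
        hdate = input_date.replace(year=input_date.year - int(subtract_year))
        return hdate.strftime('%Y-%m-%d')

    assert generate_hdate('2020-02-29', '4') == '2016-02-28'
    assert generate_hdate('2020-01-02', '4') == '2016-01-02'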
10 changes: 10 additions & 0 deletions weather_dl/download_pipeline/util_test.py
@@ -89,3 +89,13 @@ def test_valid_hdate(self):
substract_year = '4'
expected_result = '2016-01-02'
self.assertEqual(generate_hdate(date, substract_year), expected_result)

# Also test for leap day correctness
date = '2020-02-29'
substract_year = '3'
expected_result = '2017-02-28'
self.assertEqual(generate_hdate(date, substract_year), expected_result)

substract_year = '4'
expected_result = '2016-02-28'
self.assertEqual(generate_hdate(date, substract_year), expected_result)
4 changes: 3 additions & 1 deletion weather_mv/loader_pipeline/bq.py
@@ -39,7 +39,6 @@
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

DEFAULT_IMPORT_TIME = datetime.datetime.utcfromtimestamp(0).replace(tzinfo=datetime.timezone.utc).isoformat()
DATA_IMPORT_TIME_COLUMN = 'data_import_time'
@@ -224,6 +223,9 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Di
row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
for n, v in row_ds.data_vars.items()}

# Serialize coordinates.
it = {k: to_json_serializable_type(v) for k, v in it.items()}

# Add indexed coordinates.
row.update(it)
# Add un-indexed coordinates.
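The added comprehension mirrors the serialization already applied to data variables, so indexed coordinate values (numpy datetimes, numpy scalars) reach BigQuery as JSON-safe Python types. A sketch of the idea, with a deliberately reduced, hypothetical stand-in for the repository's serializer:

    import numpy as np

    def to_json_serializable_type(value):
        """Hypothetical, much-reduced stand-in for the real serializer."""
        if isinstance(value, np.datetime64):
            return str(value.astype('datetime64[s]'))  # ISO-style string
        if isinstance(value, np.generic):
            return value.item()  # unwrap numpy scalars to Python types
        return value

    # Indexed coordinates come straight off the Dataset and may be numpy
    # types; serialize them before merging into the output row.
    it = {'time': np.datetime64('2018-01-02T23:00:00'), 'latitude': np.float32(49.0)}
    row = {}
    row.update({k: to_json_serializable_type(v) for k, v in it.items()})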
1 change: 0 additions & 1 deletion weather_mv/loader_pipeline/ee.py
@@ -40,7 +40,6 @@
from .util import make_attrs_ee_compatible, RateLimit, validate_region

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

COMPUTE_ENGINE_STR = 'Metadata-Flavor: Google'
# For EE ingestion retry logic.
4 changes: 3 additions & 1 deletion weather_mv/loader_pipeline/pipeline.py
@@ -111,6 +111,8 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]:
'Default: `{"chunks": null, "consolidated": true}`.')
base.add_argument('-d', '--dry-run', action='store_true', default=False,
help='Preview the weather-mv job. Default: off')
base.add_argument('--log-level', type=int, default=2,
help='An integer to configure log level. Default: 2(INFO)')

subparsers = parser.add_subparsers(help='help for subcommand', dest='subcommand')

@@ -131,7 +133,7 @@

known_args, pipeline_args = parser.parse_known_args(argv[1:])

configure_logger(2) # 0 = error, 1 = warn, 2 = info, 3 = debug
configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug

# Validate Zarr arguments
if known_args.uris.endswith('.zarr'):
5 changes: 5 additions & 0 deletions weather_mv/loader_pipeline/pipeline_test.py
@@ -65,6 +65,7 @@ def setUp(self) -> None:
'tif_metadata_for_datetime': None,
'zarr': False,
'zarr_kwargs': {},
'log_level': 2,
}


@@ -74,6 +75,10 @@ def test_dry_runs_are_allowed(self):
known_args, _ = run(self.base_cli_args + '--dry-run'.split())
self.assertEqual(known_args.dry_run, True)

def test_log_level_arg(self):
known_args, _ = run(self.base_cli_args + '--log-level 3'.split())
self.assertEqual(known_args.log_level, 3)

def test_tif_metadata_for_datetime_raise_error_for_non_tif_file(self):
with self.assertRaisesRegex(RuntimeError, 'can be specified only for tif files.'):
run(self.base_cli_args + '--tif_metadata_for_datetime start_time'.split())
1 change: 0 additions & 1 deletion weather_mv/loader_pipeline/regrid.py
@@ -31,7 +31,6 @@
from .sinks import ToDataSink, open_local, copy

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

try:
import metview as mv
51 changes: 21 additions & 30 deletions weather_mv/loader_pipeline/sinks.py
@@ -30,6 +30,7 @@
import cfgrib
import numpy as np
import rasterio
import rioxarray
import xarray as xr
from apache_beam.io.filesystem import CompressionTypes, FileSystem, CompressedFile, DEFAULT_READ_BUFFER_SIZE
from pyproj import Transformer
@@ -40,7 +41,6 @@
DEFAULT_TIME_ORDER_LIST = ['%Y', '%m', '%d', '%H', '%M', '%S']

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class KwargsFactoryMixin:
@@ -145,14 +145,9 @@ def _preprocess_tif(ds: xr.Dataset, filename: str, tif_metadata_for_datetime: st
This also retrieves datetime from tif's metadata and stores it into dataset.
"""

def _get_band_data(i):
if not band_names_dict:
band = ds.band_data[i]
band.name = ds.band_data.attrs['long_name'][i]
else:
band = ds.band_data
band.name = band_names_dict.get(band.name)
return band
def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
rename_dict = {var_name: ds[var_name].attrs.get('long_name', var_name) for var_name in ds.variables}
return ds.rename(rename_dict)

y, x = np.meshgrid(ds['y'], ds['x'])
transformer = Transformer.from_crs(ds.spatial_ref.crs_wkt, TIF_TRANSFORM_CRS_TO, always_xy=True)
@@ -162,14 +157,9 @@ def _get_band_data(i):
ds['x'] = lon[:, 0]
ds = ds.rename({'y': 'latitude', 'x': 'longitude'})

band_length = len(ds.band)
ds = ds.squeeze().drop_vars('band').drop_vars('spatial_ref')

band_data_list = [_get_band_data(i) for i in range(band_length)]
ds = ds.squeeze().drop_vars('spatial_ref')

ds_is_normalized_attr = ds.attrs['is_normalized']
ds = xr.merge(band_data_list)
ds.attrs['is_normalized'] = ds_is_normalized_attr
ds = _replace_dataarray_names_with_long_names(ds)

end_time = None
if initialization_time_regex and forecast_time_regex:
Expand All @@ -184,17 +174,15 @@ def _get_band_data(i):
ds.attrs['start_time'] = start_time
ds.attrs['end_time'] = end_time

# TODO(#159): Explore ways to capture required metadata using xarray.
with rasterio.open(filename) as f:
datetime_value_ms = None
try:
datetime_value_s = (int(end_time.timestamp()) if end_time is not None
else int(f.tags()[tif_metadata_for_datetime]) / 1000.0)
ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)})
except KeyError:
raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.")
except ValueError:
raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.")
datetime_value_ms = None
try:
datetime_value_s = (int(end_time.timestamp()) if end_time is not None
else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)})
except KeyError:
raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.")
except ValueError:
raise RuntimeError(f"Invalid datetime value in tif's metadata: {datetime_value_ms}.")

return ds

@@ -329,7 +317,7 @@ def __open_dataset_file(filename: str,

# If URI extension is .tif, try opening file by specifying engine="rasterio".
if uri_extension in ['.tif', '.tiff']:
return _add_is_normalized_attr(xr.open_dataset(filename, engine='rasterio'), False)
return _add_is_normalized_attr(rioxarray.open_rasterio(filename, band_as_variable=True), False)

# If no open kwargs are available and URI extension is other than tif, make educated guesses about the dataset.
try:
@@ -424,8 +412,11 @@ def open_dataset(uri: str,
open_dataset_kwargs,
group_common_hypercubes)
# Extracting dtype, crs and transform from the dataset.
with rasterio.open(local_path, 'r') as f:
dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform'])
try:
with rasterio.open(local_path, 'r') as f:
dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform'])
except rasterio.errors.RasterioIOError:
logger.warning('Cannot parse projection and data type information for Dataset %r.', uri)

if group_common_hypercubes:
total_size_in_bytes = 0
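Switching from xr.open_dataset(..., engine='rasterio') to rioxarray.open_rasterio(..., band_as_variable=True) exposes each raster band as its own data variable (band_1, band_2, ...), which the new helper then renames to each band's long_name attribute; this is the likely motivation for bumping the rioxarray pin to 0.13.4 above. A sketch of that flow, with a hypothetical file path:

    import rioxarray
    import xarray as xr

    def _replace_dataarray_names_with_long_names(ds: xr.Dataset) -> xr.Dataset:
        """Rename band_1, band_2, ... to each variable's long_name, if present."""
        rename_dict = {name: ds[name].attrs.get('long_name', name) for name in ds.variables}
        return ds.rename(rename_dict)

    # band_as_variable=True returns a Dataset with one variable per band.
    ds = rioxarray.open_rasterio('example.tif', band_as_variable=True)
    ds = _replace_dataarray_names_with_long_names(ds)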
2 changes: 0 additions & 2 deletions weather_mv/loader_pipeline/streaming.py
@@ -28,9 +28,7 @@
import apache_beam as beam
from apache_beam.transforms.window import FixedWindows

logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class GroupMessagesByFixedWindows(beam.PTransform):
12 changes: 5 additions & 7 deletions weather_mv/loader_pipeline/util.py
@@ -41,7 +41,6 @@
from .sinks import DEFAULT_COORD_KEYS

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

CANARY_BUCKET_NAME = 'anthromet_canary_bucket'
CANARY_RECORD = {'foo': 'bar'}
@@ -118,6 +117,8 @@ def to_json_serializable_type(value: t.Any) -> t.Any:

# We assume here that naive timestamps are in UTC timezone.
return value.replace(tzinfo=datetime.timezone.utc).isoformat()
elif isinstance(value, datetime.timedelta):
return value.total_seconds()
elif isinstance(value, np.timedelta64):
# Return time delta in seconds.
return float(value / np.timedelta64(1, 's'))
@@ -203,14 +204,11 @@ def get_coordinates(ds: xr.Dataset, uri: str = '') -> t.Iterator[t.Dict]:
"""Generates normalized coordinate dictionaries that can be used to index Datasets with `.loc[]`."""
# Creates flattened iterator of all coordinate positions in the Dataset.
#
# Coordinates have been pre-processed to remove NaNs and to format datetime objects
# to ISO format strings.
#
# Example: (-108.0, 49.0, '2018-01-02T22:00:00+00:00')
# Example: (-108.0, 49.0, datetime.datetime('2018-01-02T22:00:00+00:00'))
coords = itertools.product(
*(
(
to_json_serializable_type(v)
v
for v in ensure_us_time_resolution(ds[c].variable.values).tolist()
)
for c in ds.coords.indexes
@@ -219,7 +217,7 @@
# Give dictionary keys to a coordinate index.
#
# Example:
# {'longitude': -108.0, 'latitude': 49.0, 'time': '2018-01-02T23:00:00+00:00'}
# {'longitude': -108.0, 'latitude': 49.0, 'time': datetime.datetime('2018-01-02T23:00:00+00:00')}
idx = 0
total_coords = math.prod(ds.coords.dims.values())
for idx, it in enumerate(coords):
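The new datetime.timedelta branch complements the existing np.timedelta64 branch, so both duration types serialize to seconds. A quick sketch of the two paths, not repository code:

    import datetime
    import numpy as np

    def serialize_duration(value):
        """Sketch of the two timedelta branches in to_json_serializable_type."""
        if isinstance(value, datetime.timedelta):
            return value.total_seconds()  # new branch
        if isinstance(value, np.timedelta64):
            return float(value / np.timedelta64(1, 's'))  # existing branch
        return value

    assert serialize_duration(datetime.timedelta(hours=1)) == 3600.0
    assert serialize_duration(np.timedelta64(90, 'm')) == 5400.0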