
weather-mv will ingest data into BQ from Zarr much faster. #357

Merged · 31 commits · Aug 16, 2023
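
This PR carries no written description, but the diffs below tell the story: when `--uris` points at a Zarr store, the BigQuery loader now opens it once with Xarray-Beam and fans the chunks out to workers instead of funnelling everything through a single path-expansion step. The happy-path unit test added in `bq_test.py` drives exactly that branch; the snippet below is a trimmed sketch of the same construction, not production guidance — the import path, Zarr store, and output table are placeholders.

```python
# Trimmed sketch mirroring the happy-path test added below in bq_test.py.
# Assumptions: weather_mv's `loader_pipeline` package is importable under
# this name; the Zarr store and output table are placeholders.
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, is_not_empty

from loader_pipeline.bq import ToBigQuery  # import path is an assumption

op = ToBigQuery.from_kwargs(
    first_uri='gs://my-bucket/era5.zarr', zarr=True, zarr_kwargs={}, dry_run=True,
    output_table='my-project.weather.era5', variables=[], area=[],
    xarray_open_dataset_kwargs={}, import_time=None, infer_schema=False,
    tif_metadata_for_datetime=None, skip_region_validation=True,
    disable_grib_schema_normalization=False,
)

with TestPipeline() as p:
    # Zarr branch: chunks are read in parallel and flattened to rows (dry run).
    rows = p | op
    assert_that(rows, is_not_empty())
```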
Commits (31)
b095cef
Fixed issues found loading Zarr into BQ.
alxmrs Jun 24, 2023
43f2374
Base weather-tools install requires gcsfs.
alxmrs Jun 24, 2023
d6487b0
Not normalized by default.
alxmrs Jun 24, 2023
be077ff
Parallel Zarr ingestion into BQ.
alxmrs Jun 24, 2023
0226ccb
Fix setup.py syntax error.
alxmrs Jun 24, 2023
3c57acb
Fixing Zarr + Xarray-Beam support.
alxmrs Jun 24, 2023
14601e4
Added happy path unit test for parallel zarr reading in BQ.
alxmrs Jun 25, 2023
9eedf3f
fix flake8 issues.
alxmrs Jun 25, 2023
88790a6
Better whitespace.
alxmrs Jun 25, 2023
64343b9
Adding open_ds kwargs to open zarr.
alxmrs Jun 25, 2023
3051917
Attempting to fix pickling issues.
alxmrs Jun 25, 2023
d955562
Another attempt to fix pickling error, now in transform.
alxmrs Jun 25, 2023
b1870a5
Experiment: is xbeam.open_zarr the issue?
alxmrs Jun 25, 2023
5146f64
adding engine=zarr.
alxmrs Jun 25, 2023
62ed509
open_zarr --> open_dataset w/ engine.
alxmrs Jun 25, 2023
18eb433
delete regrid
alxmrs Jun 26, 2023
868a0cf
Pinned Zarr version.
alxmrs Jun 26, 2023
9b6329d
Hard coded current CL for docker image.
alxmrs Jun 26, 2023
5cb6111
rm unnecessary delete.
alxmrs Jun 27, 2023
e4806db
Only recent years.
alxmrs Jun 27, 2023
239cfde
All data w/ streaming inserts.
alxmrs Jun 27, 2023
d9a9368
Experiment: added windowing.
alxmrs Jun 27, 2023
333d98a
Documented `timestamp_row` fn.
alxmrs Jun 29, 2023
c8c82de
Self-review: Prepared changes for PR.
alxmrs Jun 29, 2023
e3eda36
Small cleanup.
alxmrs Jun 29, 2023
cbd674c
Remove debug isel.
alxmrs Jul 11, 2023
c75ca93
Added types to `to_rows()`.
alxmrs Jul 11, 2023
1e8e4be
Fixed flake8 lint errors.
alxmrs Jul 11, 2023
fa99904
Better types for `to_rows()`.
alxmrs Jul 11, 2023
2cf9a28
Test updated and 'chunks' removed from zarr_kwargs
dabhicusp Jul 13, 2023
f2608e2
Zarr version updated.
dabhicusp Jul 14, 2023
3 changes: 2 additions & 1 deletion ci3.8.yml
@@ -16,7 +16,7 @@ dependencies:
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.13.4
- xarray-beam=0.3.1
- xarray-beam=0.6.2
- ecmwf-api-client=1.6.3
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -33,6 +33,7 @@ dependencies:
- ruff==0.0.260
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- zarr=2.15.0
- pip:
- earthengine-api==0.1.329
- .[test]
3 changes: 2 additions & 1 deletion ci3.9.yml
@@ -16,7 +16,7 @@ dependencies:
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.13.4
- xarray-beam=0.3.1
- xarray-beam=0.6.2
- ecmwf-api-client=1.6.3
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -33,6 +33,7 @@ dependencies:
- aria2=1.36.0
- xarray==2023.1.0
- ruff==0.0.260
- zarr=2.15.0
- pip:
- earthengine-api==0.1.329
- .[test]
3 changes: 2 additions & 1 deletion environment.yml
@@ -4,7 +4,7 @@ channels:
dependencies:
- python=3.8.13
- apache-beam=2.40.0
- xarray-beam=0.3.1
- xarray-beam=0.6.2
- xarray=2023.1.0
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -25,6 +25,7 @@ dependencies:
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- pip=22.3
- zarr=2.15.0
- pip:
- earthengine-api==0.1.329
- firebase-admin==6.0.1
6 changes: 4 additions & 2 deletions setup.py
@@ -57,8 +57,9 @@
"earthengine-api>=0.1.263",
"pyproj", # requires separate binary installation!
"gdal", # requires separate binary installation!
"xarray-beam==0.3.1",
"xarray-beam==0.6.2",
"gcsfs==2022.11.0",
"zarr==2.15.0",
]

weather_sp_requirements = [
@@ -82,6 +83,7 @@
"memray",
"pytest-memray",
"h5py",
"pooch",
]

all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \
@@ -115,7 +117,7 @@

],
python_requires='>=3.8, <3.10',
install_requires=['apache-beam[gcp]==2.40.0'],
install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'],
use_scm_version=True,
setup_requires=['setuptools_scm'],
scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'],
135 changes: 80 additions & 55 deletions weather_mv/loader_pipeline/bq.py
@@ -24,8 +24,10 @@
import geojson
import numpy as np
import xarray as xr
import xarray_beam as xbeam
from apache_beam.io import WriteToBigQuery, BigQueryDisposition
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms import window
from google.cloud import bigquery
from xarray.core.utils import ensure_us_time_resolution

@@ -236,73 +238,90 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Di
with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization,
self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds:
data_ds: xr.Dataset = _only_target_vars(ds, self.variables)
yield from self.to_rows(coordinates, data_ds, uri)

first_ts_raw = data_ds.time[0].values if isinstance(data_ds.time.values,
np.ndarray) else data_ds.time.values
first_time_step = to_json_serializable_type(first_ts_raw)

for it in coordinates:
# Use those index values to select a Dataset containing one row of data.
row_ds = data_ds.loc[it]

# Create a Name-Value map for data columns. Result looks like:
# {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None}
row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
for n, v in row_ds.data_vars.items()}

# Serialize coordinates.
it = {k: to_json_serializable_type(v) for k, v in it.items()}

# Add indexed coordinates.
row.update(it)
# Add un-indexed coordinates.
for c in row_ds.coords:
if c not in it and (not self.variables or c in self.variables):
row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values))

# Add import metadata.
row[DATA_IMPORT_TIME_COLUMN] = self.import_time
row[DATA_URI_COLUMN] = uri
row[DATA_FIRST_STEP] = first_time_step

longitude = ((row['longitude'] + 180) % 360) - 180
row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude)
row[GEO_POLYGON_COLUMN] = (
fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution)
if not self.skip_creating_polygon
else None
)
# 'row' ends up looking like:
# {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812,
# 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...}
beam.metrics.Metrics.counter('Success', 'ExtractRows').inc()
yield row
def to_rows(self, coordinates: t.Iterable[t.Dict], ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]:
first_ts_raw = (
ds.time[0].values if isinstance(ds.time.values, np.ndarray)
else ds.time.values
)
first_time_step = to_json_serializable_type(first_ts_raw)
for it in coordinates:
# Use those index values to select a Dataset containing one row of data.
row_ds = ds.loc[it]

# Create a Name-Value map for data columns. Result looks like:
# {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None}
row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
for n, v in row_ds.data_vars.items()}

# Serialize coordinates.
it = {k: to_json_serializable_type(v) for k, v in it.items()}

# Add indexed coordinates.
row.update(it)
# Add un-indexed coordinates.
for c in row_ds.coords:
if c not in it and (not self.variables or c in self.variables):
row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values))

# Add import metadata.
row[DATA_IMPORT_TIME_COLUMN] = self.import_time
row[DATA_URI_COLUMN] = uri
row[DATA_FIRST_STEP] = first_time_step

longitude = ((row['longitude'] + 180) % 360) - 180
row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude)
row[GEO_POLYGON_COLUMN] = (
fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution)
if not self.skip_creating_polygon
else None
)
# 'row' ends up looking like:
# {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812,
# 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...}
beam.metrics.Metrics.counter('Success', 'ExtractRows').inc()
yield row

def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]:
uri = ds.attrs.get(DATA_URI_COLUMN, '')
# Re-calculate import time for streaming extractions.
if not self.import_time or self.zarr:
self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
yield from self.to_rows(get_coordinates(ds, uri), ds, uri)

def expand(self, paths):
"""Extract rows of variables from data paths into a BigQuery table."""
extracted_rows = (
if not self.zarr:
extracted_rows = (
paths
| 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates)
| beam.Reshuffle()
| 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows)
)

if not self.dry_run:
(
extracted_rows
| 'WriteToBigQuery' >> WriteToBigQuery(
project=self.table.project,
dataset=self.table.dataset_id,
table=self.table.table_id,
write_disposition=BigQueryDisposition.WRITE_APPEND,
create_disposition=BigQueryDisposition.CREATE_NEVER)
)
else:
(
extracted_rows
| 'Log Extracted Rows' >> beam.Map(logger.debug)
ds, chunks = xbeam.open_zarr(self.first_uri, **self.xarray_open_dataset_kwargs)
ds.attrs[DATA_URI_COLUMN] = self.first_uri
extracted_rows = (
paths
| 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks)
| 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows)
| 'Window' >> beam.WindowInto(window.FixedWindows(60))
mahrsee1997 (Collaborator) commented on Jul 17, 2023:

If incorporating these steps (window) enables real-time data ingestion into BigQuery (in batch jobs), then we should relocate these lines here.
Ref: #291

| 'AddTimestamp' >> beam.Map(timestamp_row)
)

if self.dry_run:
return extracted_rows | 'Log Rows' >> beam.Map(logger.info)
return (
extracted_rows
| 'WriteToBigQuery' >> WriteToBigQuery(
project=self.table.project,
dataset=self.table.dataset_id,
table=self.table.table_id,
write_disposition=BigQueryDisposition.WRITE_APPEND,
create_disposition=BigQueryDisposition.CREATE_NEVER)
)


def map_dtype_to_sql_type(var_type: np.dtype) -> str:
"""Maps a np.dtype to a suitable BigQuery column type."""
@@ -343,6 +362,12 @@ def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.Schem
return fields


def timestamp_row(it: t.Dict) -> window.TimestampedValue:
"""Associate an extracted row with the import_time timestamp."""
timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp()
return window.TimestampedValue(it, timestamp)


def fetch_geo_point(lat: float, long: float) -> str:
"""Calculates a geography point from an input latitude and longitude."""
if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]:
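To make the reviewer's windowing note above concrete: in the new Zarr branch of `expand`, `xbeam.open_zarr` returns a lazy template dataset plus its chunking scheme, `DatasetToChunks` loads each chunk on a worker, `chunks_to_rows` flattens chunks into BigQuery rows, and the rows are then windowed and timestamped ahead of the streaming write. The sketch below is a standalone, simplified rendering of that pattern — not this PR's code: the row flattening is a naive `to_dataframe()`, rows are timestamped before windowing (the diff applies those two stages in the opposite order), and the Zarr path and 60-second window are placeholders.

```python
# Minimal sketch of the chunk fan-out + windowing pattern used by the new
# Zarr branch of ToBigQuery.expand. Assumptions: the Zarr URI is a
# placeholder, and rows are flattened naively via pandas.
import datetime
import typing as t

import apache_beam as beam
import xarray as xr
import xarray_beam as xbeam
from apache_beam.transforms import window


def chunk_to_rows(_key: xbeam.Key, chunk: xr.Dataset) -> t.Iterator[t.Dict]:
    """Flatten one in-memory chunk into a dict per coordinate point."""
    import_time = datetime.datetime.now(datetime.timezone.utc)
    for record in chunk.to_dataframe().reset_index().to_dict(orient='records'):
        record['data_import_time'] = import_time
        yield record


def timestamp_row(row: t.Dict) -> window.TimestampedValue:
    """Associate each row with its import time, as `timestamp_row` does above."""
    return window.TimestampedValue(row, row['data_import_time'].timestamp())


def run(zarr_uri: str = 'gs://my-bucket/era5.zarr') -> None:
    # open_zarr returns a lazy "template" Dataset plus the chunking scheme;
    # DatasetToChunks then loads each chunk on a worker.
    ds, chunks = xbeam.open_zarr(zarr_uri)
    with beam.Pipeline() as p:
        _ = (
            p
            | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks)
            | 'ToRows' >> beam.FlatMapTuple(chunk_to_rows)
            | 'AddTimestamp' >> beam.Map(timestamp_row)
            | 'Window' >> beam.WindowInto(window.FixedWindows(60))
            | 'Log' >> beam.Map(print)
        )
```

The template dataset stays lazy, so only the per-chunk reads happen on workers: ingestion parallelism follows the store's chunking rather than the number of input files.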
48 changes: 41 additions & 7 deletions weather_mv/loader_pipeline/bq_test.py
@@ -15,6 +15,7 @@
import json
import logging
import os
import tempfile
import typing as t
import unittest

@@ -23,6 +24,8 @@
import pandas as pd
import simplejson
import xarray as xr
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, is_not_empty
from google.cloud.bigquery import SchemaField

from .bq import (
@@ -205,13 +208,13 @@ def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=N
skip_creating_polygon: bool = False) -> t.Iterator[t.Dict]:
if zarr_kwargs is None:
zarr_kwargs = {}
op = ToBigQuery.from_kwargs(first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs,
output_table='foo.bar.baz', variables=variables, area=area,
xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time,
infer_schema=False, tif_metadata_for_datetime=tif_metadata_for_datetime,
skip_region_validation=True,
disable_grib_schema_normalization=disable_grib_schema_normalization,
coordinate_chunk_size=1000, skip_creating_polygon=skip_creating_polygon)
op = ToBigQuery.from_kwargs(
first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs,
output_table='foo.bar.baz', variables=variables, area=area,
xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False,
tif_metadata_for_datetime=tif_metadata_for_datetime, skip_region_validation=True,
disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000,
skip_creating_polygon=skip_creating_polygon)
coords = op.prepare_coordinates(data_path)
for uri, chunk in coords:
yield from op.extract_rows(uri, chunk)
Expand Down Expand Up @@ -737,5 +740,36 @@ def test_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema
self.assertRowsEqual(actual, expected)


class ExtractRowsFromZarrTest(ExtractRowsTestBase):

def setUp(self) -> None:
super().setUp()
self.tmpdir = tempfile.TemporaryDirectory()

def tearDown(self) -> None:
super().tearDown()
self.tmpdir.cleanup()

def test_extracts_rows(self):
input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr')

ds = (
xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder)
.isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4))
.rename(dict(lon='longitude', lat='latitude'))
)
ds.to_zarr(input_zarr)

op = ToBigQuery.from_kwargs(
first_uri=input_zarr, zarr_kwargs=dict(), dry_run=True, zarr=True, output_table='foo.bar.baz',
variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False,
tif_metadata_for_datetime=None, skip_region_validation=True, disable_grib_schema_normalization=False,
)

with TestPipeline() as p:
result = p | op
assert_that(result, is_not_empty())


if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions weather_mv/loader_pipeline/pipeline.py
@@ -27,7 +27,7 @@
from .streaming import GroupMessagesByFixedWindows, ParsePaths

logger = logging.getLogger(__name__)
SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0'
SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0'


def configure_logger(verbosity: int) -> None:
@@ -55,8 +55,9 @@ def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None
known_args.first_uri = next(iter(all_uris))

with beam.Pipeline(argv=pipeline_args) as p:
if known_args.topic or known_args.subscription:

if known_args.zarr:
paths = p
elif known_args.topic or known_args.subscription:
paths = (
p
# Windowing is based on this code sample:
@@ -140,7 +141,6 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]:
# Validate Zarr arguments
if known_args.uris.endswith('.zarr'):
known_args.zarr = True
known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None)

if known_args.zarr_kwargs and not known_args.zarr:
raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.')
2 changes: 1 addition & 1 deletion weather_mv/loader_pipeline/regrid_test.py
@@ -122,7 +122,7 @@ def test_zarr__coarsen(self):
self.Op,
first_uri=input_zarr,
output_path=output_zarr,
zarr_input_chunks={"time": 5},
zarr_input_chunks={"time": 25},
zarr=True
)

4 changes: 2 additions & 2 deletions weather_mv/loader_pipeline/sinks.py
@@ -177,7 +177,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
datetime_value_ms = None
try:
datetime_value_s = (int(end_time.timestamp()) if end_time is not None
else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)})
except KeyError:
raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.")
@@ -375,7 +375,7 @@ def open_dataset(uri: str,
"""Open the dataset at 'uri' and return a xarray.Dataset."""
try:
if is_zarr:
ds: xr.Dataset = xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs)
ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs), False)
beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc()
yield ds
ds.close()
1 change: 1 addition & 0 deletions weather_mv/loader_pipeline/sinks_test.py
@@ -112,6 +112,7 @@ def test_opens_zarr(self):
with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds:
self.assertIsNotNone(ds)
self.assertEqual(list(ds.data_vars), ['cape', 'd2m'])

def test_open_dataset__fits_memory_bounds(self):
with write_netcdf() as test_netcdf_path:
with limit_memory(max_memory=30):