Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom seasons spanning calendar years #423

Merged
merged 17 commits into from
Nov 20, 2024

Conversation

tomvothecoder
Copy link
Collaborator

@tomvothecoder tomvothecoder commented Mar 3, 2023

Description

TODO:

  • Support custom seasons that span calendar years (_shift_spanning_months())
    • Requires detecting order of the months in a season. Currently, order does not matter.
    • For example, for custom_season = ["Nov", "Dec", "Jan", "Feb", "Mar"]:
      • ["Nov", "Dec"] are from the previous year since they are listed before "Jan"
      • ["Jan", "Feb", "Mar"] are from the current year
      • Therefore, ["Nov", "Dec"] need to be shifted a year forward for correct
        grouping.
  • Detect and drop incomplete seasons (_drop_incomplete_seasons())
    • Right now xCDAT only detects incomplete "DJF" seasons with _drop_incomplete_djf()
    • Replace boolean config drop_incomplete_djf with drop_incomplete_season
    • A possible solution for detecting incomplete seasons is to check if a season has all of the required months. If the count of months for that season does not match the expected count, then drop that season.
  • Remove requirement for all 12 months to be included in a custom season
  • Refactor logic for shifting months to use Xarray instead of Pandas
  • The current logic maps the custom season to its middle month, represented as cftime time coordinates. Does it make sense to also keep the custom seasons with the time coordinates, similar to what Xarray does?
      <xarray.DataArray 'season' (season: 2)> Size: 32B
      array(['DJFM', 'AMJ'], dtype='<U4')
      Coordinates:
        * season   (season) <U4 32B 'DJFM' 'AMJ'

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • My changes generate no new warnings
  • Any dependent changes have been merged and published in downstream modules

If applicable:

  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass with my changes (locally and CI/CD build)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have noted that this is a breaking change for a major release (fix or feature that would cause existing functionality to not work as expected)

Additional Context

@tomvothecoder tomvothecoder added the type: enhancement New enhancement request label Mar 3, 2023
@tomvothecoder tomvothecoder self-assigned this Mar 3, 2023
@tomvothecoder tomvothecoder force-pushed the feature/416-custom-season-span branch from 96c5eca to fa087b7 Compare March 6, 2023 20:31
@tomvothecoder
Copy link
Collaborator Author

Example result of _drop_incomplete_seasons():

# Before dropping
# -----------------
# 2000-1, 2000-2, and 2001-12 months in incomplete "DJF" seasons" so they are dropped
ds.time
<xarray.DataArray 'time' (time: 15)>
array([cftime.DatetimeGregorian(2000, 1, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 2, 15, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 3, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 4, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 5, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 6, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 7, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 8, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 9, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 10, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 11, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 12, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2001, 1, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2001, 2, 15, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2001, 12, 16, 12, 0, 0, 0, has_year_zero=False)],
      dtype=object)
Coordinates:
  * time     (time) object 2000-01-16 12:00:00 ... 2001-12-16 12:00:00
Attributes:
    axis:           T
    long_name:      time
    standard_name:  time
    bounds:         time_bnds

# After dropping
# -----------------
ds_new.time
<xarray.DataArray 'time' (time: 12)>
array([cftime.DatetimeGregorian(2000, 3, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 4, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 5, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 6, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 7, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 8, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 9, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 10, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 11, 16, 0, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2000, 12, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2001, 1, 16, 12, 0, 0, 0, has_year_zero=False),
       cftime.DatetimeGregorian(2001, 2, 15, 0, 0, 0, 0, has_year_zero=False)],
      dtype=object)
Coordinates:
  * time     (time) object 2000-03-16 12:00:00 ... 2001-02-15 00:00:00
Attributes:
    axis:           T
    long_name:      time
    standard_name:  time
    bounds:         time_bnds

@tomvothecoder
Copy link
Collaborator Author

Hey @lee1043, this PR seemed to be mostly done when I stopped working on it last year. I just had to fix a few things and update the tests.

Would you like to check out this branch to test it out on real data? Also a code review would be appreciated.

@tomvothecoder tomvothecoder requested a review from lee1043 April 4, 2024 21:50
@tomvothecoder tomvothecoder marked this pull request as ready for review April 4, 2024 21:50
Copy link

codecov bot commented Apr 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 100.00%. Comparing base (1cfd369) to head (f2648ff).
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff            @@
##              main      #423   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           15        15           
  Lines         1555      1609   +54     
=========================================
+ Hits          1555      1609   +54     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


🚨 Try these New Features:

@lee1043
Copy link
Collaborator

lee1043 commented Apr 4, 2024

@tomvothecoder sure, I will test it out and review. Thank you for the update!

@lee1043
Copy link
Collaborator

lee1043 commented Apr 4, 2024

@tomvothecoder Can this be considered for v0.7.0?

Copy link
Collaborator Author

@tomvothecoder tomvothecoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My PR self-review

Comment on lines 1012 to 991
warnings.warn(
"The `season_config` argument 'drop_incomplete_djf' is being "
"deprecated. Please use 'drop_incomplete_seasons' instead.",
DeprecationWarning,
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Need to a specify a specific version that we will deprecate drop_incomplete_djf. Probably v0.8.0 or v0.9.0.

xcdat/temporal.py Outdated Show resolved Hide resolved
Comment on lines -1025 to -1003
if len(input_months) != len(predefined_months):
raise ValueError(
"Exactly 12 months were not passed in the list of custom seasons."
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed requirements for all 12 months to be included in a custom season

@tomvothecoder
Copy link
Collaborator Author

tomvothecoder commented Apr 4, 2024

@tomvothecoder Can this be considered for v0.7.0?

This PR still needs thorough review before I'm confident in merging it. I'll probably tag Steve at some point.

The release after v0.7.0 is more realistic and reasonable. We can always initiate a new release for this feature whenever it is merged.

@lee1043
Copy link
Collaborator

lee1043 commented Apr 4, 2024

@tomvothecoder Can this be considered for v0.7.0?

This PR still needs thorough review before I'm confident in merging it. I'll probably tag Steve at some point.

The release after v0.7.0 is more realistic and reasonable. We can always initiate a new release for this feature whenever it is merged.

@tomvothecoder no problem. Thank you for consideration.

@lee1043
Copy link
Collaborator

lee1043 commented Apr 4, 2024

@tomvothecoder it looks like when custom season go beyond calendar year (Nov, Dec, Jan) there is error as follows.

import os
import xcdat as xc

input_data = os.path.join(
    "/p/css03/esgf_publish/CMIP6/CMIP/AWI/AWI-CM-1-1-MR/historical/r1i1p1f1/Amon/psl/gn/v20181218/",
    "psl_Amon_AWI-CM-1-1-MR_historical_r1i1p1f1_gn_201301-201312.nc")

ds = xc.open_mfdataset(input_data)

# Example of custom seasons in a three month format:
custom_seasons = [
    ['Dec', 'Jan'],
]

season_config = {'custom_seasons': custom_seasons, 'dec_mode': 'DJF', 'drop_incomplete_djf': True}

ds.temporal.group_average("psl", "season", season_config=season_config)
CPU times: user 448 ms, sys: 32.8 ms, total: 481 ms
Wall time: 471 ms
xarray.Dataset
Dimensions:
lat: 192bnds: 2lon: 384time: 4
Coordinates:
lat
(lat)
float64
-89.28 -88.36 ... 88.36 89.28
lon
(lon)
float64
0.0 0.9375 1.875 ... 358.1 359.1
time
(time)
object
2013-02-01 00:00:00 ... 2013-11-...
Data variables:
lat_bnds
(lat, bnds)
float64
dask.array<chunksize=(192, 2), meta=np.ndarray>
lon_bnds
(lon, bnds)
float64
dask.array<chunksize=(384, 2), meta=np.ndarray>
psl
(time, lat, lon)
float64
dask.array<chunksize=(1, 192, 384), meta=np.ndarray>
Indexes: (3)
Attributes: (44)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xarray/core/dataarray.py:862, in DataArray._getitem_coord(self, key)
    861 try:
--> 862     var = self._coords[key]
    863 except KeyError:

KeyError: 'time.year'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
<timed exec> in ?()

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xcdat/temporal.py in ?(self, data_var, freq, weighted, keep_weights, season_config)
    400         }
    401         """
    402         self._set_data_var_attrs(data_var)
    403 
--> 404         return self._averager(
    405             data_var,
    406             "group_average",
    407             freq,

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xcdat/temporal.py in ?(self, data_var, mode, freq, weighted, keep_weights, reference_period, season_config)
    870 
    871         if self._mode == "average":
    872             dv_avg = self._average(ds, data_var)
    873         elif self._mode in ["group_average", "climatology", "departures"]:
--> 874             dv_avg = self._group_average(ds, data_var)
    875 
    876         # The original time dimension is dropped from the dataset because
    877         # it becomes obsolete after the data variable is averaged. When the

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xcdat/temporal.py in ?(self, ds, data_var)
   1295         dv = _get_data_var(ds, data_var)
   1296 
   1297         # Label the time coordinates for grouping weights and the data variable
   1298         # values.
-> 1299         self._labeled_time = self._label_time_coords(dv[self.dim])
   1300 
   1301         if self._weighted:
   1302             time_bounds = ds.bounds.get_bounds("T", var_key=data_var)

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xcdat/temporal.py in ?(self, time_coords)
   1470         >>>       dtype='datetime64[ns]')
   1471         >>> Coordinates:
   1472         >>> * time     (time) datetime64[ns] 2000-01-01T00:00:00 ... 2000-04-01T00:00:00
   1473         """
-> 1474         df_dt_components: pd.DataFrame = self._get_df_dt_components(
   1475             time_coords, drop_obsolete_cols=True
   1476         )
   1477         dt_objects = self._convert_df_to_dt(df_dt_components)

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xcdat/temporal.py in ?(self, time_coords, drop_obsolete_cols)
   1534 
   1535         # Use the TIME_GROUPS dictionary to determine which components
   1536         # are needed to form the labeled time coordinates.
   1537         for component in TIME_GROUPS[self._mode][self._freq]:
-> 1538             df[component] = time_coords[f"{self.dim}.{component}"].values
   1539 
   1540         # The season frequency requires additional datetime components for
   1541         # processing, which are later removed before time coordinates are

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xarray/core/dataarray.py in ?(self, key)
    869     def __getitem__(self, key: Any) -> Self:
    870         if isinstance(key, str):
--> 871             return self._getitem_coord(key)
    872         else:
    873             # xarray-style array indexing
    874             return self.isel(indexers=self._item_key_to_dict(key))

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xarray/core/dataarray.py in ?(self, key)
    861         try:
    862             var = self._coords[key]
    863         except KeyError:
    864             dim_sizes = dict(zip(self.dims, self.shape))
--> 865             _, key, var = _get_virtual_variable(self._coords, key, dim_sizes)
    866 
    867         return self._replace_maybe_drop_dims(var, name=key)

~/.conda/envs/xcdat_dev_20240401/lib/python3.11/site-packages/xarray/core/dataset.py in ?(variables, key, dim_sizes)
    212     if _contains_datetime_like_objects(ref_var):
    213         ref_var = DataArray(ref_var)
    214         data = getattr(ref_var.dt, var_name).data
    215     else:
--> 216         data = getattr(ref_var, var_name).data
    217     virtual_var = Variable(ref_var.dims, data)
    218 
    219     return ref_name, var_name, virtual_var

AttributeError: 'IndexVariable' object has no attribute 'year'
CPU times: user 240 ms, sys: 13.2 ms, total: 253 ms
Wall time: 249 ms
xarray.Dataset
Dimensions:
lat: 192bnds: 2lon: 384time: 1
Coordinates:
lat
(lat)
float64
-89.28 -88.36 ... 88.36 89.28
lon
(lon)
float64
0.0 0.9375 1.875 ... 358.1 359.1
time
(time)
object
2013-02-01 00:00:00
Data variables:
lat_bnds
(lat, bnds)
float64
dask.array<chunksize=(192, 2), meta=np.ndarray>
lon_bnds
(lon, bnds)
float64
dask.array<chunksize=(384, 2), meta=np.ndarray>
psl
(time, lat, lon)
float64
dask.array<chunksize=(1, 192, 384), meta=np.ndarray>
Indexes: (3)
Attributes: (44)
CPU times: user 106 ms, sys: 6.96 ms, total: 113 ms
Wall time: 110 ms
xarray.Dataset
Dimensions:
lat: 192bnds: 2lon: 384time: 1
Coordinates:
lat
(lat)
float64
-89.28 -88.36 ... 88.36 89.28
lon
(lon)
float64
0.0 0.9375 1.875 ... 358.1 359.1
time
(time)
object
2013-02-01 00:00:00
Data variables:
lat_bnds
(lat, bnds)
float64
dask.array<chunksize=(192, 2), meta=np.ndarray>
lon_bnds
(lon, bnds)
float64
dask.array<chunksize=(384, 2), meta=np.ndarray>
psl
(time, lat, lon)
float64
dask.array<chunksize=(1, 192, 384), meta=np.ndarray>
Indexes: (3)
Attributes: (44)

@tomvothecoder
Copy link
Collaborator Author

@lee1043 Thanks for trying to this out and providing an example script! I'll debug the stack trace.

@tomvothecoder tomvothecoder force-pushed the feature/416-custom-season-span branch from 8f9af92 to 8d156c2 Compare November 12, 2024 21:02
@tomvothecoder
Copy link
Collaborator Author

In commit 0b6852f (#423), I refactored the logic for shifting months to use Xarray/NumPy instead of Pandas. We'll gradually shift away from the Pandas back-end for storing and manipulating Datetime components (used to group time coordinates) towards Xarray/Numpy in #217 since xarray >=2024.09.0 now supports grouping by multiple variables.

I will do a final walk through at the next xCDAT meeting (11/20) before merging.

@lee1043
Copy link
Collaborator

lee1043 commented Nov 20, 2024

I conducted extra testing and confirmed that the current PR is working without any noticeable issue.

import xcdat
import matplotlib.pyplot as plt

filepath = "http://esgf.nci.org.au/thredds/dodsC/master/CMIP6/CMIP/CSIRO/ACCESS-ESM1-5/historical/r10i1p1f1/Amon/tas/gn/v20200605/tas_Amon_ACCESS-ESM1-5_historical_r10i1p1f1_gn_185001-201412.nc"
ds = xcdat.open_dataset(filepath)

# Climatology for default seasons

season_climo = ds.temporal.climatology(
    "tas",
    freq="season",
    weighted=True,
    season_config={"dec_mode": "DJF", "drop_incomplete_djf": True},
)

# Climatology for custom seasons

custom_seasons = [
    ["Jan", "Feb", "Mar"],  # "JanFebMar"
    ["Apr", "May", "Jun"],  # "AprMayJun"
    ["Jul", "Aug", "Sep"],  # "JunJulAug"
    ["Oct", "Nov", "Dec"],  # "OctNovDec"
]

c_season_climo = ds.temporal.climatology(
    "tas",
    freq="season",
    weighted=True,
    season_config={"custom_seasons": custom_seasons},
)

fig, ax = plt.subplots(2, 2, figsize=(12, 5))

c_season_climo.isel(time=0)["tas"].plot(ax=ax[0, 0])  # First row, first column
c_season_climo.isel(time=1)["tas"].plot(ax=ax[0, 1])  # First row, second column
c_season_climo.isel(time=2)["tas"].plot(ax=ax[1, 0])  # Second row, first column
c_season_climo.isel(time=3)["tas"].plot(ax=ax[1, 1])  # Second row, second column

plt.tight_layout()
plt.show()

test2_output1

# Climatology for custom seasons

custom_seasons_2 = [
    ["Oct", "Nov", "Dec", "Jan", "Feb"],  # "OctNovDec"
]

c_season_climo_2 = ds.temporal.climatology(
    "tas",
    freq="season",
    weighted=True,
    season_config={"custom_seasons": custom_seasons_2},
)

c_season_climo_2["tas"].plot()

test2_output2

- Methods include `_subset_coords_for_custom_seasons()` and `_shift_custom_season_years()`
@tomvothecoder
Copy link
Collaborator Author

Thank for testing @lee1043. I will now merge this PR!

@tomvothecoder
Copy link
Collaborator Author

Hey @arfriedman, @DamienIrving, and @oliviermarti, I know this PR is a long time coming (over a year and a half). If you're still interested, you can try out this custom seasons feature by checking out the latest main branch and installing xcdat into your env. I have not decided on when to release the next version of xCDAT with this feature yet.

git clone https://github.com/xCDAT/xcdat.git
cd xcdat
conda activate <YOUR-ENV>
make install # or python -m pip install . 

@tomvothecoder tomvothecoder merged commit 27396e5 into main Nov 20, 2024
10 checks passed
@arfriedman
Copy link

Thank you @tomvothecoder! I'm very excited about this feature.

@oliviermarti
Copy link

oliviermarti commented Nov 21, 2024

Tom,

Thank you for this work :-)
Up to now, it seems to produce correct results. I'm gonna try to imagine some further tests.

I still have a concern : I have a monthly variable with time values centered at the middle of the month, and correct bounds. When I compute

dx.temporal.group_average("nmonth", freq="season", weighted=True,
                                season_config={'drop_incomplete_djf':False, 
                                               'dec_mode': 'DJF',})

The result is on a time axis with values at the beginning of the season, not the middle. That not very nice for plots, for instance when I plot monthly and seasonal means on the same plot.

And no bounds are produced. That means that for example that if I start from daily values, then compute compute monthly mean. I can not compute seasonnal means.

Olivier

@pochedls
Copy link
Collaborator

  • Nice work Tom!
  • Thank you for the review and feedback @oliviermarti! I think that your main points are captured in this issue.
  • In the mean time, I think you could a) start with daily, b) compute the monthly mean, c) call add_missing_bounds, d) compute seasonal means (but add_missing_bounds wouldn't subsequently operate correctly on a seasonal time series, so it is still important to fix the root of this issue).

@tomvothecoder tomvothecoder deleted the feature/416-custom-season-span branch November 21, 2024 18:50
@tomvothecoder
Copy link
Collaborator Author

Thanks @oliviermarti! And thank you @pochedls for pointing to #565 and a partial solution. #565 should be relatively easy to address. I will add this as a higher priority item to tackle in the next few months.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
priority: soon Should be addressed soon. project: seats-fy24 type: enhancement New enhancement request
Projects
Status: Done
6 participants