Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dask to run tasks #1714

Closed
wants to merge 31 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
63b032a
Use dask to run preprocessing tasks (and disable other tasks for now)
bouweandela Sep 2, 2022
4792193
Improve use of dask client as suggested by @zklaus
bouweandela Sep 13, 2022
11b12f1
Restore sequential and parallel task run
bouweandela Sep 13, 2022
6faa8f6
Add missing compute argument to save function for multimodel statisti…
bouweandela Sep 23, 2022
49e3589
Add distributed as a dependency
bouweandela Sep 23, 2022
54edb73
Restore previous API to fix tests
bouweandela Sep 23, 2022
021b3e0
Add xarray as a dependency
bouweandela Sep 23, 2022
53aa585
Merge branch 'main' into dask-distributed
bouweandela Sep 23, 2022
d5cf4f6
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Oct 14, 2022
0f5bda8
Support configuring dask
bouweandela Oct 14, 2022
52af816
Add a suffix to output_directory if it exists instead of stopping
bouweandela Oct 16, 2022
823c731
Fix tests
bouweandela Oct 18, 2022
6f5a6bf
single call to compute
fnattino Nov 7, 2022
b455bbb
Only start cluster if necessary and fix filename-future mapping
bouweandela Nov 10, 2022
05d69f7
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Nov 21, 2022
be991f8
Use iris (https://github.com/SciTools/iris/pull/5031) for saving
bouweandela Nov 21, 2022
93b6da1
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Nov 22, 2022
37bb757
Point to iris branch
bouweandela Dec 5, 2022
2b60264
Merge branch 'main' into dask-distributed
bouweandela Dec 5, 2022
66bdf2e
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela Mar 30, 2023
065b8a4
Work in progress
bouweandela Apr 4, 2023
f47d6da
Update branch name
bouweandela Apr 4, 2023
6e25901
Remove type hint
bouweandela Apr 4, 2023
e176745
Get iris from main branch
bouweandela Apr 21, 2023
d6787fa
Merge branch 'main' into dask-distributed
bouweandela Apr 21, 2023
222c7e5
Add default scheduler
bouweandela Apr 28, 2023
b760abb
Use release candidate
bouweandela May 11, 2023
6b35638
Try to install iris from PyPI
bouweandela May 11, 2023
cf79d20
Merge branch 'main' of github.com:ESMValGroup/ESMValCore into dask-di…
bouweandela May 12, 2023
4348953
Update dependencies
bouweandela May 19, 2023
6bf6328
Remove pip install of ESMValTool_sample_data
bouweandela May 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ dependencies:
- compilers
# 1.8.18/py39, they seem weary to build manylinux wheels
# and pypi ver built with older gdal
- dask
- distributed
- fiona
- esmpy!=8.1.0 # see github.com/ESMValGroup/ESMValCore/issues/1208
- geopy
Expand All @@ -22,3 +24,4 @@ dependencies:
- python>=3.8
- python-stratify
- scipy>=1.6
- xarray
23 changes: 19 additions & 4 deletions esmvalcore/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def process_recipe(recipe_file, config_user):

from multiprocessing import cpu_count
n_processes = config_user['max_parallel_tasks'] or cpu_count()
config_user['max_parallel_tasks'] = n_processes
logger.info("Running tasks using at most %s processes", n_processes)

logger.info(
Expand Down Expand Up @@ -401,10 +402,24 @@ def run(self,
cfg = read_config_user_file(config_file, recipe.stem, kwargs)

# Create run dir
if os.path.exists(cfg['run_dir']):
print("ERROR: run_dir {} already exists, aborting to "
"prevent data loss".format(cfg['output_dir']))
os.makedirs(cfg['run_dir'])
out_dir = Path(cfg['output_dir'])
if out_dir.exists():
# Add an extra suffix to avoid path collision with another process.
suffix = 1
new_out_dir = out_dir
while new_out_dir.exists():
new_out_dir = Path(f"{out_dir}-{suffix}")
suffix += 1
if suffix > 1000:
print("ERROR: output_dir {} already exists, aborting to "
"prevent data loss".format(cfg['output_dir']))
break
# Update configuration with the new path.
cfg['output_dir'] = str(new_out_dir)
for dirname in ('run_dir', 'preproc_dir', 'work_dir', 'plot_dir'):
cfg[dirname] = str(out_dir / Path(cfg[dirname]).name)
os.makedirs(cfg['output_dir'])
os.mkdir(cfg['run_dir'])

# configure logging
log_files = configure_logging(output_dir=cfg['run_dir'],
Expand Down
8 changes: 6 additions & 2 deletions esmvalcore/_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ def _get_default_settings(variable, config_user, derive=False):
}

# Configure saving cubes to file
settings['save'] = {'compress': config_user['compress_netcdf']}
settings['save'] = {'compress': config_user['compress_netcdf'],
'compute': config_user['max_parallel_tasks'] != -1}
if variable['short_name'] != variable['original_short_name']:
settings['save']['alias'] = variable['short_name']

Expand Down Expand Up @@ -712,6 +713,9 @@ def _get_downstream_settings(step, order, products):
if key in remaining_steps:
if all(p.settings.get(key, object()) == value for p in products):
settings[key] = value
save = dict(some_product.settings.get('save', {}))
save.pop('filename', None)
settings['save'] = save
return settings


Expand Down Expand Up @@ -1918,7 +1922,7 @@ def run(self):
if not self._cfg['offline']:
esgf.download(self._download_files, self._cfg['download_dir'])

self.tasks.run(max_parallel_tasks=self._cfg['max_parallel_tasks'])
self.tasks.run(self._cfg)
self.write_html_summary()

def get_output(self) -> dict:
Expand Down
51 changes: 47 additions & 4 deletions esmvalcore/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import abc
import contextlib
import datetime
import importlib
import logging
import numbers
import os
Expand All @@ -18,6 +19,8 @@
from shutil import which
from typing import Dict, Type

import dask
import dask.distributed
import psutil
import yaml

Expand Down Expand Up @@ -711,19 +714,59 @@ def get_independent(self) -> 'TaskSet':
independent_tasks.add(task)
return independent_tasks

def run(self, max_parallel_tasks: int = None) -> None:
def run(self, cfg) -> None:
"""Run tasks.

Parameters
----------
max_parallel_tasks : int
Number of processes to run. If `1`, run the tasks sequentially.
cfg : dict
Config-user dict.
"""
if max_parallel_tasks == 1:
max_parallel_tasks = cfg['max_parallel_tasks']
if max_parallel_tasks == -1:
self._run_dask(cfg)
elif max_parallel_tasks == 1:
self._run_sequential()
else:
self._run_parallel(max_parallel_tasks)

def _run_dask(self, cfg) -> None:
"""Run tasks using dask."""
# Configure dask
client_args = cfg.get('dask', {}).get('client', {})
cluster_args = cfg.get('dask', {}).get('cluster', {})
cluster_type = cluster_args.pop(
'type',
'dask.distributed.LocalCluster',
)
cluster_scale = cluster_args.pop('scale', 1)

# STart cluster
cluster_module_name, cluster_cls_name = cluster_type.rsplit('.', 1)
cluster_module = importlib.import_module(cluster_module_name)
cluster_cls = getattr(cluster_module, cluster_cls_name)
cluster = cluster_cls(**cluster_args)
cluster.scale(cluster_scale)

# Connect client and run computation
with dask.distributed.Client(cluster, **client_args) as client:
bouweandela marked this conversation as resolved.
Show resolved Hide resolved
logger.info(f"Dask dashboard: {client.dashboard_link}")
for task in sorted(self.flatten(), key=lambda t: t.priority):
if hasattr(task, 'delayeds'):
logger.info(f"Scheduling task {task.name}")
task.run()
logger.info(f"Computing task {task.name}")
futures = client.compute(
list(task.delayeds.values()),
priority=-task.priority,
)
future_map = dict(zip(futures, task.delayeds.keys()))
else:
logger.info(f"Skipping task {task.name}")
for future in dask.distributed.as_completed(futures):
filename = future_map[future]
logger.info(f"Wrote {filename}")

def _run_sequential(self) -> None:
"""Run tasks sequentially."""
n_tasks = len(self.flatten())
Expand Down
1 change: 1 addition & 0 deletions esmvalcore/experimental/config/_config_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def deprecate(func, variable, version: str = None):
_validators = {
# From user config
'log_level': validate_string,
'dask': validate_dict,
'exit_on_warning': validate_bool,
'output_dir': validate_path,
'download_dir': validate_path,
Expand Down
17 changes: 13 additions & 4 deletions esmvalcore/preprocessor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def __init__(self, attributes, settings, ancestors=None):

self._cubes = None
self._prepared = False
self.delayed = None

def _input_files_for_log(self):
"""Do not log input files twice in output log."""
Expand Down Expand Up @@ -453,10 +454,15 @@ def cubes(self, value):

def save(self):
"""Save cubes to disk."""
self.files = preprocess(self._cubes, 'save',
input_files=self._input_files,
**self.settings['save'])
self.files = preprocess(self.files, 'cleanup',
result = save(
self._cubes,
**self.settings['save'],
)
if not self.settings['save'].get('compute', True):
self.delayed = result
self.files = [self.settings['save']['filename']]
self.files = preprocess(self.files,
'cleanup',
input_files=self._input_files,
**self.settings.get('cleanup', {}))

Expand Down Expand Up @@ -545,6 +551,7 @@ def __init__(
self.order = list(order)
self.debug = debug
self.write_ncl_interface = write_ncl_interface
self.delayeds = {}

def _initialize_product_provenance(self):
"""Initialize product provenance."""
Expand Down Expand Up @@ -589,6 +596,7 @@ def _initialize_products(self, products):

def _run(self, _):
"""Run the preprocessor."""
self.delayeds.clear()
self._initialize_product_provenance()

steps = {
Expand All @@ -613,6 +621,7 @@ def _run(self, _):

for product in self.products:
product.close()
self.delayeds[product.filename] = product.delayed
metadata_files = write_metadata(self.products,
self.write_ncl_interface)
return metadata_files
Expand Down
25 changes: 23 additions & 2 deletions esmvalcore/preprocessor/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import iris.aux_factory
import iris.exceptions
import numpy as np
import xarray
import yaml
from cf_units import suppress_errors

Expand Down Expand Up @@ -250,6 +251,7 @@ def save(cubes,
filename,
optimize_access='',
compress=False,
compute=True,
alias='',
**kwargs):
"""Save iris cubes to file.
Expand All @@ -274,6 +276,10 @@ def save(cubes,
compress: bool, optional
Use NetCDF internal compression.

compute: bool, optional
If true save immediately, otherwise return a dask.delayed.Delayed
object that can be used for saving the data later.

alias: str, optional
Var name to use when saving instead of the one in the cube.

Expand All @@ -289,6 +295,8 @@ def save(cubes,
"""
if not cubes:
raise ValueError(f"Cannot save empty cubes '{cubes}'")
if len(cubes) > 1:
raise ValueError(f"`save` expects as single cube, got '{cubes}")

# Rename some arguments
kwargs['target'] = filename
Expand Down Expand Up @@ -331,9 +339,22 @@ def save(cubes,
logger.debug('Changing var_name from %s to %s', cube.var_name,
alias)
cube.var_name = alias
iris.save(cubes, **kwargs)

return filename
cube = cubes[0]
if compute is True:
iris.save(cube, **kwargs)
return filename

data_array = xarray.DataArray.from_iris(cube)
kwargs.pop('target')
kwargs['_FillValue'] = kwargs.pop('fill_value')
encoding = {cube.var_name: kwargs}
delayed = data_array.to_netcdf(
filename,
encoding=encoding,
compute=False,
)
return delayed


def _get_debug_filename(filename, step):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
'cartopy',
# see https://github.com/SciTools/cf-units/issues/218
'cf-units>=3.0.0,<3.1.0,!=3.0.1.post0', # ESMValCore/issues/1655
'dask[array]',
'dask[array,distributed]',
'esgf-pyclient>=0.3.1',
'esmpy!=8.1.0', # see github.com/ESMValGroup/ESMValCore/issues/1208
'fiona',
Expand All @@ -55,6 +55,7 @@
'scitools-iris>=3.2.1',
'shapely[vectorized]',
'stratify',
'xarray',
'yamale',
],
# Test dependencies
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def _get_default_settings_for_chl(fix_dir, save_filename, preprocessor):
},
'save': {
'compress': False,
'compute': True,
'filename': save_filename,
}
}
Expand Down Expand Up @@ -684,6 +685,7 @@ def test_default_fx_preprocessor(tmp_path, patched_datafinder, config_user):
},
'save': {
'compress': False,
'compute': True,
'filename': product.filename,
}
}
Expand Down Expand Up @@ -3601,14 +3603,16 @@ def test_recipe_run(tmp_path, patched_datafinder, config_user, mocker):

recipe = get_recipe(tmp_path, content, config_user)

os.makedirs(config_user['output_dir'])
recipe.tasks.run = mocker.Mock()
recipe.write_filled_recipe = mocker.Mock()
recipe.run()

esmvalcore._recipe.esgf.download.assert_called_once_with(
set(), config_user['download_dir'])
recipe.tasks.run.assert_called_once_with(
max_parallel_tasks=config_user['max_parallel_tasks'])
cfg = dict(config_user)
cfg['write_ncl_interface'] = False
recipe.tasks.run.assert_called_once_with(cfg)
recipe.write_filled_recipe.assert_called_once()


Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def test_run_tasks(monkeypatch, tmp_path, max_parallel_tasks, example_tasks,
"""Check that tasks are run correctly."""
monkeypatch.setattr(esmvalcore._task, 'Pool',
multiprocessing.get_context(mpmethod).Pool)
example_tasks.run(max_parallel_tasks=max_parallel_tasks)
cfg = {
'max_parallel_tasks': max_parallel_tasks,
}
example_tasks.run(cfg)

for task in example_tasks:
print(task.name, task.output_files)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/main/test_esmvaltool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_run(mocker, tmp_path, cmd_offline, cfg_offline):
'config_file': tmp_path / '.esmvaltool' / 'config-user.yml',
'log_level': 'info',
'offline': cfg_offline,
'output_dir': str(output_dir),
'preproc_dir': str(output_dir / 'preproc_dir'),
'run_dir': str(output_dir / 'run_dir'),
}
Expand Down