Skip to content

Commit

Permalink
Bugfix reproducibility & ensemble member order with dask (#347)
Browse files Browse the repository at this point in the history
* Bugfix: fix random placement of ensemble members in numpy array due to dask multi-threading (#337)

* Bugfix: make STEPS (blending) nowcast reproducible when the seed argument is given (#346)

* Bugfix: make STEPS (blending) nowcast reproducible, independent of number of workers (#346)

* Formatting with black

---------

Co-authored-by: ned <daniele.nerini@meteoswiss.ch>
  • Loading branch information
mpvginde and dnerini authored Mar 5, 2024
1 parent 3167a11 commit 098927d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 27 deletions.
5 changes: 4 additions & 1 deletion pysteps/blending/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ def forecast(
)

# 2. Initialize the noise method
np.random.seed(seed)
pp, generate_noise, noise_std_coeffs = _init_noise(
precip,
precip_thr,
Expand All @@ -526,6 +527,7 @@ def forecast(
noise_stddev_adj,
measure_time,
num_workers,
seed,
)

# 3. Perform the cascade decomposition for the input precip fields and
Expand Down Expand Up @@ -1662,6 +1664,7 @@ def _init_noise(
noise_stddev_adj,
measure_time,
num_workers,
seed,
):
"""Initialize the noise method."""
if noise_method is None:
Expand Down Expand Up @@ -1690,6 +1693,7 @@ def _init_noise(
20,
conditional=True,
num_workers=num_workers,
seed=seed,
)

if measure_time:
Expand Down Expand Up @@ -1944,7 +1948,6 @@ def _init_random_generators(
if noise_method is not None:
randgen_prec = []
randgen_motion = []
np.random.seed(seed)
for j in range(n_ens_members):
rs = np.random.RandomState(seed)
randgen_prec.append(rs)
Expand Down
44 changes: 21 additions & 23 deletions pysteps/noise/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,39 +96,37 @@ def compute_noise_stddev_adjs(

if dask_imported and num_workers > 1:
res = []
else:
N_stds = []

N_stds = [None] * num_iter
randstates = []
seed = None

for k in range(num_iter):
randstates.append(np.random.RandomState(seed=seed))
seed = np.random.randint(0, high=1e9)

for k in range(num_iter):

def worker():
# generate Gaussian white noise field, filter it using the chosen
# method, multiply it with the standard deviation of the observed
# field and apply the precipitation mask
N = noise_generator(noise_filter, randstate=randstates[k], seed=seed)
N = N / np.std(N) * sigma + mu
N[~MASK] = R_thr_2
def worker(k):
# generate Gaussian white noise field, filter it using the chosen
# method, multiply it with the standard deviation of the observed
# field and apply the precipitation mask
N = noise_generator(noise_filter, randstate=randstates[k])
N = N / np.std(N) * sigma + mu
N[~MASK] = R_thr_2

# subtract the mean and decompose the masked noise field into a
# cascade
N -= mu
decomp_N = decomp_method(N, F, mask=MASK_)
# subtract the mean and decompose the masked noise field into a
# cascade
N -= mu
decomp_N = decomp_method(N, F, mask=MASK_)

return decomp_N["stds"]

if dask_imported and num_workers > 1:
res.append(dask.delayed(worker)())
else:
N_stds.append(worker())
N_stds[k] = decomp_N["stds"]

if dask_imported and num_workers > 1:
N_stds = dask.compute(*res, num_workers=num_workers)
for k in range(num_iter):
res.append(dask.delayed(worker)(k))
dask.compute(*res, num_workers=num_workers)

else:
for k in range(num_iter):
worker(k)

# for each cascade level, compare the standard deviations between the
# observed field and the masked noise field, which gives the correction
Expand Down
7 changes: 4 additions & 3 deletions pysteps/nowcasts/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ def f(precip, i):
precip[i, ~np.isfinite(precip[i, :])] = np.nanmin(precip[i, :])

if noise_method is not None:
np.random.seed(seed)
# get methods for perturbations
init_noise, generate_noise = noise.get_method(noise_method)

Expand All @@ -466,6 +467,7 @@ def f(precip, i):
20,
conditional=True,
num_workers=num_workers,
seed=seed,
)

if measure_time:
Expand Down Expand Up @@ -543,7 +545,6 @@ def f(precip, i):
if noise_method is not None:
randgen_prec = []
randgen_motion = []
np.random.seed(seed)
for _ in range(n_ens_members):
rs = np.random.RandomState(seed)
randgen_prec.append(rs)
Expand Down Expand Up @@ -706,7 +707,7 @@ def _check_inputs(precip, velocity, timesteps, ar_order):


def _update(state, params):
precip_forecast_out = []
precip_forecast_out = [None] * params["n_ens_members"]

if params["noise_method"] is None or params["mask_method"] == "sprog":
for i in range(params["n_cascade_levels"]):
Expand Down Expand Up @@ -828,7 +829,7 @@ def worker(j):

precip_forecast[params["domain_mask"]] = np.nan

precip_forecast_out.append(precip_forecast)
precip_forecast_out[j] = precip_forecast

if (
DASK_IMPORTED
Expand Down

0 comments on commit 098927d

Please sign in to comment.