From dfe15c729ad12924bbd92bcab488928147d73c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 16 May 2024 23:00:39 +0100 Subject: [PATCH] [BUG] fix `Mixture` distribution for more than two components (#337) The `Mixture` distribution after API upgrade broke for 3 or more components. This has been fixed, and is now covered by a test case. --- skpro/distributions/mixture.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/skpro/distributions/mixture.py b/skpro/distributions/mixture.py index 640e9b0c2..641af14df 100644 --- a/skpro/distributions/mixture.py +++ b/skpro/distributions/mixture.py @@ -225,7 +225,8 @@ def sample(self, n_samples=None): dists = [d[1] for d in self._distributions] raw_samples = [d.sample(N).values for d in dists] masked_samples = [ind * raw for ind, raw in zip(indicators, raw_samples)] - sample = np.add(*masked_samples) + masked_samples = np.array(masked_samples) + sample = masked_samples.sum(axis=0) if n_samples is None: spl_index = self.index @@ -273,12 +274,18 @@ def get_test_params(cls, parameter_set="default"): params2 = {"distributions": dists2, "weights": [0.3, 0.7]} # scalar case - normal1 = Normal(mu=0, sigma=1) - normal2 = Normal(mu=3, sigma=2) - dists = [("normal1", normal1), ("normal2", normal2)] - dists2 = [normal1, normal2] + normal3 = Normal(mu=0, sigma=1) + normal4 = Normal(mu=3, sigma=2) + dists = [("normal3", normal3), ("normal4", normal4)] + dists2 = [normal3, normal4] params3 = {"distributions": dists2} params4 = {"distributions": dists, "weights": [0.3, 0.7]} - return [params1, params2, params3, params4] + # more than 2 distributions + normal5 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=2, columns=columns) + normal6 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=0.5, columns=columns) + dists3 = [normal1, normal2, normal5, normal6] + params5 = {"distributions": dists3} + + return [params1, params2, params3, params4, params5]