Skip to content

Commit

Permalink
[BUG] fix Mixture distribution for more than two components (#337)
Browse files Browse the repository at this point in the history
The `Mixture` distribution after API upgrade broke for 3 or more
components.

This has been fixed, and is now covered by a test case.
  • Loading branch information
fkiraly authored May 16, 2024
1 parent 18822a6 commit dfe15c7
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions skpro/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def sample(self, n_samples=None):
dists = [d[1] for d in self._distributions]
raw_samples = [d.sample(N).values for d in dists]
masked_samples = [ind * raw for ind, raw in zip(indicators, raw_samples)]
sample = np.add(*masked_samples)
masked_samples = np.array(masked_samples)
sample = masked_samples.sum(axis=0)

if n_samples is None:
spl_index = self.index
Expand Down Expand Up @@ -273,12 +274,18 @@ def get_test_params(cls, parameter_set="default"):
params2 = {"distributions": dists2, "weights": [0.3, 0.7]}

# scalar case
normal1 = Normal(mu=0, sigma=1)
normal2 = Normal(mu=3, sigma=2)
dists = [("normal1", normal1), ("normal2", normal2)]
dists2 = [normal1, normal2]
normal3 = Normal(mu=0, sigma=1)
normal4 = Normal(mu=3, sigma=2)
dists = [("normal3", normal3), ("normal4", normal4)]
dists2 = [normal3, normal4]

params3 = {"distributions": dists2}
params4 = {"distributions": dists, "weights": [0.3, 0.7]}

return [params1, params2, params3, params4]
# more than 2 distributions
normal5 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=2, columns=columns)
normal6 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=0.5, columns=columns)
dists3 = [normal1, normal2, normal5, normal6]
params5 = {"distributions": dists3}

return [params1, params2, params3, params4, params5]

0 comments on commit dfe15c7

Please sign in to comment.