diff --git a/stheno/random.py b/stheno/random.py index 71a0226..9f171be 100644 --- a/stheno/random.py +++ b/stheno/random.py @@ -298,7 +298,7 @@ def kl(self, other: "Normal"): scalar: KL divergence. """ return ( - B.iqf_diag(other.var, other.mean - self.mean)[0] + B.iqf_diag(other.var, other.mean - self.mean)[..., 0] + B.ratio(self.var, other.var) + B.logdet(other.var) - B.logdet(self.var)