From 3801007feb90b6f38959cf904268d73a11c69ddd Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Mon, 6 Mar 2023 18:33:51 +0100 Subject: [PATCH] Plum 2 compatibility --- setup.py | 2 +- stheno/model/gp.py | 11 ++++------- stheno/model/measure.py | 22 ++++++++-------------- stheno/model/observations.py | 3 ++- stheno/random.py | 5 +++-- tests/model/test_model.py | 19 ++++++------------- tests/test_random.py | 2 +- tests/util.py | 4 ++-- 8 files changed, 27 insertions(+), 41 deletions(-) diff --git a/setup.py b/setup.py index 94c5e17..6fb3a51 100755 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ "numpy>=1.16", "fdm", "algebra>=1", - "plum-dispatch>=1.5.3", + "plum-dispatch>=2", "backends>=1.4.11", "backends-matrix>=1.2.11", "mlkernels>=0.3.6", diff --git a/stheno/model/gp.py b/stheno/model/gp.py index 04a2f4a..a7490a6 100644 --- a/stheno/model/gp.py +++ b/stheno/model/gp.py @@ -1,17 +1,14 @@ from types import FunctionType +from typing import Union from fdm import central_fdm from lab import B -from mlkernels import ( - OneKernel, - ZeroMean, - OneMean, -) -from plum import Union +from mlkernels import OneKernel, OneMean, ZeroMean +from plum import isinstance -from .fdd import FDD from .. import PromisedGP, PromisedMeasure, _dispatch from ..random import RandomProcess +from .fdd import FDD __all__ = ["assert_same_measure", "intersection_measure_group", "cross", "GP"] diff --git a/stheno/model/measure.py b/stheno/model/measure.py index 43bfb39..c0bf995 100644 --- a/stheno/model/measure.py +++ b/stheno/model/measure.py @@ -1,29 +1,23 @@ from types import FunctionType +from typing import Union from lab import B from matrix import Constant -from mlkernels import ( - num_elements, - ZeroKernel, - TensorProductKernel, -) -from plum import Union +from mlkernels import TensorProductKernel, ZeroKernel, num_elements +from .. import PromisedMeasure, _dispatch +from ..lazy import LazyMatrix, LazyVector +from ..mo import AmbiguousDimensionalityKernel as ADK +from ..mo import MultiOutputKernel as MOK +from ..mo import MultiOutputMean as MOM from .fdd import FDD from .gp import GP, assert_same_measure from .observations import ( AbstractObservations, - Observations, AbstractPseudoObservations, + Observations, combine, ) -from .. import _dispatch, PromisedMeasure -from ..lazy import LazyVector, LazyMatrix -from ..mo import ( - MultiOutputKernel as MOK, - MultiOutputMean as MOM, - AmbiguousDimensionalityKernel as ADK, -) __all__ = ["Measure"] diff --git a/stheno/model/observations.py b/stheno/model/observations.py index 5e9bd1a..ebd9656 100644 --- a/stheno/model/observations.py +++ b/stheno/model/observations.py @@ -1,7 +1,8 @@ +from typing import Union + from lab import B from matrix import AbstractMatrix, Diagonal from mlkernels import PosteriorKernel, PosteriorMean, SubspaceKernel, num_elements -from plum import Union from .. import _dispatch from .fdd import FDD diff --git a/stheno/random.py b/stheno/random.py index b1ecf5d..c6530e5 100644 --- a/stheno/random.py +++ b/stheno/random.py @@ -1,9 +1,10 @@ from types import FunctionType +from typing import Union from algebra.util import identical from lab import B -from matrix import AbstractMatrix, Zero, Diagonal -from plum import convert, Union +from matrix import AbstractMatrix, Diagonal, Zero +from plum import convert from wbml.util import indented_kv from . import _dispatch diff --git a/tests/model/test_model.py b/tests/model/test_model.py index d2dfc5a..762200f 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -3,28 +3,21 @@ import tensorflow as tf from lab import B from matrix import Diagonal -from mlkernels import ( - Kernel, - pairwise, - elwise, - Linear, - EQ, - Delta, - Exp, -) +from mlkernels import EQ, Delta, Exp, Kernel, Linear, elwise, pairwise +from plum import isinstance from stheno.model import ( - Measure, + FDD, GP, + Measure, Obs, PseudoObs, - PseudoObsFITC, PseudoObsDTC, + PseudoObsFITC, cross, - FDD, ) -from .util import assert_equal_normals, assert_equal_measures from ..util import approx +from .util import assert_equal_measures, assert_equal_normals def test_measure_groups(): diff --git a/tests/test_random.py b/tests/test_random.py index 1fb1c13..c7dc002 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -3,7 +3,7 @@ import pytest from algebra.util import identical from matrix import Dense, Zero -from plum import NotFoundLookupError +from plum import NotFoundLookupError, isinstance from scipy.stats import multivariate_normal from stheno.random import Normal, RandomVector diff --git a/tests/util.py b/tests/util.py index daee6c4..4c60c1f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,10 +1,10 @@ from time import time +from typing import Union import numpy as np from lab import B from matrix import AbstractMatrix -from plum import dispatch, Union - +from plum import dispatch from stheno import Normal __all__ = ["benchmark", "approx"]