Skip to content

Commit

Permalink
Plum 2 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Mar 6, 2023
1 parent f78e5c1 commit 3801007
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 41 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"numpy>=1.16",
"fdm",
"algebra>=1",
"plum-dispatch>=1.5.3",
"plum-dispatch>=2",
"backends>=1.4.11",
"backends-matrix>=1.2.11",
"mlkernels>=0.3.6",
Expand Down
11 changes: 4 additions & 7 deletions stheno/model/gp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from types import FunctionType
from typing import Union

from fdm import central_fdm
from lab import B
from mlkernels import (
OneKernel,
ZeroMean,
OneMean,
)
from plum import Union
from mlkernels import OneKernel, OneMean, ZeroMean
from plum import isinstance

from .fdd import FDD
from .. import PromisedGP, PromisedMeasure, _dispatch
from ..random import RandomProcess
from .fdd import FDD

__all__ = ["assert_same_measure", "intersection_measure_group", "cross", "GP"]

Expand Down
22 changes: 8 additions & 14 deletions stheno/model/measure.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
from types import FunctionType
from typing import Union

from lab import B
from matrix import Constant
from mlkernels import (
num_elements,
ZeroKernel,
TensorProductKernel,
)
from plum import Union
from mlkernels import TensorProductKernel, ZeroKernel, num_elements

from .. import PromisedMeasure, _dispatch
from ..lazy import LazyMatrix, LazyVector
from ..mo import AmbiguousDimensionalityKernel as ADK
from ..mo import MultiOutputKernel as MOK
from ..mo import MultiOutputMean as MOM
from .fdd import FDD
from .gp import GP, assert_same_measure
from .observations import (
AbstractObservations,
Observations,
AbstractPseudoObservations,
Observations,
combine,
)
from .. import _dispatch, PromisedMeasure
from ..lazy import LazyVector, LazyMatrix
from ..mo import (
MultiOutputKernel as MOK,
MultiOutputMean as MOM,
AmbiguousDimensionalityKernel as ADK,
)

__all__ = ["Measure"]

Expand Down
3 changes: 2 additions & 1 deletion stheno/model/observations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Union

from lab import B
from matrix import AbstractMatrix, Diagonal
from mlkernels import PosteriorKernel, PosteriorMean, SubspaceKernel, num_elements
from plum import Union

from .. import _dispatch
from .fdd import FDD
Expand Down
5 changes: 3 additions & 2 deletions stheno/random.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from types import FunctionType
from typing import Union

from algebra.util import identical
from lab import B
from matrix import AbstractMatrix, Zero, Diagonal
from plum import convert, Union
from matrix import AbstractMatrix, Diagonal, Zero
from plum import convert
from wbml.util import indented_kv

from . import _dispatch
Expand Down
19 changes: 6 additions & 13 deletions tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,21 @@
import tensorflow as tf
from lab import B
from matrix import Diagonal
from mlkernels import (
Kernel,
pairwise,
elwise,
Linear,
EQ,
Delta,
Exp,
)
from mlkernels import EQ, Delta, Exp, Kernel, Linear, elwise, pairwise
from plum import isinstance
from stheno.model import (
Measure,
FDD,
GP,
Measure,
Obs,
PseudoObs,
PseudoObsFITC,
PseudoObsDTC,
PseudoObsFITC,
cross,
FDD,
)

from .util import assert_equal_normals, assert_equal_measures
from ..util import approx
from .util import assert_equal_measures, assert_equal_normals


def test_measure_groups():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from algebra.util import identical
from matrix import Dense, Zero
from plum import NotFoundLookupError
from plum import NotFoundLookupError, isinstance
from scipy.stats import multivariate_normal
from stheno.random import Normal, RandomVector

Expand Down
4 changes: 2 additions & 2 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from time import time
from typing import Union

import numpy as np
from lab import B
from matrix import AbstractMatrix
from plum import dispatch, Union

from plum import dispatch
from stheno import Normal

__all__ = ["benchmark", "approx"]
Expand Down

0 comments on commit 3801007

Please sign in to comment.