Skip to content

Commit

Permalink
types defined in test too
Browse files Browse the repository at this point in the history
  • Loading branch information
HeikoSchuett committed Jul 2, 2024
1 parent 45998f6 commit d73856c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/test_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_calc_one_similarity(self):
descriptor='idx', i_des=1, j_des=2)
sim_c, w_c = calc_one_similarity_c(
d1, d2,
np.array([0, 1]), np.array([2, 3, 4]),
np.array([0, 1], dtype=np.int64), np.array([2, 3, 4], dtype=np.int64),
method=method)
self.assertAlmostEqual(
w, w_c, None,
Expand All @@ -93,15 +93,15 @@ def test_integer_input_one(self):
ds1 = Dataset(np.asarray([[0], [2]]).T) # one pattern, two channels
ds2 = Dataset(np.asarray([[0], [2]]).T) # one pattern, two channels
dissim, _ = calc_one_similarity_c(
ds1, ds2, np.array([0]), np.array([1]))
ds1, ds2, np.array([0], dtype=np.int64), np.array([1], dtype=np.int64))
assert_almost_equal(dissim, 2) # standard-squared euclidean


class TestCalc(unittest.TestCase):

def setUp(self):
self.rng = np.random.default_rng(0)
self.dat = self.rng.random((300, 100))
self.dat = self.rng.random((300, 100), dtype=np.float64)
self.data = rsatoolbox.data.Dataset(
self.dat,
obs_descriptors={'obs': np.repeat(np.arange(50), 6),
Expand All @@ -114,8 +114,8 @@ def test_basic(self):
# directly call c version
a = calc(
self.dat,
self.data.obs_descriptors['obs'].astype(int),
self.data.obs_descriptors['rep'].astype(int),
self.data.obs_descriptors['obs'].astype(np.int64),
self.data.obs_descriptors['rep'].astype(np.int64),
50, i + 1)
self_sim = a[:50]
rdm = a[50:]
Expand Down

0 comments on commit d73856c

Please sign in to comment.