Skip to content

Commit

Permalink
Merge pull request #66 from jakirkham/dist_more_metrics
Browse files Browse the repository at this point in the history
Support more metrics in cdist and pdist
  • Loading branch information
jakirkham authored Sep 30, 2017
2 parents 33cb8b5 + 80534f2 commit 5d26db9
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
28 changes: 26 additions & 2 deletions dask_distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,15 @@ def cdist(XA, XB, metric="euclidean", **kwargs):
"hamming": hamming,
"jaccard": jaccard,
"kulsinski": kulsinski,
"mahalanobis": mahalanobis,
"minkowski": minkowski,
"rogerstanimoto": rogerstanimoto,
"russellrao": russellrao,
"sokalmichener": sokalmichener,
"sokalsneath": sokalsneath,
"seuclidean": seuclidean,
"sqeuclidean": sqeuclidean,
"wminkowski": wminkowski,
"yule": yule,
}

Expand Down Expand Up @@ -93,8 +96,22 @@ def cdist(XA, XB, metric="euclidean", **kwargs):

metric = func_mappings[metric]

if metric == minkowski:
kwargs["p"] = kwargs.get("p", 2)
if metric == mahalanobis:
if "VI" not in kwargs:
kwargs["VI"] = (
dask.array.linalg.inv(
dask.array.cov(dask.array.vstack([XA, XB]).T)
).T
)
elif metric == minkowski:
kwargs.setdefault("p", 2)
elif metric == seuclidean:
if "V" not in kwargs:
kwargs["V"] = (
dask.array.var(dask.array.vstack([XA, XB]), axis=0, ddof=1)
)
elif metric == wminkowski:
kwargs.setdefault("p", 2)

result = metric(XA, XB, **kwargs)

Expand Down Expand Up @@ -124,6 +141,13 @@ def pdist(X, metric="euclidean", **kwargs):
other tradeoffs.
"""

if metric == "mahalanobis":
if "VI" not in kwargs:
kwargs["VI"] = dask.array.linalg.inv(dask.array.cov(X.T)).T
elif metric == "seuclidean":
if "V" not in kwargs:
kwargs["V"] = dask.array.var(X, axis=0, ddof=1)

result = cdist(X, X, metric, **kwargs)

result = dask.array.triu(result, 1)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_dask_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,15 @@ def test_1d_dist(funcname, kw, seed, size, chunks):
("correlation", {}),
("cosine", {}),
("euclidean", {}),
("mahalanobis", {"VI": None}),
("mahalanobis", {}),
("minkowski", {}),
("minkowski", {"p": 3}),
("seuclidean", {"V": None}),
("seuclidean", {}),
("sqeuclidean", {}),
("wminkowski", {}),
("wminkowski", {"p": 1.6}),
(lambda u, v: (abs(u - v) ** 3).sum() ** (1.0 / 3.0), {}),
]
)
Expand All @@ -133,6 +139,19 @@ def test_2d_cdist(metric, kw, seed, u_shape, u_chunks, v_shape, v_chunks):
d_u = da.from_array(a_u, chunks=u_chunks)
d_v = da.from_array(a_v, chunks=v_chunks)

if metric == "mahalanobis":
if "VI" not in kw:
kw["VI"] = 2 * np.random.random(2 * u_shape[-1:]) - 1
elif kw["VI"] is None:
kw.pop("VI")
elif metric == "seuclidean":
if "V" not in kw:
kw["V"] = 2 * np.random.random(u_shape[-1:]) - 1
elif kw["V"] is None:
kw.pop("V")
elif metric == "wminkowski":
kw["w"] = np.random.random(u_shape[-1:])

a_r = spdist.cdist(a_u, a_v, metric, **kw)
d_r = dask_distance.cdist(d_u, d_v, metric, **kw)

Expand All @@ -148,9 +167,15 @@ def test_2d_cdist(metric, kw, seed, u_shape, u_chunks, v_shape, v_chunks):
("correlation", {}),
("cosine", {}),
("euclidean", {}),
("mahalanobis", {"VI": None}),
("mahalanobis", {}),
("minkowski", {}),
("minkowski", {"p": 3}),
("seuclidean", {"V": None}),
("seuclidean", {}),
("sqeuclidean", {}),
("wminkowski", {}),
("wminkowski", {"p": 1.6}),
(lambda u, v: (abs(u - v) ** 3).sum() ** (1.0 / 3.0), {}),
]
)
Expand All @@ -172,6 +197,19 @@ def test_2d_pdist(metric, kw, seed, u_shape, u_chunks):
a_u = 2 * np.random.random(u_shape) - 1
d_u = da.from_array(a_u, chunks=u_chunks)

if metric == "mahalanobis":
if "VI" not in kw:
kw["VI"] = 2 * np.random.random(2 * u_shape[-1:]) - 1
elif kw["VI"] is None:
kw.pop("VI")
elif metric == "seuclidean":
if "V" not in kw:
kw["V"] = 2 * np.random.random(u_shape[-1:]) - 1
elif kw["V"] is None:
kw.pop("V")
elif metric == "wminkowski":
kw["w"] = np.random.random(u_shape[-1:])

a_r = spdist.pdist(a_u, metric, **kw)
d_r = dask_distance.pdist(d_u, metric, **kw)

Expand Down

0 comments on commit 5d26db9

Please sign in to comment.