Skip to content

Commit

Permalink
fix #1
Browse files Browse the repository at this point in the history
try to fix a bug in the lisi code
  • Loading branch information
slowkow committed Mar 3, 2020
1 parent f2499d8 commit d89ec95
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions harmonypy/lisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from typing import Iterable


def compute_lisi(
X: np.array, metadata: pd.DataFrame, label_colnames, perplexity = 30
X: np.array,
metadata: pd.DataFrame,
label_colnames: Iterable[str],
perplexity: float=30
):
"""Compute the Local Inverse Simpson Index (LISI) for each column in metadata.
Expand Down Expand Up @@ -61,7 +65,14 @@ def compute_lisi(
return lisi_df


def compute_simpson(distances, indices, labels, n_categories, perplexity, tol = 1e-5):
def compute_simpson(
distances: np.ndarray,
indices: np.ndarray,
labels: pd.Categorical,
n_categories: int,
perplexity: float,
tol: float=1e-5
):
n = distances.shape[1]
P = np.zeros(distances.shape[0])
simpson = np.zeros(n)
Expand Down Expand Up @@ -115,8 +126,8 @@ def compute_simpson(distances, indices, labels, n_categories, perplexity, tol =
# Simpson's index
for label_category in labels.categories:
ix = indices[:,i]
q = np.squeeze(np.argwhere(labels[ix] == label_category))
if len(q):
q = labels[ix] == label_category
if np.any(q):
P_sum = np.sum(P[q])
simpson[i] += P_sum * P_sum
return simpson
Expand Down

0 comments on commit d89ec95

Please sign in to comment.