Skip to content

Commit

Permalink
ENH: Add type signatures to the jenksy function to speed it up. Also,…
Browse files Browse the repository at this point in the history
… add a warning if numba is not installed. (pysal#118)
  • Loading branch information
cheginit committed Dec 10, 2021
1 parent 134f2dc commit 77da329
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions mapclassify/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@
FMT = "{:.2f}"

try:
from numba import jit
from numba import njit
HAS_NUMBA = True
except ImportError:

def jit(func):
HAS_NUMBA = False
def njit(func):
return func


Expand Down Expand Up @@ -507,8 +508,8 @@ def natural_breaks(values, k=5, init=10):
return (sids, class_ids, fit, cuts)


@jit
def _fisher_jenks_means(values, classes=5, sort=True):
@njit("f8[:](f8[:], u2)", cache=True)
def _fisher_jenks_means(values, classes=5):
"""
Jenks Optimal (Natural Breaks) algorithm implemented in Python.
Expand All @@ -523,8 +524,6 @@ def _fisher_jenks_means(values, classes=5, sort=True):
assuring heterogeneity among classes.
"""
if sort:
values.sort()
n_data = len(values)
mat1 = np.zeros((n_data + 1, classes + 1), dtype=np.int32)
mat2 = np.zeros((n_data + 1, classes + 1), dtype=np.float32)
Expand Down Expand Up @@ -562,7 +561,7 @@ def _fisher_jenks_means(values, classes=5, sort=True):
id = int(pivot - 2)
kclass[countNum - 1] = values[id]
k = int(pivot - 1)
return kclass
return np.delete(kclass, 0)


class MapClassifier(object):
Expand Down Expand Up @@ -1761,8 +1760,8 @@ class FisherJenks(MapClassifier):
----------
y : array
(n,1), values to classify
k : int
number of classes required
k : int, optional
number of classes, defatuls to 5
Attributes
----------
Expand Down Expand Up @@ -1790,6 +1789,9 @@ class FisherJenks(MapClassifier):
"""

def __init__(self, y, k=K):
if not HAS_NUMBA:
Warn("Numba not installed. Using slow pure python version.",
UserWarning)

nu = len(np.unique(y))
if nu < k:
Expand All @@ -1799,8 +1801,8 @@ def __init__(self, y, k=K):
self.name = "FisherJenks"

def _set_bins(self):
x = self.y.copy()
self.bins = np.array(_fisher_jenks_means(x, classes=self.k)[1:])
x = np.sort(self.y).astype("f8")
self.bins = _fisher_jenks_means(x, classes=self.k)


class FisherJenksSampled(MapClassifier):
Expand Down

0 comments on commit 77da329

Please sign in to comment.