Skip to content

Commit

Permalink
fix score_cache, fixes #15
Browse files Browse the repository at this point in the history
  • Loading branch information
rsteca committed Nov 23, 2016
1 parent 662fd36 commit de6d4f5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
24 changes: 8 additions & 16 deletions evolutionary_search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,23 @@ def _individual_to_params(individual, name_values):
return dict((name, values[gene]) for gene, (name, values) in zip(individual, name_values))


def _evalFunction(individual, searchobj, name_values, X, y, scorer, cv, iid, fit_params,
score_cache = {}

def _evalFunction(individual, name_values, X, y, scorer, cv, iid, fit_params,
verbose=0, error_score='raise'):
parameters = _individual_to_params(individual, name_values)
score = 0
n_test = 0
for train, test in cv:
paramkey = str(parameters)
if paramkey in searchobj.score_cache:
searchobj.num_cache_hits += 1
_score = searchobj.score_cache[paramkey]
paramkey = str(individual)
if paramkey in score_cache:
_score = score_cache[paramkey]
else:
_score, _, _ = _fit_and_score(estimator=individual.est, X=X, y=y, scorer=scorer,
train=train, test=test, verbose=verbose,
parameters=parameters, fit_params=fit_params,
error_score=error_score)
searchobj.num_evaluations += 1
searchobj.score_cache[paramkey] = _score
if searchobj.verbose and (searchobj.num_evaluations + searchobj.num_cache_hits) % searchobj.population_size == 0:
print("Scoring evaluations: %d, Cache hits: %d, Total: %d" % (
searchobj.num_evaluations, searchobj.num_cache_hits, searchobj.num_evaluations + searchobj.num_cache_hits))
score_cache[paramkey] = _score
if iid:
score += _score * len(test)
n_test += len(test)
Expand Down Expand Up @@ -272,9 +269,6 @@ def __init__(self, estimator, params, scoring=None, cv=4,
self.gene_crossover_prob = gene_crossover_prob
self.tournament_size = tournament_size
self.gene_type = gene_type
self.score_cache = {}
self.num_cache_hits = 0
self.num_evaluations = 0

def fit(self, X, y=None):
self.best_estimator_ = None
Expand Down Expand Up @@ -316,7 +310,7 @@ def _fit(self, X, y, parameter_dict):
toolbox.register("individual", _initIndividual, creator.Individual, maxints=maxints)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

toolbox.register("evaluate", _evalFunction, searchobj=self,
toolbox.register("evaluate", _evalFunction,
name_values=name_values, X=X, y=y,
scorer=self.scorer_, cv=cv, iid=self.iid, verbose=self.verbose,
error_score=self.error_score, fit_params=self.fit_params)
Expand Down Expand Up @@ -349,8 +343,6 @@ def _fit(self, X, y, parameter_dict):
if self.verbose:
print("Best individual is: %s\nwith fitness: %s" % (
current_best_params_, current_best_score_))
print("Scoring evaluations: %d, Cache hits: %d, Total: %d" % (
self.num_evaluations, self.num_cache_hits, self.num_evaluations + self.num_cache_hits))

if current_best_score_ > self.best_score_:
self.best_score_ = current_best_score_
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

setup(
name='sklearn-deap',
version='0.1.6',
version='0.1.7',
author='Rodrigo',
author_email='',
description='Use evolutionary algorithms instead of gridsearch in scikit-learn.',
url='https://github.com/rsteca/sklearn-deap',
download_url='https://github.com/rsteca/sklearn-deap/tarball/0.1.6',
download_url='https://github.com/rsteca/sklearn-deap/tarball/0.1.7',
classifiers=[
'Development Status :: 4 - Beta',
'Programming Language :: Python :: 2',
Expand Down

0 comments on commit de6d4f5

Please sign in to comment.