diff --git a/tests/test_randomizedsearch.py b/tests/test_randomizedsearch.py index b182bd7..409941b 100644 --- a/tests/test_randomizedsearch.py +++ b/tests/test_randomizedsearch.py @@ -156,7 +156,7 @@ def test_local_dir(self): parameter_grid, early_stopping=scheduler, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) tune_search.fit(x, y) self.assertTrue(len(os.listdir("./test-result")) != 0) @@ -290,7 +290,7 @@ def test_warm_start_detection(self): parameter_grid, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) self.assertEqual(tune_search.early_stop_type, EarlyStopping.NO_EARLY_STOP) @@ -301,7 +301,7 @@ def test_warm_start_detection(self): parameter_grid, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) self.assertEqual(tune_search2.early_stop_type, EarlyStopping.NO_EARLY_STOP) @@ -312,7 +312,7 @@ def test_warm_start_detection(self): parameter_grid, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) self.assertEqual(tune_search3.early_stop_type, EarlyStopping.NO_EARLY_STOP) @@ -323,7 +323,7 @@ def test_warm_start_detection(self): early_stopping=True, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) self.assertEqual(tune_search4.early_stop_type, EarlyStopping.WARM_START_ITER) @@ -334,7 +334,7 @@ def test_warm_start_detection(self): early_stopping=True, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) self.assertEqual(tune_search5.early_stop_type, EarlyStopping.WARM_START_ENSEMBLE) @@ -349,7 +349,7 @@ def test_warm_start_error(self): n_jobs=1, early_stopping=False, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) self.assertFalse(tune_search._can_early_stop()) with self.assertRaises(ValueError): tune_search = TuneSearchCV( @@ -358,7 +358,7 @@ def test_warm_start_error(self): n_jobs=1, early_stopping=True, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) from sklearn.linear_model import LogisticRegression clf = LogisticRegression() @@ -370,7 +370,7 @@ def test_warm_start_error(self): early_stopping=True, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) from sklearn.ensemble import RandomForestClassifier clf = RandomForestClassifier() @@ -382,7 +382,7 @@ def test_warm_start_error(self): early_stopping=True, n_jobs=1, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) def test_warn_reduce_maxiters(self): parameter_grid = {"alpha": Real(1e-4, 1e-1, prior="log-uniform")} @@ -390,13 +390,16 @@ def test_warn_reduce_maxiters(self): clf = RandomForestClassifier(max_depth=2, random_state=0) with self.assertWarnsRegex(UserWarning, "max_iters is set"): TuneSearchCV( - clf, parameter_grid, max_iters=10, local_dir="./test-result") + clf, + parameter_grid, + max_iters=10, + local_dir=os.path.abspath("./test-result")) with self.assertWarnsRegex(UserWarning, "max_iters is set"): TuneSearchCV( SGDClassifier(), parameter_grid, max_iters=10, - local_dir="./test-result") + local_dir=os.path.abspath("./test-result")) def test_warn_early_stop(self): X, y = make_classification( @@ -893,9 +896,9 @@ def testHyperoptPointsToEvaluate(self): from ray.tune.search.hyperopt import HyperOptSearch # Skip test if category conversion is not available if not hasattr(HyperOptSearch, "_convert_categories_to_indices"): - self.skipTest(f"The current version of Ray does not support the " - f"`points_to_evaluate` argument for search method " - f"`hyperopt`. Skipping test.") + self.skipTest("The current version of Ray does not support the " + "`points_to_evaluate` argument for search method " + "`hyperopt`. Skipping test.") return self._test_points_to_evaluate("hyperopt") diff --git a/tune_sklearn/_trainable.py b/tune_sklearn/_trainable.py index 964679e..c3010c8 100644 --- a/tune_sklearn/_trainable.py +++ b/tune_sklearn/_trainable.py @@ -319,8 +319,8 @@ def _train(self): return_train_score=self.return_train_score, error_score=self.error_score) except ValueError as e: - if ("It is very likely that your" - "model is misconfigured") not in str(e): + if ("It is very likely that your model is misconfigured" not in + str(e)): raise e fit_failed = True @@ -367,9 +367,9 @@ def _train(self): return ret - def save_checkpoint(self, checkpoint_dir): + def save_checkpoint(self, checkpoint_dir: str): # forward-compatbility - return self._save(checkpoint_dir) + self._save(checkpoint_dir) def _save(self, checkpoint_dir): """Creates a checkpoint in ``checkpoint_dir``, creating a pickle file. @@ -387,21 +387,21 @@ def _save(self, checkpoint_dir): cpickle.dump(self.estimator_list, f) except Exception: warnings.warn("Unable to save estimator.", category=RuntimeWarning) - return path - def load_checkpoint(self, checkpoint): + def load_checkpoint(self, checkpoint_dir: str): # forward-compatbility - return self._restore(checkpoint) + self._restore(checkpoint_dir) - def _restore(self, checkpoint): + def _restore(self, checkpoint_dir): """Loads a checkpoint created from `save`. Args: checkpoint (str): file path to pickled checkpoint file. """ + path = os.path.join(checkpoint_dir, "checkpoint") try: - with open(checkpoint, "rb") as f: + with open(path, "rb") as f: self.estimator_list = cpickle.load(f) except Exception: warnings.warn("No estimator restored", category=RuntimeWarning) diff --git a/tune_sklearn/tune_basesearch.py b/tune_sklearn/tune_basesearch.py index 8c8abbf..a048e64 100644 --- a/tune_sklearn/tune_basesearch.py +++ b/tune_sklearn/tune_basesearch.py @@ -327,7 +327,7 @@ def __init__(self, verbose=0, error_score="raise", return_train_score=False, - local_dir="~/ray_results", + local_dir=None, name=None, max_iters=1, use_gpu=False, @@ -773,32 +773,31 @@ def _format_results(self, n_splits, out): trials = [ trial for trial in out.trials if trial.status == Trial.TERMINATED ] - trial_dirs = [trial.logdir for trial in trials] - # The result dtaframes are indexed by their trial logdir - trial_dfs = out.fetch_trial_dataframes() + trial_dfs = out.trial_dataframes + trial_ids = list(trial_dfs) # Try to find a template df to use for trials that did not return # any results. These trials should copy the structure and fill it # with NaNs so that the later reshape actions work. template_df = None - fix_trial_dirs = [] # Holds trial dirs with no results - for trial_dir in trial_dirs: - if trial_dir in trial_dfs and template_df is None: - template_df = trial_dfs[trial_dir] - elif trial_dir not in trial_dfs: - fix_trial_dirs.append(trial_dir) + fix_trial_ids = [] # Holds trial_ids with no results + for trial_id, trial_df in trial_dfs.items(): + if template_df is None and not trial_df.empty: + template_df = trial_df + elif trial_df.empty: + fix_trial_ids.append(trial_id) # Create NaN dataframes for trials without results - if fix_trial_dirs: + if fix_trial_ids: if template_df is None: # No trial returned any results return {} - for trial_dir in fix_trial_dirs: + for trial_id in fix_trial_ids: trial_df = pd.DataFrame().reindex_like(template_df) - trial_dfs[trial_dir] = trial_df + trial_dfs[trial_id] = trial_df # Keep right order - dfs = [trial_dfs[trial_dir] for trial_dir in trial_dirs] + dfs = [trial_dfs[trial_id] for trial_id in trial_ids] finished = [df.iloc[[-1]] for df in dfs] test_scores = {} train_scores = {} diff --git a/tune_sklearn/tune_gridsearch.py b/tune_sklearn/tune_gridsearch.py index 8add6e6..7149903 100644 --- a/tune_sklearn/tune_gridsearch.py +++ b/tune_sklearn/tune_gridsearch.py @@ -157,7 +157,7 @@ def __init__(self, verbose=0, error_score="raise", return_train_score=False, - local_dir="~/ray_results", + local_dir=None, name=None, max_iters=1, use_gpu=False, diff --git a/tune_sklearn/tune_search.py b/tune_sklearn/tune_search.py index e860621..fcdfee6 100644 --- a/tune_sklearn/tune_search.py +++ b/tune_sklearn/tune_search.py @@ -316,7 +316,7 @@ def __init__(self, random_state=None, error_score=np.nan, return_train_score=False, - local_dir="~/ray_results", + local_dir=None, name=None, max_iters=1, search_optimization="random",