Skip to content

Commit

Permalink
Fixes from review
Browse files Browse the repository at this point in the history
  • Loading branch information
Sarah Krebs committed Feb 16, 2024
1 parent 04b660a commit da987e8
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 53 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Version 1.1.4
# Version 1.2

## Plugins
- Add symbolic explanations plugin (#46).
Expand Down
1 change: 0 additions & 1 deletion deepcave/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import multiprocessing
import subprocess
from pathlib import Path
Expand Down
2 changes: 1 addition & 1 deletion deepcave/evaluators/epm/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _impute_inactive(self, X: np.ndarray) -> np.ndarray:

def _check_dimensions(self, X: np.ndarray, Y: Optional[np.ndarray] = None) -> None:
"""
Checks if the dimensions of X and Y are correct wrt features.
Checks if the dimensions of X and Y are correct with respect to features.
Parameters
----------
Expand Down
2 changes: 1 addition & 1 deletion deepcave/evaluators/fanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def calculate(
seed: int = 0,
) -> None:
"""
Get the data wrt budget and trains the forest on the encoded data.
Get the data with respect to budget and train the forest on the encoded data.
Note
----
Expand Down
192 changes: 170 additions & 22 deletions deepcave/plugins/hyperparameter/symbolic_explanations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
# noqa: D400
"""
# SymbolicExplanations
This module provides utilities for generating Symbolic Explanations.
Provided utilities include getting input and output layout,
processing the data and loading the outputs.
## Classes
- SymbolicExplanations: Leverage Symbolic Explanations to obtain a formula and plot it.
## Constants
GRID_POINTS_PER_AXIS : int
SAMPLES_PER_HP : int
MAX_SAMPLES : int
MAX_SHOWN_SAMPLES : int
"""

from typing import Any, Callable, Dict, List, Union

import dash_bootstrap_components as dbc
import numpy as np
import plotly.graph_objs as go
Expand All @@ -8,8 +29,8 @@

from deepcave.config import Config
from deepcave.evaluators.epm.random_forest_surrogate import RandomForestSurrogate
from deepcave.plugins.static import StaticPlugin
from deepcave.plugins.hyperparameter.pdp import PartialDependencies
from deepcave.plugins.static import StaticPlugin
from deepcave.runs import Status
from deepcave.utils.layout import get_checklist_options, get_select_options, help_button
from deepcave.utils.styled_plotty import get_color, get_hyperparameter_ticks, save_image
Expand All @@ -22,14 +43,35 @@


class SymbolicExplanations(StaticPlugin):
"""
Generate Symbolic Explanations.
Provided utilities include getting input and output layout,
processing the data and loading the outputs.
"""

id = "symbolic_explanations"
name = "Symbolic Explanations"
icon = "fas fa-subscript"
help = "docs/plugins/symbolic_explanations.rst"
activate_run_selection = True

@staticmethod
def get_input_layout(register):
def get_input_layout(register: Callable) -> List[Union[dbc.Row, html.Details]]:
"""
Get the layout for the input block.
Parameters
----------
register : Callable
Method to register (user) variables.
The register_input function is located in the Plugin superclass.
Returns
-------
List[Union[dbc.Row, html.Details]]
The layout for the input block.
"""
return [
dbc.Row(
[
Expand Down Expand Up @@ -97,7 +139,9 @@ def get_input_layout(register):
[
dbc.Label("Parsimony coefficient"),
help_button(
"Penalizes the complexity of the resulting formulas."
"Penalizes the complexity of the resulting formulas. The "
"higher the value, the higher the penalty on the "
"complexity will be, resulting in simpler formulas."
),
dcc.Slider(
id=register("parsimony", "value", type=int),
Expand Down Expand Up @@ -180,7 +224,19 @@ def get_input_layout(register):
),
]

def load_inputs(self):
def load_inputs(self) -> Dict[str, Dict[str, Any]]:
"""
Load the content for the defined inputs in 'get_input_layout' and 'get_filter_layout'.
This method is necessary to pre-load contents for the inputs.
If the plugin is called for the first time, or there are no results in the cache,
the plugin gets its content from this method.
Returns
-------
Dict[str, Dict[str, Any]]
Content to be filled.
"""
return {
"parsimony": {"value": "-4"},
"generations": {"value": "10"},
Expand All @@ -192,7 +248,30 @@ def load_inputs(self):
},
}

def load_dependency_inputs(self, run, previous_inputs, inputs):
def load_dependency_inputs(self, run, previous_inputs, inputs) -> Dict[str, Any]: # type: ignore # noqa: E501
"""
Work like 'load_inputs' but called after inputs have changed.
Note
----
Only the changes have to be returned. The returned dictionary
will be merged with the inputs.
Parameters
----------
run
The selected run.
inputs
Current content of the inputs.
previous_inputs
Previous content of the inputs.
Not used in this specific function.
Returns
-------
Dict[str, Any]
Dictionary with the changes.
"""
objective_names = run.get_objective_names()
objective_ids = run.get_objective_ids()
objective_options = get_select_options(objective_names, objective_ids)
Expand Down Expand Up @@ -235,8 +314,37 @@ def load_dependency_inputs(self, run, previous_inputs, inputs):
}

@staticmethod
def process(run, inputs):
# Surrogate
def process(run, inputs) -> Dict[str, Any]: # type: ignore
"""
Return raw data based on a run and the input data.
Warning
-------
The returned data must be JSON serializable.
Note
----
The passed inputs are cleaned and therefore differ
compared to 'load_inputs' or 'load_dependency_inputs'.
Please see '_clean_inputs' for more information.
Parameters
----------
run
The run to process.
inputs
The input data.
Returns
-------
Dict[str, Any]
A serialized dictionary.
Raises
------
RuntimeError
If the objective is None.
"""
hp_names = run.configspace.get_hyperparameter_names()
objective = run.get_objective(inputs["objective_id"])
budget = run.get_budget(inputs["budget_id"])
Expand Down Expand Up @@ -276,7 +384,7 @@ def process(run, inputs):
idxs += [idx2]

num_samples = SAMPLES_PER_HP * len(X)
# We limit the samples to max 10k
# The samples are limited to max 10k
if num_samples > MAX_SAMPLES:
num_samples = MAX_SAMPLES

Expand All @@ -296,18 +404,19 @@ def process(run, inputs):
x_ice = pdp._ice.x_ice.tolist()
y_ice = pdp._ice.y_ice.tolist()

# We have to cut the ICE curves because it's too much data
# The ICE curves have to be cut because it's too much data
if len(x_ice) > MAX_SHOWN_SAMPLES:
x_ice = x_ice[:MAX_SHOWN_SAMPLES]
y_ice = y_ice[:MAX_SHOWN_SAMPLES]

if len(selected_hyperparameters) < len(hp_names):
# If number of hyperparameters to explain is smaller than number of hyperparameters optimizes,
# use PDP to train the symbolic explanation
# If number of hyperparameters to explain is smaller than number of hyperparameters
# optimized, use PDP to train the symbolic explanation
x_symbolic = x_pdp
y_train = y_pdp
else:
# Else, use random samples evaluated with the surrogate model to train the symbolic explanation
# Else, use random samples evaluated with the surrogate model to train the symbolic
# explanation
cs = surrogate_model.config_space
random_samples = np.asarray(
[
Expand Down Expand Up @@ -365,13 +474,54 @@ def process(run, inputs):
}

@staticmethod
def get_output_layout(register):
return [dcc.Graph(register("symb_graph", "figure"), style={"height": Config.FIGURE_HEIGHT}),
dcc.Graph(register("pdp_graph", "figure"), style={"height": Config.FIGURE_HEIGHT})]
def get_output_layout(register: Callable) -> List[dcc.Graph]:
"""
Get the layout for the output block.
Parameters
----------
register : Callable
Method to register outputs.
The register_output function is located in the Plugin superclass.
Returns
-------
List[dcc.Graph]
Layout for the output block.
"""
return [
dcc.Graph(register("symb_graph", "figure"), style={"height": Config.FIGURE_HEIGHT}),
dcc.Graph(register("pdp_graph", "figure"), style={"height": Config.FIGURE_HEIGHT}),
]

@staticmethod
def load_outputs(run, inputs, outputs):
# Parse inputs
def load_outputs(run, inputs, outputs) -> List[go.Figure]: # type: ignore
"""
Read the raw data and prepare it for the layout.
Note
----
The passed inputs are cleaned and therefore differ
compared to 'load_inputs' or 'load_dependency_inputs'.
Please see '_clean_inputs' for more information.
Parameters
----------
run
The selected run.
inputs
Input and filter values from the user.
outputs
Raw output from the run.
Returns
-------
List[go.Figure]
The figure of the Symbolic Explanation and the Partial Dependency Plot (PDP) leveraged
for training in the case that the number of hyperparameters to be explained is smaller
than the number of hyperparameters that was optimized, else, a Partial Dependency Plot
(PDP) for comparison.
"""
hp1_name = inputs["hyperparameter_name_1"]
hp1_idx = run.configspace.get_idx_by_hyperparameter_name(hp1_name)
hp1 = run.configspace.get_hyperparameter(hp1_name)
Expand Down Expand Up @@ -454,10 +604,8 @@ def load_outputs(run, inputs, outputs):
else:
pdp_title = "Partial Dependency for comparison:"

figure2 = PartialDependencies.get_pdp_figure(run, inputs, outputs,
show_confidence=False,
show_ice=False,
title=pdp_title
)
figure2 = PartialDependencies.get_pdp_figure(
run, inputs, outputs, show_confidence=False, show_ice=False, title=pdp_title
)

return [figure1, figure2]
6 changes: 3 additions & 3 deletions deepcave/utils/logging.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ formatters:
handlers:
console:
class: logging.StreamHandler
level: DEBUG
level: INFO
formatter: simple
stream: ext://sys.stdout
loggers:
src.plugins:
level: DEBUG
level: INFO
handlers: [ console ]
propagate: no
matplotlib:
Expand All @@ -22,6 +22,6 @@ loggers:
handlers: [ console ]
propagate: no
root:
level: DEBUG
level: INFO
handlers: [console]
disable_existing_loggers: true
Loading

0 comments on commit da987e8

Please sign in to comment.