Skip to content

Commit

Permalink
Add pdp plot to symbolic explanation output
Browse files Browse the repository at this point in the history
  • Loading branch information
Sarah Krebs committed Jan 23, 2024
1 parent b3723aa commit 933a137
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 26 deletions.
16 changes: 12 additions & 4 deletions deepcave/plugins/hyperparameter/pdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def get_output_layout(register):
return dcc.Graph(register("graph", "figure"), style={"height": config.FIGURE_HEIGHT})

@staticmethod
def load_outputs(run, inputs, outputs):
def get_pdp_figure(run, inputs, outputs, show_confidence, show_ice, title=None):
# Parse inputs
hp1_name = inputs["hyperparameter_name_1"]
hp1_idx = run.configspace.get_idx_by_hyperparameter_name(hp1_name)
Expand All @@ -250,9 +250,6 @@ def load_outputs(run, inputs, outputs):
hp2_idx = run.configspace.get_idx_by_hyperparameter_name(hp2_name)
hp2 = run.configspace.get_hyperparameter(hp2_name)

show_confidence = inputs["show_confidence"]
show_ice = inputs["show_ice"]

objective = run.get_objective(inputs["objective_id"])
objective_name = objective.name

Expand Down Expand Up @@ -323,6 +320,7 @@ def load_outputs(run, inputs, outputs):
"yaxis": {
"title": objective_name,
},
"title": title
}
)
else:
Expand All @@ -349,10 +347,20 @@ def load_outputs(run, inputs, outputs):
xaxis=dict(tickvals=x_tickvals, ticktext=x_ticktext, title=hp1_name),
yaxis=dict(tickvals=y_tickvals, ticktext=y_ticktext, title=hp2_name),
margin=config.FIGURE_MARGIN,
title=title
)
)

figure = go.Figure(data=traces, layout=layout)
save_image(figure, "pdp.pdf")

return figure

@staticmethod
def load_outputs(run, inputs, outputs):
show_confidence = inputs["show_confidence"]
show_ice = inputs["show_ice"]

figure = PartialDependencies.get_pdp_figure(run, inputs, outputs, show_confidence, show_ice)

return figure
78 changes: 56 additions & 22 deletions deepcave/plugins/hyperparameter/symbolic_explanations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import signal

import dash_bootstrap_components as dbc
import numpy as np
import plotly.graph_objs as go
Expand All @@ -11,6 +9,7 @@
from deepcave import config
from deepcave.evaluators.epm.random_forest_surrogate import RandomForestSurrogate
from deepcave.plugins.static import StaticPlugin
from deepcave.plugins.hyperparameter.pdp import PartialDependencies
from deepcave.runs import Status
from deepcave.utils.layout import get_checklist_options, get_select_options, help_button
from deepcave.utils.styled_plotty import get_color, get_hyperparameter_ticks, save_image
Expand Down Expand Up @@ -295,9 +294,21 @@ def process(run, inputs):
num_samples=num_samples,
)

x = pdp.x_pdp[:, idxs].tolist()
x = pdp.x_pdp
y = pdp.y_pdp.tolist()

# Save PDP information for PDP plot as comparison
y_pdp = y
pdp_variances = pdp.y_variances.tolist()

# We have to cut the ICE curves because it's too much data
x_ice = pdp._ice.x_ice.tolist()
y_ice = pdp._ice.y_ice.tolist()

if len(x_ice) > MAX_SHOWN_SAMPLES:
x_ice = x_ice[:MAX_SHOWN_SAMPLES]
y_ice = y_ice[:MAX_SHOWN_SAMPLES]

else:
cs = surrogate_model.config_space
random_samples = np.asarray(
Expand All @@ -308,8 +319,9 @@ def process(run, inputs):
)
]
)
x = random_samples.tolist()
x = random_samples
y = surrogate_model.predict(random_samples)[0]
x_ice, y_ice, pdp_variances, y_pdp = [], [], [], []

symb_params = dict(
population_size=population_size,
Expand All @@ -323,7 +335,7 @@ def process(run, inputs):

# run SR on samples
symb_model = SymbolicRegressor(**symb_params)
symb_model.fit(x, y)
symb_model.fit(x[:, idxs], y)

try:
conv_expr = (
Expand All @@ -342,47 +354,59 @@ def process(run, inputs):
"the parsimony hyperparameter."
)

y_symbolic = symb_model.predict(x).tolist()
y_symbolic = symb_model.predict(x[:, idxs]).tolist()

return {"x": x, "y": y_symbolic, "expr": conv_expr}
return {
"x": x.tolist(),
"y": y_pdp,
"y_symbolic": y_symbolic,
"expr": conv_expr,
"variances": pdp_variances,
"x_ice": x_ice,
"y_ice": y_ice,
}

@staticmethod
def get_output_layout(register):
return dcc.Graph(register("graph", "figure"), style={"height": config.FIGURE_HEIGHT})
return [dcc.Graph(register("symb_graph", "figure"), style={"height": config.FIGURE_HEIGHT}),
dcc.Graph(register("pdp_graph", "figure"), style={"height": config.FIGURE_HEIGHT})]

@staticmethod
def load_outputs(run, inputs, outputs):
# Parse inputs
hp1_name = inputs["hyperparameter_name_1"]
hp1_idx = run.configspace.get_idx_by_hyperparameter_name(hp1_name)
hp1 = run.configspace.get_hyperparameter(hp1_name)

hp2_name = inputs["hyperparameter_name_2"]
hp2_idx = None
hp2 = None
if hp2_name is not None and hp2_name != "":
hp2_idx = run.configspace.get_idx_by_hyperparameter_name(hp2_name)
hp2 = run.configspace.get_hyperparameter(hp2_name)

objective = run.get_objective(inputs["objective_id"])
objective_name = objective.name

# Parse outputs
x = np.asarray(outputs["x"])
y = np.asarray(outputs["y"])
y_symbolic = np.asarray(outputs["y_symbolic"])
expr = outputs["expr"]

traces = []
traces1 = []
if hp2 is None: # 1D
traces += [
traces1 += [
go.Scatter(
x=x[:, 0],
y=y,
x=x[:, hp1_idx],
y=y_symbolic,
line=dict(color=get_color(0, 1)),
hoverinfo="skip",
showlegend=False,
)
]

tickvals, ticktext = get_hyperparameter_ticks(hp1)
layout = go.Layout(
layout1 = go.Layout(
{
"xaxis": {
"tickvals": tickvals,
Expand All @@ -396,12 +420,12 @@ def load_outputs(run, inputs, outputs):
}
)
else:
z = y
traces += [
z = y_symbolic
traces1 += [
go.Contour(
z=z,
x=x[:, 0],
y=x[:, 1],
x=x[:, hp1_idx],
y=x[:, hp2_idx],
colorbar=dict(
title=objective_name,
),
Expand All @@ -412,7 +436,7 @@ def load_outputs(run, inputs, outputs):
x_tickvals, x_ticktext = get_hyperparameter_ticks(hp1)
y_tickvals, y_ticktext = get_hyperparameter_ticks(hp2)

layout = go.Layout(
layout1 = go.Layout(
dict(
xaxis=dict(tickvals=x_tickvals, ticktext=x_ticktext, title=hp1_name),
yaxis=dict(tickvals=y_tickvals, ticktext=y_ticktext, title=hp2_name),
Expand All @@ -421,7 +445,17 @@ def load_outputs(run, inputs, outputs):
)
)

figure = go.Figure(data=traces, layout=layout)
save_image(figure, "symbolic_explanation.pdf")
figure1 = go.Figure(data=traces1, layout=layout1)
save_image(figure1, "symbolic_explanation.pdf")

return figure
if len(outputs["y_ice"]) > 0:
figure2 = PartialDependencies.get_pdp_figure(run, inputs, outputs,
show_confidence=False,
show_ice=True,
title="Partial Dependency Plot leveraged for training of "
"Symbolic Explanation:"
)

return [figure1, figure2]
else:
return [figure1, []]

0 comments on commit 933a137

Please sign in to comment.