Skip to content

Commit

Permalink
Add BoTorch_Modular to config, giving custom GP and acq func ability
Browse files Browse the repository at this point in the history
Now users can customize their GP (kernel, MLL, GP class) as well as the acq func in the config without writing any code. Giving users a lot more control and power from any language with little setup
  • Loading branch information
madeline-scyphers committed Jul 22, 2024
1 parent 0f8df06 commit 70876f5
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 3 deletions.
35 changes: 35 additions & 0 deletions boa/config/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@

import ax.early_stopping.strategies as early_stopping_strats
import ax.global_stopping.strategies as global_stopping_strats
import botorch.acquisition
import botorch.models
import gpytorch.kernels
import gpytorch.mlls
from ax.modelbridge.generation_node import GenerationStep
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.service.utils.instantiation import TParameterRepresentation
from ax.service.utils.scheduler_options import SchedulerOptions

Expand Down Expand Up @@ -49,6 +54,36 @@ def _gen_strat_converter(gs: Optional[dict] = None) -> dict:
gs["steps"][i] = step
steps.append(step)
continue
if "model_kwargs" in step:
if "botorch_acqf_class" in step["model_kwargs"] and not isinstance(
step["model_kwargs"]["botorch_acqf_class"], botorch.acquisition.AcquisitionFunction
):
step["model_kwargs"]["botorch_acqf_class"] = getattr(
botorch.acquisition, step["model_kwargs"]["botorch_acqf_class"]
)

if "surrogate" in step["model_kwargs"]:
if "mll_class" in step["model_kwargs"]["surrogate"] and not isinstance(
step["model_kwargs"]["surrogate"]["mll_class"], gpytorch.mlls.MarginalLogLikelihood
):
step["model_kwargs"]["surrogate"]["mll_class"] = getattr(
gpytorch.mlls, step["model_kwargs"]["surrogate"]["mll_class"]
)
if "botorch_model_class" in step["model_kwargs"]["surrogate"] and not isinstance(
step["model_kwargs"]["surrogate"]["botorch_model_class"], botorch.models.model.Model
):
step["model_kwargs"]["surrogate"]["botorch_model_class"] = getattr(
botorch.models, step["model_kwargs"]["surrogate"]["botorch_model_class"]
)
if "covar_module_class" in step["model_kwargs"]["surrogate"] and not isinstance(
step["model_kwargs"]["surrogate"]["covar_module_class"], gpytorch.kernels.Kernel
):
step["model_kwargs"]["surrogate"]["covar_module_class"] = getattr(
gpytorch.kernels, step["model_kwargs"]["surrogate"]["covar_module_class"]
)

step["model_kwargs"]["surrogate"] = Surrogate(**step["model_kwargs"]["surrogate"])

try:
step["model"] = Models[step["model"]]
except KeyError:
Expand Down
42 changes: 41 additions & 1 deletion boa/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
from ax.storage.json_store.registry import CORE_DECODER_REGISTRY, CORE_ENCODER_REGISTRY
from __future__ import annotations

from typing import Type

import botorch.acquisition
import gpytorch.kernels
from ax.storage.botorch_modular_registry import (
ACQUISITION_FUNCTION_REGISTRY,
CLASS_TO_REGISTRY,
CLASS_TO_REVERSE_REGISTRY,
)
from ax.storage.json_store.registry import (
CORE_CLASS_DECODER_REGISTRY,
CORE_CLASS_ENCODER_REGISTRY,
CORE_DECODER_REGISTRY,
CORE_ENCODER_REGISTRY,
botorch_modular_to_dict,
class_from_json,
)


def config_to_dict(inst):
Expand All @@ -15,3 +33,25 @@ def _add_common_encodes_and_decodes():
CORE_ENCODER_REGISTRY[BOAConfig] = config_to_dict
# CORE_DECODER_REGISTRY[BOAConfig.__name__] = BOAConfig
CORE_DECODER_REGISTRY[MetricType.__name__] = MetricType

CORE_CLASS_DECODER_REGISTRY["Type[Kernel]"] = class_from_json
CORE_CLASS_ENCODER_REGISTRY[gpytorch.kernels.Kernel] = botorch_modular_to_dict

KERNEL_REGISTRY = {getattr(gpytorch.kernels, kernel): kernel for kernel in gpytorch.kernels.__all__}

REVERSE_KERNEL_REGISTRY: dict[str, Type[gpytorch.kernels.Kernel]] = {v: k for k, v in KERNEL_REGISTRY.items()}

CLASS_TO_REGISTRY[gpytorch.kernels.Kernel] = KERNEL_REGISTRY
CLASS_TO_REVERSE_REGISTRY[gpytorch.kernels.Kernel] = REVERSE_KERNEL_REGISTRY

for acq_func_name in botorch.acquisition.__all__:
acq_func = getattr(botorch.acquisition, acq_func_name)
if acq_func not in ACQUISITION_FUNCTION_REGISTRY:
ACQUISITION_FUNCTION_REGISTRY[acq_func] = acq_func_name

REVERSE_ACQUISITION_FUNCTION_REGISTRY: dict[str, Type[botorch.acquisition.AcquisitionFunction]] = {
v: k for k, v in ACQUISITION_FUNCTION_REGISTRY.items()
}

CLASS_TO_REGISTRY[botorch.acquisition.AcquisitionFunction] = ACQUISITION_FUNCTION_REGISTRY
CLASS_TO_REVERSE_REGISTRY[botorch.acquisition.AcquisitionFunction] = REVERSE_ACQUISITION_FUNCTION_REGISTRY
33 changes: 33 additions & 0 deletions tests/1unit_tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import botorch.acquisition
import botorch.models
import gpytorch.kernels
import gpytorch.mlls
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

Expand Down Expand Up @@ -41,3 +45,32 @@ def test_auto_gen_use_saasbo(saasbo_config, tmp_path):
assert "SAASBO" in gs.name
else:
assert "FullyBayesian" in gs.name


def test_modular_botorch(gen_strat_modular_botorch_config, tmp_path):
controller = Controller(
config=gen_strat_modular_botorch_config,
wrapper=ScriptWrapper(config=gen_strat_modular_botorch_config, experiment_dir=tmp_path),
)
exp = get_experiment(
config=controller.config, runner=WrappedJobRunner(wrapper=controller.wrapper), wrapper=controller.wrapper
)
gs = get_generation_strategy(config=controller.config, experiment=exp)
cfg_botorch_modular = gen_strat_modular_botorch_config.orig_config["generation_strategy"]["steps"][-1]
step = gs._steps[-1]
assert step.model == Models.BOTORCH_MODULAR
mdl_kw = step.model_kwargs
assert mdl_kw["botorch_acqf_class"] == getattr(
botorch.acquisition, cfg_botorch_modular["model_kwargs"]["botorch_acqf_class"]
)
assert mdl_kw["acquisition_options"] == cfg_botorch_modular["model_kwargs"]["acquisition_options"]

assert mdl_kw["surrogate"].mll_class == getattr(
gpytorch.mlls, cfg_botorch_modular["model_kwargs"]["surrogate"]["mll_class"]
)
assert mdl_kw["surrogate"].botorch_model_class == getattr(
botorch.models, cfg_botorch_modular["model_kwargs"]["surrogate"]["botorch_model_class"]
)
assert mdl_kw["surrogate"].covar_module_class == getattr(
gpytorch.kernels, cfg_botorch_modular["model_kwargs"]["surrogate"]["covar_module_class"]
)
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def gen_strat1_config():
return BOAConfig.from_jsonlike(file=config_path)


@pytest.fixture
def gen_strat_modular_botorch_config():
config_path = TEST_DIR / f"scripts/other_langs/r_package_streamlined/config_modular_botorch.yaml"
return BOAConfig.from_jsonlike(file=config_path)


@pytest.fixture
def synth_config():
config_path = TEST_CONFIG_DIR / "test_config_synth.yaml"
Expand Down Expand Up @@ -233,3 +239,9 @@ def r_streamlined(tmp_path_factory, cd_to_root_and_back_session):
config_path = TEST_DIR / f"scripts/other_langs/r_package_streamlined/config.yaml"

yield cli_main(split_shell_command(f"--config-path {config_path} -td"), standalone_mode=False)


@pytest.fixture(scope="session")
def r_streamlined_botorch_modular(tmp_path_factory, cd_to_root_and_back_session):
config_path = TEST_DIR / f"scripts/other_langs/r_package_streamlined/config_modular_botorch.yaml"
return cli_main(split_shell_command(f"--config-path {config_path} -td"), standalone_mode=False)
7 changes: 5 additions & 2 deletions tests/integration_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,12 @@ def test_calling_command_line_test_script_doesnt_error_out_and_produces_correct_

# parametrize the test to use the full version (all scripts) or the light version (only run_model.R)
# or parametrize the test to use the streamlined version (doesn't use trial_status.json, only use output.json)
# the botorch modular version is the same as the streamlined version, but also uses botorch modular
# which uses a custom kernel, acquisition function, mll and botorch model class
# (which can customize the GP process even more)
@pytest.mark.parametrize(
"r_scripts_run",
["r_full", "r_light", "r_streamlined"],
["r_full", "r_light", "r_streamlined", "r_streamlined_botorch_modular"],
)
@pytest.mark.skipif(not R_INSTALLED, reason="requires R to be installed")
def test_calling_command_line_r_test_scripts(r_scripts_run, request):
Expand All @@ -92,7 +95,7 @@ def test_calling_command_line_r_test_scripts(r_scripts_run, request):
assert "param_names" in data
assert "metric_properties" in data

if "r_streamlined" == r_scripts_run:
if r_scripts_run in ("r_streamlined", "r_streamlined_botorch_modular"):
with cd_and_cd_back(scheduler.wrapper.config_path.parent):

pre_num_trials = len(scheduler.experiment.trials)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
objective:
metrics:
- name: metric
scheduler:
n_trials: 15

parameters:
x0:
'bounds': [ 0, 1 ]
'type': 'range'
'value_type': 'float'
x1:
'bounds': [ 0, 1]
'type': 'range'
'value_type': 'float'
x2:
'bounds': [ 0, 1 ]
'type': 'range'
'value_type': 'float'
x3:
'bounds': [ 0, 1]
'type': 'range'
'value_type': 'float'
x4:
'bounds': [ 0, 1 ]
'type': 'range'
'value_type': 'float'
x5:
'bounds': [ 0, 1]
'type': 'range'
'value_type': 'float'

script_options:
# notice here that this is a shell command
# this is what BOA will do to launch your script
# it will also pass as a command line argument the current trial directory
# that is being parameterized

# This can either be a relative path or absolute path
# (by default when BOA launches from a config file
# it uses the config file directory as your working directory)
# here config.yaml and run_model.R are in the same directory
run_model: Rscript run_model.R
exp_name: "r_streamlined_botorch_modular"

generation_strategy:
steps:
- model: SOBOL
num_trials: 5
- model: BOTORCH_MODULAR
num_trials: -1 # No limitation on how many trials should be produced from this step
model_kwargs:
surrogate:
botorch_model_class: SingleTaskGP # BoTorch model class name

covar_module_class: RBFKernel # GPyTorch kernel class name
mll_class: LeaveOneOutPseudoLikelihood
botorch_acqf_class: qUpperConfidenceBound # BoTorch acquisition function class name
acquisition_options:
beta: 0.5

5 comments on commit 70876f5

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
__main__.py550%1–2, 7, 10–11
async_opt.py1051882%137, 181, 188, 191, 211–212, 218–219, 221–222, 224, 231–236, 240
ax_instantiation_utils.py47197%120
controller.py76396%73, 84, 176
definitions.py80100% 
instantiation_base.py45197%48
metaclasses.py72395%50–51, 57
plot.py12120%17, 19–20, 22–23, 26, 29, 36, 42–43, 46–47
plotting.py1412880%51–52, 59, 94–95, 251–253, 301–305, 345, 429–432, 434–441, 443, 448
registry.py260100% 
runner.py53492%54, 88–90
scheduler.py811482%38, 128–131, 138–139, 153, 160–161, 168–169, 264–265
storage.py1241587%83, 115–117, 177, 195, 273–281
template.py31583%52–56
utils.py992079%179, 193–194, 219, 229–233, 235–237, 239, 241, 245–250
config
   __main__.py00100% 
   config.py3073488%33, 211, 217, 221, 223–224, 226, 282, 409, 412, 420–421, 429, 603, 605, 607, 613–617, 619, 690, 723, 744, 751–752, 760, 773, 783–784, 796, 821, 824
   converters.py941386%19, 30, 49, 54–56, 89–90, 103, 105, 111, 119, 135
metrics
   metric_funcs.py34488%58, 80–81, 83
   metrics.py1011585%127, 304–305, 310–311, 314, 320, 331–333, 337–341
   modular_metric.py1382184%40–43, 45–52, 66, 137, 148, 172, 210, 265–266, 287–288
   synthetic_funcs.py39489%31, 35, 58, 65
scripts
   moo.py30196%44
   run_branin.py34197%56
   script_wrappers.py31293%57–58
   synth_func_cli.py210100% 
wrappers
   base_wrapper.py1953084%63–64, 89–92, 94, 102, 110, 115, 147, 156, 167, 215–216, 218, 267, 269, 271, 275, 330, 402, 416, 589, 591–594, 602, 612
   script_wrapper.py1191190%196, 207–208, 276, 279–281, 285, 326, 345, 350
   synthetic_wrapper.py16287%16, 28
   wrapper_utils.py150994%148–149, 247, 299, 423–425, 433–434
TOTAL223427687% 

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
__main__.py550%1–2, 7, 10–11
async_opt.py1051882%137, 181, 188, 191, 211–212, 218–219, 221–222, 224, 231–236, 240
ax_instantiation_utils.py47197%120
controller.py76396%73, 84, 176
definitions.py80100% 
instantiation_base.py45197%48
metaclasses.py72395%50–51, 57
plot.py12120%17, 19–20, 22–23, 26, 29, 36, 42–43, 46–47
plotting.py1412880%51–52, 59, 94–95, 251–253, 301–305, 345, 429–432, 434–441, 443, 448
registry.py260100% 
runner.py53492%54, 88–90
scheduler.py811482%38, 128–131, 138–139, 153, 160–161, 168–169, 264–265
storage.py1241587%83, 115–117, 177, 195, 273–281
template.py31583%52–56
utils.py992079%179, 193–194, 219, 229–233, 235–237, 239, 241, 245–250
config
   __main__.py00100% 
   config.py3073488%33, 211, 217, 221, 223–224, 226, 282, 409, 412, 420–421, 429, 603, 605, 607, 613–617, 619, 690, 723, 744, 751–752, 760, 773, 783–784, 796, 821, 824
   converters.py941386%19, 30, 49, 54–56, 89–90, 103, 105, 111, 119, 135
metrics
   metric_funcs.py34488%58, 80–81, 83
   metrics.py1011585%127, 304–305, 310–311, 314, 320, 331–333, 337–341
   modular_metric.py1382184%40–43, 45–52, 66, 137, 148, 172, 210, 265–266, 287–288
   synthetic_funcs.py39489%31, 35, 58, 65
scripts
   moo.py30196%44
   run_branin.py34197%56
   script_wrappers.py31293%57–58
   synth_func_cli.py210100% 
wrappers
   base_wrapper.py1953084%63–64, 89–92, 94, 102, 110, 115, 147, 156, 167, 215–216, 218, 267, 269, 271, 275, 330, 402, 416, 589, 591–594, 602, 612
   script_wrapper.py1191190%196, 207–208, 276, 279–281, 285, 326, 345, 350
   synthetic_wrapper.py16287%16, 28
   wrapper_utils.py150994%148–149, 247, 299, 423–425, 433–434
TOTAL223427687% 

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
__main__.py550%1–2, 7, 10–11
async_opt.py1051882%137, 181, 188, 193, 211–212, 218–219, 221–222, 224, 231–236, 240
ax_instantiation_utils.py47197%120
controller.py76396%73, 84, 176
definitions.py80100% 
instantiation_base.py45197%48
metaclasses.py72395%50–51, 57
plot.py12120%17, 19–20, 22–23, 26, 29, 36, 42–43, 46–47
plotting.py1412880%51–52, 59, 94–95, 251–253, 301–305, 345, 429–432, 434–441, 443, 448
registry.py260100% 
runner.py53492%54, 88–90
scheduler.py811482%38, 128–131, 138–139, 153, 160–161, 168–169, 264–265
storage.py1241587%83, 115–117, 177, 195, 273–281
template.py31583%52–56
utils.py992079%179, 193–194, 219, 229–233, 235–237, 239, 241, 245–250
config
   __main__.py00100% 
   config.py3073488%33, 211, 217, 221, 223–224, 226, 282, 409, 412, 420–421, 429, 603, 605, 607, 613–617, 619, 690, 723, 744, 751–752, 760, 773, 783–784, 796, 821, 824
   converters.py941386%19, 30, 49, 54–56, 89–90, 103, 105, 111, 119, 135
metrics
   metric_funcs.py34488%58, 80–81, 83
   metrics.py1011585%127, 304–305, 310–311, 314, 320, 331–333, 337–341
   modular_metric.py1382184%40–43, 45–52, 66, 137, 148, 172, 210, 265–266, 287–288
   synthetic_funcs.py39489%31, 35, 58, 65
scripts
   moo.py30196%44
   run_branin.py34197%56
   script_wrappers.py31293%57–58
   synth_func_cli.py210100% 
wrappers
   base_wrapper.py1953084%63–64, 89–92, 94, 102, 110, 115, 147, 156, 167, 215–216, 218, 267, 269, 271, 275, 330, 402, 416, 589, 591–594, 602, 612
   script_wrapper.py1191190%196, 207–208, 276, 279–281, 285, 326, 345, 350
   synthetic_wrapper.py16287%16, 28
   wrapper_utils.py150994%148–149, 247, 299, 423–425, 433–434
TOTAL223427687% 

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
__main__.py550%1–2, 7, 10–11
async_opt.py1051882%137, 181, 188, 193, 211–212, 218–219, 221–222, 224, 231–236, 240
ax_instantiation_utils.py470100% 
controller.py76396%73, 84, 176
definitions.py80100% 
instantiation_base.py45197%48
metaclasses.py72395%50–51, 57
plot.py12120%17, 19–20, 22–23, 26, 29, 36, 42–43, 46–47
plotting.py1412880%51–52, 59, 94–95, 251–253, 301–305, 345, 429–432, 434–441, 443, 448
registry.py260100% 
runner.py53492%54, 88–90
scheduler.py811482%38, 128–131, 138–139, 153, 160–161, 168–169, 264–265
storage.py1241587%83, 115–117, 177, 195, 273–281
template.py31583%52–56
utils.py992079%179, 193–194, 219, 229–233, 235–237, 239, 241, 245–250
config
   __main__.py00100% 
   config.py3073488%33, 211, 217, 221, 223–224, 226, 282, 409, 412, 420–421, 429, 603, 605, 607, 613–617, 619, 690, 723, 744, 751–752, 760, 773, 783–784, 796, 821, 824
   converters.py941386%19, 30, 49, 54–56, 89–90, 103, 105, 111, 119, 135
metrics
   metric_funcs.py34488%58, 80–81, 83
   metrics.py1011585%127, 304–305, 310–311, 314, 320, 331–333, 337–341
   modular_metric.py1382184%40–43, 45–52, 66, 137, 148, 172, 210, 265–266, 287–288
   synthetic_funcs.py39489%31, 35, 58, 65
scripts
   moo.py30196%44
   run_branin.py34197%56
   script_wrappers.py31293%57–58
   synth_func_cli.py210100% 
wrappers
   base_wrapper.py1953084%63–64, 89–92, 94, 102, 110, 115, 147, 156, 167, 215–216, 218, 267, 269, 271, 275, 330, 402, 416, 589, 591–594, 602, 612
   script_wrapper.py1191190%196, 207–208, 276, 279–281, 285, 326, 345, 350
   synthetic_wrapper.py16287%16, 28
   wrapper_utils.py150994%148–149, 247, 299, 423–425, 433–434
TOTAL223427587% 

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
__main__.py550%1–2, 7, 10–11
async_opt.py1051882%137, 181, 188, 193, 211–212, 218–219, 221–222, 224, 231–236, 240
ax_instantiation_utils.py470100% 
controller.py76396%73, 84, 176
definitions.py80100% 
instantiation_base.py45197%48
metaclasses.py72395%50–51, 57
plot.py12120%17, 19–20, 22–23, 26, 29, 36, 42–43, 46–47
plotting.py1412880%51–52, 59, 94–95, 251–253, 301–305, 345, 429–432, 434–441, 443, 448
registry.py260100% 
runner.py53492%54, 88–90
scheduler.py811482%38, 128–131, 138–139, 153, 160–161, 168–169, 264–265
storage.py1241587%83, 115–117, 177, 195, 273–281
template.py31583%52–56
utils.py992079%179, 193–194, 219, 229–233, 235–237, 239, 241, 245–250
config
   __main__.py00100% 
   config.py3073488%33, 211, 217, 221, 223–224, 226, 282, 409, 412, 420–421, 429, 603, 605, 607, 613–617, 619, 690, 723, 744, 751–752, 760, 773, 783–784, 796, 821, 824
   converters.py941386%19, 30, 49, 54–56, 89–90, 103, 105, 111, 119, 135
metrics
   metric_funcs.py34488%58, 80–81, 83
   metrics.py1011585%127, 304–305, 310–311, 314, 320, 331–333, 337–341
   modular_metric.py1382184%40–43, 45–52, 66, 137, 148, 172, 210, 265–266, 287–288
   synthetic_funcs.py39489%31, 35, 58, 65
scripts
   moo.py30196%44
   run_branin.py34197%56
   script_wrappers.py31293%57–58
   synth_func_cli.py210100% 
wrappers
   base_wrapper.py1953084%63–64, 89–92, 94, 102, 110, 115, 147, 156, 167, 215–216, 218, 267, 269, 271, 275, 330, 402, 416, 589, 591–594, 602, 612
   script_wrapper.py1191190%196, 207–208, 276, 279–281, 285, 326, 345, 350
   synthetic_wrapper.py16287%16, 28
   wrapper_utils.py150994%148–149, 247, 299, 423–425, 433–434
TOTAL223427587% 

Please sign in to comment.