Skip to content

Commit

Permalink
Make "device" attribute of contextual models a property instead (#1611)
Browse files Browse the repository at this point in the history
Summary:
## Motivation

As of cornellius-gp/gpytorch#2234, the parent class of BoTorch kernels now has a property "device." This means that if a subclass tries to set `self.device`, it will error. This is why the BoTorch CI is currently breaking: https://github.com/pytorch/botorch/actions/runs/3841992968/jobs/6542850176

Pull Request resolved: #1611

Test Plan: Tests should pass

Reviewed By: saitcakmak, Balandat

Differential Revision: D42354199

Pulled By: esantorella

fbshipit-source-id: c53e5b508dd75f4116870cd30ab90d11cd3eb573
  • Loading branch information
docusaurus-bot authored and facebook-github-bot committed Jan 5, 2023
1 parent 2998cfe commit b730a5f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion botorch/models/kernels/contextual_lcea.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.decomposition = decomposition
self.batch_shape = batch_shape
self.train_embedding = train_embedding
self.device = device
self._device = device

num_param = len(next(iter(decomposition.values())))
self.context_list = list(decomposition.keys())
Expand Down Expand Up @@ -128,6 +128,10 @@ def __init__(
)
self.register_constraint("raw_outputscale_list", Positive())

@property
def device(self) -> Optional[torch.device]:
return self._device

@property
def outputscale_list(self) -> Tensor:
return self.raw_outputscale_list_constraint.transform(self.raw_outputscale_list)
Expand Down
6 changes: 5 additions & 1 deletion botorch/models/kernels/contextual_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(

super().__init__(batch_shape=batch_shape)
self.decomposition = decomposition
self.device = device
self._device = device

num_param = len(next(iter(decomposition.values())))
for active_parameters in decomposition.values():
Expand Down Expand Up @@ -86,6 +86,10 @@ def __init__(
)
self.kernel_dict = ModuleDict(self.kernel_dict)

@property
def device(self) -> Optional[torch.device]:
return self._device

def forward(
self,
x1: Tensor,
Expand Down

0 comments on commit b730a5f

Please sign in to comment.