Skip to content

Commit

Permalink
Merge pull request #217 from GFNOrg/remove_unused_method
Browse files Browse the repository at this point in the history
small fixes to seeding and hypergrid tests; docstring improvements to resolve confusion about the purpose of `get_terminating_states_indices`
  • Loading branch information
josephdviviano authored Nov 15, 2024
2 parents b5d909e + 3041fb2 commit 6f132a8
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return states_indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the terminating states.
"""Get the indices of the terminating states in the canonical ordering from the submitted states.
Args:
states: DiscreteStates object representing the states.
Expand Down
5 changes: 3 additions & 2 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def update_masks(self, states: type[DiscreteStates]) -> None:
"""Update the masks based on the current states."""
# Not allowed to take any action beyond the environment height, but
# allow early termination.
# TODO: do we need to handle the conditional case here?
states.set_nonexit_action_masks(
states.tensor == self.height - 1,
allow_exit=True,
Expand Down Expand Up @@ -174,9 +175,9 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
    """Get the indices of the terminating states in the canonical ordering.

    In this environment every state can terminate, so a terminating state's
    index is just its ordinary state index; delegate to `get_states_indices`.

    Args:
        states: DiscreteStates object representing the submitted states.

    Returns:
        Tensor of shape `batch_shape` with each state's index in the
        canonical ordering.
    """
    return self.get_states_indices(states)

Expand Down
3 changes: 3 additions & 0 deletions src/gfn/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
np.random.seed(seed)
torch.manual_seed(seed)

if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)

# These are only set when we care about reproducibility over performance.
if not performance_mode:
torch.backends.cudnn.deterministic = True
Expand Down
8 changes: 4 additions & 4 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ class BoxArgs(CommonArgs):
@pytest.mark.parametrize("ndim", [2, 4])
@pytest.mark.parametrize("height", [8, 16])
def test_hypergrid(ndim: int, height: int):
    """Train on the HyperGrid environment and check the final L1 distance.

    The expected values are empirical results for a fixed seed; the
    tolerances are deliberately loose (roughly the magnitude of the
    expected value) so the test is robust to minor numerical drift.
    """
    n_trajectories = 64000
    args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories)
    final_l1_dist = train_hypergrid_main(args)
    if ndim == 2 and height == 8:
        assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-3)
    elif ndim == 2 and height == 16:
        assert np.isclose(final_l1_dist, 2.62e-4, atol=1e-3)
    elif ndim == 4 and height == 8:
        assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-3)
    elif ndim == 4 and height == 16:
        assert np.isclose(final_l1_dist, 6.89e-6, atol=1e-5)


@pytest.mark.parametrize("ndim", [2, 4])
Expand Down

0 comments on commit 6f132a8

Please sign in to comment.