Skip to content

Commit

Permalink
use trajectory/transitions instead of traj/trans
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Nov 14, 2024
1 parent 519f5b4 commit 9587f2a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.prob_calculations import get_traj_pfs_and_pbs
from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_pfs_and_pbs(
ValueError: if the trajectories are backward.
AssertionError: when actions and states dimensions mismatch.
"""
return get_traj_pfs_and_pbs(
return get_trajectory_pfs_and_pbs(
self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs
)

Expand Down
4 changes: 2 additions & 2 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_trans_pfs_and_pbs
from gfn.utils.prob_calculations import get_transition_pfs_and_pbs


def check_compatibility(states, actions, transitions):
Expand Down Expand Up @@ -82,7 +82,7 @@ def logF_parameters(self):
def get_pfs_and_pbs(
self, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
return get_trans_pfs_and_pbs(
return get_transition_pfs_and_pbs(
self.pf, self.pb, transitions, recalculate_all_logprobs
)

Expand Down
20 changes: 10 additions & 10 deletions src/gfn/utils/prob_calculations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def check_cond_forward(
#########################


def get_traj_pfs_and_pbs(
def get_trajectory_pfs_and_pbs(
pf: GFNModule,
pb: GFNModule,
trajectories: Trajectories,
Expand All @@ -45,13 +45,13 @@ def get_traj_pfs_and_pbs(
# uncomment next line for debugging
# assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)

log_pf_trajectories = get_traj_pfs(
log_pf_trajectories = get_trajectory_pfs(
pf,
trajectories,
fill_value=fill_value,
recalculate_all_logprobs=recalculate_all_logprobs,
)
log_pb_trajectories = get_traj_pbs(pb, trajectories, fill_value=fill_value)
log_pb_trajectories = get_trajectory_pbs(pb, trajectories, fill_value=fill_value)

assert log_pf_trajectories.shape == (
trajectories.max_length,
Expand All @@ -65,7 +65,7 @@ def get_traj_pfs_and_pbs(
return log_pf_trajectories, log_pb_trajectories


def get_traj_pfs(
def get_trajectory_pfs(
pf: GFNModule,
trajectories: Trajectories,
fill_value: float = 0.0,
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_traj_pfs(
return log_pf_trajectories


def get_traj_pbs(
def get_trajectory_pbs(
pb: GFNModule, trajectories: Trajectories, fill_value: float = 0.0
) -> torch.Tensor:
# Note the different mask for valid states and actions compared to the pf case.
Expand Down Expand Up @@ -160,7 +160,7 @@ def get_traj_pbs(
########################


def get_trans_pfs_and_pbs(
def get_transition_pfs_and_pbs(
pf: GFNModule,
pb: GFNModule,
transitions: Transitions,
Expand All @@ -169,16 +169,16 @@ def get_trans_pfs_and_pbs(
if transitions.is_backward:
raise ValueError("Backward transitions are not supported")

log_pf_transitions = get_trans_pfs(pf, transitions, recalculate_all_logprobs)
log_pb_transitions = get_trans_pbs(pb, transitions)
log_pf_transitions = get_transition_pfs(pf, transitions, recalculate_all_logprobs)
log_pb_transitions = get_transition_pbs(pb, transitions)

assert log_pf_transitions.shape == (transitions.n_transitions,)
assert log_pb_transitions.shape == (transitions.n_transitions,)

return log_pf_transitions, log_pb_transitions


def get_trans_pfs(
def get_transition_pfs(
pf: GFNModule, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> torch.Tensor:
states = transitions.states
Expand All @@ -203,7 +203,7 @@ def get_trans_pfs(
return log_pf_actions


def get_trans_pbs(pb: GFNModule, transitions: Transitions) -> torch.Tensor:
def get_transition_pbs(pb: GFNModule, transitions: Transitions) -> torch.Tensor:
# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_done]
non_exit_actions = transitions.actions[~transitions.actions.is_exit]
Expand Down

0 comments on commit 9587f2a

Please sign in to comment.