Skip to content

Commit

Permalink
separate logprob calculations from GFlowNet objective by making them …
Browse files Browse the repository at this point in the history
…into util functions
  • Loading branch information
hyeok9855 committed Nov 14, 2024
1 parent 4a9f112 commit 519f5b4
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 151 deletions.
106 changes: 4 additions & 102 deletions src/gfn/gflownet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
from gfn.modules import GFNModule
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_traj_pfs_and_pbs

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
Expand Down Expand Up @@ -145,6 +141,7 @@ def get_pfs_and_pbs(
trajectories: Trajectories to evaluate.
fill_value: Value to use for invalid states (i.e. $s_f$ that is added to
shorter trajectories).
recalculate_all_logprobs: Whether to re-evaluate all logprobs.
Returns: A tuple of float tensors of shape (max_length, n_trajectories) containing
the log_pf and log_pb for each action in each trajectory. The first one can be None.
Expand All @@ -153,103 +150,9 @@ def get_pfs_and_pbs(
ValueError: if the trajectories are backward.
AssertionError: when actions and states dimensions mismatch.
"""
# fill value is the value used for invalid states (sink state usually)
if trajectories.is_backward:
raise ValueError("Backward trajectories are not supported")

valid_states = trajectories.states[~trajectories.states.is_sink_state]
valid_actions = trajectories.actions[~trajectories.actions.is_dummy]

# uncomment next line for debugging
# assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy)

if valid_states.batch_shape != tuple(valid_actions.batch_shape):
raise AssertionError("Something wrong happening with log_pf evaluations")

if has_log_probs(trajectories) and not recalculate_all_logprobs:
log_pf_trajectories = trajectories.log_probs
else:
if (
trajectories.estimator_outputs is not None
and not recalculate_all_logprobs
):
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
else:
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
valid_states, estimator_outputs
).log_prob(
valid_actions.tensor
) # Using the actions sampled off-policy.
log_pf_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
)
log_pf_trajectories[~trajectories.actions.is_dummy] = valid_log_pf_actions

non_initial_valid_states = valid_states[~valid_states.is_initial_state]
non_exit_valid_actions = valid_actions[~valid_actions.is_exit]

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
if trajectories.conditioning is not None:
# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)

log_pb_trajectories = torch.full_like(
trajectories.actions.tensor[..., 0],
fill_value=fill_value,
dtype=torch.float,
return get_traj_pfs_and_pbs(
self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs
)
log_pb_trajectories_slice = torch.full_like(
valid_actions.tensor[..., 0], fill_value=fill_value, dtype=torch.float
)
log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions
log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice

assert log_pf_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
assert log_pb_trajectories.shape == (
trajectories.max_length,
trajectories.n_trajectories,
)
return log_pf_trajectories, log_pb_trajectories

def get_trajectories_scores(
self,
Expand All @@ -265,7 +168,6 @@ def get_trajectories_scores(
Returns: A tuple of float tensors of shape (n_trajectories,)
containing the total log_pf, total log_pb, and the total
log-likelihood of the trajectories.
"""
log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs(
trajectories, recalculate_all_logprobs=recalculate_all_logprobs
Expand Down
67 changes: 21 additions & 46 deletions src/gfn/gflownet/detailed_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)
from gfn.utils.prob_calculations import get_trans_pfs_and_pbs


def check_compatibility(states, actions, transitions):
Expand Down Expand Up @@ -78,6 +79,13 @@ def logF_parameters(self):
)
)

def get_pfs_and_pbs(
self, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
return get_trans_pfs_and_pbs(
self.pf, self.pb, transitions, recalculate_all_logprobs
)

def get_scores(
self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand All @@ -101,70 +109,39 @@ def get_scores(
"""
if transitions.is_backward:
raise ValueError("Backward transitions are not supported")

states = transitions.states
actions = transitions.actions

# uncomment next line for debugging
# assert transitions.states.is_sink_state.equal(transitions.actions.is_dummy)
check_compatibility(states, actions, transitions)

if has_log_probs(transitions) and not recalculate_all_logprobs:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions, with optional conditioning.
# TODO: Inefficient duplication in case of tempered policy
# The Transitions container should then have some
# estimator_outputs attribute as well, to avoid duplication here ?
# See (#156).
if transitions.conditioning is not None:
with has_conditioning_exception_handler("pf", self.pf):
module_output = self.pf(states, transitions.conditioning)
else:
with no_conditioning_exception_handler("pf", self.pf):
module_output = self.pf(states)

valid_log_pf_actions = self.pf.to_probability_distribution(
states, module_output
).log_prob(actions.tensor)
log_pf_actions, log_pb_actions = self.get_pfs_and_pbs(
transitions, recalculate_all_logprobs
)

# LogF is potentially a conditional computation.
if transitions.conditioning is not None:
with has_conditioning_exception_handler("logF", self.logF):
valid_log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
log_F_s = self.logF(states, transitions.conditioning).squeeze(-1)
else:
with no_conditioning_exception_handler("logF", self.logF):
valid_log_F_s = self.logF(states).squeeze(-1)
log_F_s = self.logF(states).squeeze(-1)

if self.forward_looking:
log_rewards = env.log_reward(states) # TODO: RM unsqueeze(-1) ?
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
valid_log_F_s = valid_log_F_s + log_rewards
log_F_s = log_F_s + log_rewards

preds = valid_log_pf_actions + valid_log_F_s
targets = torch.zeros_like(preds)
preds = log_pf_actions + log_F_s

# uncomment next line for debugging
# assert transitions.next_states.is_sink_state.equal(transitions.is_done)

# automatically removes invalid transitions (i.e. s_f -> s_f)
valid_next_states = transitions.next_states[~transitions.is_done]
non_exit_actions = actions[~actions.is_exit]

# Evaluate the log PB of the actions, with optional conditioning.
if transitions.conditioning is not None:
with has_conditioning_exception_handler("pb", self.pb):
module_output = self.pb(
valid_next_states, transitions.conditioning[~transitions.is_done]
)
else:
with no_conditioning_exception_handler("pb", self.pb):
module_output = self.pb(valid_next_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
valid_next_states, module_output
).log_prob(non_exit_actions.tensor)

valid_transitions_is_done = transitions.is_done[
~transitions.states.is_sink_state
]
Expand All @@ -179,23 +156,21 @@ def get_scores(
with no_conditioning_exception_handler("logF", self.logF):
valid_log_F_s_next = self.logF(valid_next_states).squeeze(-1)

targets[~valid_transitions_is_done] = valid_log_pb_actions
log_pb_actions = targets.clone()
targets[~valid_transitions_is_done] += valid_log_F_s_next
log_F_s_next = torch.zeros_like(log_pb_actions)
log_F_s_next[~valid_transitions_is_done] += valid_log_F_s_next
assert transitions.log_rewards is not None
valid_transitions_log_rewards = transitions.log_rewards[
~transitions.states.is_sink_state
]
targets[valid_transitions_is_done] = valid_transitions_log_rewards[
log_F_s_next[valid_transitions_is_done] = valid_transitions_log_rewards[
valid_transitions_is_done
]
targets = log_pb_actions + log_F_s_next

scores = preds - targets

assert valid_log_pf_actions.shape == (transitions.n_transitions,)
assert log_pb_actions.shape == (transitions.n_transitions,)
assert scores.shape == (transitions.n_transitions,)
return valid_log_pf_actions, log_pb_actions, scores
return log_pf_actions, log_pb_actions, scores

def loss(self, env: Env, transitions: Transitions) -> torch.Tensor:
"""Detailed balance loss.
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _forward_trunk(

return out

def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor:
"""Forward pass of the module.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def sample_actions(
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Any,
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
"""Samples actions from the given states.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/utils/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class UnsqueezedCategorical(Categorical):
"""Samples froma categorical distribution with an unsqueezed final dimension.
"""Samples from a categorical distribution with an unsqueezed final dimension.
Samples are unsqueezed to be of shape (batch_size, 1) instead of (batch_size,).
Expand Down
Loading

0 comments on commit 519f5b4

Please sign in to comment.