Skip to content

Commit

Permalink
Merge branch 'master' into hyeok9855/get_logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Nov 15, 2024
2 parents 9587f2a + 8259a21 commit 92301c4
Showing 1 changed file with 134 additions and 11 deletions.
145 changes: 134 additions & 11 deletions src/gfn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from gfn.states import DiscreteStates, States
from gfn.utils.distributions import UnsqueezedCategorical

# Named reduction functions that scalar estimators may use to collapse a
# module's non-scalar output (along its final dimension) into a single value.
REDUCTION_FXNS = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}


class GFNModule(ABC, nn.Module):
r"""Base class for modules mapping states distributions.
Expand Down Expand Up @@ -41,9 +47,11 @@ class GFNModule(ABC, nn.Module):
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
_output_dim_is_checked: Flag for tracking whether the output dimenions of
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
"""

def __init__(
Expand All @@ -52,7 +60,7 @@ def __init__(
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
) -> None:
"""Initalize the FunctionEstimator with an environment and a module.
"""Initialize the GFNModule with nn.Module and a preprocessor.
Args:
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
Expand Down Expand Up @@ -134,9 +142,82 @@ def to_probability_distribution(


class ScalarEstimator(GFNModule):
    r"""Class for estimating scalars such as LogZ or state flow functions of DB/SubTB.

    Training a GFlowNet sometimes requires the estimation of precise scalar
    values, such as the partition function of flows on the DAG. This Estimator is
    designed for those cases.

    The function approximator used for `module` need not directly output a scalar.
    If it does not, `reduction` will be used to aggregate the outputs of the
    module into a single scalar.

    Attributes:
        module: The module to use. If the module is a Tabular module (from
            `gfn.utils.modules`), then the environment preprocessor needs to be an
            `EnumPreprocessor`.
        preprocessor: Preprocessor from the environment.
        reduction_fxn: The torch reduction (one of REDUCTION_FXNS: "mean", "sum",
            or "prod") applied over the final dimension of the module output when
            that dimension is not already of size 1.
        _output_dim_is_checked: Flag for tracking whether the output dimensions of
            the states (after being preprocessed and transformed by the modules)
            have been verified.
        _is_backward: Flag for tracking whether this estimator is used for
            predicting probability distributions over parents.
    """

    def __init__(
        self,
        module: nn.Module,
        preprocessor: Preprocessor | None = None,
        is_backward: bool = False,
        reduction: str = "mean",
    ):
        """Initialize the GFNModule with a scalar output.

        Args:
            module: The module to use. If the module is a Tabular module (from
                `gfn.utils.modules`), then the environment preprocessor needs to
                be an `EnumPreprocessor`.
            preprocessor: Preprocessor object.
            is_backward: Flags estimators of probability distributions over parents.
            reduction: Name of one of the REDUCTION_FXNS keys ("mean", "sum",
                or "prod") used to aggregate non-scalar module outputs.

        Raises:
            ValueError: If `reduction` is not one of the REDUCTION_FXNS keys.
        """
        # NOTE: the original attached ".format(...)" to the docstring literal,
        # turning it into a discarded expression — the function had no __doc__.
        # The valid key names are now spelled out literally above instead.
        super().__init__(module, preprocessor, is_backward)
        # Validate with an explicit exception rather than `assert`, which is
        # stripped when Python runs with the -O flag.
        if reduction not in REDUCTION_FXNS:
            raise ValueError(
                "reduction function not one of {}".format(list(REDUCTION_FXNS.keys()))
            )
        self.reduction_fxn = REDUCTION_FXNS[reduction]

    def expected_output_dim(self) -> int:
        """Return the expected trailing output dimension (1: a scalar per state)."""
        return 1

    def forward(self, input: States | torch.Tensor) -> torch.Tensor:
        """Forward pass of the module.

        Args:
            input: The input to the module, as states or a tensor.

        Returns:
            The output of the module, as a tensor of shape (*batch_shape, output_dim).
        """
        if isinstance(input, States):
            input = self.preprocessor(input)

        out = self.module(input)

        # Ensures estimator outputs are always scalar.
        # NOTE(review): the reduction is applied without keepdim, so the reduced
        # dimension is removed rather than left as size 1 — confirm downstream
        # check_output_dim expects this.
        if out.shape[-1] != 1:
            out = self.reduction_fxn(out, -1)

        if not self._output_dim_is_checked:
            self.check_output_dim(out)
            self._output_dim_is_checked = True

        return out


class DiscretePolicyEstimator(GFNModule):
r"""Container for forward and backward policy estimators for discrete environments.
Expand Down Expand Up @@ -290,14 +371,57 @@ def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor:


class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
r"""Class for conditionally estimating scalars (LogZ, DB/SubTB state logF).
Training a GFlowNet requires sometimes requires the estimation of precise scalar
values, such as the partition function of flows on the DAG. In the case of a
conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
designed for those cases.
The function approximator used for `final_module` need not directly output a scalar.
If it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.
Attributes:
preprocessor: Preprocessor object that transforms raw States objects to tensors
that can be used as input to the module. Optional, defaults to
`IdentityPreprocessor`.
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
reduction_fxn: the selected torch reduction operation.
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
reduction_function: String denoting the
"""

def __init__(
self,
state_module: nn.Module,
conditioning_module: nn.Module,
final_module: nn.Module,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
reduction: str = "mean",
):
"""Initialize a conditional GFNModule with a scalar output.
Args:
state_module: The module to use for state representations. If the module is
a Tabular module (from `gfn.utils.modules`), then the environment
preprocessor needs to be an `EnumPreprocessor`.
conditioning_module: The module to use for conditioning representations.
final_module: The module to use for computing the final output.
preprocessor: Preprocessor object.
is_backward: Flags estimators of probability distributions over parents.
reduction: str name of the one of the REDUCTION_FXNS keys: {}
""".format(
list(REDUCTION_FXNS.keys())
)

super().__init__(
state_module,
conditioning_module,
Expand All @@ -306,6 +430,10 @@ def __init__(
preprocessor=preprocessor,
is_backward=is_backward,
)
assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
REDUCTION_FXNS.keys()
)
self.reduction_fxn = REDUCTION_FXNS[reduction]

def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""Forward pass of the module.
Expand All @@ -318,6 +446,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""
out = self._forward_trunk(states, conditioning)

# Ensures estimator outputs are always scalar.
if out.shape[-1] != 1:
out = self.reduction_fxn(out, -1)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True
Expand All @@ -333,13 +465,4 @@ def to_probability_distribution(
module_output: torch.Tensor,
**policy_kwargs: Any,
) -> Distribution:
"""Transform the output of the module into a probability distribution.
Args:
states: The states to use.
module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
**policy_kwargs: Keyword arguments to modify the distribution.
Returns a distribution object.
"""
raise NotImplementedError

0 comments on commit 92301c4

Please sign in to comment.