
Refactoring ideas for log_rewards #200

Open
younik opened this issue Oct 15, 2024 · 6 comments

@younik
Collaborator

younik commented Oct 15, 2024

Computing log_rewards requires access to the environment. However, Transitions, Trajectories and States all provide log_rewards, which creates a complicated dependency among these classes.

I propose two solutions:

  1. We drop log_rewards from Transitions, Trajectories and States. I suppose log_rewards is only needed in the GFlowNet classes, so we can compute it directly there. The only exception is PrioritizedReplayBuffer, to which we can add either a scoring-function attribute or a score for each added object.
    This solution has the drawback of removing the caching mechanism (is log_rewards computed multiple times on the same object? Is it a heavy computation?)
  2. We require log_rewards when initializing Transitions, Trajectories and States, without accepting None. This is problematic for States, since env.log_reward works on States, which makes it a chicken-and-egg problem.
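For solution 1, the replay buffer could own the scoring instead of reading log_rewards off the stored objects. A minimal sketch, assuming a user-supplied `score_fn` (for example one wrapping `env.log_reward`); `ScoredReplayBuffer` and `score_fn` are hypothetical names, not part of torchgfn:

```python
import heapq
from typing import Callable, Generic, List, Tuple, TypeVar

T = TypeVar("T")


class ScoredReplayBuffer(Generic[T]):
    """Hypothetical prioritized buffer that scores items itself.

    The stored containers never need to know about the environment; the
    buffer calls score_fn once per added item.
    """

    def __init__(self, capacity: int, score_fn: Callable[[T], float]) -> None:
        self.capacity = capacity
        self.score_fn = score_fn
        # Min-heap of (score, insertion_counter, item): the lowest score is on top.
        # The unique counter breaks ties so items themselves are never compared.
        self._heap: List[Tuple[float, int, T]] = []
        self._counter = 0

    def add(self, item: T) -> None:
        entry = (self.score_fn(item), self._counter, item)
        self._counter += 1
        if len(self._heap) < self.capacity:
            heapq.heappush(self._heap, entry)
        elif entry[0] > self._heap[0][0]:
            # Evict the lowest-scoring item to make room for a better one.
            heapq.heapreplace(self._heap, entry)

    def items(self) -> List[T]:
        """Stored items, highest score first."""
        return [item for _, _, item in sorted(self._heap, reverse=True)]


# Usage: keep the two items closest to 10 (score = -|x - 10|).
buffer = ScoredReplayBuffer(capacity=2, score_fn=lambda x: -abs(x - 10))
for x in [3, 9, 12, 20]:
    buffer.add(x)
print(buffer.items())  # [9, 12]
```

The alternative mentioned above (passing a precomputed score with each added object) would simply replace the `score_fn(item)` call with an extra `score` argument to `add`.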
@hyeok9855
Collaborator

hyeok9855 commented Oct 18, 2024

IMHO, solution 2 seems more reasonable. I think the possible issue can be resolved by removing log_reward from States, with some further modifications (I've quickly checked, and it does not seem very tricky; e.g., create another subclass of Container that holds both the States and the log_reward).
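One possible shape for such a container, sketched under the assumption that the plain states are built first and the rewards are attached afterwards, which sidesteps the chicken-and-egg problem. `StatesWithLogRewards` and `from_states` are hypothetical names, lists stand in for tensors, and `log_reward_fn` stands in for `env.log_reward`:

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass(frozen=True)
class StatesWithLogRewards:
    """Hypothetical wrapper pairing a batch of states with their log-rewards."""

    states: Sequence[str]
    log_rewards: List[float]

    def __post_init__(self) -> None:
        # log_rewards is mandatory (never None) and must match the batch size.
        if len(self.log_rewards) != len(self.states):
            raise ValueError(
                f"expected {len(self.states)} log-rewards, got {len(self.log_rewards)}"
            )

    @classmethod
    def from_states(
        cls,
        states: Sequence[str],
        log_reward_fn: Callable[[Sequence[str]], List[float]],
    ) -> "StatesWithLogRewards":
        # The plain states exist first; the reward function is applied to
        # them exactly once, here, so States itself stays reward-agnostic.
        return cls(states, log_reward_fn(states))

    def __len__(self) -> int:
        return len(self.states)


# Usage with a toy reward: log R(s) = -len(s).
batch = StatesWithLogRewards.from_states(
    ["a", "bb", "ccc"], lambda states: [-float(len(s)) for s in states]
)
print(batch.log_rewards)  # [-1.0, -2.0, -3.0]
```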

@josephdviviano
Collaborator

josephdviviano commented Oct 18, 2024 via email

@younik
Collaborator Author

younik commented Nov 1, 2024

I investigated it further, and it seems log_rewards is never computed inside Transitions and Trajectories, because log_rewards is never None.
In fact, at initialization we do this (it was introduced here):

```python
self._log_rewards = (
    log_rewards
    if log_rewards is not None
    else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
```

which ensures that log_rewards is never None.

So the computation here is never triggered:

```python
def log_rewards(self) -> torch.Tensor | None:
    """Returns the log rewards of the trajectories as a tensor of shape (n_trajectories,)."""
    if self._log_rewards is not None:
        assert self._log_rewards.shape == (self.n_trajectories,)
        return self._log_rewards
    if self.is_backward:
        return None
    try:
        return self.env.log_reward(self.last_states)
    except NotImplementedError:
        return torch.log(self.env.reward(self.last_states))
```

I checked that this is the case in this commit (the tests pass): https://github.com/younik/torchgfn/tree/test-log-rewards-comp
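To make the argument concrete, here is a toy mirror of the two quoted snippets in plain Python (lists instead of tensors; `CountingEnv` and `ToyTrajectories` are hypothetical names, and the `is_backward` branch is omitted). Assuming assertions are enabled, it illustrates that once None is replaced by an empty container, the env fallback can never run: without log rewards, the property either returns the empty list (zero trajectories) or trips the shape assertion.

```python
from typing import List, Optional


class CountingEnv:
    """Toy environment that records how often log_reward is called."""

    def __init__(self) -> None:
        self.calls = 0

    def log_reward(self, last_states: List[str]) -> List[float]:
        self.calls += 1
        return [0.0] * len(last_states)


class ToyTrajectories:
    """Plain-Python mirror of the quoted __init__ default and log_rewards logic."""

    def __init__(
        self,
        env: CountingEnv,
        last_states: List[str],
        log_rewards: Optional[List[float]] = None,
    ) -> None:
        self.env = env
        self.last_states = last_states
        # Same pattern as the quoted code: None is replaced by an empty container.
        self._log_rewards = log_rewards if log_rewards is not None else []

    @property
    def n_trajectories(self) -> int:
        return len(self.last_states)

    @property
    def log_rewards(self) -> Optional[List[float]]:
        if self._log_rewards is not None:  # always true after __init__
            assert len(self._log_rewards) == self.n_trajectories
            return self._log_rewards
        return self.env.log_reward(self.last_states)  # never reached


env = CountingEnv()
print(ToyTrajectories(env, []).log_rewards)  # []
try:
    ToyTrajectories(env, ["s1", "s2"]).log_rewards  # no log_rewards passed
except AssertionError:
    print("shape assertion fired instead of computing")
print(env.calls)  # 0
```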

This makes solution 2 easy to implement and lets us remove the env dependency outright. It also allows a fair amount of code cleanup (in several places we check whether log_rewards is None).
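A minimal sketch of what solution 2 could look like once the env dependency is gone, assuming log rewards are always supplied by whoever builds the trajectories. `RewardedTrajectories` and its methods are hypothetical, and lists stand in for tensors:

```python
from typing import List


class RewardedTrajectories:
    """Hypothetical env-free container: log_rewards is required at construction."""

    def __init__(self, last_states: List[str], log_rewards: List[float]) -> None:
        if len(log_rewards) != len(last_states):
            raise ValueError(
                f"log_rewards has length {len(log_rewards)}, "
                f"expected {len(last_states)}"
            )
        self.last_states = list(last_states)
        self._log_rewards = list(log_rewards)

    @property
    def n_trajectories(self) -> int:
        return len(self.last_states)

    @property
    def log_rewards(self) -> List[float]:
        # No Optional return type, no None checks, and no env lookup.
        return self._log_rewards

    def extend(self, other: "RewardedTrajectories") -> None:
        # Rewards travel with the trajectories, so concatenation keeps them aligned.
        self.last_states.extend(other.last_states)
        self._log_rewards.extend(other.log_rewards)


trajs = RewardedTrajectories(["s1"], [-0.5])
trajs.extend(RewardedTrajectories(["s2", "s3"], [-1.0, -2.0]))
print(trajs.n_trajectories, trajs.log_rewards)  # 3 [-0.5, -1.0, -2.0]
```

The sketch raises `ValueError` rather than relying on `assert`, so the length check still runs under `python -O`; that is a design choice for the sketch, not something decided in this thread.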
@josephdviviano
Collaborator

josephdviviano commented Nov 5, 2024

Hi @younik, I think the easiest fix is to replace line 157 with the appropriate check (one that checks whether _log_rewards is empty). But I also think we need to ensure that line 163 either never needs to be called (i.e., the log rewards are only ever set externally) OR has a path to being called (i.e., the Transitions object carries some state indicating that the log rewards need to be updated).

I'm curious what you think.
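One way to read the second option (the object "carries a state which determines that the log rewards need to be updated") is a staleness flag. The sketch below assumes that reading; `LazyLogRewards`, `add_states`, `set_log_rewards` and `compute_fn` are hypothetical names, with `compute_fn` standing in for something like `env.log_reward`. It supports both paths: values set externally, or recomputed lazily when the flag is set.

```python
from typing import Callable, List


class LazyLogRewards:
    """Hypothetical holder that recomputes log-rewards only when marked stale."""

    def __init__(self, compute_fn: Callable[[List[str]], List[float]]) -> None:
        self._compute_fn = compute_fn
        self._states: List[str] = []
        self._cache: List[float] = []
        self._stale = False

    def add_states(self, states: List[str]) -> None:
        self._states.extend(states)
        self._stale = True  # new states invalidate the cached rewards

    def set_log_rewards(self, log_rewards: List[float]) -> None:
        # "Updated externally" path: the caller supplies values, nothing is computed.
        if len(log_rewards) != len(self._states):
            raise ValueError("one log-reward per state is required")
        self._cache = list(log_rewards)
        self._stale = False

    @property
    def log_rewards(self) -> List[float]:
        # "Path to being called": compute only if the flag says the cache is stale.
        if self._stale:
            self._cache = self._compute_fn(self._states)
            self._stale = False
        return self._cache


calls: List[int] = []


def toy_log_reward(states: List[str]) -> List[float]:
    calls.append(len(states))  # record each (re)computation
    return [-float(len(s)) for s in states]


lazy = LazyLogRewards(toy_log_reward)
lazy.add_states(["a", "bb"])
print(lazy.log_rewards, lazy.log_rewards, calls)  # [-1.0, -2.0] [-1.0, -2.0] [2]
```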

josephdviviano self-assigned this Nov 5, 2024
@younik
Collaborator Author

younik commented Nov 5, 2024


I believe the semantics of an empty tensor should be "the trajectory is empty"; it should never happen that we have n states with an empty log-reward tensor. To indicate that something must be computed, it is better to use None.
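A small illustration of the ambiguity, using hypothetical helper names: a batch with zero trajectories legitimately has zero log rewards, so an emptiness check cannot tell it apart from "not computed yet", whereas a None sentinel can.

```python
from typing import List, Optional


def needs_compute_if_empty(log_rewards: List[float]) -> bool:
    """Empty-means-missing convention."""
    return len(log_rewards) == 0


def needs_compute_if_none(log_rewards: Optional[List[float]]) -> bool:
    """None-means-missing convention."""
    return log_rewards is None


# A batch with zero trajectories legitimately has zero log-rewards...
no_trajectories: List[float] = []
# ...but under the empty convention it looks like "not computed yet".
print(needs_compute_if_empty(no_trajectories))  # True (misleading)
print(needs_compute_if_none(no_trajectories))   # False
print(needs_compute_if_none(None))              # True
```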

However, it looks like we don't need log_rewards computation inside Trajectories (and maybe not inside Transitions either).
For maintainability, it is better to prune everything that is unused, because dead code hurts readability.
Of course, we must ensure the user is using it properly, but I believe this line already does that:

```python
# line 158
assert self._log_rewards.shape == (self.n_trajectories,)
```

@josephdviviano
Collaborator

Yes, I agree with this. Sorry for the lag in my reply.


3 participants