-
OA uses https://github.com/CarperAI/trlx/ for RL. I heard that Carper is experimenting with DPO internally and that it might become part of trlx in the future.
-
It seems that the DPO loss is almost a drop-in replacement for the RM loss of OA (obviously we will have to replace the reward model with a regular model).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer


class DPOLoss(nn.Module):
    """DPO loss function."""

    def __init__(self, reduction="mean", beta=0.001):
        super().__init__()
        self.reduction = reduction
        self.beta = beta

    def forward(self, logits, logits_ref, labels, cu_lengths=None):
        # if cu_lengths is None, assume that all examples belong to the same conversation
        if cu_lengths is None:
            cu_lengths = [0, logits.size(0)]

        device = logits.device
        losses = []
        rewards = []
        for start, end in zip(cu_lengths[:-1], cu_lengths[1:]):
            # all (preferred, rejected) pairs within the conversation; completions are
            # assumed to be ordered from best to worst, so the lower index is the positive
            pairs = torch.combinations(torch.arange(end - start, device=device), 2)
            pos_ids, neg_ids = pairs[:, 0], pairs[:, 1]

            # compute logprob of pos and neg examples under the policy
            pos_logits = logits[start + pos_ids]
            neg_logits = logits[start + neg_ids]
            pos_logprob = F.log_softmax(pos_logits, dim=-1)
            neg_logprob = F.log_softmax(neg_logits, dim=-1)
            pos_logprob = torch.gather(pos_logprob, 2, labels[start + pos_ids].unsqueeze(-1))
            neg_logprob = torch.gather(neg_logprob, 2, labels[start + neg_ids].unsqueeze(-1))

            # same logprobs under the reference model
            pos_logits_ref = logits_ref[start + pos_ids]
            neg_logits_ref = logits_ref[start + neg_ids]
            pos_logprob_ref = F.log_softmax(pos_logits_ref, dim=-1)
            neg_logprob_ref = F.log_softmax(neg_logits_ref, dim=-1)
            pos_logprob_ref = torch.gather(pos_logprob_ref, 2, labels[start + pos_ids].unsqueeze(-1))
            neg_logprob_ref = torch.gather(neg_logprob_ref, 2, labels[start + neg_ids].unsqueeze(-1))

            # compute loss and implicit reward
            pi_logratios = pos_logprob.mean() - neg_logprob.mean()
            ref_logratios = pos_logprob_ref.mean() - neg_logprob_ref.mean()
            loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))
            reward = self.beta * (pi_logratios - ref_logratios)
            losses.append(loss)
            rewards.append(reward)

        return sum(losses) / len(losses), sum(rewards) / len(rewards)


class DPOTrainer(Trainer):
    """DPOTrainer class for training a model with the DPO algorithm."""

    def __init__(
        self,
        model=None,
        args=None,
        sampler=None,
        train_collate_fn=None,
        **kwargs,
    ):
        super().__init__(model, args, **kwargs)
        self.train_collate_fn = train_collate_fn
        self.sampler = sampler
        self.loss_fct = DPOLoss()

    def compute_loss(self, model, inputs, return_logits=False):
        """The important part of the DPO trainer."""
        batch, cu_lens = inputs

        # compute logits under the policy
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        ).logits

        # compute logits for the reference model (placeholder: this reuses the policy
        # under no_grad; a frozen copy of the SFT model should be used here instead)
        with torch.no_grad():
            logits_ref = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
            ).logits

        loss, reward = self.loss_fct(logits, logits_ref, batch["input_ids"], cu_lens)
        return (loss, logits, reward) if return_logits else loss
```
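For illustration, a minimal sketch of how `DPOLoss` above could be exercised on random tensors; the batch size, sequence length, and vocab size are made-up numbers, and the single-conversation case (`cu_lengths=None`) is assumed:

```python
import torch

# hypothetical shapes: 4 completions of the same prompt (ranked best to worst),
# 16 tokens each, vocab of 100; purely illustrative numbers
logits = torch.randn(4, 16, 100)
logits_ref = torch.randn(4, 16, 100)
labels = torch.randint(0, 100, (4, 16))

loss_fn = DPOLoss(beta=0.1)
loss, reward = loss_fn(logits, logits_ref, labels)  # cu_lengths=None -> one conversation
print(loss.item(), reward.item())
```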
-
I'd like to retrain the OA LLaMA 30B (SFT), but currently the flash_attn lib is buggy so I can't build the model_training docker image :(
-
@Forbu @andreaskoepf You might be interested to know that we've just released a reference implementation of DPO: https://github.com/eric-mitchell/direct-preference-optimization
Happy to answer any questions you may have!
-
@eric-mitchell Nice job! I also saw some work being done in the trl repository: huggingface/trl#416, which is compatible with the Trainer class from Hugging Face (for easier integration).
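For anyone curious what that integration could look like, here is a rough sketch of HF-Trainer-style usage, assuming the API lands roughly as in the trl docs (policy model plus frozen ref_model, a beta hyperparameter, and a dataset with prompt/chosen/rejected columns); the exact signature in that PR may differ:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

# tiny toy preference dataset, just to show the expected columns
train_dataset = Dataset.from_dict({
    "prompt": ["What is DPO?"],
    "chosen": ["Direct Preference Optimization, trained straight on preference pairs."],
    "rejected": ["No idea."],
})

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")  # frozen reference policy
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

trainer = DPOTrainer(
    model,
    ref_model,
    args=TrainingArguments(
        output_dir="dpo-out",
        per_device_train_batch_size=1,
        remove_unused_columns=False,  # keep prompt/chosen/rejected for the DPO collator
    ),
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```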
-
I also have a quick question for you, @eric-mitchell:
-
A new way of doing RLHF-style training directly, without using a reward model: https://arxiv.org/pdf/2305.18290.pdf
I wonder if it would be possible to use this instead of the usual RM + PPO setting. This is an ML proposition.
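For reference, the core objective from the paper: given preference pairs $(x, y_w, y_l)$ where $y_w$ is preferred over $y_l$, the policy $\pi_\theta$ is trained directly against a frozen reference policy $\pi_{\mathrm{ref}}$ with a classification-style loss, so no separate reward model and no PPO loop are needed:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\!\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]
$$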