diff --git a/moco/builder.py b/moco/builder.py index 7952a981c..9d1d4dd29 100644 --- a/moco/builder.py +++ b/moco/builder.py @@ -60,7 +60,7 @@ def _momentum_update_key_encoder(self): for param_q, param_k in zip( self.encoder_q.parameters(), self.encoder_k.parameters() ): - param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) + param_k.mul_(self.m).add_(param_q.mul(1. - self.m)) @torch.no_grad() def _dequeue_and_enqueue(self, keys):