Skip to content

Commit

Permalink
Changed function signatures to comply with updated tensorflow require…
Browse files Browse the repository at this point in the history
…ments. -Subramanian Iyer <subramanian.iyer@paramount.com>
  • Loading branch information
siyerp authored Aug 15, 2024
1 parent 0296e95 commit 551ec31
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions recommenders/models/sasrec/ssept.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def call(self, x, training):
# --- ATTENTION BLOCKS ---
seq_attention = seq_embeddings # (b, s, h1 + h2)

seq_attention = self.encoder(seq_attention, training, mask)
seq_attention = self.encoder(seq_attention, training=training, mask=mask)
seq_attention = self.layer_normalization(seq_attention) # (b, s, h1+h2)

# --- PREDICTION LAYER ---
Expand Down Expand Up @@ -197,7 +197,7 @@ def predict(self, inputs):

seq_embeddings *= mask
seq_attention = seq_embeddings
seq_attention = self.encoder(seq_attention, training, mask)
seq_attention = self.encoder(seq_attention, training=training, mask=mask)
seq_attention = self.layer_normalization(seq_attention) # (b, s, h1+h2)
seq_emb = tf.reshape(
seq_attention,
Expand Down

0 comments on commit 551ec31

Please sign in to comment.