Skip to content

Commit

Permalink
Added "hashed" cross features
Browse files Browse the repository at this point in the history
Signed-off-by: David Davó <david@ddavo.me>
  • Loading branch information
daviddavo committed Sep 24, 2024
1 parent d5e461e commit 9bd7d25
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions recommenders/models/wide_deep/wide_deep_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class WideAndDeepHyperParams:
user_dim: int = 32
item_dim: int = 32
crossed_feat_dim: int = 1000
dnn_hidden_units: Tuple[int, ...] = (128, 128)
dnn_dropout: float = 0.0
dnn_additional_embeddings_sizes: dict[str, Tuple[int, int]] = field(default_factory=dict)
Expand Down Expand Up @@ -53,15 +54,12 @@ def __init__(

self.deep = nn.Sequential(*layers)

# P(Y=1|X) = W*wide + W'*a^(lf) + bias
# which is eq. to W"*cat(wide, a^(lf))+bias
wide_input = num_items # TODO: cross product
wide_output = num_items
# Cross product of users-items
exclusive_wide_input = hparams.crossed_feat_dim

print('wide_input:', wide_input, 'wide_output:', wide_output, 'total:', wide_input*wide_output)
print('wide_input:', wide_input, 'prev_output:', prev_output, 'total:', wide_input+prev_output)
self.head = nn.Sequential(
nn.Linear(wide_input+prev_output, wide_output),
# Output is a binary score: 1 iff the item-user pair is a good recommendation.
nn.Linear(exclusive_wide_input+prev_output, 1),
nn.Sigmoid(),
)

Expand All @@ -72,17 +70,16 @@ def forward(
continuous_features: Optional[torch.Tensor] = None,
) -> torch.Tensor:
users, items = interactions.T

all_embed = torch.cat([
self.users_emb(users), # Receives the indices
self.items_emb(items),
*[ emb(additional_embeddings[k]) for k, emb in self.additional_embs.items() ]
], dim=1)

# The cross-feature is really only the items because there is no
# impression data??
# https://datascience.stackexchange.com/a/58915/169220
cross_product = torch.zeros([items.numel(), self.n_items])
cross_product[torch.arange(items.numel()), items] = 1
# TODO: Use hashing to avoid problems with biased distributions
cross_product_idx = (users*self.n_items + items) % self.hparams.crossed_feat_dim
cross_product = nn.functional.one_hot(cross_product_idx, self.hparams.crossed_feat_dim)

if self.hparams.dnn_cont_features > 0:
deep_input = torch.cat([continuous_features, all_embed], dim=1)
Expand Down

0 comments on commit 9bd7d25

Please sign in to comment.