Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
hhou435 authored Nov 21, 2023
1 parent 19e1216 commit b6bca11
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion uer/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def pooling(memory_bank, seg, pooling_type):
features = torch.sum(memory_bank, dim=1)
features = torch.div(features, torch.sum(seg, dim=1))
elif pooling_type == "last":
features = memory_bank[torch.arange(memory_bank.shape[0]), torch.squeeze(torch.sum(seg, dim=1).type(torch.int64) - 1), :]
features = memory_bank[torch.arange(memory_bank.shape[0]), torch.squeeze(torch.sum(seg!=0, dim=1).type(torch.int64) - 1), :]
elif pooling_type == "max":
features = torch.max(memory_bank + (seg - 1) * sys.maxsize, dim=1)[0]
else:
Expand Down

0 comments on commit b6bca11

Please sign in to comment.