Skip to content

Commit

Permalink
fix embedding sparsity log bug of -1% density (microsoft#20420)
Browse files Browse the repository at this point in the history
### Description
When the embedding sparsity has not yet been checked as valid, the log prints
incorrect info of "-1% density"; this PR fixes it.
  • Loading branch information
guyang3532 authored Apr 23, 2024
1 parent ed6f1ad commit ffb9c8d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,12 @@ def _embedding_hook(module, args, output):
valid_token = torch.count_nonzero(ebd_input - module.padding_idx)
total_token = ebd_input.numel()
embed_density = float(valid_token) / float(total_token) * 100
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
self._logger.warning("Found Embedding module not in the map. %s", module)
return None

if embed_density < 90:
self._logger.info("Embedding sparsity-based optimization is ON for density: %.0f%%", embed_density)
if module not in self._runtime_inspector._embedding_module_to_padding_density_map:
self._logger.warning("Found Embedding module not in the map. %s", module)
return None
if self._runtime_inspector._embedding_module_to_padding_density_map[module][1] != -1:
self._logger.warning(
"Found duplicate Embedding module. %s",
Expand Down Expand Up @@ -794,6 +795,7 @@ def _enable_conditional_optimizations(
[
f"{v[0]}:{v[1]:.0f}%"
for v in self._runtime_inspector._embedding_module_to_padding_density_map.values()
if v[1] != -1
]
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5776,19 +5776,25 @@ def run_step(model, input, positions):
_ = run_step(ort_model, input, label)

found_embed_is_sparse = False
found_embed_is_dense = False
found_label_is_sparse = False
for record in caplog.records:
if "Label sparsity-based optimization is ON for" in record.getMessage():
found_label_is_sparse = True

if "Embedding sparsity-based optimization is OFF for" in record.getMessage():
found_embed_is_dense = True

if "Embedding sparsity-based optimization is ON for" in record.getMessage():
found_embed_is_sparse = True

if label_is_sparse:
assert found_label_is_sparse

if embed_is_sparse:
assert found_embed_is_sparse
assert found_embed_is_sparse and not found_embed_is_dense
else:
assert not found_embed_is_sparse and found_embed_is_dense


@pytest.mark.parametrize(
Expand Down

0 comments on commit ffb9c8d

Please sign in to comment.