Add support for GPTNeo for KV cache injection (#1720)
* Add support for GPTNeo for kv cache injection

* Style

* Update configs.py

* Update configs.py

---------

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
mgoin and dbogunowicz authored Sep 14, 2023
1 parent c610973 commit d8f015b
Showing 1 changed file with 14 additions and 0 deletions.
src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -138,6 +138,19 @@ class Config:
     multiply_batch_by_num_att_heads=False,
 )
 
+# Reusing the CodeGen transforms because it happens to match what we need for GPTNeo
+additional_transforms_gpt_neo = AdditionalTransformsCodeGen
+
+GPT_NEO_CONFIG = KeyValueCacheConfig(
+    model_name="gpt_neo",
+    additional_transforms=additional_transforms_gpt_neo,
+    key_num_attention_heads="num_heads",
+    key_num_embedding_hidden_size="hidden_size",
+    transpose_value_input=(0, 2, 1, 3),
+    transpose_key_input=None,
+    multiply_batch_by_num_att_heads=False,
+)
+
 
 def get_kv_cache_config(
     model_path: str,
@@ -147,6 +160,7 @@ def get_kv_cache_config(
         BLOOM_CONFIG,
         MPT_CONFIG,
         LLAMA_CONFIG,
+        GPT_NEO_CONFIG,
     ],
 ) -> KeyValueCacheConfig:
     """
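The new entry is resolved through get_kv_cache_config, shown in the second hunk. A minimal usage sketch follows; the function name, module path, and config fields come from this diff, while the deployment-directory layout and the matching behavior are assumptions:

# Hypothetical usage: look up the KV cache config for an exported GPT-Neo model.
# Assumes model_path points at a directory holding the model's config.json,
# which get_kv_cache_config inspects to pick a matching supported config.
from sparseml.exporters.transforms.kv_cache.configs import get_kv_cache_config

kv_config = get_kv_cache_config(model_path="./gpt_neo_deployment")
print(kv_config.model_name)  # "gpt_neo" if matched; the config carries the
# "num_heads"/"hidden_size" keys and the (0, 2, 1, 3) value-input transpose
# defined in the diff above.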
