diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py
index 2e23d3b3eec..ff9189b1c41 100644
--- a/src/sparseml/exporters/transforms/kv_cache/configs.py
+++ b/src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -138,6 +138,19 @@ class Config:
     multiply_batch_by_num_att_heads=False,
 )
 
+# Reusing the CodeGen transforms because it happens to match what we need for GPTNeo
+additional_transforms_gpt_neo = AdditionalTransformsCodeGen
+
+GPT_NEO_CONFIG = KeyValueCacheConfig(
+    model_name="gpt_neo",
+    additional_transforms=additional_transforms_gpt_neo,
+    key_num_attention_heads="num_heads",
+    key_num_embedding_hidden_size="hidden_size",
+    transpose_value_input=(0, 2, 1, 3),
+    transpose_key_input=None,
+    multiply_batch_by_num_att_heads=False,
+)
+
 
 def get_kv_cache_config(
     model_path: str,
@@ -147,6 +160,7 @@ def get_kv_cache_config(
         BLOOM_CONFIG,
         MPT_CONFIG,
         LLAMA_CONFIG,
+        GPT_NEO_CONFIG,
     ],
 ) -> KeyValueCacheConfig:
     """