From bbc39fed6806ec21f37dca899a29163fdd50d90e Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 8 Mar 2024 00:15:58 -0500
Subject: [PATCH] [KVCache] Update FlashInfer PackedFunc names

This PR updates the FlashInfer names given
https://github.com/apache/tvm/pull/16692 has been merged.
---
 python/mlc_chat/nn/kv_cache.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py
index f63e74d855..91b9ee1899 100644
--- a/python/mlc_chat/nn/kv_cache.py
+++ b/python/mlc_chat/nn/kv_cache.py
@@ -259,15 +259,15 @@ def __init__(  # pylint: disable=too-many-arguments,too-many-locals
             # pylint: disable=line-too-long
             # fmt: off
             bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"),
-            rx.extern("paged_kv_cache.attention_kernel_prefill"),
-            rx.extern("paged_kv_cache.attention_kernel_decode"),
+            rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"),
+            rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"),
             rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"),
             rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"),
             rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"),
-            rx.extern("paged_kv_cache.attention_kernel_prefill_begin_forward"),
-            rx.extern("paged_kv_cache.attention_kernel_prefill_end_forward"),
-            rx.extern("paged_kv_cache.attention_kernel_decode_begin_forward"),
-            rx.extern("paged_kv_cache.attention_kernel_decode_end_forward"),
+            rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward"),
+            rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward"),
+            rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"),
+            rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"),
             rx.extern("flashinfer.merge_state_in_place"),
             bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
             bb.add_func(llama_inplace_rope(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, target, rotary_dim), "tir_qk_rotary_inplace"),
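
A minimal sketch, outside the patch itself: the names passed to
rx.extern(...) above are resolved at runtime from TVM's global
PackedFunc registry, so, assuming a TVM build with FlashInfer kernels
enabled, the renamed functions can be sanity-checked with
tvm.get_global_func before constructing the KV cache. The function list
below is copied from the patch; everything else is illustrative.

    import tvm

    # PackedFunc names introduced by this rename (from the patch above).
    FLASHINFER_FUNCS = [
        "flashinfer.attention_kernel_prefill_with_paged_kv_cache",
        "flashinfer.attention_kernel_decode_with_paged_kv_cache",
        "flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward",
        "flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward",
        "flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward",
        "flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward",
        "flashinfer.merge_state_in_place",
    ]

    for name in FLASHINFER_FUNCS:
        # With allow_missing=True, get_global_func returns None instead of
        # raising when the function is not registered (e.g. a TVM build
        # without FlashInfer support).
        func = tvm.get_global_func(name, allow_missing=True)
        print(f"{name}: {'found' if func is not None else 'MISSING'}")

If any entry prints MISSING, the runtime was likely built against a TVM
that predates apache/tvm#16692 (or without FlashInfer), which is exactly
the mismatch this rename guards against.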