diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 2227944b8653..f6857a9dceae 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -159,7 +159,7 @@ class AttentionKVCacheObj : public KVStateObj { * This function is supposed to be invoked after calling BeginForward. * \return The in-sequence query positions, in shape `(total_length,)`. */ - virtual NDArray GetQueryPositions() const = 0; + virtual NDArray GetQueryPositions() = 0; /************** Debug Helpers **************/ diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 0c64800cec2d..9c3ee5d427c2 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -838,10 +838,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - NDArray GetQueryPositions() const final { - CHECK(!dirty_aux_data_device_) - << "The auxiliary arrays are not synchronized to device. Please call " - "`BeginForward` to synchronize before calling `GetQueryPositions`."; + NDArray GetQueryPositions() final { + // Sync the copy stream and the compute stream before handing out the position view. + ComputeStreamWaitForCopyStream(); + // The auxiliary data structure on device must have been synchronized, i.e. `BeginForward` must have been called first. + ICHECK(!dirty_aux_data_device_); return q_rope_position_map_view_; };