Skip to content

Commit

Permalink
[Fix][Builtin] Fix "GetQueryPositions" of PagedKVCache (#16746)
Browse files Browse the repository at this point in the history
Since #16692 introduced the copy stream separation, the function
`GetQueryPositions` also needs to eagerly synchronize the copy stream
and the compute stream before returning in order to work properly.
This PR fixes the previous incorrect behavior.
  • Loading branch information
MasterJH5574 authored Mar 20, 2024
1 parent 48cedc7 commit a9436b8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class AttentionKVCacheObj : public KVStateObj {
* This function is supposed to be invoked after calling BeginForward.
* \return The in-sequence query positions, in shape `(total_length,)`.
*/
virtual NDArray GetQueryPositions() const = 0;
virtual NDArray GetQueryPositions() = 0;

/************** Debug Helpers **************/

Expand Down
9 changes: 5 additions & 4 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -838,10 +838,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

NDArray GetQueryPositions() const final {
CHECK(!dirty_aux_data_device_)
<< "The auxiliary arrays are not synchronized to device. Please call "
"`BeginForward` to synchronize before calling `GetQueryPositions`.";
/*!
 * \brief Get the in-sequence query positions, after synchronizing the
 * copy stream and the compute stream.
 *
 * This method is non-const; it invokes `ComputeStreamWaitForCopyStream()`
 * before handing back the position view.
 * NOTE(review): per the base-class contract this is supposed to be invoked
 * after `BeginForward` -- confirm that is what clears `dirty_aux_data_device_`.
 * \return The member view `q_rope_position_map_view_`.
 */
NDArray GetQueryPositions() final {
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
// This is checked only after the stream sync above has been issued.
ICHECK(!dirty_aux_data_device_);
return q_rope_position_map_view_;
};

Expand Down

0 comments on commit a9436b8

Please sign in to comment.