[Runtime] PagedKVCache execute data copy on a separate stream (apache#16692)

This PR enhances PagedKVCache with copy stream separation.
In detail, for the CUDA and ROCm backends, we create a
standalone copy stream for copying the auxiliary data
structures from CPU to GPU. Furthermore, we move the copy
from BeginForward to Attention, so it is no longer executed
eagerly; instead, it is executed lazily when the Attention
computation is needed.

By making these changes, we are able to overlap the auxiliary
data copy time (on the copy stream) with the model forward
computation that happens before the first Attention. As a result,
we can hide some of the copy latency.

This PR also bumps the version of FlashInfer for the copy stream
support.
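
To make the intent above concrete, here is a minimal sketch of the pattern this commit adopts: the auxiliary data is only staged on the host and marked dirty in BeginForward, and the first Attention call triggers the host-to-GPU copy on a dedicated copy stream, after which the compute stream is made to wait on it. The DeviceAPI and NDArray::CopyFromTo calls mirror those in the diff below; the class and method names (CopyStreamSketch, StageAuxData, SyncIfDirty) are illustrative placeholders rather than the actual PagedKVCache code.

```cpp
// Sketch: lazy auxiliary-data sync on a dedicated copy stream.
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>

using tvm::runtime::DeviceAPI;
using tvm::runtime::NDArray;

class CopyStreamSketch {
 public:
  explicit CopyStreamSketch(DLDevice device) : device_(device) {
    // The compute stream is the device's default stream.
    compute_stream_ = DeviceAPI::Get(device_)->GetCurrentStream(device_);
    // Only CUDA/ROCm get a standalone copy stream.
    if (device_.device_type == DLDeviceType::kDLCUDA ||
        device_.device_type == DLDeviceType::kDLROCM) {
      copy_stream_ = DeviceAPI::Get(device_)->CreateStream(device_);
    }
  }

  ~CopyStreamSketch() {
    if (copy_stream_ != nullptr) {
      DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
    }
  }

  // BeginForward only records host-side data and marks it dirty.
  void StageAuxData() { dirty_ = true; }

  // Attention calls this lazily: issue the host-to-device copy on the copy
  // stream, then make the compute stream wait for the copy to finish.
  void SyncIfDirty(const DLTensor* host_src, NDArray device_dst) {
    if (!dirty_) return;
    DLTensor copy_dst = *device_dst.operator->();
    NDArray::CopyFromTo(host_src, &copy_dst, copy_stream_);
    dirty_ = false;
    if (copy_stream_ != nullptr) {
      DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, compute_stream_);
    }
  }

 private:
  DLDevice device_;
  TVMStreamHandle compute_stream_ = nullptr;
  TVMStreamHandle copy_stream_ = nullptr;
  bool dirty_ = false;
};
```

Because the copy is issued on its own stream before any attention kernel runs, the transfer overlaps with the compute-stream work emitted ahead of the first Attention, which is where the latency hiding described above comes from.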
MasterJH5574 authored and thaisacs committed Apr 3, 2024
1 parent 484c6c1 commit d50b364
Showing 2 changed files with 106 additions and 57 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 53 files
+1 −1 .github/workflows/release_wheel.yml
+1 −1 .release-please-manifest.json
+26 −0 CHANGELOG.md
+5 −6 README.md
+2 −0 cmake/config.cmake
+25 −6 cmake/modules/FindThrust.cmake
+3 −5 docs/conf.py
+28 −7 docs/installation.rst
+2 −6 include/flashinfer.cuh
+25 −0 include/flashinfer/attention.cuh
+3 −3 include/flashinfer/attention/cascade.cuh
+136 −107 include/flashinfer/attention/decode.cuh
+7 −6 include/flashinfer/attention/handler.cuh
+471 −274 include/flashinfer/attention/prefill.cuh
+2 −2 include/flashinfer/attention/state.cuh
+47 −42 include/flashinfer/attention/wrapper.cuh
+25 −2 include/flashinfer/mma.cuh
+6 −6 include/flashinfer/page.cuh
+6 −8 include/flashinfer/permuted_smem.cuh
+22 −12 include/flashinfer/pos_enc.cuh
+63 −68 include/flashinfer/utils.cuh
+9 −0 include/flashinfer/vec_dtypes.cuh
+116 −50 python/csrc/batch_decode.cu
+30 −30 python/csrc/batch_prefill.cu
+42 −39 python/csrc/flashinfer_decl.h
+19 −13 python/csrc/flashinfer_ops.h
+43 −1 python/csrc/pytorch_extension_utils.h
+29 −12 python/csrc/single_decode.cu
+8 −6 python/csrc/single_prefill.cu
+26 −13 python/flashinfer/cascade.py
+49 −34 python/flashinfer/decode.py
+1 −1 python/flashinfer/page.py
+140 −40 python/flashinfer/prefill.py
+12 −10 python/flashinfer/utils.py
+36 −15 python/setup.py
+123 −0 python/tests/alibi_reference.py
+78 −0 python/tests/test_alibi.py
+22 −12 python/tests/test_batch_decode_kernels.py
+28 −9 python/tests/test_batch_prefill_kernels.py
+5 −3 python/tests/test_shared_prefix_kernels.py
+16 −14 src/bench_batch_decode.cu
+20 −18 src/bench_cascade.cu
+16 −11 src/bench_single_decode.cu
+9 −7 src/bench_single_prefill.cu
+6 −6 src/cpu_reference.h
+24 −18 src/test_batch_decode.cu
+32 −29 src/test_batch_prefill.cu
+21 −19 src/test_cascade.cu
+16 −13 src/test_page.cu
+22 −13 src/test_single_decode.cu
+23 −23 src/test_single_prefill.cu
+117 −104 src/tvm_wrapper.cu
+1 −1 version.txt
161 changes: 105 additions & 56 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -242,7 +242,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
//-------------------------------------------
/*!
* \brief A boolean flag indicating if the auxiliary arrays are dirty.
* If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
* If it is dirty, an explicit "ComputeStreamWaitForCopyStream" should be invoked.
*/
bool dirty_aux_data_device_ = false;
/*! \brief The batch size of the current round of forwarding. */
@@ -285,6 +285,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
NDArray merged_attn_scores_device_;
std::vector<NDArray> temp_attn_workspace_;

//-------------------------------------------
// Below are the auxiliary data structures on CPU.
// We make them class members to avoid repeated allocation in BeginForward.
//-------------------------------------------
std::vector<std::vector<int32_t>> qo_indptr_on_depths_host_;
std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
std::vector<int32_t> k_ragged_rope_pos_offset_host_;
std::vector<int32_t> q_rope_position_map_host_;
std::vector<int32_t> append_position_map_host_;
std::vector<int32_t> cur_append_lengths_indptr_host_;

//-------------------------------------------
// For efficient memory management, the actual sizes of the arrays
// above are over-allocated.
@@ -328,6 +342,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
std::vector<bool> use_decode_kernel_;
/*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
bool is_decode_request_;
/*! \brief The device this PagedKVCache runs on. */
DLDevice device_;
/*! \brief The device stream for the default computation operations. */
TVMStreamHandle compute_stream_ = nullptr;
/*! \brief The device stream for copying auxiliary data structure to GPU. */
TVMStreamHandle copy_stream_ = nullptr;

public:
/*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
@@ -370,7 +390,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
f_rotary_inplace_(std::move(f_rotary_inplace)),
f_debug_get_kv_(std::move(f_debug_get_kv)) {
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
for (int i = 0; i < num_layers; ++i) {
pages_.push_back(
@@ -417,6 +438,22 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
free_page_ids_.push_back(page_id);
}

// The compute stream is the default stream.
// If the device is CUDA/ROCm, we create a standalone copy stream in
// order to hide the latency of the auxiliary data copy.
compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
if (device.device_type == DLDeviceType::kDLCUDA ||
device.device_type == DLDeviceType::kDLROCM) {
copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
}
}

~PagedAttentionKVCacheObj() {
// Free the copy stream if defined.
if (copy_stream_ != nullptr) {
DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
}
}

/*! \brief Reset the KV cache. */
@@ -522,16 +559,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

// - Collect sequence/block/page information for attention.
std::vector<const Sequence*> sequences;
std::vector<int32_t> k_ragged_rope_pos_offset;
is_decode_request_ = true;
sequences.reserve(cur_batch_size_);
k_ragged_rope_pos_offset.reserve(cur_batch_size_);
k_ragged_rope_pos_offset_host_.resize(cur_batch_size_);
for (int i = 0; i < cur_batch_size_; ++i) {
auto it = seq_map_.find(seq_ids[i]);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
<< "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
k_ragged_rope_pos_offset.push_back(it->second.seq_length);
k_ragged_rope_pos_offset_host_[i] = it->second.seq_length;
it->second.seq_length += append_lengths[i];
if (append_lengths[i] != 1) {
is_decode_request_ = false;
@@ -561,18 +597,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

std::vector<std::vector<int32_t>> qo_indptr_on_depths;
std::vector<std::vector<int32_t>> page_indptr_on_depths;
std::vector<std::vector<int32_t>> page_indices_on_depths;
std::vector<std::vector<int32_t>> last_page_len_on_depths;
std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths;
qo_indptr_on_depths_host_.resize(num_depths_);
page_indptr_on_depths_host_.resize(num_depths_);
page_indices_on_depths_host_.resize(num_depths_);
last_page_len_on_depths_host_.resize(num_depths_);
k_rope_pos_offset_on_depths_host_.resize(num_depths_);

for (int d = 0; d < num_depths_; ++d) {
std::vector<int32_t> qo_indptr_h{0};
std::vector<int32_t> page_indptr_h{0};
std::vector<int32_t> page_indices_h;
std::vector<int32_t> last_page_len_h;
std::vector<int32_t> k_rope_pos_offset_h;
std::vector<int32_t>& qo_indptr_h = qo_indptr_on_depths_host_[d];
std::vector<int32_t>& page_indptr_h = page_indptr_on_depths_host_[d];
std::vector<int32_t>& page_indices_h = page_indices_on_depths_host_[d];
std::vector<int32_t>& last_page_len_h = last_page_len_on_depths_host_[d];
std::vector<int32_t>& k_rope_pos_offset_h = k_rope_pos_offset_on_depths_host_[d];
qo_indptr_h.clear();
page_indptr_h.clear();
page_indices_h.clear();
last_page_len_h.clear();
k_rope_pos_offset_h.clear();
qo_indptr_h.push_back(0);
page_indptr_h.push_back(0);
for (const auto& [block_id, chunk_append_length] : chunked_block_ids_arr[d]) {
qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
if (block_id == -1) {
@@ -588,11 +631,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
k_rope_pos_offset_h.push_back(block.start_pos);
}
}
qo_indptr_on_depths.push_back(qo_indptr_h);
page_indptr_on_depths.push_back(page_indptr_h);
page_indices_on_depths.push_back(page_indices_h);
last_page_len_on_depths.push_back(last_page_len_h);
k_rope_pos_offset_on_depths.push_back(k_rope_pos_offset_h);
}

if (!append_before_attn_) {
@@ -606,38 +644,25 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

// Map each token position in the input batch to the position
// in the global KV cache. The mapping is used when appending k/v values.
std::vector<int32_t> q_rope_position_map;
std::vector<int32_t> append_position_map;
q_rope_position_map_host_.clear();
append_position_map_host_.clear();
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = append_lengths[i];
const Block& block = global_block_pool_[sequences[i]->last_block_idx];
for (int64_t pos = 0; pos < append_length; ++pos) {
int64_t pos_in_block = block.seq_length - append_length + pos;
q_rope_position_map.push_back(sequences[i]->seq_length - append_length + pos);
append_position_map.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
pos_in_block % page_size_);
q_rope_position_map_host_.push_back(sequences[i]->seq_length - append_length + pos);
append_position_map_host_.push_back(block.page_ids[pos_in_block / page_size_] * page_size_ +
pos_in_block % page_size_);
}
}

// - Sync NDArrays to GPU.
SyncAuxArrayToDevice(std::move(qo_indptr_on_depths), std::move(page_indptr_on_depths),
std::move(page_indices_on_depths), std::move(last_page_len_on_depths),
std::move(k_rope_pos_offset_on_depths),
std::move(k_ragged_rope_pos_offset), std::move(q_rope_position_map),
std::move(append_position_map));

// NOTE(Zihao): This logic is problematic ATM because we need a unique split per depth
KernelBeginForward();
}

void EndForward() final {
if (!f_attention_prefill_end_forward_.defined() || !f_attention_decode_end_forward_.defined() ||
!f_attention_prefill_ragged_end_forward_.defined()) {
return;
}
// Mark the dirty flag as true, so that BeginForward is required
// to be invoked before the next round of model forward.
dirty_aux_data_device_ = true;
f_attention_prefill_ragged_end_forward_.value()();
for (int d = 0; d < num_depths_; ++d) {
f_attention_prefill_end_forward_.value()(d);
@@ -681,10 +706,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
total_seq_length += cur_append_lengths_[seq_id];
}
CHECK_EQ(total_seq_length, q_data->shape[0]);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
CHECK(!dirty_aux_data_device_)
<< "The auxiliary arrays are not synchronized to device. Please call "
"`BeginForward` to synchronize before calling `Attention`.";
ICHECK(!dirty_aux_data_device_);

if (rope_mode_ == RoPEMode::kNormal) {
// Apply rotary embedding to q/k data.
@@ -726,10 +751,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
total_seq_length += cur_append_lengths_[seq_id];
}
CHECK_EQ(total_seq_length, qkv_data->shape[0]);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
CHECK(!dirty_aux_data_device_)
<< "The auxiliary arrays are not synchronized to device. Please call "
"`BeginForward` to synchronize before calling `Attention`.";
ICHECK(!dirty_aux_data_device_);

NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, num_qo_heads_, head_dim_},
qkv_data->dtype);
@@ -965,11 +990,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_attention_decode_begin_forward_.value()(
/*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
} else {
f_attention_prefill_ragged_begin_forward_.value()(
temp_attn_workspace_[0], cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_,
num_kv_heads_);
num_kv_heads_, head_dim_, copy_stream_);
for (int d = 0; d < num_depths_; ++d) {
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
continue;
@@ -978,11 +1003,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_attention_decode_begin_forward_.value()(
d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, head_dim_, page_size_,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
} else {
f_attention_prefill_begin_forward_.value()(
/*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_view_[d],
last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_);
last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
copy_stream_);
}
}
}
@@ -1041,6 +1067,28 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

/*! \brief Synchronize the copy stream and the compute stream. */
void ComputeStreamWaitForCopyStream() {
if (!dirty_aux_data_device_) {
// If the auxiliary data is already synced, return; there is no need to sync again.
return;
}
// - Sync NDArrays to GPU.
SyncAuxArrayToDevice(qo_indptr_on_depths_host_, page_indptr_on_depths_host_,
page_indices_on_depths_host_, last_page_len_on_depths_host_,
k_rope_pos_offset_on_depths_host_, k_ragged_rope_pos_offset_host_,
q_rope_position_map_host_, append_position_map_host_);
KernelBeginForward();
// - Clear the dirty flag.
dirty_aux_data_device_ = false;
// - If there is no particular copy stream, no action is needed.
if (copy_stream_ == nullptr) {
return;
}
// - Sync two streams.
DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, compute_stream_);
}

/*!
* \brief Synchronize auxiliary arrays to device.
* \note This method resets the dirty flag to false, and needs to be
@@ -1061,15 +1109,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
ICHECK_EQ(last_page_len_on_depths.size(), num_depths_);
int64_t total_append_length = 0;
int num_sequences = cur_append_lengths_.size();
std::vector<int32_t> cur_append_lengths_indptr{0};
for (int i = 0; i < static_cast<int>(cur_append_lengths_.size()); ++i) {
cur_append_lengths_indptr.push_back(cur_append_lengths_indptr.back() +
cur_append_lengths_[i]);
cur_append_lengths_indptr_host_.resize(num_sequences + 1);
cur_append_lengths_indptr_host_[0] = 0;
for (int i = 0; i < num_sequences; ++i) {
cur_append_lengths_indptr_host_[i + 1] =
cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i];
}
total_append_length = cur_append_lengths_indptr.back();
total_append_length = cur_append_lengths_indptr_host_.back();
ICHECK_EQ(total_append_length, append_position_map.size());

auto fcopy_from_vec = [](NDArray array, int32_t* vec_data) {
auto fcopy_from_vec = [copy_stream = this->copy_stream_](NDArray array, int32_t* vec_data) {
DLTensor copy_dst = *array.operator->();
DLTensor copy_src;
copy_src.data = vec_data;
Expand All @@ -1079,7 +1128,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
copy_src.shape = array->shape;
copy_src.strides = nullptr;
copy_src.byte_offset = 0;
NDArray::CopyFromTo(&copy_src, &copy_dst);
NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream);
};

// 1. qo_indptr_on_depths
@@ -1126,7 +1175,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// 6. cur_append_lengths_indptr
cur_append_length_indptr_view_ =
cur_append_length_indptr_device_.CreateView({num_sequences + 1}, dtype_aux_);
fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr.data());
fcopy_from_vec(cur_append_length_indptr_view_, cur_append_lengths_indptr_host_.data());

// 7. k_ragged_rope_pos_offset
ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences);
