Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that we can run reduce_by_key with const inputs #1528

Merged
merged 3 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion thrust/testing/zip_iterator_reduce_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct TestZipIteratorReduceByKey
ASSERT_EQUAL(h_data4, d_data4);
ASSERT_EQUAL(h_data5, d_data5);
}

// The tests below get miscompiled on Tesla hw for 8b types

#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
Expand Down Expand Up @@ -118,6 +118,51 @@ struct TestZipIteratorReduceByKey
ASSERT_EQUAL(h_data5, d_data5);
ASSERT_EQUAL(h_data6, d_data6);
}

// const inputs, see #1527
// Checks that reduce_by_key can be called when the key and value inputs are
// only reachable through pointers-to-const, on both the host and the device
// execution policies, and that both policies produce the same results.
{
// Value inputs (data3) and outputs (data4/5 = reduced keys, data6 = reduced
// values) are local to this scope; n and the key vectors h_data1/h_data2 and
// d_data1/d_data2 are declared outside of it.
host_vector<float> h_data3(n, 0.0f);
host_vector<T> h_data4(n, 0);
host_vector<T> h_data5(n, 0);
host_vector<float> h_data6(n, 0.0f);
device_vector<float> d_data3(n, 0.0f);
device_vector<T> d_data4(n, 0);
device_vector<T> d_data5(n, 0);
device_vector<float> d_data6(n, 0.0f);

// run on host
// Inputs (keys 1/2, values 3) are viewed through pointer-to-const; only the
// outputs (4/5/6) are taken as writable pointers.
const T* h_begin1 = thrust::raw_pointer_cast(h_data1.data());
const T* h_begin2 = thrust::raw_pointer_cast(h_data2.data());
const float* h_begin3 = thrust::raw_pointer_cast(h_data3.data());
T* h_begin4 = thrust::raw_pointer_cast(h_data4.data());
T* h_begin5 = thrust::raw_pointer_cast(h_data5.data());
float* h_begin6 = thrust::raw_pointer_cast(h_data6.data());
// Argument order: policy, keys begin, keys end, values begin, keys output,
// values output. The key range is a zip of the two const key pointers.
thrust::reduce_by_key(thrust::host,
thrust::make_zip_iterator(thrust::make_tuple(h_begin1, h_begin2)),
thrust::make_zip_iterator(thrust::make_tuple(h_begin1, h_begin2)) + n,
h_begin3,
thrust::make_zip_iterator(thrust::make_tuple(h_begin4, h_begin5)),
h_begin6);

// run on device
// Same setup as the host run, using raw pointers into the device vectors.
const T* d_begin1 = thrust::raw_pointer_cast(d_data1.data());
const T* d_begin2 = thrust::raw_pointer_cast(d_data2.data());
const float* d_begin3 = thrust::raw_pointer_cast(d_data3.data());
T* d_begin4 = thrust::raw_pointer_cast(d_data4.data());
T* d_begin5 = thrust::raw_pointer_cast(d_data5.data());
float* d_begin6 = thrust::raw_pointer_cast(d_data6.data());
thrust::reduce_by_key(thrust::device,
thrust::make_zip_iterator(thrust::make_tuple(d_begin1, d_begin2)),
thrust::make_zip_iterator(thrust::make_tuple(d_begin1, d_begin2)) + n,
d_begin3,
thrust::make_zip_iterator(thrust::make_tuple(d_begin4, d_begin5)),
d_begin6);

// Host and device buffers must agree element-wise: data3 is the value input
// (passed only via const pointers), data4/5/6 hold the outputs.
ASSERT_EQUAL(h_data3, d_data3);
ASSERT_EQUAL(h_data4, d_data4);
ASSERT_EQUAL(h_data5, d_data5);
ASSERT_EQUAL(h_data6, d_data6);
}
}
};
VariableUnitTest<TestZipIteratorReduceByKey, UnsignedIntegralTypes> TestZipIteratorReduceByKeyInstance;
Expand Down
4 changes: 2 additions & 2 deletions thrust/thrust/system/cuda/detail/reduce_by_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ namespace __reduce_by_key {
}

key_type tile_pred_key = (threadIdx.x == 0)
? keys_load_it[tile_offset - 1]
? key_type(keys_load_it[tile_offset - 1])
: key_type();

sync_threadblock();
Expand Down Expand Up @@ -1057,7 +1057,7 @@ namespace __reduce_by_key {
status = cuda_cub::synchronize(policy);
cuda_cub::throw_on_error(status, "reduce_by_key: failed to synchronize");

int num_runs_out = cuda_cub::get_value(policy, d_num_runs_out);
const auto num_runs_out = cuda_cub::get_value(policy, d_num_runs_out);

return thrust::make_pair(
keys_output + num_runs_out,
Expand Down
Loading