Skip to content

Commit

Permalink
Fix sampling for fp32 devices (#58)
Browse files Browse the repository at this point in the history
* replace bernoulli by uniform

* dispatch different approaches

* optimize dispatching

* fix missprint

* linting

---------

Co-authored-by: Dmitry Razdoburdin <>
  • Loading branch information
razdoburdin authored Jul 9, 2024
1 parent 888ff62 commit ba88551
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
58 changes: 42 additions & 16 deletions plugin/sycl/tree/hist_updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,23 +447,49 @@ void HistUpdater<GradientSumT>::InitSampling(
::sycl::buffer<uint64_t, 1> flag_buf(&num_samples, 1);
uint64_t seed = seed_;
seed_ += num_rows;
event = qu_.submit([&](::sycl::handler& cgh) {
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
[=](::sycl::item<1> pid) {
uint64_t i = pid.get_id(0);

// Create minstd_rand engine
oneapi::dpl::minstd_rand engine(seed, i);
oneapi::dpl::bernoulli_distribution coin_flip(subsample);

auto rnd = coin_flip(engine);
if (gpair_ptr[i].GetHess() >= 0.0f && rnd) {
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
row_idx[num_samples_ref++] = i;
}

/*
* oneDLP bernoulli_distribution implicitly uses double.
* In this case the device doesn't have fp64 support,
* we generate bernoulli distributed random values from uniform distribution
*/
if (has_fp64_support_) {
// Use oneDPL bernoulli_distribution for better perf
event = qu_.submit([&](::sycl::handler& cgh) {
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
[=](::sycl::item<1> pid) {
uint64_t i = pid.get_id(0);
// Create minstd_rand engine
oneapi::dpl::minstd_rand engine(seed, i);
oneapi::dpl::bernoulli_distribution coin_flip(subsample);
auto bernoulli_rnd = coin_flip(engine);

if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
row_idx[num_samples_ref++] = i;
}
});
});
});
} else {
// Use oneDPL uniform for better perf, as far as bernoulli_distribution uses fp64
event = qu_.submit([&](::sycl::handler& cgh) {
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(::sycl::range<1>(num_rows)),
[=](::sycl::item<1> pid) {
uint64_t i = pid.get_id(0);
oneapi::dpl::minstd_rand engine(seed, i);
oneapi::dpl::uniform_real_distribution<float> distr;
const float rnd = distr(engine);
const bool bernoulli_rnd = rnd < subsample ? 1 : 0;

if (gpair_ptr[i].GetHess() >= 0.0f && bernoulli_rnd) {
AtomicRef<uint64_t> num_samples_ref(flag_buf_acc[0]);
row_idx[num_samples_ref++] = i;
}
});
});
}
/* After calling a destructor for flag_buf, content will be copyed to num_samples */
}

Expand Down
2 changes: 2 additions & 0 deletions plugin/sycl/tree/hist_updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class HistUpdater {
if (param.max_depth > 0) {
snode_device_.Resize(&qu, 1u << (param.max_depth + 1));
}
has_fp64_support_ = qu_.get_device().has(::sycl::aspect::fp64);
const auto sub_group_sizes =
qu_.get_device().get_info<::sycl::info::device::sub_group_sizes>();
sub_group_size_ = sub_group_sizes.back();
Expand Down Expand Up @@ -211,6 +212,7 @@ class HistUpdater {

// --data fields--
const Context* ctx_;
bool has_fp64_support_;
size_t sub_group_size_;
const xgboost::tree::TrainParam& param_;
std::shared_ptr<xgboost::common::ColumnSampler> column_sampler_;
Expand Down

0 comments on commit ba88551

Please sign in to comment.