Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2H/N) (pytorch#3539)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#626

- Add `int32_t` indices support for the TBE CPU kernels

Reviewed By: sryap

Differential Revision: D67784746
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 16, 2025
1 parent b96e2e2 commit 7bd2d46
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

namespace {
template <typename scalar_t, typename grad_t>
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_approx_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
Expand All @@ -44,8 +50,11 @@ void split_embedding_backward_approx_cpu_kernel(
{{ args.split_cpu_kernel_args | join(", ") }}) {
auto grad_output_data = grad_output.accessor<grad_t, 2>();
auto host_weights_data = host_weights.accessor<scalar_t, 1>();
const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();

[[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel";
const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);

// If indice_weights are not defined, then this accessor won't be used
auto indice_weights_data = indice_weights.defined()
? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
Expand Down Expand Up @@ -133,75 +142,84 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
!indice_weights.defined() && static_cast<PoolingMode>(pooling_mode) == PoolingMode::SUM;

if (use_fbgemm) {
auto grad_stride = grad_output.size(1);
const float* grad_output_data = grad_output.data_ptr<float>();
float* host_weights_data = host_weights.data_ptr<float>();
const int64_t* indices_data = indices.data_ptr<int64_t>();
const int64_t* offsets_data = offsets.data_ptr<int64_t>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;
for (const auto t : c10::irange(t_begin,t_end)) {
auto D_begin = D_offsets_data[t];
auto D = D_offsets_data[t + 1] - D_offsets_data[t];
auto table_begin = weights_offsets_data[t];
auto momentum_begin = momentum1_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

auto kernel =
fbgemm::GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
D,
/*prefetch=*/16,
/*use_offsets=*/true,
/*use_stochastic_round=*/true,
/*grad_stride=*/grad_stride);
auto offsets_begin_ptr = offsets_data + t * B + b_begin;
auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
bool success = kernel(
b_end - b_begin,
index_size,
hash_size,
reinterpret_cast<float*>(host_weights_data + table_begin),
reinterpret_cast<const float*>(
grad_output_data + b_begin * grad_stride + D_begin),
reinterpret_cast<float*>(momentum1_data + momentum_begin),
indices_data + *offsets_begin_ptr,
offsets_begin_ptr,
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);

if (!success) {
fbgemm_gpu::report_embedding_error(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

auto grad_stride = grad_output.size(1);
const float* grad_output_data = grad_output.data_ptr<float>();
float* host_weights_data = host_weights.data_ptr<float>();

const auto* indices_data = indices.data_ptr<index_t>();
const auto* offsets_data = offsets.data_ptr<index_t>();

const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;

for (const auto t : c10::irange(t_begin,t_end)) {
auto D_begin = D_offsets_data[t];
auto D = D_offsets_data[t + 1] - D_offsets_data[t];
auto table_begin = weights_offsets_data[t];
auto momentum_begin = momentum1_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

auto kernel =
fbgemm::GenerateRowWiseSparseAdaGradFused<index_t, index_t, float>(
D,
/*prefetch=*/16,
/*use_offsets=*/true,
/*use_stochastic_round=*/true,
/*grad_stride=*/grad_stride);
auto offsets_begin_ptr = offsets_data + t * B + b_begin;
auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
bool success = kernel(
b_end - b_begin,
index_size,
hash_size,
reinterpret_cast<float*>(host_weights_data + table_begin),
reinterpret_cast<const float*>(
grad_output_data + b_begin * grad_stride + D_begin),
reinterpret_cast<float*>(momentum1_data + momentum_begin),
indices_data + *offsets_begin_ptr,
offsets_begin_ptr,
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);

if (!success) {
fbgemm_gpu::report_embedding_error(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
}
}
}
}); // parallel_for
}); // parallel_for
}); // dispatch indices.scalar_type()

return;
} // use_fbgemm

{% endif %}

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_cpu", [&] {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] {
using grad_t = scalar_t;
FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(),
"split_embedding_backward_cpu_inner",
[&] {
split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>(

FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] {
split_embedding_backward_approx_cpu_kernel<index_t, scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
Expand All @@ -220,7 +238,8 @@ for (const auto t : c10::irange(t_begin,t_end)) {
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
}); // dispatch host_weights.scalar_type()
}); // dispatch grad_output.scalar_type()
}); // dispatch grad_output.scalar_type()
}); // dispatch indices.scalar_type()

return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
#include "fbgemm_gpu/utils/cpu_utils.h"
#include "fbgemm_gpu/utils/ops_utils.h"

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

Expand All @@ -40,7 +46,7 @@ struct half2float16<at::Half> {
} // namespace internal

namespace {
template <typename scalar_t, typename grad_t>
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_exact_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
Expand Down Expand Up @@ -94,8 +100,8 @@ for (const auto t : c10::irange(num_tables)) {
::internal::csr2csc(
cscs[t],
B,
MAKE_TA_WITH_NAME(func_name, offsets, int64_t, 1),
MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1),
MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1),
MAKE_TA_WITH_NAME(func_name, indices, index_t, 1),
MAKE_TA_WITH_NAME(func_name, indice_weights, weight_t, 1),
pooling_mode,
table_to_feature_offset + t,
Expand Down Expand Up @@ -196,19 +202,21 @@ for (const auto t : c10::irange(num_tables)) {
// TODO: to parallelize, we should easily identify segments belong to
// the same column.
at::acc_type<grad_t, true> grad_buffer[D];
for (const auto c : c10::irange(num_non_zero_columns)) {
for (const auto c : c10::irange(num_non_zero_columns)) {
int64_t idx = col_segment_indices[c];
if (c == 0 || col_segment_indices[c - 1] != idx) {
memset(grad_buffer, 0, D * sizeof(at::acc_type<grad_t, true>));
}
[[maybe_unused]] const int64_t embedding_begin = table_begin + idx * D;

for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
int D_offset = D_begin;
if (is_shared_table) {
D_offset += cscs[t].column_segment_ids[r] * D;
}
int b = cscs[t].row_indices[r];
for (const auto d : c10::irange(D)) {

for (const auto d : c10::irange(D)) {
if (cscs[t].weights != nullptr) {
grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] *
cscs[t].weights[r];
Expand All @@ -225,7 +233,7 @@ for (const auto d : c10::irange(D)) {
} // for each table
}

template <typename scalar_t>
template <typename index_t, typename scalar_t>
void split_embedding_backward_exact_cpu_dense_kernel(
Tensor grad,
Tensor grad_output,
Expand All @@ -242,8 +250,10 @@ void split_embedding_backward_exact_cpu_dense_kernel(

auto grad_output_data = grad_output.accessor<scalar_t, 2>();

const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();
[[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_exact_cpu_dense_kernel";

const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);
const auto indice_weights_data = indice_weights.defined()
?
// If indice_weights are not defined, then this accessor won't be
Expand Down Expand Up @@ -349,34 +359,41 @@ for (const auto d : c10::irange(D)) {

grad_output = grad_output.contiguous();


FBGEMM_DISPATCH_FLOAT_AND_HALF(
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"split_embedding_backward_exact_cpu_kernel_1", [&] {

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu_outer", [&]() {
using grad_t = scalar_t;
"split_embedding_backward_exact_cpu_kernel_2", [&] {
using grad_t = scalar_t;

FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
num_tables,
B,
table_to_feature_offset,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
});
host_weights.scalar_type(),
"split_embedding_backward_exact_cpu_kernel_3", [&] {

split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
num_tables,
B,
table_to_feature_offset,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
});
});
});

return;
Expand All @@ -385,10 +402,15 @@ for (const auto d : c10::irange(D)) {

// When input is dense enough, avoid sorting and just treat as dense.
auto grad = zeros_like(host_weights, grad_output.dtype());
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"split_embedding_backward_exact_cpu_dense_kernel", [&] {

split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu", [&] {

split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
grad,
grad_output,
weights_offsets_data,
Expand All @@ -400,7 +422,8 @@ for (const auto d : c10::irange(D)) {
num_tables,
B,
table_to_feature_offset);
}); // dispatch host_weights.scalar_type()
});
});

return grad;
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }},
output_dtype);

static auto op2 =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
Expand Down
Loading

0 comments on commit 7bd2d46

Please sign in to comment.