Add support for int32_t indices in TBE training (2H/N) #3539

Closed · wants to merge 1 commit · Changes from all commits
@@ -19,11 +19,17 @@
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif
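
For readers unfamiliar with the pattern above: when `FBGEMM_GPU_MEMCHECK` is on, `func_name` is consumed by the checked accessors, so the macro expands to nothing; when it is off, the variable would otherwise be unused, so the macro expands to `maybe_unused` and the declaration picks up `[[maybe_unused]]`. A minimal standalone sketch of the same trick (not part of this PR):

```cpp
#include <cstdio>

// An attribute list may be empty ([[ ]]), so the macro can expand to either
// nothing or maybe_unused and the declaration stays well-formed both ways.
#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

void demo() {
  // Without memcheck this expands to [[maybe_unused]], silencing
  // -Wunused-variable; with memcheck the name is actually used below.
  [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "demo";
#if FBGEMM_GPU_MEMCHECK
  std::printf("%s\n", func_name);
#endif
}
```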

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

namespace {
template <typename scalar_t, typename grad_t>
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_approx_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
@@ -44,8 +50,11 @@ void split_embedding_backward_approx_cpu_kernel(
{{ args.split_cpu_kernel_args | join(", ") }}) {
auto grad_output_data = grad_output.accessor<grad_t, 2>();
auto host_weights_data = host_weights.accessor<scalar_t, 1>();
const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();

[[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel";
const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);

// If indice_weights are not defined, then this accessor won't be used
auto indice_weights_data = indice_weights.defined()
? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
@@ -133,75 +142,84 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
!indice_weights.defined() && static_cast<PoolingMode>(pooling_mode) == PoolingMode::SUM;

if (use_fbgemm) {
auto grad_stride = grad_output.size(1);
const float* grad_output_data = grad_output.data_ptr<float>();
float* host_weights_data = host_weights.data_ptr<float>();
const int64_t* indices_data = indices.data_ptr<int64_t>();
const int64_t* offsets_data = offsets.data_ptr<int64_t>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;
for (const auto t : c10::irange(t_begin,t_end)) {
auto D_begin = D_offsets_data[t];
auto D = D_offsets_data[t + 1] - D_offsets_data[t];
auto table_begin = weights_offsets_data[t];
auto momentum_begin = momentum1_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

auto kernel =
fbgemm::GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
D,
/*prefetch=*/16,
/*use_offsets=*/true,
/*use_stochastic_round=*/true,
/*grad_stride=*/grad_stride);
auto offsets_begin_ptr = offsets_data + t * B + b_begin;
auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
bool success = kernel(
b_end - b_begin,
index_size,
hash_size,
reinterpret_cast<float*>(host_weights_data + table_begin),
reinterpret_cast<const float*>(
grad_output_data + b_begin * grad_stride + D_begin),
reinterpret_cast<float*>(momentum1_data + momentum_begin),
indices_data + *offsets_begin_ptr,
offsets_begin_ptr,
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);

if (!success) {
fbgemm_gpu::report_embedding_error(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

auto grad_stride = grad_output.size(1);
const float* grad_output_data = grad_output.data_ptr<float>();
float* host_weights_data = host_weights.data_ptr<float>();

const auto* indices_data = indices.data_ptr<index_t>();
const auto* offsets_data = offsets.data_ptr<index_t>();

const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;

for (const auto t : c10::irange(t_begin,t_end)) {
auto D_begin = D_offsets_data[t];
auto D = D_offsets_data[t + 1] - D_offsets_data[t];
auto table_begin = weights_offsets_data[t];
auto momentum_begin = momentum1_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

auto kernel =
fbgemm::GenerateRowWiseSparseAdaGradFused<index_t, index_t, float>(
D,
/*prefetch=*/16,
/*use_offsets=*/true,
/*use_stochastic_round=*/true,
/*grad_stride=*/grad_stride);
auto offsets_begin_ptr = offsets_data + t * B + b_begin;
auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
bool success = kernel(
b_end - b_begin,
index_size,
hash_size,
reinterpret_cast<float*>(host_weights_data + table_begin),
reinterpret_cast<const float*>(
grad_output_data + b_begin * grad_stride + D_begin),
reinterpret_cast<float*>(momentum1_data + momentum_begin),
indices_data + *offsets_begin_ptr,
offsets_begin_ptr,
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);

if (!success) {
fbgemm_gpu::report_embedding_error(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
}
}
}
}); // parallel_for
}); // parallel_for
}); // dispatch indices.scalar_type()

return;
} // use_fbgemm

{% endif %}
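
For context on what the fast path above delegates to fbgemm's `GenerateRowWiseSparseAdaGradFused`: row-wise AdaGrad keeps a single accumulator per embedding row (`momentum1`) instead of one per element. A hedged scalar sketch of the update rule (my restatement, not fbgemm's implementation; index gathering, prefetching, and stochastic rounding are omitted):

```cpp
#include <cmath>
#include <cstdint>

// Row-wise AdaGrad on one embedding row of width D (sketch only).
// momentum1 is the single per-row accumulator of the mean squared gradient.
void rowwise_adagrad_update(
    float* w,           // weights of one row, length D
    const float* g,     // gradient for that row, length D
    float& momentum1,   // per-row accumulator
    int64_t D,
    float learning_rate,
    float eps) {
  float g_sq_sum = 0.f;
  for (int64_t d = 0; d < D; ++d) {
    g_sq_sum += g[d] * g[d];
  }
  momentum1 += g_sq_sum / D;  // accumulate mean squared gradient
  const float scale = learning_rate / (std::sqrt(momentum1) + eps);
  for (int64_t d = 0; d < D; ++d) {
    w[d] -= scale * g[d];
  }
}
```

This is consistent with the call above passing `momentum1_data + momentum_begin`: one accumulator value per embedding row, offset per table.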

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_cpu", [&] {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] {
using grad_t = scalar_t;
FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(),
"split_embedding_backward_cpu_inner",
[&] {
split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>(

FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] {
split_embedding_backward_approx_cpu_kernel<index_t, scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
@@ -220,7 +238,8 @@ for (const auto t : c10::irange(t_begin,t_end)) {
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
}); // dispatch host_weights.scalar_type()
}); // dispatch grad_output.scalar_type()
}); // dispatch grad_output.scalar_type()
}); // dispatch indices.scalar_type()

return;
}
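
The mechanical core of this diff is `AT_DISPATCH_INDEX_TYPES`, which switches on the runtime dtype of `indices` (int32 or int64) and binds the matching C++ type to `index_t` inside the lambda; nested with the two `FBGEMM_DISPATCH_FLOAT_AND_HALF` levels, it yields a 2 x 2 x 2 grid of kernel instantiations. A minimal standalone sketch of the dispatch macro (the `sum_indices` helper is hypothetical):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sums an index tensor regardless of whether it is int32 or int64.
// AT_DISPATCH_INDEX_TYPES instantiates the lambda body once per supported
// index type and selects the right instantiation at runtime.
int64_t sum_indices(const at::Tensor& indices) {
  int64_t total = 0;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_indices", [&] {
    const auto* data = indices.data_ptr<index_t>();
    for (int64_t i = 0; i < indices.numel(); ++i) {
      total += static_cast<int64_t>(data[i]);
    }
  });
  return total;
}
```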
@@ -24,6 +24,12 @@
#include "fbgemm_gpu/utils/cpu_utils.h"
#include "fbgemm_gpu/utils/ops_utils.h"

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

@@ -40,7 +46,7 @@ struct half2float16<at::Half> {
} // namespace internal

namespace {
template <typename scalar_t, typename grad_t>
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_exact_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
@@ -94,8 +100,8 @@ for (const auto t : c10::irange(num_tables)) {
::internal::csr2csc(
cscs[t],
B,
MAKE_TA_WITH_NAME(func_name, offsets, int64_t, 1),
MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1),
MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1),
MAKE_TA_WITH_NAME(func_name, indices, index_t, 1),
MAKE_TA_WITH_NAME(func_name, indice_weights, weight_t, 1),
pooling_mode,
table_to_feature_offset + t,
@@ -196,19 +202,21 @@ for (const auto t : c10::irange(num_tables)) {
// TODO: to parallelize, we should easily identify segments belong to
// the same column.
at::acc_type<grad_t, true> grad_buffer[D];
for (const auto c : c10::irange(num_non_zero_columns)) {
for (const auto c : c10::irange(num_non_zero_columns)) {
int64_t idx = col_segment_indices[c];
if (c == 0 || col_segment_indices[c - 1] != idx) {
memset(grad_buffer, 0, D * sizeof(at::acc_type<grad_t, true>));
}
[[maybe_unused]] const int64_t embedding_begin = table_begin + idx * D;

for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
int D_offset = D_begin;
if (is_shared_table) {
D_offset += cscs[t].column_segment_ids[r] * D;
}
int b = cscs[t].row_indices[r];
for (const auto d : c10::irange(D)) {

for (const auto d : c10::irange(D)) {
if (cscs[t].weights != nullptr) {
grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] *
cscs[t].weights[r];
@@ -225,7 +233,7 @@ for (const auto d : c10::irange(D)) {
} // for each table
}
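
The exact kernel above first transposes each table's batch from CSR form (per-sample `offsets` into `indices`) to CSC form, so that all gradient contributions targeting the same embedding row become contiguous and `grad_buffer` can be flushed once per row. A toy illustration of that transpose, independent of fbgemm's `::internal::csr2csc` (the function name and container choices here are hypothetical):

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Toy CSR -> per-column row lists: offsets has B + 1 entries; sample b owns
// indices[offsets[b] .. offsets[b + 1]). The result maps each embedding row
// to the samples that referenced it, which is what lets the backward pass
// accumulate one gradient buffer per embedding row.
template <typename index_t>
std::map<index_t, std::vector<int64_t>> csr_to_columns(
    const std::vector<index_t>& offsets,
    const std::vector<index_t>& indices) {
  std::map<index_t, std::vector<int64_t>> columns;
  const int64_t B = static_cast<int64_t>(offsets.size()) - 1;
  for (int64_t b = 0; b < B; ++b) {
    for (index_t r = offsets[b]; r < offsets[b + 1]; ++r) {
      columns[indices[r]].push_back(b);
    }
  }
  return columns;
}
```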

template <typename scalar_t>
template <typename index_t, typename scalar_t>
void split_embedding_backward_exact_cpu_dense_kernel(
Tensor grad,
Tensor grad_output,
@@ -242,8 +250,10 @@ void split_embedding_backward_exact_cpu_dense_kernel(

auto grad_output_data = grad_output.accessor<scalar_t, 2>();

const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();
[[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_exact_cpu_dense_kernel";

const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);
const auto indice_weights_data = indice_weights.defined()
?
// If indice_weights are not defined, then this accessor won't be
@@ -349,34 +359,41 @@ for (const auto d : c10::irange(D)) {

grad_output = grad_output.contiguous();


FBGEMM_DISPATCH_FLOAT_AND_HALF(
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"split_embedding_backward_exact_cpu_kernel_1", [&] {

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu_outer", [&]() {
using grad_t = scalar_t;
"split_embedding_backward_exact_cpu_kernel_2", [&] {
using grad_t = scalar_t;

FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
num_tables,
B,
table_to_feature_offset,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
});
host_weights.scalar_type(),
"split_embedding_backward_exact_cpu_kernel_3", [&] {

split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
num_tables,
B,
table_to_feature_offset,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
});
});
});

return;
Expand All @@ -385,10 +402,15 @@ for (const auto d : c10::irange(D)) {

// When input is dense enough, avoid sorting and just treat as dense.
auto grad = zeros_like(host_weights, grad_output.dtype());
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"split_embedding_backward_exact_cpu_dense_kernel", [&] {

split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu", [&] {

split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
grad,
grad_output,
weights_offsets_data,
@@ -400,7 +422,8 @@ num_tables,
num_tables,
B,
table_to_feature_offset);
}); // dispatch host_weights.scalar_type()
});
});

return grad;
{% endif %}
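
The dense fallback above sidesteps the CSR-to-CSC transpose entirely: it allocates `grad = zeros_like(host_weights, grad_output.dtype())` and scatter-adds each sample's gradient row directly. A toy sketch of that idea (assumes a single table, sum pooling, and no per-sample weights; not the actual kernel):

```cpp
#include <ATen/ATen.h>

// Toy version of the dense fallback: scatter-add each sample's pooled
// gradient row into a dense grad buffer shaped like the weights.
template <typename index_t, typename scalar_t>
void dense_backward(
    at::TensorAccessor<scalar_t, 2> grad,         // [num_rows, D]
    at::TensorAccessor<scalar_t, 2> grad_output,  // [B, D]
    at::TensorAccessor<index_t, 1> indices,
    at::TensorAccessor<index_t, 1> offsets) {
  const auto B = grad_output.size(0);
  const auto D = grad_output.size(1);
  for (int64_t b = 0; b < B; ++b) {
    for (index_t r = offsets[b]; r < offsets[b + 1]; ++r) {
      const auto row = indices[r];
      for (int64_t d = 0; d < D; ++d) {
        grad[row][d] += grad_output[b][d];
      }
    }
  }
}
```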
@@ -158,6 +158,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }},
output_dtype);

static auto op2 =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
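
For reference, the lookup-then-call pattern used here resolves a registered operator by schema name once and caches the handle in a function-local static. A hedged sketch of the same pattern with an illustrative schema (the `myns::my_op` name and signature are placeholders, not the fbgemm schema):

```cpp
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// Look up a registered operator once, then call it through its typed handle.
// The signature passed to .typed<>() must match the registered schema.
at::Tensor call_registered_op(const at::Tensor& input) {
  static auto op = torch::Dispatcher::singleton()
      .findSchemaOrThrow("myns::my_op", "")  // hypothetical schema name
      .typed<at::Tensor(const at::Tensor&)>();
  return op.call(input);
}
```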