Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2H/N) (pytorch#3539)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#626

- Update benchmark test for `int32_t` Indicies

Reviewed By: sryap

Differential Revision: D67784746
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 16, 2025
1 parent dcfa428 commit 2c37eac
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 132 deletions.
14 changes: 10 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def cli() -> None:
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--dense", is_flag=True, default=False)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--export-trace", is_flag=True, default=False)
Expand Down Expand Up @@ -166,6 +167,7 @@ def device( # noqa C901
flush_gpu_cache_size_mb: int,
dense: bool,
output_dtype: SparseType,
indices_dtype: str,
requests_data_file: Optional[str],
tables: Optional[str],
export_trace: bool,
Expand All @@ -176,6 +178,9 @@ def device( # noqa C901
cache_load_factor: float,
) -> None:
assert not ssd or not dense, "--ssd cannot be used together with --dense"
indices_dtype_torch: torch.dtype = (
torch.int32 if int(indices_dtype) == 32 else torch.int64
)
np.random.seed(42)
torch.manual_seed(42)
B = batch_size
Expand Down Expand Up @@ -272,6 +277,7 @@ def device( # noqa C901
**common_split_args,
)
else:
print(f"IS CUDA AVAILABLE? {torch.cuda.is_available()}")
tbe_type: str = "split"
emb = SplitTableBatchedEmbeddingBagsCodegen(
[
Expand Down Expand Up @@ -352,8 +358,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down Expand Up @@ -384,8 +390,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

namespace {
template <typename scalar_t, typename grad_t>
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_approx_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
Expand All @@ -44,8 +50,11 @@ void split_embedding_backward_approx_cpu_kernel(
{{ args.split_cpu_kernel_args | join(", ") }}) {
auto grad_output_data = grad_output.accessor<grad_t, 2>();
auto host_weights_data = host_weights.accessor<scalar_t, 1>();
const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();

[[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel";
const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);

// If indice_weights are not defined, then this accessor won't be used
auto indice_weights_data = indice_weights.defined()
? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
Expand Down Expand Up @@ -133,75 +142,87 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
!indice_weights.defined() && static_cast<PoolingMode>(pooling_mode) == PoolingMode::SUM;

if (use_fbgemm) {
auto grad_stride = grad_output.size(1);
const float* grad_output_data = grad_output.data_ptr<float>();
float* host_weights_data = host_weights.data_ptr<float>();
const int64_t* indices_data = indices.data_ptr<int64_t>();
const int64_t* offsets_data = offsets.data_ptr<int64_t>();
const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;
for (const auto t : c10::irange(t_begin,t_end)) {
auto D_begin = D_offsets_data[t];
auto D = D_offsets_data[t + 1] - D_offsets_data[t];
auto table_begin = weights_offsets_data[t];
auto momentum_begin = momentum1_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

auto kernel =
fbgemm::GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
D,
/*prefetch=*/16,
/*use_offsets=*/true,
/*use_stochastic_round=*/true,
/*grad_stride=*/grad_stride);
auto offsets_begin_ptr = offsets_data + t * B + b_begin;
auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
bool success = kernel(
b_end - b_begin,
index_size,
hash_size,
reinterpret_cast<float*>(host_weights_data + table_begin),
reinterpret_cast<const float*>(
grad_output_data + b_begin * grad_stride + D_begin),
reinterpret_cast<float*>(momentum1_data + momentum_begin),
indices_data + *offsets_begin_ptr,
offsets_begin_ptr,
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);

if (!success) {
fbgemm_gpu::report_embedding_error(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

auto grad_stride = grad_output.size(1);
const float* grad_output_data = grad_output.data_ptr<float>();
float* host_weights_data = host_weights.data_ptr<float>();

const auto* indices_data = indices.data_ptr<index_t>();
const auto* offsets_data = offsets.data_ptr<index_t>();

// const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
// const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);

const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
float* momentum1_data = momentum1_host.data_ptr<float>();

at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
int t_begin = tb_begin / B;
int t_end = (tb_end + B - 1) / B;

for (const auto t : c10::irange(t_begin,t_end)) {
auto D_begin = D_offsets_data[t];
auto D = D_offsets_data[t + 1] - D_offsets_data[t];
auto table_begin = weights_offsets_data[t];
auto momentum_begin = momentum1_offsets_data[t];

int64_t hash_size;
int t_temp = t + 1;
do {
hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
++t_temp;
} while (hash_size == 0);

int b_begin = (t == t_begin) ? tb_begin % B : 0;
int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;

auto kernel =
fbgemm::GenerateRowWiseSparseAdaGradFused<index_t, index_t, float>(
D,
/*prefetch=*/16,
/*use_offsets=*/true,
/*use_stochastic_round=*/true,
/*grad_stride=*/grad_stride);
auto offsets_begin_ptr = offsets_data + t * B + b_begin;
auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
bool success = kernel(
b_end - b_begin,
index_size,
hash_size,
reinterpret_cast<float*>(host_weights_data + table_begin),
reinterpret_cast<const float*>(
grad_output_data + b_begin * grad_stride + D_begin),
reinterpret_cast<float*>(momentum1_data + momentum_begin),
indices_data + *offsets_begin_ptr,
offsets_begin_ptr,
eps,
// fbgemm follows caffe2 convention of negative learning rate
-learning_rate);

if (!success) {
fbgemm_gpu::report_embedding_error(
t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
}
}
}
}); // parallel_for
}); // parallel_for
}); // dispatch indices.scalar_type()

return;
} // use_fbgemm

{% endif %}

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_cpu", [&] {
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] {
using grad_t = scalar_t;
FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(),
"split_embedding_backward_cpu_inner",
[&] {
split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>(

FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] {
split_embedding_backward_approx_cpu_kernel<index_t, scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
Expand All @@ -220,7 +241,8 @@ for (const auto t : c10::irange(t_begin,t_end)) {
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
}); // dispatch host_weights.scalar_type()
}); // dispatch grad_output.scalar_type()
}); // dispatch grad_output.scalar_type()
}); // dispatch indices.scalar_type()

return;
}
Expand Down
Loading

0 comments on commit 2c37eac

Please sign in to comment.