Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2H/N) (#3539)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#626

- Update benchmark test for `int32_t` Indicies

Reviewed By: sryap

Differential Revision: D67784746
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 9, 2025
1 parent 7cd05f6 commit b2e29a7
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 43 deletions.
2 changes: 2 additions & 0 deletions fbgemm_gpu/bench/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def benchmark_requests(

if num_warmups > 0:
indices, offsets, weights = requests[0].unpack_3()
print(f"INDICES BENCHMARK {indices.dtype}")
print(f"OFFSETS BENCHMARK {offsets.dtype}")
for _ in range(num_warmups):
out = func(indices, offsets, weights)
if bwd_only:
Expand Down
13 changes: 9 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def cli() -> None:
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--dense", is_flag=True, default=False)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--export-trace", is_flag=True, default=False)
Expand Down Expand Up @@ -166,6 +167,7 @@ def device( # noqa C901
flush_gpu_cache_size_mb: int,
dense: bool,
output_dtype: SparseType,
indices_dtype: str,
requests_data_file: Optional[str],
tables: Optional[str],
export_trace: bool,
Expand All @@ -176,6 +178,9 @@ def device( # noqa C901
cache_load_factor: float,
) -> None:
assert not ssd or not dense, "--ssd cannot be used together with --dense"
indices_dtype_torch: torch.dtype = (
torch.int32 if int(indices_dtype) == 32 else torch.int64
)
np.random.seed(42)
torch.manual_seed(42)
B = batch_size
Expand Down Expand Up @@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down Expand Up @@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct half2float16<at::Half> {
} // namespace internal

namespace {
template <typename scalar_t, typename grad_t>
template <typename index_t, typename scalar_t, typename grad_t>
void split_embedding_backward_exact_cpu_kernel(
Tensor grad_output,
Tensor host_weights,
Expand Down Expand Up @@ -90,8 +90,8 @@ for (const auto t : c10::irange(num_tables)) {
::internal::csr2csc(
cscs[t],
B,
offsets.accessor<int64_t, 1>(),
indices.accessor<int64_t, 1>(),
offsets.accessor<index_t, 1>(),
indices.accessor<index_t, 1>(),
indice_weights.defined()
? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
: at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr),
Expand Down Expand Up @@ -223,7 +223,7 @@ for (const auto d : c10::irange(D)) {
} // for each table
}

template <typename scalar_t>
template <typename index_t, typename scalar_t>
void split_embedding_backward_exact_cpu_dense_kernel(
Tensor grad,
Tensor grad_output,
Expand All @@ -240,8 +240,8 @@ void split_embedding_backward_exact_cpu_dense_kernel(

auto grad_output_data = grad_output.accessor<scalar_t, 2>();

const auto indices_data = indices.accessor<int64_t, 1>();
const auto offsets_data = offsets.accessor<int64_t, 1>();
const auto indices_data = indices.accessor<index_t, 1>();
const auto offsets_data = offsets.accessor<index_t, 1>();
const auto indice_weights_data = indice_weights.defined()
?
// If indice_weights are not defined, then this accessor won't be
Expand Down Expand Up @@ -347,34 +347,42 @@ for (const auto d : c10::irange(D)) {

grad_output = grad_output.contiguous();


FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"split_embedding_backward_exact_cpu_kernel_1", [&] {
using index_t = scalar_t;

FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu_outer", [&]() {
using grad_t = scalar_t;
"split_embedding_backward_exact_cpu_kernel_2", [&] {
using grad_t = scalar_t;

FBGEMM_DISPATCH_FLOAT_AND_HALF(
host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
num_tables,
B,
table_to_feature_offset,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
});
host_weights.scalar_type(),
"split_embedding_backward_exact_cpu_kernel_3", [&] {

split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
grad_output,
host_weights,
weights_offsets_data,
D_offsets_data,
hash_size_cumsum,
indices,
offsets,
pooling_mode,
indice_weights,
num_tables,
B,
table_to_feature_offset,
{% if "momentum1_offsets" in args.split_function_arg_names %}
momentum1_offsets_data,
{% endif %}
{% if "momentum2_offsets" in args.split_function_arg_names %}
momentum2_offsets_data,
{% endif %}
{{ args.split_cpu_kernel_arg_constructors | join(", ") }});
});
});
});

return;
Expand All @@ -383,10 +391,16 @@ for (const auto d : c10::irange(D)) {

// When input is dense enough, avoid sorting and just treat as dense.
auto grad = zeros_like(host_weights, grad_output.dtype());
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
FBGEMM_DISPATCH_INTEGRAL_TYPES(
indices.scalar_type(),
"split_embedding_backward_exact_cpu_dense_kernel", [&] {
using index_t = scalar_t;

split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
FBGEMM_DISPATCH_FLOAT_AND_HALF(
grad_output.scalar_type(),
"split_embedding_backward_exact_cpu", [&] {

split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
grad,
grad_output,
weights_offsets_data,
Expand All @@ -398,7 +412,8 @@ for (const auto d : c10::irange(D)) {
num_tables,
B,
table_to_feature_offset);
}); // dispatch host_weights.scalar_type()
});
});

return grad;
{% endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3315,8 +3315,10 @@ def prepare_inputs(
)

if force_cast_input_types:
# Force casting indices and offsets to long
(indices, offsets) = indices.long(), offsets.long()
# NOTE: Force offsets to have the same dtype as indices since the
# kernels assume same dtype. We might need to revisit the assumption
# of same dtypes in the future.
offsets = offsets.to(dtype=indices.dtype)

# Force casting per_sample_weights to float
if per_sample_weights is not None:
Expand Down Expand Up @@ -3681,7 +3683,11 @@ def forward(
offsets, batch_size_per_feature_per_rank
)

(indices, offsets) = indices.long(), offsets.long()
# NOTE: Force offsets to have the same dtype as indices since the
# kernels assume same dtype. We might need to revisit the assumption
# of same dtypes in the future.
offsets = offsets.to(dtype=indices.dtype)

# Force casting per_sample_weights to float
if per_sample_weights is not None:
per_sample_weights = per_sample_weights.float()
Expand Down

0 comments on commit b2e29a7

Please sign in to comment.