Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2K/N) (#3583)
Browse files Browse the repository at this point in the history
Summary:

- Update TBE benchmark test to support `int32_t` indices

Differential Revision: D68296454
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 16, 2025
1 parent 3449377 commit 8442ccf
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
13 changes: 9 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def cli() -> None:
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--dense", is_flag=True, default=False)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--export-trace", is_flag=True, default=False)
Expand Down Expand Up @@ -166,6 +167,7 @@ def device( # noqa C901
flush_gpu_cache_size_mb: int,
dense: bool,
output_dtype: SparseType,
indices_dtype: str,
requests_data_file: Optional[str],
tables: Optional[str],
export_trace: bool,
Expand All @@ -176,6 +178,9 @@ def device( # noqa C901
cache_load_factor: float,
) -> None:
assert not ssd or not dense, "--ssd cannot be used together with --dense"
indices_dtype_torch: torch.dtype = (
torch.int32 if int(indices_dtype) == 32 else torch.int64
)
np.random.seed(42)
torch.manual_seed(42)
B = batch_size
Expand Down Expand Up @@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down Expand Up @@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3375,8 +3375,10 @@ def prepare_inputs(
)

if force_cast_input_types:
# Force casting indices and offsets to long
(indices, offsets) = indices.long(), offsets.long()
# NOTE: Force offsets to have the same dtype as indices since the
# kernels assume same dtype. We might need to revisit the assumption
# of same dtypes in the future.
offsets = offsets.to(dtype=indices.dtype)

# Force casting per_sample_weights to float
if per_sample_weights is not None:
Expand Down Expand Up @@ -3741,7 +3743,11 @@ def forward(
offsets, batch_size_per_feature_per_rank
)

(indices, offsets) = indices.long(), offsets.long()
# NOTE: Force offsets to have the same dtype as indices since the
# kernels assume same dtype. We might need to revisit the assumption
# of same dtypes in the future.
offsets = offsets.to(dtype=indices.dtype)

# Force casting per_sample_weights to float
if per_sample_weights is not None:
per_sample_weights = per_sample_weights.float()
Expand Down

0 comments on commit 8442ccf

Please sign in to comment.