Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2H/N)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#626

- Update benchmark test for `int32_t` Indicies

Reviewed By: sryap

Differential Revision: D67784746
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 6, 2025
1 parent 5544329 commit 1fab876
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def cli() -> None:
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--dense", is_flag=True, default=False)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--export-trace", is_flag=True, default=False)
Expand Down Expand Up @@ -166,6 +167,7 @@ def device( # noqa C901
flush_gpu_cache_size_mb: int,
dense: bool,
output_dtype: SparseType,
indices_dtype: int,
requests_data_file: Optional[str],
tables: Optional[str],
export_trace: bool,
Expand All @@ -176,6 +178,9 @@ def device( # noqa C901
cache_load_factor: float,
) -> None:
assert not ssd or not dense, "--ssd cannot be used together with --dense"
indices_dtype_torch: torch.dtype = (
torch.int32 if indices_dtype == 32 else torch.int64
)
np.random.seed(42)
torch.manual_seed(42)
B = batch_size
Expand Down Expand Up @@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down Expand Up @@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
indices.to(dtype=indices_dtype_torch),
offsets.to(dtype=indices_dtype_torch),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
Expand Down

0 comments on commit 1fab876

Please sign in to comment.