diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 87d9437be..7d489579a 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -25,7 +25,6 @@ # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers - from fbgemm_gpu.config import FeatureGate, FeatureGateName from fbgemm_gpu.runtime_monitor import ( AsyncSeriesTimer, @@ -49,6 +48,7 @@ generate_vbe_metadata, is_torchdynamo_compiling, ) +from fbgemm_gpu.tbe_input_dump import TBEInputDump, TBEInputDumpConfig from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc @@ -647,6 +647,7 @@ def __init__( # noqa C901 global_weight_decay: Optional[GlobalWeightDecayDefinition] = None, uvm_host_mapped: bool = False, extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None, + tbe_input_dump_config: Optional[TBEInputDumpConfig] = None, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() @@ -820,6 +821,21 @@ def __init__( # noqa C901 self.feature_table_map: List[int] = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) + + self.tbe_input_dump: Optional[TBEInputDump] = ( + tbe_input_dump_config.create_tbe_input_dump( + table_names=( + table_names + if table_names + else [f"table-{i}" for i in range(len(embedding_specs))] + ), + table_heights=rows, + tbe_uuid=self.uuid, + feature_table_map=self.feature_table_map, + ) + if tbe_input_dump_config is not None + else None + ) T = len(self.feature_table_map) assert T_ <= T table_has_feature = [False] * T_ @@ -1789,6 +1805,11 @@ def forward( # noqa: C901 self._report_io_size_count("fwd_input", indices) self._report_tbe_mem_usage() + if self.tbe_input_dump is not None: + tbe_input_dump: TBEInputDump = 
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc

from dataclasses import dataclass
from typing import List, Optional

from torch import Tensor


class TBEInputDump(abc.ABC):
    """
    Interface for dumping TBE input data; an actual implementation may store
    the data to files.
    """

    @abc.abstractmethod
    def should_dump(self, step: int) -> bool:
        """
        Check whether the dump should be triggered at this step.

        Args:
            step: the current step
        Returns:
            True if the dump should be triggered, otherwise False
        """

    @abc.abstractmethod
    def run(
        self,
        indices: Tensor,
        offsets: Tensor,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> None:
        """
        Run the TBE input dump; this is called for every batch that needs to
        be dumped.

        Args:
            indices: A 1D-tensor that contains indices to be looked up
                from all embedding tables.
            offsets: A 1D-tensor that contains offsets of indices.
            batch_size_per_feature_per_rank: An optional nested list that
                contains batch sizes for every rank and every feature; this
                is needed to support VBE.
        """


@dataclass(frozen=True)
class TBEInputDumpConfig:
    """
    Configuration for TBEInputDump.
    """

    # First batch at which to start dumping; -1 means no dump.
    monitored_batch_start: int = -1
    # Total number of batches to dump.
    monitored_total_batch: int = 0

    def create_tbe_input_dump(
        self,
        table_names: List[str],
        table_heights: List[int],
        tbe_uuid: str,
        feature_table_map: List[int],
    ) -> Optional[TBEInputDump]:
        """
        Factory for a TBEInputDump instance built from this config.

        This base config carries no concrete dump implementation, so it only
        permits the "no dump" setting (monitored_batch_start == -1) and
        always returns None; subclasses with a real implementation are
        expected to override this method.
        """
        # NOTE(review): `assert` is stripped under `python -O`; preserved
        # as-is since callers may rely on AssertionError being raised here.
        assert (
            self.monitored_batch_start == -1
        ), "Cannot specify monitored_batch_start without an actual implementation of tbe dump"
        return None