diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 87d9437be..15f1b448c 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -49,6 +49,12 @@ generate_vbe_metadata, is_torchdynamo_compiling, ) +from fbgemm_gpu.tbe_input_multiplexer import ( + TBEInfo, + TBEInputInfo, + TBEInputMultiplexer, + TBEInputMultiplexerConfig, +) from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc @@ -647,6 +653,7 @@ def __init__( # noqa C901 global_weight_decay: Optional[GlobalWeightDecayDefinition] = None, uvm_host_mapped: bool = False, extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None, + tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() @@ -820,6 +827,23 @@ def __init__( # noqa C901 self.feature_table_map: List[int] = ( feature_table_map if feature_table_map is not None else list(range(T_)) ) + + self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = ( + tbe_input_multiplexer_config.create_tbe_input_multiplexer( + tbe_info=TBEInfo( + table_names=( + table_names + if table_names + else [f"table-{i}" for i in range(len(embedding_specs))] + ), + table_heights=rows, + tbe_uuid=self.uuid, + feature_table_map=self.feature_table_map, + ) + ) + if tbe_input_multiplexer_config is not None + else None + ) T = len(self.feature_table_map) assert T_ <= T table_has_feature = [False] * T_ @@ -1789,6 +1813,15 @@ def forward( # noqa: C901 self._report_io_size_count("fwd_input", indices) self._report_tbe_mem_usage() + if self.tbe_input_multiplexer is not None: + tbe_input_multiplexer: TBEInputMultiplexer = self.tbe_input_multiplexer + if tbe_input_multiplexer.should_run(self.step): + tbe_input_multiplexer.run( + tbe_input_info=TBEInputInfo( + indices, 
offsets, batch_size_per_feature_per_rank + ) ) + if len(self.timesteps_prefetched) == 0: # In forward, we don't enable multi-pass prefetch as we want the process # to be as fast as possible and memory usage doesn't matter (will be recycled diff --git a/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py b/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py new file mode 100644 index 000000000..0b126fe3f --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import abc + +from dataclasses import dataclass +from typing import List, Optional + +from torch import Tensor + + +@dataclass(frozen=True) +class TBEInfo: + """ + contains selective TBE info used for multiplexing. For more info, check https://fburl.com/code/ljnd6j65 + + Args: + table_names: table names within the tbe + table_heights: table heights (hashsize) + tbe_uuid: a unique identifier for the TBE + feature_table_map: feature to table map + """ + + table_names: List[str] + table_heights: List[int] + tbe_uuid: str + feature_table_map: List[int] + + +@dataclass(frozen=True) +class TBEInputInfo: + """ + indices: A 1D-tensor that contains indices to be looked up + from all embedding tables. + offsets: A 1D-tensor that contains offsets of indices. + batch_size_per_feature_per_rank: An optional 2D list that contains batch sizes for every rank and + every feature. This is needed to support VBE. 
+ """ + + indices: Tensor + offsets: Tensor + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None + + +class TBEInputMultiplexer(abc.ABC): + """ + Interface for multiplexing TBE input data out, actual implementation may store the data to files + """ + + @abc.abstractmethod + def should_run(self, step: int) -> bool: + """ + To check if should run at this step + Args: + step: the current step + Returns: + True if should run, otherwise False + """ + pass + + @abc.abstractmethod + def run( + self, + tbe_input_info: TBEInputInfo, + ) -> None: + """ + To run the tbe input multiplex, and this is called for every batch that needs to be dumped + Args: + tbe_input_info: tbe input info that contains all the necessary info for further processing + """ + pass + + +@dataclass(frozen=True) +class TBEInputMultiplexerConfig: + """ + Configuration for TBEInputMultiplexer + """ + + # first batch to start run, -1 means no run + start_batch: int = -1 + # total batch to multiplex + total_batch: int = 0 + + def create_tbe_input_multiplexer( + self, + tbe_info: TBEInfo, + ) -> Optional[TBEInputMultiplexer]: + assert ( + self.start_batch == -1 + ), "Cannot specify start_batch without an actual implementation." + return None