Skip to content

Commit

Permalink
: basic tbe input dump framework (#3593)
Browse files Browse the repository at this point in the history
Summary:

Add a pluggable capability to dump TBE input; it is a no-op in OSS.

Reviewed By: levythu

Differential Revision: D68446857
  • Loading branch information
Sihui Han authored and facebook-github-bot committed Jan 24, 2025
1 parent 1aff241 commit 5a344c1
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
generate_vbe_metadata,
is_torchdynamo_compiling,
)
from fbgemm_gpu.tbe_input_multiplexer import (
TBEInfo,
TBEInputInfo,
TBEInputMultiplexer,
TBEInputMultiplexerConfig,
)

from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc

Expand Down Expand Up @@ -647,6 +653,7 @@ def __init__( # noqa C901
global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
uvm_host_mapped: bool = False,
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand Down Expand Up @@ -820,6 +827,23 @@ def __init__( # noqa C901
self.feature_table_map: List[int] = (
feature_table_map if feature_table_map is not None else list(range(T_))
)

self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
tbe_input_multiplexer_config.create_tbe_input_multiplexer(
tbe_info=TBEInfo(
table_names=(
table_names
if table_names
else [f"table-{i}" for i in range(len(embedding_specs))]
),
table_heights=rows,
tbe_uuid=self.uuid,
feature_table_map=self.feature_table_map,
)
)
if tbe_input_multiplexer_config is not None
else None
)
T = len(self.feature_table_map)
assert T_ <= T
table_has_feature = [False] * T_
Expand Down Expand Up @@ -1789,6 +1813,15 @@ def forward( # noqa: C901
self._report_io_size_count("fwd_input", indices)
self._report_tbe_mem_usage()

if self.tbe_input_multiplexer is not None:
tbe_input_multiplexer: TBEInputMultiplexer = self.tbe_input_multiplexer
if tbe_input_multiplexer.should_run(self.step):
tbe_input_multiplexer.run(
tbe_input_info=TBEInputInfo(
indices, offsets, batch_size_per_feature_per_rank
)
)

if len(self.timesteps_prefetched) == 0:
# In forward, we don't enable multi-pass prefetch as we want the process
# to be as fast as possible and memory usage doesn't matter (will be recycled
Expand Down
96 changes: 96 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc

from dataclasses import dataclass
from typing import List, Optional

from torch import Tensor


@dataclass(frozen=True)
class TBEInfo:
    """
    Immutable subset of TBE metadata handed to the input multiplexer.
    For more info, check https://fburl.com/code/ljnd6j65

    Args:
        table_names: names of the tables within the TBE
        table_heights: heights (hash sizes) of the tables
        tbe_uuid: unique identifier of the TBE
        feature_table_map: mapping from feature to table
    """

    # Field order is part of the constructor signature; keep it stable.
    table_names: List[str]
    table_heights: List[int]
    tbe_uuid: str
    feature_table_map: List[int]


@dataclass(frozen=True)
class TBEInputInfo:
    """
    Immutable bundle of TBE inputs passed to a TBEInputMultiplexer.

    Args:
        indices: A 1D-tensor that contains indices to be looked up
            from all embedding tables.
        offsets: A 1D-tensor that contains offsets of indices.
        batch_size_per_feature_per_rank: An optional nested list of ints
            (2D) that contains batch sizes for every rank and every
            feature. This is needed to support VBE.
    """

    indices: Tensor
    offsets: Tensor
    # None when no per-feature/per-rank batch sizes are supplied.
    batch_size_per_feature_per_rank: Optional[List[List[int]]] = None


class TBEInputMultiplexer(abc.ABC):
    """
    Abstract interface for multiplexing TBE input data out of the module.
    A concrete implementation may, for example, store the data to files.
    """

    @abc.abstractmethod
    def should_run(self, step: int) -> bool:
        """
        Decide whether the multiplexer should run at the given step.

        Args:
            step: the current step

        Returns:
            True if it should run, otherwise False
        """

    @abc.abstractmethod
    def run(self, tbe_input_info: TBEInputInfo) -> None:
        """
        Multiplex one batch of TBE input; called for every batch that
        needs to be dumped.

        Args:
            tbe_input_info: TBE input info that contains all the necessary
                info for further processing
        """


@dataclass(frozen=True)
class TBEInputMultiplexerConfig:
    """
    Configuration for TBEInputMultiplexer.

    No concrete multiplexer is available here, so the factory method only
    accepts the disabled configuration (``start_batch == -1``) and always
    returns None.
    """

    # first batch to start run, -1 means no run
    start_batch: int = -1
    # total number of batches to multiplex
    total_batch: int = 0

    def create_tbe_input_multiplexer(
        self,
        tbe_info: TBEInfo,
    ) -> Optional[TBEInputMultiplexer]:
        """
        Create the multiplexer described by this configuration.

        Args:
            tbe_info: info about the TBE the multiplexer would be attached
                to; unused here because nothing is created

        Returns:
            Always None (no multiplexer is created).

        Raises:
            AssertionError: if ``start_batch`` is anything other than -1,
                i.e. multiplexing was requested without an implementation.
        """
        # Raise explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if self.start_batch != -1:
            raise AssertionError(
                "Cannot specify start_batch without an actual implementation."
            )
        return None

0 comments on commit 5a344c1

Please sign in to comment.