basic tbe input dump framework #3593

Closed
wants to merge 1 commit into from
fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

@@ -49,6 +49,12 @@
generate_vbe_metadata,
is_torchdynamo_compiling,
)
from fbgemm_gpu.tbe_input_multiplexer import (
TBEInfo,
TBEInputInfo,
TBEInputMultiplexer,
TBEInputMultiplexerConfig,
)

from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc

@@ -647,6 +653,7 @@ def __init__(  # noqa C901
global_weight_decay: Optional[GlobalWeightDecayDefinition] = None,
uvm_host_mapped: bool = False,
extra_optimizer_config: Optional[UserEnabledConfigDefinition] = None,
tbe_input_multiplexer_config: Optional[TBEInputMultiplexerConfig] = None,
) -> None:
super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()

@@ -820,6 +827,23 @@ def __init__(  # noqa C901
self.feature_table_map: List[int] = (
feature_table_map if feature_table_map is not None else list(range(T_))
)

self.tbe_input_multiplexer: Optional[TBEInputMultiplexer] = (
tbe_input_multiplexer_config.create_tbe_input_multiplexer(
tbe_info=TBEInfo(
table_names=(
table_names
if table_names
else [f"table-{i}" for i in range(len(embedding_specs))]
),
table_heights=rows,
tbe_uuid=self.uuid,
feature_table_map=self.feature_table_map,
)
)
if tbe_input_multiplexer_config is not None
else None
)
T = len(self.feature_table_map)
assert T_ <= T
table_has_feature = [False] * T_
@@ -1789,6 +1813,15 @@ def forward(  # noqa: C901
self._report_io_size_count("fwd_input", indices)
self._report_tbe_mem_usage()

if self.tbe_input_multiplexer is not None:
tbe_input_multiplexer: TBEInputMultiplexer = self.tbe_input_multiplexer
if tbe_input_multiplexer.should_run(self.step):
tbe_input_multiplexer.run(
tbe_input_info=TBEInputInfo(
indices, offsets, batch_size_per_feature_per_rank
)
)

if len(self.timesteps_prefetched) == 0:
# In forward, we don't enable multi-pass prefetch as we want the process
# to be as fast as possible and memory usage doesn't matter (will be recycled
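For context, here is a hedged usage sketch (not part of this PR) of how a caller might enable input dumping through the new `tbe_input_multiplexer_config` keyword argument. It assumes a concrete config subclass, such as the hypothetical `FileDumpTBEInputMultiplexerConfig` sketched after the new module below; import paths follow current fbgemm_gpu and may differ across versions.

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (1000, 16, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    # FileDumpTBEInputMultiplexerConfig is an illustrative assumption,
    # not an fbgemm_gpu API; see the sketch at the end of this diff.
    tbe_input_multiplexer_config=FileDumpTBEInputMultiplexerConfig(
        start_batch=10,  # begin dumping at step 10
        total_batch=5,   # dump five consecutive batches
    ),
)
# During training, forward() calls should_run(self.step) and, when it returns
# True, run(TBEInputInfo(indices, offsets, batch_size_per_feature_per_rank)).
```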
96 changes: 96 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe_input_multiplexer.py
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc

from dataclasses import dataclass
from typing import List, Optional

from torch import Tensor


@dataclass(frozen=True)
class TBEInfo:
"""
contains selective TBE info used for multiplexing. For more info, check https://fburl.com/code/ljnd6j65
Args:
table_names: table names within the tbe
table_heights: table heights (hashsize)
tbe_uuid: a unique identifier for the TBE
feature_table_map: feature to table map
"""

table_names: List[str]
table_heights: List[int]
tbe_uuid: str
feature_table_map: List[int]


@dataclass(frozen=True)
class TBEInputInfo:
"""
indices: A 1D-tensor that contains indices to be looked up
from all embedding table.
offsets: A 1D-tensor that conatins offsets of indices.
batch_size_per_feature_per_rank: An optional 2D-tensor that contains batch sizes for every rank and
every feature. this is needed to support VBE.
"""

indices: Tensor
offsets: Tensor
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None


class TBEInputMultiplexer(abc.ABC):
"""
Interface for multiplex TBE input data out, actual implementation may store the data to files
"""

@abc.abstractmethod
def should_run(self, step: int) -> bool:
"""
To check if should run at this step
Args:
step: the current step
Returns:
True if should run, otherwise False
"""
pass

@abc.abstractmethod
def run(
self,
tbe_input_info: TBEInputInfo,
) -> None:
"""
To run the tbe input multiplex, and this is called for every batch that needs to be dumped
Args:
tbe_input_info: tbe input info that contains all the necessary info for further processing
"""
pass


@dataclass(frozen=True)
class TBEInputMultiplexerConfig:
"""
Configuration for TBEInputMultiplexer
"""

    # first batch at which to start running; -1 means never run
    start_batch: int = -1
    # total number of batches to multiplex
    total_batch: int = 0

def create_tbe_input_multiplexer(
self,
tbe_info: TBEInfo,
) -> Optional[TBEInputMultiplexer]:
        assert (
            self.start_batch == -1
        ), "Cannot specify start_batch without an actual TBEInputMultiplexer implementation."
return None
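
Since the base `TBEInputMultiplexerConfig.create_tbe_input_multiplexer` deliberately returns `None`, a consumer must subclass both the config and the multiplexer to get any dumping behavior. Below is a minimal hypothetical sketch (not part of this PR) that writes each selected batch to a `.pt` file via `torch.save`; the `FileDump*` names and the `dump_dir` field are illustrative assumptions, not fbgemm_gpu APIs.

```python
import os
from dataclasses import dataclass
from typing import Optional

import torch

from fbgemm_gpu.tbe_input_multiplexer import (
    TBEInfo,
    TBEInputInfo,
    TBEInputMultiplexer,
    TBEInputMultiplexerConfig,
)


class FileDumpTBEInputMultiplexer(TBEInputMultiplexer):
    """Hypothetical multiplexer that dumps each selected batch to a .pt file."""

    def __init__(
        self, tbe_info: TBEInfo, start_batch: int, total_batch: int, dump_dir: str
    ) -> None:
        self._tbe_info = tbe_info
        self._start_batch = start_batch
        self._end_batch = start_batch + total_batch
        self._dump_dir = dump_dir
        self._dump_count = 0

    def should_run(self, step: int) -> bool:
        # Run only inside the configured [start_batch, start_batch + total_batch) window.
        return self._start_batch <= step < self._end_batch

    def run(self, tbe_input_info: TBEInputInfo) -> None:
        os.makedirs(self._dump_dir, exist_ok=True)
        # Persist the raw TBE input together with identifying TBE metadata.
        torch.save(
            {
                "tbe_uuid": self._tbe_info.tbe_uuid,
                "table_names": self._tbe_info.table_names,
                "feature_table_map": self._tbe_info.feature_table_map,
                "indices": tbe_input_info.indices.cpu(),
                "offsets": tbe_input_info.offsets.cpu(),
                "batch_size_per_feature_per_rank": tbe_input_info.batch_size_per_feature_per_rank,
            },
            os.path.join(
                self._dump_dir,
                f"{self._tbe_info.tbe_uuid}_batch{self._dump_count}.pt",
            ),
        )
        self._dump_count += 1


@dataclass(frozen=True)
class FileDumpTBEInputMultiplexerConfig(TBEInputMultiplexerConfig):
    # Hypothetical field: directory the dump files are written to.
    dump_dir: str = "/tmp/tbe_input_dumps"

    def create_tbe_input_multiplexer(
        self, tbe_info: TBEInfo
    ) -> Optional[TBEInputMultiplexer]:
        if self.start_batch == -1:
            return None
        return FileDumpTBEInputMultiplexer(
            tbe_info, self.start_batch, self.total_batch, self.dump_dir
        )
```

Keeping the window check inside `should_run` means the hot path in `forward()` only pays for a `None` check plus an integer comparison when dumping is disabled or out of window.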