forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
: basic tbe input dump framework (pytorch#3593)
Summary: Plugin capability to dump TBE input and no-ops in OSS Reviewed By: damianr99 Differential Revision: D68446857
- Loading branch information
1 parent
b858408
commit a12f4ed
Showing
2 changed files
with
94 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import abc | ||
|
||
from dataclasses import dataclass | ||
from typing import List, Optional | ||
|
||
from torch import Tensor | ||
|
||
|
||
class TBEInputDump(abc.ABC): | ||
""" | ||
Interface for dump TBE input data out, actual implementation may store the data to files | ||
""" | ||
|
||
@abc.abstractmethod | ||
def should_dump(self, step: int) -> bool: | ||
""" | ||
To check if the dump should be triggered at this step | ||
Args: | ||
step: the current step | ||
Returns: | ||
True if the dump should be triggered, otherwise False | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def run( | ||
self, | ||
indices: Tensor, | ||
offsets: Tensor, | ||
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, | ||
) -> None: | ||
""" | ||
To run the tbe input dump, and this is called for every batch that needs to be dumped | ||
Args: | ||
indices: A 1D-tensor that contains indices to be looked up | ||
from all embedding table. | ||
offsets: A 1D-tensor that conatins offsets of indices. | ||
batch_size_per_feature_per_rank: An optional 2D-tensor that contains batch sizes for every rank and | ||
every feature. this is needed to support VBE. | ||
""" | ||
pass | ||
|
||
|
||
@dataclass(frozen=True) | ||
class TBEInputDumpConfig: | ||
""" | ||
Configuration for TBEInputDump | ||
""" | ||
|
||
# first batch to start dump, -1 means no dump | ||
monitored_batch_start: int = -1 | ||
# total batch to dump | ||
monitored_total_batch: int = 0 | ||
|
||
def create_tbe_input_dump( | ||
self, | ||
table_names: List[str], | ||
table_heights: List[int], | ||
tbe_uuid: str, | ||
feature_table_map: List[int], | ||
) -> Optional[TBEInputDump]: | ||
assert ( | ||
self.monitored_batch_start == -1 | ||
), "Cannot specify monitored_batch_start without an actual implementation of tbe dump" | ||
return None |