fix llm shapes in quantize bench and add ldm shapes
Summary:
X-link: facebookresearch/FBGEMM#689

As title

Differential Revision: D68594150
mxz297 authored and facebook-github-bot committed Jan 24, 2025
1 parent 5754ce7 commit 559e94f
Showing 1 changed file with 49 additions and 9 deletions: fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
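For context before the diff: each benchmark shape gains a leading dimension, so entries are now 4-tuples rather than (M, N, K) triples. Below is a minimal sketch of how one such tuple might be exercised; `run_shape` is a hypothetical helper, and reading the leading 1 as a batch/group dimension is an assumption, not something this diff states.

```python
import torch

def run_shape(b: int, m: int, n: int, k: int) -> torch.Tensor:
    # One (M, K) x (K, N) GEMM per batch entry.
    x = torch.randn(b, m, k)  # activations
    w = torch.randn(b, n, k)  # weights
    return torch.bmm(x, w.transpose(1, 2))  # -> (B, M, N)

# A llama-70B entry from this diff: B=1, M=16, N=1280, K=8192.
out = run_shape(1, 16, 1280, 8192)
print(out.shape)  # torch.Size([1, 16, 1280])
```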
@@ -31,29 +31,61 @@ def set_amd_env_vars() -> None:
os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"


def get_llama_shapes() -> List[Tuple[int, int, int]]:
def get_llama_shapes() -> List[Tuple[int, int, int, int]]:
# Helper function that returns a list of shapes relevant to llama.

llama_shapes = []
for M in [1, 16, 32, 64, 96, 128, 16384]:
# Add shapes for llama 70B
llama_shapes += [
(M, 1280, 8192),
(M, 8192, 1024),
(M, 7168, 8192),
(M, 8192, 3584),
(1, M, 1280, 8192),
(1, M, 8192, 1024),
(1, M, 7168, 8192),
(1, M, 8192, 3584),
]
# Add shapes for llama 405B
llama_shapes += [
(M, 13312, 6656),
(M, 13312, 16384),
(M, 16384, 6656),
(M, 16384, 16384),
(1, M, 13312, 6656),
(1, M, 13312, 16384),
(1, M, 16384, 6656),
(1, M, 16384, 16384),
]

return llama_shapes


def get_ldm_shapes() -> List[Tuple[int, int, int, int]]:
# Helper function that returns a list of shapes relevant to ldm.
return [
(1, 1536, 3584, 3584),
(1, 8192, 9728, 3584),
(1, 8192, 3584, 9728),
(1, 8192, 3584, 3584),
(1, 4096, 3584, 3584),
(1, 768, 3584, 3584),
(1, 4096, 9728, 3584),
(1, 4096, 3584, 9728),
(1, 7200, 3584, 3584),
(1, 7200, 9728, 3584),
(1, 7200, 3584, 9728),
(1, 3600, 3584, 3584),
(1, 3600, 9728, 3584),
(1, 3600, 3584, 9728),
(1, 1536, 4096, 4096),
(1, 3600, 4096, 4096),
(1, 3600, 11008, 4096),
(1, 3600, 4096, 11008),
(1, 4096, 4096, 4096),
(1, 4096, 11008, 4096),
(1, 4096, 4096, 11008),
(1, 32768, 128, 8192),
(1, 32768, 8192, 1024),
(1, 32768, 8192, 3072),
(1, 32768, 3072, 8192),
(1, 32768, 1024, 8192),
]


def benchmark_grouped(
quantize_ops: List[QuantizeOpBase],
b: List[int],
@@ -297,6 +329,8 @@ def main(args: Any):
     B = [int(b) for b in args.B.strip().split(",")]
     if args.use_llama_shapes:
         MNK = get_llama_shapes()
+    elif args.use_ldm_shapes:
+        MNK = get_ldm_shapes()
     else:
         if args.M is None:
             M = [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
@@ -419,6 +453,12 @@ def invoke_main() -> None:
action="store_true",
help="If set, benchmark using fixed shapes relevant to llama workloads.",
)
parser.add_argument(
"--use_ldm_shapes",
default=False,
action="store_true",
help="If set, benchmark using fixed shapes relevant to ldm workloads.",
)

args = parser.parse_args()
main(args)
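For reference, a small self-contained smoke test of the new flag wiring, rebuilding only the two argparse options visible in this diff (a sketch, not code from the file):

```python
import argparse

# Rebuild just the two flags shown in this diff.
parser = argparse.ArgumentParser()
parser.add_argument("--use_llama_shapes", default=False, action="store_true")
parser.add_argument("--use_ldm_shapes", default=False, action="store_true")

args = parser.parse_args(["--use_ldm_shapes"])
# With this namespace, main(args) falls through `if args.use_llama_shapes`
# and takes the new `elif args.use_ldm_shapes` branch, i.e. MNK = get_ldm_shapes().
assert args.use_ldm_shapes and not args.use_llama_shapes
```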
