-
Notifications
You must be signed in to change notification settings - Fork 528
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
amd fp8 rowwise gemm prefill shape tuning (#3607)
Summary: X-link: facebookresearch/FBGEMM#685 Pull Request resolved: #3607 This diff aims to add a more robust FP8 rowwise heuristics for LLM, especially for prefill cases. Consider input [M, K] and weight [N, K]. For LLMs, N and K are fixed across different prefill/decode lengths. Thus the new heuristic is based on lookup for (N,K) and then do a range based lookup for M. For each combination of N and K, there is offline tuning for many M, looking like: ``` 5280, 8192, 3584, 0.318272, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1 5312, 8192, 3584, 0.322179, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1 5344, 8192, 3584, 0.320632, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1 5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1 5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 5440, 8192, 3584, 0.341432, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 5472, 8192, 3584, 0.3436, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 5536, 8192, 3584, 0.341703, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 5568, 8192, 3584, 0.342054, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 5632, 8192, 3584, 0.347904, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 5664, 8192, 3584, 0.345129, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 ``` A clear pattern is that a single instance is the top choice for a large range, justifying the M range based heuristic. The full tuning log is parsed and converted into a std::map for range based lookup. One key question here is which instance to use right at the range where the best instance has changed. For example: ``` 5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1 5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1 ``` Should we use 256x256x192x128 or 256x224x256x128 for M = 5377 to 5407? The implementation uses the tuning entry for the larger value (so use 256x224x256x128). The rational is if we use the smaller entry, it may lead to increased thread blocks and cause bad perf; in contrast, if we use the larger entry, the perf will in theory be the same as the larger entry itself. Empirically, using the smaller entry lead to some degraded perf for untuned values. Reviewed By: jwfromm Differential Revision: D68521662 fbshipit-source-id: e59a8634678a77e4e4d5c2110dbe5d92febc3ad8
- Loading branch information
1 parent
5f3adca
commit 74490d6
Showing
32 changed files
with
1,279 additions
and
473 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
...els/fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "fp8_rowwise_common.h" | ||
|
||
at::Tensor | ||
fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1( | ||
at::Tensor XQ, | ||
at::Tensor WQ, | ||
at::Tensor x_scale, | ||
at::Tensor w_scale, | ||
at::Tensor Y) { | ||
using DeviceGemmInstance = DeviceGemmHelper< | ||
128, | ||
16, | ||
32, | ||
256, | ||
16, | ||
16, | ||
1, | ||
1, | ||
S<16, 8, 1>, | ||
S<16, 8, 1>, | ||
S<1, 16, 1, 8>, | ||
S<4, 4, 1>, | ||
1, | ||
1, | ||
ck::BlockGemmPipelineScheduler::Intrawave, | ||
ck::BlockGemmPipelineVersion::v1, | ||
ck::tensor_operation::device::GemmSpecialization::Default>; | ||
// Run kernel instance. | ||
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); | ||
} | ||
|
39 changes: 39 additions & 0 deletions
39
...els/fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "fp8_rowwise_common.h" | ||
|
||
at::Tensor | ||
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2( | ||
at::Tensor XQ, | ||
at::Tensor WQ, | ||
at::Tensor x_scale, | ||
at::Tensor w_scale, | ||
at::Tensor Y) { | ||
using DeviceGemmInstance = DeviceGemmHelper< | ||
128, | ||
16, | ||
32, | ||
512, | ||
16, | ||
16, | ||
1, | ||
1, | ||
S<32, 4, 1>, | ||
S<32, 4, 1>, | ||
S<1, 16, 1, 8>, | ||
S<4, 4, 1>, | ||
1, | ||
1, | ||
ck::BlockGemmPipelineScheduler::Interwave, | ||
ck::BlockGemmPipelineVersion::v2, | ||
ck::tensor_operation::device::GemmSpecialization::Default>; | ||
// Run kernel instance. | ||
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.