Skip to content

Commit

Permalink
amd fp8 rowwise gemm prefill shape tuning (#3607)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#685

Pull Request resolved: #3607

This diff aims to add a more robust FP8 rowwise heuristics for LLM, especially for prefill cases.  Consider input [M, K] and weight [N, K].

For LLMs, N and K are fixed across different prefill/decode lengths. Thus the new heuristic is based on lookup for (N,K) and then do a range based lookup for M.

For each combination of N and K, there is offline tuning for many M, looking like:

```
5280, 8192, 3584, 0.318272, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5312, 8192, 3584, 0.322179, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5344, 8192, 3584, 0.320632, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5440, 8192, 3584, 0.341432, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5472, 8192, 3584, 0.3436, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5536, 8192, 3584, 0.341703, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5568, 8192, 3584, 0.342054, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5632, 8192, 3584, 0.347904, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5664, 8192, 3584, 0.345129, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```

A clear pattern is that a single instance is the top choice for a large range, justifying the M range based heuristic.

The full tuning log is parsed and converted into a std::map for range based lookup.

One key question here is which instance to use right at the range where the best instance has changed. For example:

```
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```

Should we use 256x256x192x128 or 256x224x256x128 for M = 5377 to 5407?  The implementation uses the tuning entry for the larger value (so use 256x224x256x128). The rational is if we use the smaller entry, it may lead to increased thread blocks and cause bad perf; in contrast, if we use the larger entry, the perf will in theory be the same as the larger entry itself. Empirically, using the smaller entry lead to some degraded perf for untuned values.

Reviewed By: jwfromm

Differential Revision: D68521662

fbshipit-source-id: e59a8634678a77e4e4d5c2110dbe5d92febc3ad8
  • Loading branch information
mxz297 authored and facebook-github-bot committed Jan 24, 2025
1 parent 5f3adca commit 74490d6
Show file tree
Hide file tree
Showing 32 changed files with 1,279 additions and 473 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,15 @@ namespace fbgemm_gpu {
using RowwiseKernel = std::function<
at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;

using NKLookupTableType = std::map<int, RowwiseKernel>;

// Define a custom hash function for std::tuple<int, int, int>
struct IntTupleHash {
size_t operator()(const std::tuple<int, int>& t) const {
auto hash1 = std::hash<int>{}(std::get<0>(t));
auto hash2 = std::hash<int>{}(std::get<1>(t));
return hash1 ^ hash2;
}
size_t operator()(const std::tuple<int, int, int>& t) const {
auto hash1 = std::hash<int>{}(std::get<0>(t));
auto hash2 = std::hash<int>{}(std::get<1>(t));
Expand All @@ -38,24 +45,6 @@ struct IntTupleHash {
// For certain high priority shapes, we directly map to the best kernel rather
// than use heuristics.
static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTupleHash> rowwise_lookup_dispatch = {
// Support for decode for [1024, 5120]
{{16, 1024, 5120},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
{{32, 1024, 5120},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
{{64, 1024, 5120},
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{{128, 1024, 5120},
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
// Support for decode for [5120, 1024]
{{16, 5120, 1024},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{32, 5120, 1024},
fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
{{64, 5120, 1024},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{128, 5120, 1024},
fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2},
// LLama 70B Decode shapes.
// Support for decode across batch sizes for [1280, 8192]
{{16, 1280, 8192},
Expand All @@ -75,40 +64,6 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{128, 8192, 1024},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
// Support for decode across batch sizes for [7168, 8192]
{{16, 7168, 8192},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
{{32, 7168, 8192},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{64, 7168, 8192},
fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{{128, 7168, 8192},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 7168, 8192},
fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{2048, 7168, 8192},
fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{4096, 7168, 8192},
fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{8192, 7168, 8192},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Support for decode across batch sizes for [8192, 3584]
{{16, 8192, 3584},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
{{32, 8192, 3584},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{{64, 8192, 3584},
fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{{128, 8192, 3584},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 8192, 3584},
fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{2048, 8192, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{4096, 8192, 3584},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{8192, 8192, 3584},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Llama 405B Decode Shapes.
// Support for decode across batch sizes for [13312, 6656].
{{16, 13312, 6656},
Expand Down Expand Up @@ -218,6 +173,119 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
{{32768, 1024, 8192},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}};

static const std::map<int, RowwiseKernel> N_7168_K_8192_dispatch_table = {
{ 8, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2},
{ 32, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
{ 64, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 128, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 320, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 512, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 576, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 640, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 768, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 1024, fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 1280, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 1536, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2048, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 2304, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2560, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3328, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 4096, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 4672, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 4864, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 5376, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 6144, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 7168, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 8192, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
};

static const std::map<int, RowwiseKernel> N_8192_K_3584_dispatch_table = {
{ 8, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2},
{ 16, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{ 32, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
{ 64, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 128, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 192, fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 256, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 384, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 512, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 640, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 896, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1024, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 1280, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 1792, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 2048, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2304, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 2368, fp8_rowwise_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 2816, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 3584, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 4256, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 4864, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 5376, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 6272, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 7168, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 7424, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 8192, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
};

static const std::map<int, RowwiseKernel> N_1024_K_5120_dispatch_table = {
{ 32, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
{ 64, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2_2},
{ 128, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{ 192, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
{ 608, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 1216, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 2432, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 3456, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 4864, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 5472, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 6368, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 6912, fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 8192, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
};

static const std::map<int, RowwiseKernel> N_5120_K_1024_dispatch_table = {
{ 16, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
{ 32, fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
{ 64, fp8_rowwise_128x32x16x256_16x16_1x1_16x8x1_16x8x1_1x32x1x4_4x4x1_1x1_interwave_v1},
{ 96, fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
{ 192, fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 256, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 320, fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 448, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 640, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 896, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1152, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 1408, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1920, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 2304, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2816, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3360, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3840, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 4864, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 5632, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 6720, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 7680, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 8192, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
};

static const std::unordered_map<std::tuple<int, int>, NKLookupTableType, IntTupleHash> NK_lookup_table = {
{{7168, 8192}, N_7168_K_8192_dispatch_table},
{{8192, 3584}, N_8192_K_3584_dispatch_table},
{{1024, 5120}, N_1024_K_5120_dispatch_table},
{{5120, 1024}, N_5120_K_1024_dispatch_table}
};

RowwiseKernel rowwise_nk_lookup(int M, const NKLookupTableType& table) {
auto it = table.lower_bound(M);
if (it != table.end()) {
return it->second;
} else {
--it;
return it->second;
}
}

RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
// Apply shape heuristics to find a suitable kernel implementation.

Expand Down Expand Up @@ -281,6 +349,11 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K) {
// If we found an optimal kernel, use it.
if (it != rowwise_lookup_dispatch.end()) {
return it->second;
} else {
auto nk_lookup_it = NK_lookup_table.find({N,K});
if (nk_lookup_it != NK_lookup_table.end()){
return rowwise_nk_lookup(M, nk_lookup_it->second);
}
}
// Otherwise, use heuristics.
return rowwise_heuristic_dispatch(M, N, K);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,25 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// The smallest kernel we have available. Works well for memory bound shapes.

// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);

}
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
256,
16,
16,
1,
1,
S<16, 8, 1>,
S<16, 8, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
512,
16,
16,
1,
1,
S<32, 4, 1>,
S<32, 4, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Loading

0 comments on commit 74490d6

Please sign in to comment.