Enable fast FP8 GEMM for memory bound (resubmit) (#3608)
Summary:
Pull Request resolved: #3608

X-link: facebookresearch/FBGEMM#686

This Diff (a resubmit of D68193920) enables fast FP8 GEMM for memory-bound shapes by adding the TRT-LLM FP8 CUDA GEMM to FBGEMM. Beyond the original kernel, this Diff extends it to:

- Support PyTorch operations
- Support CUDA graphs by handling the scale as a tensor
- Support only small M dims, which keeps compilation much faster
- Support benchmarks and unit tests

For decode attention linear shapes:
- When BS=1, the TRT-LLM FP8 GEMM brings a 2x speedup over BF16, while the FP8 CUTLASS GEMM performs about the same as BF16
- When BS>4, the TRT-LLM FP8 GEMM brings no perf gain
- This TRT-LLM kernel uses tensorwise quantization, not rowwise

Note: Since M>4 brings no perf gain in our use cases, we instantiate only 4 template instances (M = 1..4) to reduce compilation time (10 min -> 2.5 min). If we want to cover larger M in the future, we can either accept a longer compilation time or give each instance its own CUDA file and compile them in parallel. A minimal usage sketch of the new op follows.
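A minimal usage sketch, mirroring the unit test added below (shapes are illustrative; it assumes a CUDA device and that the FBGEMM gen_ai ops have been loaded, e.g. via import fbgemm_gpu.experimental.gen_ai):

import torch

# Decode-style, memory-bound GEMM: tiny M, weight stored as (N, K).
M, N, K = 1, 1024, 512
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") * 0.01

# Tensorwise FP8 quantization of activations and weights.
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)

# The combined scale is passed as a tensor, which is what makes the op
# capturable in a CUDA graph.
y = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)  # BF16 output of shape (M, N)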

Reviewed By: q10, jwfromm

Differential Revision: D68568596

fbshipit-source-id: ba8b565a564533717deb29f9d701550d99a8c759
jiawenliu64 authored and facebook-github-bot committed Jan 23, 2025
1 parent 31d41dc commit 5754ce7
Showing 4 changed files with 333 additions and 0 deletions.
34 changes: 34 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -716,6 +716,40 @@ def cuda(self) -> bool:
return True


@register_quantize_op
class FP8LiteGemm(QuantizeOpBase):
"""
FP8 lite matmul for memory-bound shapes.
"""

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
return self.compute(xq, wq, x_scale, w_scale)

@property
def name(self) -> str:
return "cuda_lite"

@property
def hip(self) -> bool:
# Need to add support for better quantize kernel.
# Also may have an issue with cuda graphs.
return False

@property
def cuda(self) -> bool:
return True


@register_quantize_op
class TritonFP8RowwiseGemm(QuantizeOpBase):
"""
@@ -0,0 +1,263 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cutlass/numeric_conversion.h>
#include <cub/cub.cuh>

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

using SizeType32 = std::size_t;

struct Params {
void const* act;
void const* weight;
void const* alpha;
void* output;
SizeType32 m, n, k;

Params(
void const* _act,
void const* _weight,
void const* _alpha,
void* _output,
SizeType32 _m,
SizeType32 _n,
SizeType32 _k)
: act(_act),
weight(_weight),
alpha(_alpha),
output(_output),
m(_m),
n(_n),
k(_k) {}
};

template <
typename InputType,
typename OutputType,
SizeType32 TILE_M,
SizeType32 TILE_N,
SizeType32 BLOCK_SIZE>
__global__ void cudaCoreGemm(
InputType const* __restrict__ act,
InputType const* __restrict__ weight,
float const* alpha,
OutputType* __restrict__ output,
SizeType32 m,
SizeType32 n,
SizeType32 k) {
using VecType = int4;
static constexpr SizeType32 kStepK =
static_cast<SizeType32>(128 / (8 * sizeof(InputType)));
static constexpr SizeType32 kTileK = kStepK * BLOCK_SIZE;
auto tileIdM = static_cast<SizeType32>(blockIdx.x * TILE_M);
auto tileIdN = static_cast<SizeType32>(blockIdx.y * TILE_N);
auto tid = static_cast<SizeType32>(threadIdx.x);
float tile_a[kStepK], tile_w[TILE_N * kStepK];
float acc[TILE_M * TILE_N];

static_assert(kStepK % 4 == 0);
using CvtInputType = cutlass::float_e4m3_t;
using Converter = cutlass::NumericArrayConverter<float, CvtInputType, 4>;
using CvtSrcType = typename Converter::source_type;
using CvtResType = typename Converter::result_type;
static constexpr SizeType32 kCvtCount =
static_cast<SizeType32>(sizeof(VecType) / sizeof(CvtSrcType));

#pragma unroll
for (SizeType32 i = 0; i < TILE_M * TILE_N; ++i) {
acc[i] = 0;
}
act += tileIdM * k;
weight += tileIdN * k;
output += tileIdM * n + tileIdN;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
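// Programmatic dependent launch (SM90+): wait for the grid this kernel depends
// on to finish its writes before reading the activation/weight buffers.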
cudaGridDependencySynchronize();
#endif

// Each thread walks K in strides of kTileK, loading kStepK FP8 values at a time
// as one 16-byte (int4) vector, converting them to FP32, and accumulating a
// TILE_M x TILE_N block of partial dot products in registers.
for (SizeType32 idxK = tid * kStepK; idxK < k; idxK += kTileK) {
for (SizeType32 i = 0; i < TILE_N; ++i) {
auto tile_w_quantized =
reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
#pragma unroll
for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
Converter::convert(
reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
}
}
#pragma unroll
for (SizeType32 i = 0; i < TILE_M; ++i) {
auto tile_a_quantized =
reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
#pragma unroll
for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
}
#pragma unroll
for (SizeType32 j = 0; j < TILE_N; ++j) {
#pragma unroll
for (SizeType32 l = 0; l < kStepK; ++l) {
acc[i * TILE_N + j] =
fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
}
}
}
}

// Reduce each thread's partial sums: first within a warp via cub::WarpReduce,
// then across warps through shared memory.
typedef cub::WarpReduce<float> WarpReduce;

static constexpr SizeType32 kWarpSize = 32;
static constexpr SizeType32 kWarpNum = BLOCK_SIZE / kWarpSize;
SizeType32 warpId = tid / kWarpSize, laneId = tid % kWarpSize;
__shared__ float shmem[TILE_M * TILE_N * kWarpNum];
__shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
#pragma unroll
for (SizeType32 mi = 0; mi < TILE_M; ++mi) {
#pragma unroll
for (SizeType32 ni = 0; ni < TILE_N; ++ni) {
float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
if (laneId == 0) {
shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
}
}
}
__syncthreads();
for (SizeType32 ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
SizeType32 mid = ii / TILE_N, nid = ii % TILE_N;
float val = 0;
#pragma unroll
for (SizeType32 jj = 0; jj < kWarpNum; ++jj) {
val += shmem[jj * TILE_M * TILE_N + ii];
}
output[mid * n + nid] = static_cast<OutputType>(val * *alpha);
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
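// Programmatic dependent launch (SM90+): signal that dependent grids may begin launching.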
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <
typename InputType,
typename OutputType,
SizeType32 TILE_M,
SizeType32 TILE_N,
SizeType32 BLOCK_SIZE>
void cudaCoreGemmKernel(Params const& params, cudaStream_t stream) {
// One BLOCK_SIZE-thread block computes one TILE_M x TILE_N tile of the output.
dim3 block(BLOCK_SIZE);
dim3 grid(params.m / TILE_M, params.n / TILE_N);

cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>
<<<grid, block, 0, stream>>>(
reinterpret_cast<InputType const*>(params.act),
reinterpret_cast<InputType const*>(params.weight),
reinterpret_cast<float const*>(params.alpha),
reinterpret_cast<OutputType*>(params.output),
params.m,
params.n,
params.k);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Compile-time dispatch over TILE_M: recursively instantiate TILE_M = 1..cudaCoreGemmTemplateMaxM
// and launch the variant whose TILE_M equals the runtime M; returns false if none matches.
template <
typename InputType,
typename OutputType,
int TILE_M,
int TILE_N,
int BLOCK_SIZE>
bool cudaCoreGemmTemplateCaller(Params const& params, cudaStream_t stream) {
constexpr int cudaCoreGemmTemplateMaxM = 4;
if (params.m == TILE_M) {
cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
params, stream);
return true;
}
if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
return cudaCoreGemmTemplateCaller<
InputType,
OutputType,
TILE_M + 1,
TILE_N,
BLOCK_SIZE>(params, stream);
}
return false;
}

template <typename InputType, typename OutputType>
bool cudaCoreGemmLauncher(Params const& params, cudaStream_t stream) {
return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 128>(
params, stream);
}

at::Tensor f8f8bf16_lite(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
bool dispatched = true;
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
TORCH_CHECK(XQ.size(-1) == K);

if (M > 4) {
throw std::runtime_error("f8f8bf16_lite cannot run when M > 4");
} else if (N % 2 != 0) {
throw std::runtime_error("f8f8bf16_lite cannot run when N % 2 != 0");
} else if (K % 16 != 0) {
throw std::runtime_error("f8f8bf16_lite cannot run when K % 16 != 0");
}

auto out_sizes = XQ.sizes().vec();
out_sizes.back() = N;
at::Tensor Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));

Params params{
XQ.data_ptr(),
WQ.data_ptr(),
scale.data_ptr(),
Y.data_ptr(),
(SizeType32)M,
(SizeType32)N,
(SizeType32)K};
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dispatched = cudaCoreGemmLauncher<cutlass::float_e4m3_t, __nv_bfloat16>(
params, stream);
if (!dispatched) {
throw std::runtime_error("f8f8bf16_lite cannot run");
}
return Y;
}

#else

at::Tensor f8f8bf16_lite(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
at::Tensor scale) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
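For reference, the math the kernel computes can be emulated in eager PyTorch. A sketch continuing the example in the summary above, assuming the tensorwise convention that the dequantized value is fp8_value * scale (as the unit test implies):

# Reference emulation of f8f8bf16_lite (sketch): FP8 -> FP32 convert, plain
# matmul, then scale by the single tensorwise alpha and cast to BF16.
alpha = x_scale * w_scale
y_ref = ((xq.to(torch.float32) @ wq.to(torch.float32).T) * alpha).to(torch.bfloat16)
torch.testing.assert_close(y, y_ref, atol=9.0e-2, rtol=9.0e-2)  # same tolerance as the unit test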
12 changes: 12 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -55,6 +55,7 @@ at::Tensor f8f8bf16_tensorwise(
at::Tensor WQ,
double scale,
bool use_fast_accum = true);
at::Tensor f8f8bf16_lite(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
std::vector<at::Tensor> f8f8bf16_grouped(
at::TensorList XQ,
at::TensorList WQ,
@@ -187,6 +188,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def(
"f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
m.def(
"bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def(
@@ -268,6 +270,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
m.impl("f8f8bf16_lite", f8f8bf16_lite);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -295,6 +298,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("f8f8bf16", f8f8bf16);
m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
m.impl("f8f8bf16_lite", f8f8bf16_lite);
m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -415,6 +419,13 @@ at::Tensor f8f8bf16_tensorwise_meta(
return Y;
}

at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor f8i4bf16_rowwise_meta(
at::Tensor XQ, // FP8
at::Tensor WQ, // INT4
@@ -533,6 +544,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
#endif
}
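The Meta implementation registered above makes the op shape-inferable without touching the GPU (useful for tracing and compilation). A small sketch, assuming the gen_ai ops have been loaded (import path illustrative):

import torch
# import fbgemm_gpu.experimental.gen_ai  # assumed op registration; path illustrative

xq = torch.empty(1, 512, dtype=torch.float8_e4m3fn, device="meta")
wq = torch.empty(1024, 512, dtype=torch.float8_e4m3fn, device="meta")
scale = torch.empty(1, device="meta")
y = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, scale)  # dispatches to f8f8bf16_lite_meta
assert y.shape == (1, 1024) and y.dtype == torch.bfloat16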

24 changes: 24 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -1110,6 +1110,30 @@ def test_quantize_zero_input(self, K) -> None:
torch.testing.assert_close(w.shape, wq.shape)
torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)

@unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is not yet supported.")
@settings(deadline=None)
@given(
M=st.sampled_from([1, 4]),
N=st.sampled_from([1024, 6144]),
K=st.sampled_from([512, 3584]),
CudaGraph=st.sampled_from([True, False]),
)
def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
x = torch.randn(size=(M, K), dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
if CudaGraph:
zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
g.replay()
else:
zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
zq_ref = (x @ w.T).to(torch.bfloat16)
torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)


if __name__ == "__main__":
unittest.main()
