Port oss f16_fast_gemv into fbcode
Summary:
This diff includes:
1. Port the OSS FastGEMV `fp16` kernel into fbcode and expose it to Python as step 1: `torch.ops.fbgemm.f16_fast_gemv`
https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14
2. Add `fp16_oss_fast_gemv` to the quantize ops benchmark script
3. Add two simple tests for the custom op `torch.ops.fbgemm.f16_fast_gemv` (see the sketch below), covering:
     - `torch.compile()` compatibility
     - numerical correctness

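As an illustration of what those two tests cover, here is a minimal sketch (hypothetical shapes and tolerances; assumes a CUDA device with CUDA >= 12 and the `(K, M)` / `(N, K)` layout documented in the kernel wrapper below):

```python
import torch

# Hypothetical shapes: X is (K, M) with M = 1, W is (N, K); the op returns (N, M).
X = torch.randn(1024, 1, dtype=torch.half, device="cuda")
W = torch.randn(256, 1024, dtype=torch.half, device="cuda")

# Correctness: the gemv result should match a plain fp16 matmul (loose tolerances).
Y = torch.ops.fbgemm.f16_fast_gemv(X, W)
torch.testing.assert_close(Y, W @ X, atol=1e-2, rtol=1e-2)

# torch.compile() compatibility: the custom op should trace and run when compiled.
compiled = torch.compile(lambda x, w: torch.ops.fbgemm.f16_fast_gemv(x, w))
torch.testing.assert_close(compiled(X, W), Y)
```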
**Next step:**
Add `fp8` mixed-precision support to the fast gemv kernel, which is the variant we ultimately want.

Differential Revision: D68470488
YUNQIUGUO authored and facebook-github-bot committed Jan 23, 2025
1 parent 5754ce7 commit 7e680f3
Showing 7 changed files with 560 additions and 0 deletions.
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -361,6 +361,38 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class FP16OSSFastGemv(QuantizeOpBase):
    """
    FP16 OSS fast gemv.
    """

    def quantize(self, x, w):
        return x, w

    def compute(self, x, w):
        out = torch.ops.fbgemm.f16_fast_gemv(x, w)
        return out

    def quantize_and_compute(self, x, w):
        # Dummy quantize; fp16 inputs are used as-is.
        x, w = self.quantize(x, w)
        return self.compute(x, w)

    @property
    def name(self) -> str:
        return "fp16_oss_fast_gemv"

    @property
    def hip(self) -> bool:
        # The fast gemv kernel is CUDA-only and not supported on ROCm.
        return False

    @property
    def cuda(self) -> bool:
        return True

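For reference, a minimal sketch of exercising the registered benchmark op directly, outside the bench harness (module path inferred from the file path above; shapes are hypothetical):

```python
import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import FP16OSSFastGemv

op = FP16OSSFastGemv()
x = torch.randn(1024, 1, dtype=torch.half, device="cuda")    # (K, M) input vector
w = torch.randn(256, 1024, dtype=torch.half, device="cuda")  # (N, K) weight matrix

# quantize() is a pass-through for fp16; compute() calls the custom op.
y = op.quantize_and_compute(x, w)
print(op.name, tuple(y.shape))  # fp16_oss_fast_gemv (256, 1)
```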

@register_quantize_op
class FP8CublasRowwiseGemm(QuantizeOpBase):
"""
@@ -0,0 +1,58 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>

#include "include/fast_gemv.cuh"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

at::Tensor f16_fast_gemv(at::Tensor X, at::Tensor W) {
  // Note: the OSS fast gemv implementation accepts the vector shaped as
  // (size, 1), i.e. (K, M).
  // X: K x M
  // W: N x K
  auto m = X.size(1);
  auto n = W.size(0);
  auto k = W.size(1);

  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
  TORCH_CHECK(W.is_cuda() && W.is_contiguous());

  auto block_dim_x = k / 8;
  auto block_dim_y = MAX_THREADS_PER_BLOCK / block_dim_x;
  dim3 block_dim(block_dim_x, block_dim_y);
  dim3 grid_dim(1, n / block_dim_y);
  unsigned int num_per_thread = k / block_dim_x;

  auto Y = at::empty({n, m}, X.options().dtype(at::kHalf));

  gemv_fp16<<<grid_dim, block_dim>>>(
      (half*)W.data_ptr(), // mat
      (half*)X.data_ptr(), // vec
      (half*)Y.data_ptr(), // res
      k,
      num_per_thread);

  return Y;
}

#else

at::Tensor f16_fast_gemv(at::Tensor X, at::Tensor W) {
  throw std::runtime_error(
      "CUDA version is older than 12.0"); // requires CUDA >= 12
}

#endif

} // namespace fbgemm_gpu
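
To make the launch configuration concrete, a worked example of the arithmetic above (assumes `MAX_THREADS_PER_BLOCK` is 1024; the real value comes from `include/fast_gemv.cuh`, and the problem size here is hypothetical):

```python
MAX_THREADS_PER_BLOCK = 1024  # assumed here; the real value comes from include/fast_gemv.cuh

K, N = 4096, 4096                                    # hypothetical gemv problem size (M = 1)
block_dim_x = K // 8                                 # 512 threads across the K dimension
block_dim_y = MAX_THREADS_PER_BLOCK // block_dim_x   # 2 output rows handled per block
grid_dim = (1, N // block_dim_y)                     # (1, 2048) blocks cover all N rows
num_per_thread = K // block_dim_x                    # each thread reduces 8 fp16 elements

# Implied preconditions: K is divisible by 8, K / 8 <= MAX_THREADS_PER_BLOCK,
# and N is divisible by block_dim_y.
```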
