Back out "Enable fast FP8 GEMM for memory bound"
Summary:
Original commit changeset: fbf34e283e94

Original Phabricator Diff: D68193920

Differential Revision: D68351266
jiawenliu64 authored and facebook-github-bot committed Jan 17, 2025
1 parent 379db5f commit d5d375b
Showing 4 changed files with 0 additions and 333 deletions.
34 changes: 0 additions & 34 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -716,40 +716,6 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8LiteGemm(QuantizeOpBase):
    """
    FP8 lite matmul for memory-bound workloads.
    """

    def quantize(self, x, w):
        # Quantize both input tensors.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        return xq, wq, x_scale, w_scale

    def compute(self, xq, wq, x_scale, w_scale):
        return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)

    def quantize_and_compute(self, x, w):
        xq, wq, x_scale, w_scale = self.quantize(x, w)
        return self.compute(xq, wq, x_scale, w_scale)

    @property
    def name(self) -> str:
        return "cuda_lite"

    @property
    def hip(self) -> bool:
        # Need to add support for a better quantize kernel.
        # Also may have an issue with CUDA graphs.
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class TritonFP8RowwiseGemm(QuantizeOpBase):
    """
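For context on the quantize() path used by the removed FP8LiteGemm op: torch.ops.fbgemm.quantize_fp8_per_tensor produces an FP8 tensor plus a single dequantization scale. The following is only a rough Python sketch of that tensorwise scheme, not the actual fbgemm kernel; the FP8 max constant, epsilon, and clamping are assumptions made for illustration.

import torch

FP8_E4M3_MAX = 448.0  # assumed: largest finite float8_e4m3fn value

def quantize_fp8_per_tensor_sketch(x: torch.Tensor):
    # One scale for the whole tensor: map the largest magnitude onto the FP8 range.
    amax = x.abs().max().clamp(min=1e-12)
    scale = (amax / FP8_E4M3_MAX).float()  # dequantization scale
    xq = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return xq, scale

# With both operands quantized this way, x @ w.T is approximately
# (xq.float() @ wq.float().T) * (x_scale * w_scale), which is why compute()
# above passes the combined scale x_scale * w_scale to f8f8bf16_lite.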

This file was deleted. (Its name and contents are not shown in this view; it accounts for the remaining 263 of the 333 deleted lines.)

12 changes: 0 additions & 12 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -55,7 +55,6 @@ at::Tensor f8f8bf16_tensorwise(
    at::Tensor WQ,
    double scale,
    bool use_fast_accum = true);
at::Tensor f8f8bf16_lite(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
std::vector<at::Tensor> f8f8bf16_grouped(
    at::TensorList XQ,
    at::TensorList WQ,
@@ -188,7 +187,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
      "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
  m.def(
      "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
  m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
  m.def(
      "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
  m.def(
@@ -270,7 +268,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("f8f8bf16", f8f8bf16);
  m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
  m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
  m.impl("f8f8bf16_lite", f8f8bf16_lite);
  m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
  m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -298,7 +295,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
  m.impl("f8f8bf16", f8f8bf16);
  m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
  m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
  m.impl("f8f8bf16_lite", f8f8bf16_lite);
  m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
  m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -419,13 +415,6 @@ at::Tensor f8f8bf16_tensorwise_meta(
  return Y;
}

at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
  const at::SymInt M = X.sym_size(0);
  const at::SymInt N = W.sym_size(0);
  auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
  return Y;
}

at::Tensor f8i4bf16_rowwise_meta(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // INT4
@@ -544,7 +533,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
  m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
  m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
  m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
#endif
}

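The f8f8bf16_lite_meta function and Meta-dispatch registration removed above let the op run on "meta" tensors, so tracing and compilation machinery could infer the [M, N] bf16 output shape without launching a CUDA kernel. Below is a minimal sketch of what that enabled, assuming a build that predates this backout with fbgemm_gpu's gen_ai operators loaded; the shapes and dtypes are illustrative.

import torch

# No real compute happens: the Meta registration only propagates shape and dtype.
xq = torch.empty(16, 512, dtype=torch.float8_e4m3fn, device="meta")
wq = torch.empty(1024, 512, dtype=torch.float8_e4m3fn, device="meta")
scale = torch.empty((), dtype=torch.float32, device="meta")
y = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, scale)
assert y.shape == (16, 1024) and y.dtype == torch.bfloat16 and y.is_meta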
24 changes: 0 additions & 24 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -1110,30 +1110,6 @@ def test_quantize_zero_input(self, K) -> None:
        torch.testing.assert_close(w.shape, wq.shape)
        torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)

    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is not yet supported.")
    @settings(deadline=None)
    @given(
        M=st.sampled_from([1, 5, 16]),
        N=st.sampled_from([1024, 6144]),
        K=st.sampled_from([512, 3584]),
        CudaGraph=st.sampled_from([True, False]),
    )
    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
        x = torch.randn(size=(M, K), dtype=torch.bfloat16, device="cuda") * 0.1
        w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
        if CudaGraph:
            # Warm up once outside the graph, then capture and replay.
            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
            g.replay()
        else:
            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
        zq_ref = (x @ w.T).to(torch.bfloat16)
        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)


if __name__ == "__main__":
    unittest.main()
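The CudaGraph=True branch of the removed test follows the usual warm-up / capture / replay sequence for CUDA graphs. A standalone sketch of that pattern, with a plain matmul standing in for the removed f8f8bf16_lite op (shapes arbitrary; requires a CUDA device):

import torch

a = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)

out = a @ b.T                # warm-up: run once eagerly so lazy init happens outside capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):    # capture: kernel launches are recorded, not executed
    out = a @ b.T
g.replay()                   # replay: re-launch the recorded work with minimal CPU overhead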
