Skip to content

Commit

Permalink
perf: optimize TileShape configuration for bf16 and mixed
Browse files Browse the repository at this point in the history
- Change TileShape from 128x128x128 to 128x256x64 for Large/Medium kernel modes (Medium bf16 grouped kernel's cluster shape also changes from 1x2x1 to 2x1x1)
- Optimize bf16 and mixed format kernels
- Use the cooperative (instead of pingpong) schedule by default for mixed-input kernels, and switch the bf16 grouped Large/Medium kernels to the non-pong schedule as well
  • Loading branch information
MatrixAssembler committed Jan 20, 2025
1 parent 2d025dc commit dd31dda
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,20 @@ __global__ void set_dynamic_kernel_args_kernel(
GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA*>(stride_buf);
GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] =
Expand All @@ -167,15 +167,15 @@ __global__ void set_dynamic_kernel_args_kernel(
zero_start_index_M[group_index], N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
{zero_start_index_M[group_index], K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
{zero_start_index_M[group_index], N, 1});
}
}
Expand Down Expand Up @@ -212,20 +212,20 @@ __global__ void set_static_kernel_args_kernel(
GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
problem_shape_buf);
// Pass dummy configs to get Stride structure
GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA* stride_input_A_ptr = reinterpret_cast<
GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputA*>(stride_buf);
GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB* stride_input_B_ptr = reinterpret_cast<
GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideInputB*>(stride_buf + stride_size);
GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput* stride_output_ptr = reinterpret_cast<
GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
StrideOutput*>(stride_buf + (stride_size * 2));

output_args_ptr[group_index] = reinterpret_cast<int64_t>(output_data);
Expand All @@ -237,15 +237,15 @@ __global__ void set_static_kernel_args_kernel(
GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape(M, N, K);
stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
{M, K, 1});
stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
{N, K, 1});
stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
typename GroupedGemmBF16Args::
GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
{M, N, 1});
}
}
Expand Down Expand Up @@ -470,10 +470,10 @@ std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
x_group, w_group, output_tensor, zero_start_index_M);
} else if (kernel == KernelMode::Large) {
return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
x_group, w_group, output_tensor, zero_start_index_M);
} else {
return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
x_group, w_group, output_tensor, zero_start_index_M);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,18 @@ at::Tensor bf16i4bf16_rowwise_impl(
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using CooperativeSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
Expand Down Expand Up @@ -231,18 +236,18 @@ at::Tensor dispatch_bf16i4bf16_rowwise_kernel(
} else if (kernel == KernelMode::Large) {
return bf16i4bf16_rowwise_impl<
128,
128,
128,
256,
64,
2,
1,
1,
true,
false,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
} else {
return bf16i4bf16_rowwise_impl<
128,
128,
128,
256,
64,
2,
1,
1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,18 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using CooperativeSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
Expand Down Expand Up @@ -235,17 +240,17 @@ at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
} else if (kernel == KernelMode::Large) {
return bf16i4bf16_rowwise_batched_impl<
128,
128,
256,
64,
2,
1,
1,
true,
false,
WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
} else {
return bf16i4bf16_rowwise_batched_impl<
128,
128,
256,
64,
2,
1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,18 @@ at::Tensor f8i4bf16_rowwise_impl(
cute::Int<TBS_K>>; // Shape of the
// threadblocks in a
// cluster
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
using CooperativeSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
using PongSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using CooperativeEpilogueSchedule =
cutlass::epilogue::TmaWarpSpecializedCooperative;
using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using MainLoopSchedule =
cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;

// Implement rowwise scaling epilogue for x
using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
Expand Down Expand Up @@ -254,19 +259,19 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
} else if (kernel == KernelMode::Large) {
return f8i4bf16_rowwise_impl<
128,
128,
128,
256,
64,
2,
1,
1,
true,
false,
InputDType,
WEIGHT_SCALE_DTYPE>(XQ, WQ, x_scale, w_scale, w_zp);
} else {
return f8i4bf16_rowwise_impl<
128,
128,
128,
256,
64,
2,
1,
1,
Expand Down

0 comments on commit dd31dda

Please sign in to comment.