diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp index 47dd00a55..8a4bcac81 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp @@ -28,7 +28,7 @@ struct fused_moe_args { ck_tile::index_t block_m; // block_m, used to devide the input ck_tile::index_t hidden_size; // k ck_tile::index_t - intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2 + intermediate_size; // n / TP, for Gate. and Up, Down is also this value ck_tile::index_t num_tokens; // input number of tokens for current iteration ck_tile::index_t num_experts; // number of groups ck_tile::index_t topk; // need this? @@ -47,7 +47,8 @@ struct fused_moe_traits { std::string prec_sq; // smooth quant scale std::string prec_kw; // topk-weight data type int block_m; - int gate_only; + int activation; // 0:gelu, 1:silu + int gate_only; // 0:g1u0, 1:g1u1 int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moegemm.hpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moegemm.hpp index a5dd50da0..1ee871db4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moegemm.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moegemm.hpp @@ -99,7 +99,8 @@ struct fused_moegemm_traits { std::string prec_sq; // smooth quant scale std::string prec_kw; // topk-weight data type int block_m; - int gate_only; + int activation; // 0:gelu, 1:silu + int gate_only; // 0:g1u0, 1:g1u1 int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip index bfc0ce409..d19599295 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip @@ -3,78 +3,86 @@ #include "fused_moe.hpp" -float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) -{ - auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; +float fused_moe( + fused_moe_traits t, + fused_moe_args a, + const ck_tile::stream_config& s) { + auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; - auto o_data_bytes = [&]() { - if(t.prec_o == "fp32") - return 4; - else if(t.prec_o == "fp16" || t.prec_o == "bf16") - return 2; - else if(t.prec_o == "int8" || t.prec_o == "fp8") - return 1; - return 1; - }(); + auto o_data_bytes = [&]() { + if (t.prec_o == "fp32") + return 4; + else if (t.prec_o == "fp16" || t.prec_o == "bf16") + return 2; + else if (t.prec_o == "int8" || t.prec_o == "fp8") + return 1; + return 1; + }(); - auto t0 = fused_moesorting_trait{"int32", "fp32"}; - auto a0 = fused_moesorting_args{ - a.topk_ids_ptr, // const void* p_topk_ids; - a.topk_weight_ptr, // const void* p_weights; - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; - a.sorted_weight_ptr, // void* p_sorted_weights; - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; - a.o_ptr, // void* p_moe_buf; - a.num_tokens, // index_t tokens; - a.block_m, // index_t unit_size; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; - a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; - }; + auto t0 = fused_moesorting_trait{"int32", "fp32"}; + auto a0 = fused_moesorting_args{ + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; + }; - auto t1 = fused_moegemm_traits{t.prec_i, - t.prec_w, - t.prec_o, - t.prec_st, - t.prec_sw, - t.prec_sq, - t.prec_kw, - t.block_m, - t.gate_only, - t.fused_quant}; - auto a1 = fused_moegemm_args{ - a.a_ptr, // const void* a_ptr; - a.a_scale_ptr, // const void* a_scale_ptr; - a.g_ptr, // const void* g_ptr; - a.d_ptr, // const void* d_ptr; - a.g_scale_ptr, // const void* g_scale_ptr; - a.d_scale_ptr, // const void* d_scale_ptr; - a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr; - a.o_ptr, // void* o_ptr; - a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr; - a.sorted_weight_ptr, // const void* sorted_weight_ptr; - a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr; - a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr; - a.hidden_size, // index_t hidden_size; - a.intermediate_size, // index_t intermediate_size; - a.num_tokens, // index_t num_tokens; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; - a.stride_token // index_t stride_token; - }; + auto t1 = fused_moegemm_traits{ + t.prec_i, + t.prec_w, + t.prec_o, + t.prec_st, + t.prec_sw, + t.prec_sq, + t.prec_kw, + t.block_m, + t.activation, + t.gate_only, + t.fused_quant}; + auto a1 = fused_moegemm_args{ + a.a_ptr, // const void* a_ptr; + a.a_scale_ptr, // const void* a_scale_ptr; + a.g_ptr, // const void* g_ptr; + a.d_ptr, // const void* d_ptr; + a.g_scale_ptr, // const void* g_scale_ptr; + a.d_scale_ptr, // const void* d_scale_ptr; + a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr; + a.o_ptr, // void* o_ptr; + a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr; + a.sorted_weight_ptr, // const void* sorted_weight_ptr; + a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr; + a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr; + a.hidden_size, // index_t hidden_size; + a.intermediate_size, // index_t intermediate_size; + a.num_tokens, // index_t num_tokens; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + a.stride_token // index_t stride_token; + }; - float r0 = -1; - float r1 = -1; + float r0 = -1; + float r1 = -1; - float r = ck_tile::launch_kernel( - s, - [=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); }, - [=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); }); + float r = ck_tile::launch_kernel( + s, + [=, &r0](const ck_tile::stream_config&) { + r0 = fused_moesorting(t0, a0, s_sub); + }, + [=, &r1](const ck_tile::stream_config&) { + r1 = fused_moegemm(t1, a1, s_sub); + }); - // keep unsupported case return negative - if(r0 < 0 || r1 < 0) - return -1; + // keep unsupported case return negative + if (r0 < 0 || r1 < 0) + return -1; - return r; + return r; } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api.hip index ed5857705..be1e64f2d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api.hip @@ -20,15 +20,67 @@ float fused_moegemm( // clang-format off float r = -1; if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && - t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0) + { + constexpr ck_tile::index_t act_ = 0; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1) + { + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 1; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; r = fused_moegemm_(s, a); } else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && - t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1) { - using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + constexpr ck_tile::index_t act_ = 1; + constexpr ck_tile::index_t go_ = 0; + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>; r = fused_moegemm_(s, a); } // clang-format on diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_internal.hpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_internal.hpp index b67929341..8bf5571e0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_internal.hpp @@ -23,6 +23,15 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) { typename Ts_::BlockTile_1, typename Ts_::WarpPerBlock_0, typename Ts_::WarpTile_0>; + + constexpr auto get_activation_ = []() { + if constexpr (Ts_::Activation == 0) { + return ck_tile::element_wise::FastGeluAsm{}; + } else + return ck_tile::element_wise::Silu{}; + }; + using f_act_ = ck_tile::remove_cvref_t; + using f_problem = ck_tile::FusedMoeGemmPipelineProblem< typename Ts_::ADataType, typename Ts_::GDataType, @@ -35,7 +44,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) { typename Ts_::YSmoothScaleDataType, typename Ts_::TopkWeightDataType, typename Ts_::IndexDataType, - ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded + f_act_, // TODO: hardcoded f_shape, f_traits>; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_traits.hpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_traits.hpp index a6fbe97e1..2c4d71ca4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_traits.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_api_traits.hpp @@ -18,6 +18,7 @@ template < typename BlockTIle_, // seq typename WarpPerBlock_, typename WarpTile_, // seq<*,*,*>, used to select mfma + ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu ck_tile::index_t GateOnly_ = 0, ck_tile::index_t FusedQuant_ = 0> struct fmoe_ // traits, ugly name, only used for internal @@ -55,10 +56,11 @@ struct fmoe_ // traits, ugly name, only used for internal using WarpPerBlock_0 = ck_tile::remove_cvref_t; using WarpTile_0 = ck_tile::remove_cvref_t; - using BlockTile_1 = ck_tile::sequence; + using BlockTile_1 = ck_tile::sequence; using WarpPerBlock_1 = ck_tile::remove_cvref_t; using WarpTile_1 = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu static constexpr ck_tile::index_t GateOnly = GateOnly_; static constexpr ck_tile::index_t FusedQuant = FusedQuant_; }; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_bf16_m32.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_bf16_m32.hip index 57e4e845b..cf59492b1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_bf16_m32.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_bf16_m32.hip @@ -8,7 +8,18 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_fp16_m32.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_fp16_m32.hip index a68fea939..b86caeccf 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_fp16_m32.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moegemm_fp16_m32.hip @@ -8,7 +8,19 @@ // clang-format off template float fused_moegemm_< - fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0> >(const ck_tile::stream_config& s, fused_moegemm_args a); // clang-format on