Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CK MoE: cherry-pick #1808 #3609

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct fused_moe_args {
ck_tile::index_t block_m; // block_m, used to divide the input
ck_tile::index_t hidden_size; // k
ck_tile::index_t
intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
intermediate_size; // n / TP, for Gate; Up and Down also use this value
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
ck_tile::index_t topk; // need this?
Expand All @@ -47,7 +47,8 @@ struct fused_moe_traits {
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ struct fused_moegemm_traits {
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,78 +3,86 @@

#include "fused_moe.hpp"

float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
{
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
float fused_moe(
fused_moe_traits t,
fused_moe_args a,
const ck_tile::stream_config& s) {
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};

auto o_data_bytes = [&]() {
if(t.prec_o == "fp32")
return 4;
else if(t.prec_o == "fp16" || t.prec_o == "bf16")
return 2;
else if(t.prec_o == "int8" || t.prec_o == "fp8")
return 1;
return 1;
}();
auto o_data_bytes = [&]() {
if (t.prec_o == "fp32")
return 4;
else if (t.prec_o == "fp16" || t.prec_o == "bf16")
return 2;
else if (t.prec_o == "int8" || t.prec_o == "fp8")
return 1;
return 1;
}();

auto t0 = fused_moesorting_trait{"int32", "fp32"};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
};
auto t0 = fused_moesorting_trait{"int32", "fp32"};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
};

auto t1 = fused_moegemm_traits{t.prec_i,
t.prec_w,
t.prec_o,
t.prec_st,
t.prec_sw,
t.prec_sq,
t.prec_kw,
t.block_m,
t.gate_only,
t.fused_quant};
auto a1 = fused_moegemm_args{
a.a_ptr, // const void* a_ptr;
a.a_scale_ptr, // const void* a_scale_ptr;
a.g_ptr, // const void* g_ptr;
a.d_ptr, // const void* d_ptr;
a.g_scale_ptr, // const void* g_scale_ptr;
a.d_scale_ptr, // const void* d_scale_ptr;
a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
a.o_ptr, // void* o_ptr;
a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
a.sorted_weight_ptr, // const void* sorted_weight_ptr;
a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
a.hidden_size, // index_t hidden_size;
a.intermediate_size, // index_t intermediate_size;
a.num_tokens, // index_t num_tokens;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.stride_token // index_t stride_token;
};
auto t1 = fused_moegemm_traits{
t.prec_i,
t.prec_w,
t.prec_o,
t.prec_st,
t.prec_sw,
t.prec_sq,
t.prec_kw,
t.block_m,
t.activation,
t.gate_only,
t.fused_quant};
auto a1 = fused_moegemm_args{
a.a_ptr, // const void* a_ptr;
a.a_scale_ptr, // const void* a_scale_ptr;
a.g_ptr, // const void* g_ptr;
a.d_ptr, // const void* d_ptr;
a.g_scale_ptr, // const void* g_scale_ptr;
a.d_scale_ptr, // const void* d_scale_ptr;
a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
a.o_ptr, // void* o_ptr;
a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
a.sorted_weight_ptr, // const void* sorted_weight_ptr;
a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
a.hidden_size, // index_t hidden_size;
a.intermediate_size, // index_t intermediate_size;
a.num_tokens, // index_t num_tokens;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.stride_token // index_t stride_token;
};

float r0 = -1;
float r1 = -1;
float r0 = -1;
float r1 = -1;

float r = ck_tile::launch_kernel(
s,
[=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
[=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
float r = ck_tile::launch_kernel(
s,
[=, &r0](const ck_tile::stream_config&) {
r0 = fused_moesorting(t0, a0, s_sub);
},
[=, &r1](const ck_tile::stream_config&) {
r1 = fused_moegemm(t1, a1, s_sub);
});

// keep unsupported case return negative
if(r0 < 0 || r1 < 0)
return -1;
// keep unsupported case return negative
if (r0 < 0 || r1 < 0)
return -1;

return r;
return r;
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,67 @@ float fused_moegemm(
// clang-format off
float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) {
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;

// Maps the compile-time selector Ts_::Activation to an element-wise functor:
// 0 yields FastGeluAsm, any other value yields Silu.
constexpr auto get_activation_ = []() {
if constexpr (Ts_::Activation == 0) {
return ck_tile::element_wise::FastGeluAsm{};
} else
return ck_tile::element_wise::Silu{};
};
// Activation functor type, taken from the lambda's return type via decltype
// (the lambda is never needed at run time, only its result type).
using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;

using f_problem = ck_tile::FusedMoeGemmPipelineProblem<
typename Ts_::ADataType,
typename Ts_::GDataType,
Expand All @@ -35,7 +44,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) {
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
f_act_, // activation functor selected by Ts_::Activation (0: FastGeluAsm, otherwise Silu)
f_shape,
f_traits>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ template <
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
typename WarpPerBlock_,
typename WarpTile_, // seq<*,*,*>, used to select mfma
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
ck_tile::index_t GateOnly_ = 0,
ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal
Expand Down Expand Up @@ -55,10 +56,11 @@ struct fmoe_ // traits, ugly name, only used for internal
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;

using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;

static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@

// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// Explicit instantiations for bf16. The last three fmoe_ arguments are
// Activation_, GateOnly_, FusedQuant_ (Activation_: 0 = Gelu, 1 = Silu).

// bf16: Activation_ = 0 (Gelu), GateOnly_ = 1, FusedQuant_ = 0
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// bf16: Activation_ = 1 (Silu), GateOnly_ = 0, FusedQuant_ = 0
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// bf16: Activation_ = 1 (Silu), GateOnly_ = 1, FusedQuant_ = 0
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,19 @@

// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// Explicit instantiations for fp16. The last three fmoe_ arguments are
// Activation_, GateOnly_, FusedQuant_ (Activation_: 0 = Gelu, 1 = Silu).

// fp16: Activation_ = 0 (Gelu), GateOnly_ = 1, FusedQuant_ = 0
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// fp16: Activation_ = 1 (Silu), GateOnly_ = 0, FusedQuant_ = 0
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// fp16: Activation_ = 1 (Silu), GateOnly_ = 1, FusedQuant_ = 0
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);

// clang-format on
Loading