Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable brgemm-based RNN for int8 on avx2 #2476

Merged
merged 3 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,7 @@ struct memory : public handle<dnnl_memory_t> {
abCd32c = dnnl_abCd32c,
abdEc16e = dnnl_abdEc16e,
abdEc32e = dnnl_abdEc32e,
abdEC16e4c = dnnl_abdEC16e4c,
abdEC32e2c = dnnl_abdEC32e2c,
abdEC32e4c = dnnl_abdEC32e4c,
abdCe16c = dnnl_abdCe16c,
Expand Down Expand Up @@ -1979,6 +1980,7 @@ struct memory : public handle<dnnl_memory_t> {
ldOi32o = abDc32d,
ldOI32o4i = abDC32d4c,
ldgOi16o = abdEc16e,
ldgOI16o4i = abdEC16e4c,
ldgOi32o = abdEc32e,
ldgOI32o2i = abdEC32e2c,
ldgOI32o4i = abdEC32e4c,
Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -1045,6 +1045,7 @@ typedef enum {
dnnl_BAcd8a8b,
dnnl_BAcde8a8b,
dnnl_aCBdef8b8c,
dnnl_abdEC16e4c,

/// Just a sentinel, not real memory format tag. Must be changed after new
/// format tag is added.
Expand Down Expand Up @@ -1184,6 +1185,7 @@ typedef enum {
dnnl_ldIo32i = dnnl_abCd32c,
/// 6D RNN weights tensor
dnnl_ldgOi16o = dnnl_abdEc16e,
dnnl_ldgOI16o4i = dnnl_abdEC16e4c,
dnnl_ldgOi32o = dnnl_abdEc32e,
dnnl_ldgOI32o2i = dnnl_abdEC32e2c,
dnnl_ldgOI32o4i = dnnl_abdEC32e4c,
Expand Down
4 changes: 3 additions & 1 deletion src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -717,6 +717,7 @@ const format_tag_t abCde32c = dnnl_abCde32c;
const format_tag_t abCdef32c = dnnl_abCdef32c;
const format_tag_t abdEc16e = dnnl_abdEc16e;
const format_tag_t abdEc32e = dnnl_abdEc32e;
const format_tag_t abdEC16e4c = dnnl_abdEC16e4c;
const format_tag_t abdEC32e2c = dnnl_abdEC32e2c;
const format_tag_t abdEC32e4c = dnnl_abdEC32e4c;
const format_tag_t abdEC64e2c = dnnl_abdEC64e2c;
Expand Down Expand Up @@ -1483,6 +1484,7 @@ const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
const format_tag_t ldIo32i = dnnl_ldIo32i;
const format_tag_t ldgOi16o = dnnl_ldgOi16o;
const format_tag_t ldgOi32o = dnnl_ldgOi32o;
const format_tag_t ldgOI16o4i = dnnl_ldgOI16o4i;
const format_tag_t ldgOI32o2i = dnnl_ldgOI32o2i;
const format_tag_t ldgOI32o4i = dnnl_ldgOI32o4i;
const format_tag_t ldgOI64o2i = dnnl_ldgOI64o2i;
Expand Down
4 changes: 3 additions & 1 deletion src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -376,6 +376,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_abDc32d) return "abDc32d";
if (v == dnnl_abDC32d4c) return "abDC32d4c";
if (v == dnnl_abdEc32e) return "abdEc32e";
if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
if (v == dnnl_abdEC32e2c) return "abdEC32e2c";
if (v == dnnl_abdEC32e4c) return "abdEC32e4c";
if (v == dnnl_aBdefC16b4c) return "aBdefC16b4c";
Expand Down Expand Up @@ -1008,6 +1009,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_ldIo32i) return "ldIo32i";
if (v == dnnl_ldgOi16o) return "ldgOi16o";
if (v == dnnl_ldgOi32o) return "ldgOi32o";
if (v == dnnl_ldgOI16o4i) return "ldgOI16o4i";
if (v == dnnl_ldgOI32o2i) return "ldgOI32o2i";
if (v == dnnl_ldgOI32o4i) return "ldgOI32o4i";
if (v == dnnl_ldgOI64o2i) return "ldgOI64o2i";
Expand Down
3 changes: 2 additions & 1 deletion src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -628,6 +628,7 @@ status_t memory_desc_wrapper::compute_blocking(
C(abCdef32c, {0, 1, 2, 3, 4, 5}, {32}, {2});
C(abdEc16e, {0, 1, 3, 4, 2}, {16}, {4});
C(abdEc32e, {0, 1, 3, 4, 2}, {32}, {4});
C(abdEC16e4c, {0, 1, 3, 4, 2}, {16, 4}, {4, 2});
C(abdEC32e2c, {0, 1, 3, 4, 2}, {32, 2}, {4, 2});
C(abdEC32e4c, {0, 1, 3, 4, 2}, {32, 4}, {4, 2});
C(abdEC64e2c, {0, 1, 3, 4, 2}, {64, 2}, {4, 2});
Expand Down
4 changes: 3 additions & 1 deletion src/common/tag_traits.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -142,6 +142,7 @@ enum class inner_blk_t {
_16b4c,
_16c2b,
_16c4b,
_16e4c,
_24a2b,
_24a4b,
_24b2a,
Expand Down Expand Up @@ -860,6 +861,7 @@ DECL_TRAITS(abCde4c, _C, _4c, 5);
DECL_TRAITS(abCdef4c, _C, _4c, 6);
DECL_TRAITS(abdEc16e, _E, _16e, 5);
DECL_TRAITS(abdEc32e, _E, _32e, 5);
DECL_TRAITS(abdEC16e4c, _CE, _16e4c, 5);
DECL_TRAITS(abdEC32e2c, _CE, _32e2c, 5);
DECL_TRAITS(abdEC32e4c, _CE, _32e4c, 5);
DECL_TRAITS(abdEC64e2c, _CE, _64e2c, 5);
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/rnn/ref_rnn.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -414,7 +414,7 @@ _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>::pd_t::init_brgemm(
VDISPATCH_RNN(
!(rnn_.is_signed_int8_conf() && !is_superset(isa, avx512_core_amx)),
VERBOSE_ISA_DT_MISMATCH);
VDISPATCH_RNN(!(rnn_.is_int8_conf() && !is_superset(isa, avx512_core_vnni)),
VDISPATCH_RNN(!(rnn_.is_int8_conf() && !is_superset(isa, avx2)),
VERBOSE_ISA_DT_MISMATCH);
VDISPATCH_RNN(!(rnn_.is_f32_conf() && !is_superset(isa, avx2)),
VERBOSE_ISA_DT_MISMATCH);
Expand Down
9 changes: 4 additions & 5 deletions src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,8 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
format_tag_t otag, itag;

itag = id.matches_one_of_tag(ldigo, ldio);
otag = od.matches_one_of_tag(ldgOI64o4i, ldgOI32o4i, ldOI32o4i);
otag = od.matches_one_of_tag(
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i);
if (itag != format_tag::undef && otag != format_tag::undef) {
_pd->itag_ = itag;
_pd->otag_ = otag;
Expand Down Expand Up @@ -855,15 +856,13 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
return status::success;
}

const auto &blocked_d = dst_d;
const auto &pdims = blocked_d.padded_dims();

const int o_block = pd()->otag_ == ldgOI64o4i ? 64 : 32;
const int o_block = dst_d.blocking_desc().inner_blks[0];
static constexpr int i_block = 4;

dim_t L, D, I, G, O;
init_dims(L, D, I, G, O, src_d);

const auto &pdims = dst_d.padded_dims();
const dim_t pI = pdims[2];
const dim_t pO = (src_d.ndims() == 5) ? pdims[4] : pdims[3];
const dim_t IB = pI / i_block;
Expand Down
9 changes: 5 additions & 4 deletions src/cpu/rnn/rnn_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -76,8 +76,8 @@ bool rnn_utils::is_ldoi(const memory_desc_wrapper &mdw) {
bool rnn_utils::is_ldigo_blocked(const memory_desc_wrapper &mdw) {
format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldgOi32o,
format_tag::ldgOI32o2i, format_tag::ldgOI32o4i,
format_tag::ldgOI64o2i, format_tag::ldgOI64o4i,
format_tag::ldgOi16o);
format_tag::ldgOI16o4i, format_tag::ldgOI64o2i,
format_tag::ldgOI64o4i, format_tag::ldgOi16o);
return md_format_tag != format_tag::undef;
}

Expand Down Expand Up @@ -293,7 +293,8 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
} else if (rnn.is_fwd) {
if (rnn.is_int8_conf())
tag = utils::map(n_block, format_tag::undef, 64,
format_tag::ldgOI64o4i, 32, ldgOI32o4i);
format_tag::ldgOI64o4i, 32, ldgOI32o4i, 16,
ldgOI16o4i);
else if (rnn.is_xf16_conf())
tag = utils::map(n_block, format_tag::undef, 64,
format_tag::ldgOI64o2i, 32, ldgOI32o2i);
Expand Down
15 changes: 8 additions & 7 deletions src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,14 @@ void set_isa_impl(brgemm_desc_t *brg) {
is_isa_ok(avx512_core_fp16), avx512_core_fp16);
}
} else if (brg->is_int8) {
brg->isa_impl = utils::map(true, isa_undef,
is_isa_ok(avx512_core_amx_fp16), avx512_core_amx_fp16,
is_isa_ok(avx512_core_amx), avx512_core_amx,
is_isa_ok(avx512_core_fp16), avx512_core_fp16,
is_isa_ok(avx512_core_vnni), avx512_core_vnni,
is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2),
avx2_vnni_2, is_isa_ok(avx2_vnni), avx2_vnni);
brg->isa_impl
= utils::map(true, isa_undef, is_isa_ok(avx512_core_amx_fp16),
avx512_core_amx_fp16, is_isa_ok(avx512_core_amx),
avx512_core_amx, is_isa_ok(avx512_core_fp16),
avx512_core_fp16, is_isa_ok(avx512_core_vnni),
avx512_core_vnni, is_isa_ok(avx512_core), avx512_core,
is_isa_ok(avx2_vnni_2), avx2_vnni_2,
is_isa_ok(avx2_vnni), avx2_vnni, is_isa_ok(avx2), avx2);
} else if (brg->is_fp8) {
brg->isa_impl = utils::map(true, isa_undef,
is_isa_ok(avx10_1_512_amx_fp16), avx10_1_512_amx_fp16);
Expand Down
9 changes: 8 additions & 1 deletion src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2899,7 +2899,14 @@ void jit_brgemm_kernel_t<Wmm>::generate() {

if (brg.is_int8 && !brg.has_int8_vnni) {
mov(reg_tmp_gpr.cvt16(), 0x1);
vpbroadcastw(int8_ones_words(), reg_tmp_gpr.cvt16());

if (is_superset(brg.isa_impl, avx512_core))
vpbroadcastw(int8_ones_words(), reg_tmp_gpr.cvt16());
else if (is_superset(brg.isa_impl, avx2)) {
movq(Xmm(int8_ones_words().getIdx()), reg_tmp_gpr);
vpbroadcastw(int8_ones_words(), Xmm(int8_ones_words().getIdx()));
} else
assert(!"unsupported isa");
}

if (brg.is_f16_b_non_amx_vnni()) {
Expand Down
10 changes: 9 additions & 1 deletion src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4936,7 +4936,15 @@ status_t create_brgemm_matmul_copy_b(
CHECK(safe_ptr_assign(copy_ker,
new jit_avx512_core_brgemm_matmul_copy_b_int8_t(conf)));
else {
assert(one_of(conf->isa, avx2_vnni, avx2_vnni_2));
// TODO: jit_avx2_vnni_brgemm_matmul_copy_b_int8_t can handle
// avx2 if no compensation is required. Consider enabling it
// for avx2 and renaming the kernel (drop "vnni" part).
const bool is_comp_required = conf->s8s8_compensation_required
|| conf->has_zero_point_a;
MAYBE_UNUSED(is_comp_required);
assert(one_of(conf->isa, avx2_vnni, avx2_vnni_2, avx2)
&& IMPLICATION(conf->isa == avx2, !is_comp_required));

CHECK(safe_ptr_assign(copy_ker,
new jit_avx2_vnni_brgemm_matmul_copy_b_int8_t(conf)));
}
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/x64/rnn/rnn_brgemm_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -85,7 +85,8 @@ x64::cpu_isa_t brgemm_calc_isa(

if (rnn.is_cell_dt_int8()) {
return utils::map(true, x64::isa_undef, mayiuse(avx512_core_vnni),
avx512_core_vnni, mayiuse(avx512_core), avx512_core);
avx512_core_vnni, mayiuse(avx512_core), avx512_core,
mayiuse(avx2), avx2);
} else if (rnn.is_cell_dt_bf16()) {
return x64::avx512_core_bf16;
} else if (rnn.is_cell_dt_f16()) {
Expand Down
5 changes: 3 additions & 2 deletions tests/benchdnn/rnn/rnn.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -780,7 +780,8 @@ void skip_unimplemented_prb(const prb_t *prb_, res_t *res) {
}
#endif
// cpu backward only supports `any` or `abx` layouts for weights
if (IMPLICATION(prb.prop == dnnl_backward, prb.tag[1] != tag::abx)) {
if (prb.prop == dnnl_backward && prb.tag[1] != tag::abx
&& prb.tag[1] != tag::any) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
return;
Expand Down
Loading