From 0c1cb4aea195118b22111a7f84b8b936106eec84 Mon Sep 17 00:00:00 2001 From: xuxinzen Date: Wed, 22 Jan 2025 14:05:07 -0500 Subject: [PATCH] Backport: cpu: x64: matmul: fixes correctness issue about tags (#2464) --- src/cpu/x64/matmul/brgemm_matmul_utils.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index 7734afb99f6..a58ce38a66e 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -1658,7 +1658,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, : 0; bgmmc.LDB = bm_conf_utils.get_actual_LDB(); - bgmmc.LDD = dst_d.blocking_desc().strides[bgmmc.ndims - 2]; + bgmmc.LDD = dst_d.ndims() == 2 && dst_d.count_non_unit_dims(1) + ? bgmmc.N + : dst_d.blocking_desc().strides[bgmmc.ndims - 2]; bgmmc.LDC = bgmmc.use_buffer_c && bgmmc.nthr_k <= 1 ? bgmmc.N_blk * (bgmmc.is_runtime_N ? bgmmc.N_chunk_size : 1) : bgmmc.LDD; @@ -1852,9 +1854,13 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, const dim_t src_stride = src_d.matches_tag(acbd) ? bgmmc.A_strides[1] : bgmmc.A_strides[0]; + const dim_t copy_A_src_stride = src_d.matches_tag(dabc) + && bgmmc.K * bgmmc.batch + == src_d.blocking_desc().strides[0] + ? src_d.blocking_desc().strides[0] + : src_d.blocking_desc().strides[0] * bgmmc.K; bgmmc.copy_A_src_stride - = nstl::min(src_d.blocking_desc().strides[0], - src_stride / factor) + = nstl::min(copy_A_src_stride, src_stride / factor) * factor; }