Skip to content

Commit

Permalink
Backport: cpu: x64: matmul: fixes correctness issue about tags (#2464)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuxinzen authored Jan 22, 2025
1 parent 87d098e commit 0c1cb4a
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
: 0;

bgmmc.LDB = bm_conf_utils.get_actual_LDB();
bgmmc.LDD = dst_d.blocking_desc().strides[bgmmc.ndims - 2];
bgmmc.LDD = dst_d.ndims() == 2 && dst_d.count_non_unit_dims(1)
? bgmmc.N
: dst_d.blocking_desc().strides[bgmmc.ndims - 2];
bgmmc.LDC = bgmmc.use_buffer_c && bgmmc.nthr_k <= 1
? bgmmc.N_blk * (bgmmc.is_runtime_N ? bgmmc.N_chunk_size : 1)
: bgmmc.LDD;
Expand Down Expand Up @@ -1852,9 +1854,13 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc,
const dim_t src_stride = src_d.matches_tag(acbd)
? bgmmc.A_strides[1]
: bgmmc.A_strides[0];
const dim_t copy_A_src_stride = src_d.matches_tag(dabc)
&& bgmmc.K * bgmmc.batch
== src_d.blocking_desc().strides[0]
? src_d.blocking_desc().strides[0]
: src_d.blocking_desc().strides[0] * bgmmc.K;
bgmmc.copy_A_src_stride
= nstl::min(src_d.blocking_desc().strides[0],
src_stride / factor)
= nstl::min(copy_A_src_stride, src_stride / factor)
* factor;
}

Expand Down

0 comments on commit 0c1cb4a

Please sign in to comment.