/*******************************************************************************
* Copyright 2020 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_MATMUL_UTILS_HPP
#define CPU_MATMUL_UTILS_HPP

#include "common/memory_desc_wrapper.hpp"
#include "common/utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

namespace matmul {

struct matmul_helper_t {
    using mdw_t = const memory_desc_wrapper;

    matmul_helper_t(mdw_t &src_md, mdw_t &weights_md, mdw_t &dst_md)
        : src_md_(src_md), weights_md_(weights_md), dst_md_(dst_md) {}

    int ndims() const { return dst_md_.ndims(); }
    bool batched() const { return ndims() > 2; }

    // Products of the batch (leading) dims of dst, src, and weights.
    dim_t batch() const {
        return utils::array_product(dst_md_.dims(), ndims() - 2);
    }
    dim_t src_batch() const {
        return utils::array_product(src_md_.dims(), ndims() - 2);
    }
    dim_t wei_batch() const {
        return utils::array_product(weights_md_.dims(), ndims() - 2);
    }

    // GEMM dimensions: src is M x K, weights are K x N, dst is M x N.
    dim_t M() const { return dst_md_.dims()[ndims() - 2]; }
    dim_t N() const { return dst_md_.dims()[ndims() - 1]; }
    dim_t K() const { return src_md_.dims()[ndims() - 1]; }

    // Transpose flags and leading dimensions of src (A), weights (B), and
    // dst (C) as expected by a GEMM kernel.
    char transA() const {
        const auto &strides = &src_md_.blocking_desc().strides[ndims() - 2];
        return (strides[1] == 1 && src_md_.dims()[ndims() - 2] > 1) ? 'N' : 'T';
    }

    char transB() const {
        const auto &strides = &weights_md_.blocking_desc().strides[ndims() - 2];
        return (strides[1] == 1 && weights_md_.dims()[ndims() - 2] > 1)
                ? 'N'
                : 'T';
    }

    dim_t lda() const {
        const auto &strides = &src_md_.blocking_desc().strides[ndims() - 2];
        return strides[transA() == 'N' ? 0 : 1];
    }

    dim_t ldb() const {
        const auto &strides = &weights_md_.blocking_desc().strides[ndims() - 2];
        return strides[transB() == 'N' ? 0 : 1];
    }

    dim_t ldc() const { return dst_md_.blocking_desc().strides[ndims() - 2]; }

    // TODO: a similar optimization is also possible for wei batch fusion.
    bool can_fuse_src_batch_dims() const {
        /* Note:
            The src batch dims can be fused so that a single GEMM call can be
            used iff:
            1. src is not transposed;
            2. wei batch dims are all 1's;
            3. the strides of the batch dims are trivial (permutations are
               allowed);
            4. src and dst layouts are identical, e.g.:
               src layout : {batch dim_idx permutation} x M x K
               dst layout : {identical batch dim_idx permutation} x M x N

            For example,
               src layout : a x d x c x b x m x k
               wei layout : 1 x 1 x 1 x 1 x k x n (or 1 x 1 x 1 x 1 x n x k)
               dst layout : a x d x c x b x m x n

            can be handled with a single GEMM call with M = a*d*c*b*m.
        */
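        /* A concrete (hypothetical) instance of the conditions above:
               src dims {2, 3, M, K} with strides {3*M*K, M*K, K, 1}
               wei dims {1, 1, K, N}
               dst dims {2, 3, M, N} with strides {3*M*N, M*N, N, 1}
           Both batch dims of src and dst are laid out contiguously right
           above the M dim, so the whole problem collapses into one GEMM
           with M' = 2*3*M, N' = N, K' = K. */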
        // Note 1:
        if (transA() == 'T') return false;

        const int n_dims = ndims();
        const int batch_ndims = n_dims - 2;
        if (batch_ndims == 0) return true;

        // Note 2:
        if (utils::array_product(weights_md_.dims(), batch_ndims) != 1)
            return false;

        // Determine the layout of the src batch dims.
        dims_t src_strides;
        utils::array_copy(
                src_strides, src_md_.blocking_desc().strides, batch_ndims);

        // Compute the outer dims (padded dims divided by block sizes); they
        // are required to get the correct perm.
        dims_t blocks = {0};
        src_md_.compute_blocks(blocks);
        dims_t ou_dims;
        for (int i = 0; i < batch_ndims; ++i)
            ou_dims[i] = src_md_.padded_dims()[i] / blocks[i];

        dims_t perm;
        for (int i = 0; i < batch_ndims; ++i)
            perm[i] = i;

        // Permute the batch dim indices by sorting them by stride (ascending).
        utils::simultaneous_sort(src_strides, ou_dims, perm, batch_ndims,
                [](stride_t a, stride_t b) { return a - b; });

        // Expected strides of the innermost batch dim in src and dst.
        dim_t src_stride = M() * lda();
        dim_t dst_stride = M() * ldc();

        // Notes 3-4: walk the batch dims from the innermost to the outermost
        // and check that the src and dst strides form the same contiguous
        // chain.
        for (int i = 0; i < batch_ndims; ++i) {
            const int dim_idx = perm[i];
            if (src_md_.blocking_desc().strides[dim_idx] != src_stride
                    || dst_md_.blocking_desc().strides[dim_idx] != dst_stride)
                return false;
            src_stride = src_stride * src_md_.dims()[dim_idx];
            dst_stride = dst_stride * dst_md_.dims()[dim_idx];
        }

        return true;
    }

private:
    mdw_t src_md_;
    mdw_t weights_md_;
    mdw_t dst_md_;
};
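
/* A minimal usage sketch (illustrative; the descriptor variables and the
   decision logic below are assumptions, not part of this header):

       memory_desc_wrapper src_d(&src_md), wei_d(&weights_md), dst_d(&dst_md);
       matmul_helper_t helper(src_d, wei_d, dst_d);

       const char transa = helper.transA(), transb = helper.transB();
       const dim_t M = helper.M(), N = helper.N(), K = helper.K();
       const dim_t lda = helper.lda(), ldb = helper.ldb(), ldc = helper.ldc();

       // One GEMM over (batch() * M) x N if the batch dims fold into M,
       // otherwise a loop of batch() independent GEMM calls.
       const bool fuse_batch = helper.can_fuse_src_batch_dims();
*/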

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif