1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_GEMM_UTILS_HPP |
18 | #define COMMON_GEMM_UTILS_HPP |
19 | |
#include <cstring>
#include <memory>

#include "oneapi/dnnl/dnnl.h"

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/utils.hpp"
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | |
30 | static inline status_t check_gemm_input(char transa, char transb, int m, int n, |
31 | int k, int lda, int ldb, int ldc, float alpha, float beta) { |
32 | using namespace status; |
33 | bool consistency = true && utils::one_of(transa, 'T', 't', 'N', 'n') |
34 | && utils::one_of(transb, 'T', 't', 'N', 'n') && m >= 0 && n >= 0 |
35 | && k >= 0; |
36 | if (!consistency) return invalid_arguments; |
37 | bool isTransA = utils::one_of(transa, 'T', 't'); |
38 | bool isTransB = utils::one_of(transb, 'T', 't'); |
39 | int nrowA = isTransA ? k : m; |
40 | int nrowB = isTransB ? n : k; |
41 | consistency = true && lda >= nstl::max(1, nrowA) |
42 | && ldb >= nstl::max(1, nrowB) && ldc >= nstl::max(1, m); |
43 | if (!consistency) return invalid_arguments; |
44 | |
45 | return success; |
46 | } |
47 | |
48 | static inline status_t check_gemm_x8x8s32_input(char offsetc, char transa, |
49 | char transb, int m, int n, int k, int lda, int ldb, int ldc, |
50 | float alpha, float beta) { |
51 | using namespace status; |
52 | if (!utils::one_of(offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) |
53 | return invalid_arguments; |
54 | return check_gemm_input( |
55 | transa, transb, m, n, k, lda, ldb, ldc, alpha, beta); |
56 | } |
57 | |
58 | // This function makes a 2d tensor from an nd tensor. |
59 | // the 2d tensor just collapes dims[1...ndims-1] from the nd tensor |
60 | // The only reason we do not use reshape here is that we want to allow |
61 | // fusing blocked dimensions and padded dimensions. |
62 | static inline void init_2d_desc(memory_desc_t *md_2d, |
63 | const memory_desc_t *md_nd, bool transpose_dims = false) { |
64 | auto p_dims = md_nd->padded_dims; |
65 | auto blk = md_nd->format_desc.blocking; |
66 | auto strides = blk.strides; |
67 | |
68 | // we assume that the innermost dimension always has stride 1 |
69 | assert(IMPLICATION(blk.inner_nblks == 0, |
70 | utils::array_min(strides, md_nd->ndims) == 1)); |
71 | |
72 | // TODO: add checks to see if the memory descriptor can be 2d-fied |
73 | // TODO: change signature to specifiy at which dimension shall we 2d-fy (currently 1st) |
74 | auto p_dim1 = utils::array_product(p_dims + 1, md_nd->ndims - 1); |
75 | auto stride1 = blk.inner_nblks == 0 |
76 | ? utils::array_min(strides + 1, md_nd->ndims - 1) |
77 | : 1; |
78 | |
79 | if (transpose_dims) { |
80 | dnnl_dims_t dims_2d = {p_dim1, p_dims[0]}; |
81 | dnnl_dims_t strides_2d = {stride1, strides[0]}; |
82 | memory_desc_init_by_strides( |
83 | *md_2d, 2, dims_2d, md_nd->data_type, strides_2d); |
84 | } else { |
85 | dnnl_dims_t dims_2d = {p_dims[0], p_dim1}; |
86 | dnnl_dims_t strides_2d = {strides[0], stride1}; |
87 | memory_desc_init_by_strides( |
88 | *md_2d, 2, dims_2d, md_nd->data_type, strides_2d); |
89 | } |
90 | } |
91 | |
92 | static inline void create_2d_desc(memory_desc_t *md_2d, int d0, int d1, |
93 | data_type_t dt, transpose_t trans, int ld) { |
94 | dnnl_dims_t dims_2d = {d0, d1}; |
95 | if (trans == transpose::notrans) { |
96 | dnnl_dims_t strides_2d = {ld, 1}; |
97 | memory_desc_init_by_strides(*md_2d, 2, dims_2d, dt, strides_2d); |
98 | } else { |
99 | dnnl_dims_t strides_2d = {1, ld}; |
100 | memory_desc_init_by_strides(*md_2d, 2, dims_2d, dt, strides_2d); |
101 | } |
102 | } |
103 | |
104 | static inline gemm_desc_t create_gemm_desc(const memory_desc_t *a_md, |
105 | const memory_desc_t *b_md, const memory_desc_t *c_md, |
106 | const memory_desc_t *bias_md, data_type_t acc_dt, engine_t *engine, |
107 | sum_ab_t sum_ab = sum_ab::sum_none, |
108 | data_type_t sum_ab_dt = data_type::undef) { |
109 | auto gemm_desc = gemm_desc_t(); |
110 | gemm_desc.primitive_kind = primitive_kind::gemm; |
111 | gemm_desc.a_desc = *a_md; |
112 | gemm_desc.b_desc = *b_md; |
113 | gemm_desc.c_desc = *c_md; |
114 | gemm_desc.bias_desc = *bias_md; |
115 | gemm_desc.acc_type = acc_dt; |
116 | gemm_desc.sum_ab = sum_ab; |
117 | gemm_desc.sum_ab_type = sum_ab_dt; |
118 | // Downgrade accumulation type for f16 if allowed. |
119 | if (engine->mayiuse_f16_accumulator_with_f16() |
120 | && utils::everyone_is( |
121 | data_type::f16, a_md->data_type, b_md->data_type)) { |
122 | gemm_desc.acc_type = data_type::f16; |
123 | } |
124 | return gemm_desc; |
125 | } |
126 | |
127 | static inline status_t create_gemm_pd( |
128 | std::shared_ptr<primitive_desc_t> &gemm_pd_, engine_t *engine, |
129 | const memory_desc_t *a_md, const memory_desc_t *b_md, |
130 | const memory_desc_t *c_md, const memory_desc_t *bias_md, |
131 | data_type_t acc_dt, const primitive_attr_t *attr, bool skip_ref = false, |
132 | sum_ab_t sum_ab = sum_ab::sum_none, |
133 | data_type_t sum_ab_dt = data_type::undef) { |
134 | auto gemm_desc = create_gemm_desc( |
135 | a_md, b_md, c_md, bias_md, acc_dt, engine, sum_ab, sum_ab_dt); |
136 | |
137 | primitive_attr_t gemm_attr = *attr; |
138 | |
139 | primitive_desc_iterator_t it( |
140 | engine, (op_desc_t *)&gemm_desc, &gemm_attr, nullptr); |
141 | |
142 | gemm_pd_ = *(++it); |
143 | if (!gemm_pd_) return status::unimplemented; |
144 | if (skip_ref && strstr(gemm_pd_.get()->name(), "ref" ) != NULL) |
145 | return status::unimplemented; |
146 | |
147 | return status::success; |
148 | } |
149 | |
150 | static inline bool is_md_gemm_compatible_plain_format( |
151 | const memory_desc_t *md, bool is_dst = false) { |
152 | |
153 | if (md->format_kind != format_kind::blocked) return false; |
154 | |
155 | auto &blk_desc = md->format_desc.blocking; |
156 | |
157 | if (blk_desc.inner_nblks != 0) return false; |
158 | |
159 | return (blk_desc.strides[md->ndims - 1] == 1) |
160 | || (!is_dst && blk_desc.strides[md->ndims - 2] == 1); |
161 | } |
162 | |
163 | } // namespace impl |
164 | } // namespace dnnl |
165 | |
166 | #endif |
167 | |