1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_GEMM_UTILS_HPP
18#define COMMON_GEMM_UTILS_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "common/c_types_map.hpp"
23#include "common/nstl.hpp"
24#include "common/primitive_desc_iterator.hpp"
25#include "common/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29
30static inline status_t check_gemm_input(char transa, char transb, int m, int n,
31 int k, int lda, int ldb, int ldc, float alpha, float beta) {
32 using namespace status;
33 bool consistency = true && utils::one_of(transa, 'T', 't', 'N', 'n')
34 && utils::one_of(transb, 'T', 't', 'N', 'n') && m >= 0 && n >= 0
35 && k >= 0;
36 if (!consistency) return invalid_arguments;
37 bool isTransA = utils::one_of(transa, 'T', 't');
38 bool isTransB = utils::one_of(transb, 'T', 't');
39 int nrowA = isTransA ? k : m;
40 int nrowB = isTransB ? n : k;
41 consistency = true && lda >= nstl::max(1, nrowA)
42 && ldb >= nstl::max(1, nrowB) && ldc >= nstl::max(1, m);
43 if (!consistency) return invalid_arguments;
44
45 return success;
46}
47
48static inline status_t check_gemm_x8x8s32_input(char offsetc, char transa,
49 char transb, int m, int n, int k, int lda, int ldb, int ldc,
50 float alpha, float beta) {
51 using namespace status;
52 if (!utils::one_of(offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
53 return invalid_arguments;
54 return check_gemm_input(
55 transa, transb, m, n, k, lda, ldb, ldc, alpha, beta);
56}
57
58// This function makes a 2d tensor from an nd tensor.
59// the 2d tensor just collapes dims[1...ndims-1] from the nd tensor
60// The only reason we do not use reshape here is that we want to allow
61// fusing blocked dimensions and padded dimensions.
62static inline void init_2d_desc(memory_desc_t *md_2d,
63 const memory_desc_t *md_nd, bool transpose_dims = false) {
64 auto p_dims = md_nd->padded_dims;
65 auto blk = md_nd->format_desc.blocking;
66 auto strides = blk.strides;
67
68 // we assume that the innermost dimension always has stride 1
69 assert(IMPLICATION(blk.inner_nblks == 0,
70 utils::array_min(strides, md_nd->ndims) == 1));
71
72 // TODO: add checks to see if the memory descriptor can be 2d-fied
73 // TODO: change signature to specifiy at which dimension shall we 2d-fy (currently 1st)
74 auto p_dim1 = utils::array_product(p_dims + 1, md_nd->ndims - 1);
75 auto stride1 = blk.inner_nblks == 0
76 ? utils::array_min(strides + 1, md_nd->ndims - 1)
77 : 1;
78
79 if (transpose_dims) {
80 dnnl_dims_t dims_2d = {p_dim1, p_dims[0]};
81 dnnl_dims_t strides_2d = {stride1, strides[0]};
82 memory_desc_init_by_strides(
83 *md_2d, 2, dims_2d, md_nd->data_type, strides_2d);
84 } else {
85 dnnl_dims_t dims_2d = {p_dims[0], p_dim1};
86 dnnl_dims_t strides_2d = {strides[0], stride1};
87 memory_desc_init_by_strides(
88 *md_2d, 2, dims_2d, md_nd->data_type, strides_2d);
89 }
90}
91
92static inline void create_2d_desc(memory_desc_t *md_2d, int d0, int d1,
93 data_type_t dt, transpose_t trans, int ld) {
94 dnnl_dims_t dims_2d = {d0, d1};
95 if (trans == transpose::notrans) {
96 dnnl_dims_t strides_2d = {ld, 1};
97 memory_desc_init_by_strides(*md_2d, 2, dims_2d, dt, strides_2d);
98 } else {
99 dnnl_dims_t strides_2d = {1, ld};
100 memory_desc_init_by_strides(*md_2d, 2, dims_2d, dt, strides_2d);
101 }
102}
103
104static inline gemm_desc_t create_gemm_desc(const memory_desc_t *a_md,
105 const memory_desc_t *b_md, const memory_desc_t *c_md,
106 const memory_desc_t *bias_md, data_type_t acc_dt, engine_t *engine,
107 sum_ab_t sum_ab = sum_ab::sum_none,
108 data_type_t sum_ab_dt = data_type::undef) {
109 auto gemm_desc = gemm_desc_t();
110 gemm_desc.primitive_kind = primitive_kind::gemm;
111 gemm_desc.a_desc = *a_md;
112 gemm_desc.b_desc = *b_md;
113 gemm_desc.c_desc = *c_md;
114 gemm_desc.bias_desc = *bias_md;
115 gemm_desc.acc_type = acc_dt;
116 gemm_desc.sum_ab = sum_ab;
117 gemm_desc.sum_ab_type = sum_ab_dt;
118 // Downgrade accumulation type for f16 if allowed.
119 if (engine->mayiuse_f16_accumulator_with_f16()
120 && utils::everyone_is(
121 data_type::f16, a_md->data_type, b_md->data_type)) {
122 gemm_desc.acc_type = data_type::f16;
123 }
124 return gemm_desc;
125}
126
127static inline status_t create_gemm_pd(
128 std::shared_ptr<primitive_desc_t> &gemm_pd_, engine_t *engine,
129 const memory_desc_t *a_md, const memory_desc_t *b_md,
130 const memory_desc_t *c_md, const memory_desc_t *bias_md,
131 data_type_t acc_dt, const primitive_attr_t *attr, bool skip_ref = false,
132 sum_ab_t sum_ab = sum_ab::sum_none,
133 data_type_t sum_ab_dt = data_type::undef) {
134 auto gemm_desc = create_gemm_desc(
135 a_md, b_md, c_md, bias_md, acc_dt, engine, sum_ab, sum_ab_dt);
136
137 primitive_attr_t gemm_attr = *attr;
138
139 primitive_desc_iterator_t it(
140 engine, (op_desc_t *)&gemm_desc, &gemm_attr, nullptr);
141
142 gemm_pd_ = *(++it);
143 if (!gemm_pd_) return status::unimplemented;
144 if (skip_ref && strstr(gemm_pd_.get()->name(), "ref") != NULL)
145 return status::unimplemented;
146
147 return status::success;
148}
149
150static inline bool is_md_gemm_compatible_plain_format(
151 const memory_desc_t *md, bool is_dst = false) {
152
153 if (md->format_kind != format_kind::blocked) return false;
154
155 auto &blk_desc = md->format_desc.blocking;
156
157 if (blk_desc.inner_nblks != 0) return false;
158
159 return (blk_desc.strides[md->ndims - 1] == 1)
160 || (!is_dst && blk_desc.strides[md->ndims - 2] == 1);
161}
162
163} // namespace impl
164} // namespace dnnl
165
166#endif
167