/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_GEMM_TYPES_HPP
#define COMMON_GEMM_TYPES_HPP

#include <assert.h>
#include "common/c_types_map.hpp"
#include "common/memory_desc.hpp"

namespace dnnl {
namespace impl {

enum transpose_t { dnnl_notrans, dnnl_trans };

namespace transpose {
const transpose_t notrans = dnnl_notrans;
const transpose_t trans = dnnl_trans;
} // namespace transpose

enum offsetc_t { dnnl_fixed, dnnl_column, dnnl_row };

namespace offsetc {
const offsetc_t fixed = dnnl_fixed;
const offsetc_t column = dnnl_column;
const offsetc_t row = dnnl_row;
} // namespace offsetc

enum sum_ab_t { dnnl_sum_a_row, dnnl_sum_b_col, dnnl_sum_none };
namespace sum_ab {
const sum_ab_t sum_a_row = dnnl_sum_a_row;
const sum_ab_t sum_b_col = dnnl_sum_b_col;
const sum_ab_t sum_none = dnnl_sum_none;
} // namespace sum_ab

// A descriptor for a matrix multiplication (GEMM) operation. To keep the
// interface consistent, the descriptor represents the GEMM operation in
// row-major order.
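//
// Note (for illustration, not part of the original API documentation): a
// row-major product
//   C(M x N) = A(M x K) * B(K x N)
// is equivalent to the column-major product
//   C^T(N x M) = B^T(N x K) * A^T(K x M),
// which is why the accessors below swap the roles of the A and B descriptors
// when presenting the problem to a column-major GEMM implementation.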
struct gemm_desc_t {
    // The kind of primitive. Used for self-identifying the primitive
    // descriptor. Must be #dnnl_gemm.
    dnnl_primitive_kind_t primitive_kind;
    memory_desc_t a_desc;
    memory_desc_t b_desc;
    memory_desc_t c_desc;
    memory_desc_t bias_desc;
    // Type used for accumulating A*B.
    dnnl_data_type_t acc_type;
    // Sum across the k dimension of either the A or the B tensor,
    // and output the result to the sum_ab tensor.
    sum_ab_t sum_ab;
    dnnl_data_type_t sum_ab_type;
    // These accessors are to be used by the GEMM implementation. Because the
    // GEMM implementation currently assumes a column-major layout, these
    // accessors return data in column-major fashion.
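    //
    // Worked example (for illustration only; plain, densely packed row-major
    // descriptors assumed): for a non-batched row-major problem
    //   C(M x N) = A(M x K) * B(K x N)
    // with strides {K, 1}, {N, 1} and {N, 1} for A, B and C respectively,
    // the "a" accessors below read b_desc and the "b" accessors read a_desc,
    // so transa() == transb() == transpose::notrans and the implementation
    // effectively computes the column-major product C^T = B^T * A^T.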

    inline bool is_batched() const { return c_desc.ndims >= 3; }

    // Simplified accessors that comply with the GEMM API.
    transpose_t get_trans(memory_desc_t md) const {
        return md.format_desc.blocking.strides[md.ndims - 1] != 1
                ? transpose::trans
                : transpose::notrans;
    }
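    // For illustration (values assumed): a 2D md with dims {4, 8} and
    // row-major strides {8, 1} yields transpose::notrans, while column-major
    // strides {1, 4} yield transpose::trans, since only the position of the
    // unit stride is inspected.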
    transpose_t transa() const { return get_trans(b_desc); }
    transpose_t transb() const { return get_trans(a_desc); }
    transpose_t trans_bias() const { return get_trans(bias_desc); }

    dnnl_dim_t batch() const {
        // If ndims < 3 there are no batch dimensions, so the batch is 1.
        int64_t batch = 1;
        for (int i = 0; i < c_desc.ndims - 2; ++i) {
            if (c_desc.dims[i] == DNNL_RUNTIME_DIM_VAL)
                return DNNL_RUNTIME_DIM_VAL;
            batch *= c_desc.dims[i];
        }
        return batch;
    }
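    // For example (illustrative values): a c_desc with ndims == 4 and
    // dims {2, 3, 16, 32} gives batch() == 6; if any of the leading (batch)
    // dims equals DNNL_RUNTIME_DIM_VAL, batch() returns DNNL_RUNTIME_DIM_VAL.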

    // Number of rows of C (in the column-major view of the problem).
    dnnl_dim_t m() const { return c_desc.dims[c_desc.ndims - 1]; }
    // Number of columns of C (in the column-major view of the problem).
    dnnl_dim_t n() const { return c_desc.dims[c_desc.ndims - 2]; }
    // Size of the inner dimension shared between A and B.
    dnnl_dim_t k() const { return a_desc.dims[a_desc.ndims - 1]; }
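    // Continuing the illustrative row-major example C(M x N) = A(M x K) *
    // B(K x N): m() == N, n() == M and k() == K, matching the transposed
    // (column-major) product C^T = B^T * A^T.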

    static dnnl_dim_t get_stride(const memory_desc_t &md, int dim = 0) {
        return (dim >= md.ndims - 2 || md.dims[dim] == 1)
                ? 0
                : md.format_desc.blocking.strides[dim];
    }
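    // For example (illustrative values): a batched md with dims {4, 16, 32}
    // and plain row-major strides {512, 32, 1} gives get_stride(md, 0) == 512,
    // while a broadcast batch dimension, dims {1, 16, 32}, gives 0.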

    /** Stride between 2 matrices A in a batch. */
    dnnl_dim_t stride_a(int dim = 0) const { return get_stride(b_desc, dim); }
    /** Stride between 2 matrices B in a batch. */
    dnnl_dim_t stride_b(int dim = 0) const { return get_stride(a_desc, dim); }
    /** Stride between 2 matrices C in a batch. */
    dnnl_dim_t stride_c(int dim = 0) const { return get_stride(c_desc, dim); }

    // This assumes that one of the two innermost dimensions has stride 1.
    static dnnl_dim_t get_ld(const memory_desc_t &md) {
        auto strides = md.format_desc.blocking.strides;
        assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1);
        return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1]
                                          : strides[md.ndims - 2];
    }
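    // For illustration (values assumed): a 2D md with dims {16, 32} and
    // row-major strides {32, 1} gives get_ld == 32, while column-major
    // strides {1, 16} give get_ld == 16.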
    // Leading dimension of A.
    dnnl_dim_t lda() const { return get_ld(b_desc); }
    // Leading dimension of B.
    dnnl_dim_t ldb() const { return get_ld(a_desc); }
    // Leading dimension of C.
    dnnl_dim_t ldc() const { return get_ld(c_desc); }
    /** Leading dimension of bias. */
    dnnl_dim_t ld_bias() const { return get_ld(bias_desc); }

    // Type of matrix A.
    dnnl_data_type_t a_type() const { return b_desc.data_type; }
    // Type of matrix B.
    dnnl_data_type_t b_type() const { return a_desc.data_type; }
    // Type of matrix C.
    dnnl_data_type_t c_type() const { return c_desc.data_type; }
    // Type of bias.
    dnnl_data_type_t bias_type() const { return bias_desc.data_type; }
    // Broadcast mask of the bias.
    int bias_mask() const {
        assert(bias_desc.ndims <= 3);
        int mask = 0;
        // TODO: update the mask for the batched dimension if we start
        // supporting more batch dimensions
        if (is_batched()) mask |= (bias_desc.dims[0] > 1) ? 1 << 0 : 0;

        // Because the bias mask is in row-major order, convert it to
        // column-major here by swapping the two last dimensions.
        int m_idx = is_batched();
        mask |= (bias_desc.dims[m_idx] > 1) ? 1 << (bias_desc.ndims - m_idx)
                                            : 0;
        mask |= (bias_desc.dims[m_idx + 1] > 1)
                ? 1 << (bias_desc.ndims - (m_idx + 1))
                : 0;
        return mask;
    }
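    // Illustrative traces of the logic above (values assumed, batched 3D
    // descriptors): a bias with dims {1, M, 1} (M > 1) yields mask == 1 << 2,
    // dims {1, 1, N} (N > 1) yields mask == 1 << 1, and a per-batch bias with
    // dims {B, 1, N} (B, N > 1) yields mask == (1 << 0) | (1 << 1).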
};

} // namespace impl
} // namespace dnnl

#endif // COMMON_GEMM_TYPES_HPP