1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_GEMM_TYPES_HPP |
18 | #define COMMON_GEMM_TYPES_HPP |
19 | |
20 | #include <assert.h> |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/memory_desc.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | |
// Transposition flag for a GEMM matrix argument. Kept as a plain (unscoped)
// enum so the values remain implicitly convertible, matching the C-style API.
enum transpose_t {
    dnnl_notrans = 0, // matrix is used as-is
    dnnl_trans = 1, // matrix is transposed
};

// Short aliases used by the implementation (e.g. transpose::trans).
namespace transpose {
const transpose_t notrans = transpose_t::dnnl_notrans;
const transpose_t trans = transpose_t::dnnl_trans;
} // namespace transpose
33 | |
// Broadcast pattern of the C offset: a single value for the whole matrix
// (fixed), one value per column, or one value per row. Plain enum to stay
// implicitly convertible, matching the C-style API.
enum offsetc_t {
    dnnl_fixed = 0,
    dnnl_column = 1,
    dnnl_row = 2,
};

// Short aliases used by the implementation (e.g. offsetc::fixed).
namespace offsetc {
const offsetc_t fixed = offsetc_t::dnnl_fixed;
const offsetc_t column = offsetc_t::dnnl_column;
const offsetc_t row = offsetc_t::dnnl_row;
} // namespace offsetc
41 | |
// Selects an optional reduction across the k dimension that is produced
// alongside the GEMM result: row sums of A, column sums of B, or none.
// Plain enum to stay implicitly convertible, matching the C-style API.
enum sum_ab_t {
    dnnl_sum_a_row = 0,
    dnnl_sum_b_col = 1,
    dnnl_sum_none = 2,
};

// Short aliases used by the implementation (e.g. sum_ab::sum_none).
namespace sum_ab {
const sum_ab_t sum_a_row = sum_ab_t::dnnl_sum_a_row;
const sum_ab_t sum_b_col = sum_ab_t::dnnl_sum_b_col;
const sum_ab_t sum_none = sum_ab_t::dnnl_sum_none;
} // namespace sum_ab
48 | |
// A descriptor for a matrix multiplication (gemm) operation. To make the
// interface consistent, the descriptor represents the GEMM operation in row
// major.
52 | struct gemm_desc_t { |
53 | // The kind of primitive. Used for self identifying the primitive |
54 | // descriptor. Must be #dnnl_gemm. |
55 | dnnl_primitive_kind_t primitive_kind; |
56 | memory_desc_t a_desc; |
57 | memory_desc_t b_desc; |
58 | memory_desc_t c_desc; |
59 | memory_desc_t bias_desc; |
60 | // Type for accumulating A*B. |
61 | dnnl_data_type_t acc_type; |
62 | // Sum across k dimension in either A or B tensor |
63 | // and output to sum_ab tensor. |
64 | sum_ab_t sum_ab; |
65 | dnnl_data_type_t sum_ab_type; |
66 | |
67 | // These accessors are to be used by the GEMM implementation. Because the |
68 | // GEMM implementation currently assumes column major. These accessors |
69 | // return data in column major fashion. |
70 | |
71 | inline bool is_batched() const { return c_desc.ndims >= 3; } |
72 | |
73 | // Simplified accessors that comply to GEMM API |
74 | transpose_t get_trans(memory_desc_t md) const { |
75 | return md.format_desc.blocking.strides[md.ndims - 1] != 1 |
76 | ? transpose::trans |
77 | : transpose::notrans; |
78 | } |
79 | transpose_t transa() const { return get_trans(b_desc); }; |
80 | transpose_t transb() const { return get_trans(a_desc); }; |
81 | transpose_t trans_bias() const { return get_trans(bias_desc); } |
82 | |
83 | dnnl_dim_t batch() const { |
84 | // if ndims < 3, it should return 1 |
85 | int64_t batch = 1; |
86 | for (int i = 0; i < c_desc.ndims - 2; ++i) { |
87 | if (c_desc.dims[i] == DNNL_RUNTIME_DIM_VAL) |
88 | return DNNL_RUNTIME_DIM_VAL; |
89 | batch *= c_desc.dims[i]; |
90 | } |
91 | return batch; |
92 | } |
93 | |
94 | // Number of rows of C. |
95 | dnnl_dim_t m() const { return c_desc.dims[c_desc.ndims - 1]; } |
96 | // Number of columns of C. |
97 | dnnl_dim_t n() const { return c_desc.dims[c_desc.ndims - 2]; } |
98 | // Size of inner dimension shared between A and B. |
99 | dnnl_dim_t k() const { return a_desc.dims[a_desc.ndims - 1]; } |
100 | |
101 | static dnnl_dim_t get_stride(const memory_desc_t &md, int dim = 0) { |
102 | return (dim >= md.ndims - 2 || md.dims[dim] == 1) |
103 | ? 0 |
104 | : md.format_desc.blocking.strides[dim]; |
105 | } |
106 | |
107 | /** Stride between 2 matrices A in a batch. */ |
108 | dnnl_dim_t stride_a(int dim = 0) const { return get_stride(b_desc, dim); }; |
109 | /** Stride between 2 matrices B in a batch. */ |
110 | dnnl_dim_t stride_b(int dim = 0) const { return get_stride(a_desc, dim); }; |
111 | /** Stride between 2 matrices C in a batch. */ |
112 | dnnl_dim_t stride_c(int dim = 0) const { return get_stride(c_desc, dim); }; |
113 | |
114 | // This assumes that one of the dimensions has strides 1 |
115 | static dnnl_dim_t get_ld(const memory_desc_t &md) { |
116 | auto strides = md.format_desc.blocking.strides; |
117 | assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1); |
118 | return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1] |
119 | : strides[md.ndims - 2]; |
120 | } |
121 | // Leading dimension of A. |
122 | dnnl_dim_t lda() const { return get_ld(b_desc); } |
123 | // Leading dimension of B. |
124 | dnnl_dim_t ldb() const { return get_ld(a_desc); } |
125 | // Leading dimension of C. |
126 | dnnl_dim_t ldc() const { return get_ld(c_desc); } |
127 | /** Leading dimension of bias. */ |
128 | dnnl_dim_t ld_bias() const { return get_ld(bias_desc); } |
129 | |
130 | // Type of matrix A. |
131 | dnnl_data_type_t a_type() const { return b_desc.data_type; } |
132 | // Type of matrix B. |
133 | dnnl_data_type_t b_type() const { return a_desc.data_type; } |
134 | // Type of matrix C. |
135 | dnnl_data_type_t c_type() const { return c_desc.data_type; } |
136 | // Type of bias. |
137 | dnnl_data_type_t bias_type() const { return bias_desc.data_type; } |
138 | // Type of bias. |
139 | int bias_mask() const { |
140 | assert(bias_desc.ndims <= 3); |
141 | int mask = 0; |
142 | // TODO: update the mask for batched dimension if we start |
143 | // supporting more batch dimensions |
144 | if (is_batched()) mask |= (bias_desc.dims[0] > 1) ? 1 << 0 : 0; |
145 | |
146 | // because the bias mask is in row major, we have to convert |
147 | // to col major here by swapping two last dimensions |
148 | int m_idx = is_batched(); |
149 | mask |= (bias_desc.dims[m_idx] > 1) ? 1 << (bias_desc.ndims - m_idx) |
150 | : 0; |
151 | mask |= (bias_desc.dims[m_idx + 1] > 1) |
152 | ? 1 << (bias_desc.ndims - (m_idx + 1)) |
153 | : 0; |
154 | return mask; |
155 | } |
156 | }; |
157 | |
158 | } // namespace impl |
159 | } // namespace dnnl |
160 | |
161 | #endif // COMMON_GEMM_TYPES_HPP |
162 | |