1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_TAG_TRAITS_HPP |
18 | #define COMMON_TAG_TRAITS_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "utils.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | |
/** Set of tensor dimensions that carry inner blocks in a format tag.
 *
 * `_` denotes no blocked dimension, `_A`.._`E` a single one, and `_AB`,
 * `_BC`, `_CD`, `_CE` a pair (presumably named after the logical dimension
 * letters a..e used by format tags — confirm against DECL_TRAITS users).
 * Values are implicit and positional, so reordering enumerators changes
 * their numeric values. */
enum class block_dim_t { _, _A, _B, _C, _D, _E, _AB, _BC, _CD, _CE };
40 | |
/** Inner-block layouts that can appear inside a format tag.
 *
 * Going by the offset formulas in AB_or_BC_blk_off, a name lists its inner
 * blocks from outermost to innermost: `_8a16b2a` is 8 groups of `a`, then
 * 16 of `b`, then 2 of `a` innermost. `_` means no inner blocking.
 * Values are implicit and positional, so the enumerator order below must
 * not change. */
enum class inner_blk_t {
    _,

    // One blocked dimension.
    _4a, _4b, _4c,
    _8a, _8b,
    _16a, _16b, _16c,
    _32a, _32b, _32c, _32d, _32e,
    _48b, _48c,
    _64b, _64c,

    // Two blocking levels.
    _4a4b, _4b4a, _4b4c, _4c4b,
    _8a8b, _8b8a, _8b8c, _8c8b,
    _16a16b, _16a32b, _16a48b, _16a64b,
    _16b64a, _16b32a, _16b16a, _16b16c, _16c16b,
    _32a32b,
    _16a2b, _16a4b, _16b2a, _16b4a, _16b2c, _16b4c, _16c2b, _16c4b,
    _32d4c, _32e2c, _32e4c,
    _32b2a, _32b4a, _32c2b, _32c4b,
    _64e2c, _64e4c,
    _32c2e,
    _48c2b, _48c4b, _48b2a, _48b4a,
    _64b2a, _64b4a, _64c2b, _64c4b,

    // Three blocking levels.
    _2c8b4c, _8a16b2a, _4b16a4b, _4b32a4b, _4b64a4b,
    _2b8a4b, _8b16a2b, _8b32a2b, _8b64a2b, _8b16c2b,
    _4c16b4c, _8c16b2c,
    _2b4c2b, _2c4b2c, _4b8c2b, _4c8b2c,

    // Mixed two-, three- and four-level layouts.
    _16a16b2a, _16a32b2a, _16a48b2a, _16a64b2a,
    _16b16a2b, _16b32a2b, _16b48a2b, _16b64a2b,
    _16a16b4a, _16a32b4a, _16a48b4a, _16a64b4a,
    _16b16a4b, _16b32a4b, _16b48a4b, _16b64a4b,
    _16b16c2b, _16c16b2c, _16c16b4c,
    _2a8b8a2b, _2b8c8b2c, _4a8b8a4b, _4b8c8b4c,
    _16c32b2c, _16c48b2c, _16c64b2c,
    _16c32b4c, _16c48b4c, _16c64b4c,
    _16b32c, _16b48c, _16b64c,
    _16b32c2b, _16b48c2b, _16b64c2b,
    _16b16c4b, _16b32c4b, _16b48c4b, _16b64c4b,
};
163 | |
164 | /** returns the offset within the block for weights blocked over oc and ic */ |
165 | template <inner_blk_t f> |
166 | constexpr int AB_or_BC_blk_off(int x0, int x1) { |
167 | using ib = inner_blk_t; |
168 | static_assert( |
169 | utils::one_of(f, ib::_4a4b, ib::_4b4a, ib::_4b4c, ib::_4c4b, |
170 | ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, |
171 | ib::_16b64a, ib::_16b32a, ib::_16b16a, ib::_16b16c, |
172 | ib::_16c16b, ib::_32a32b, ib::_16a2b, ib::_16a4b, |
173 | ib::_16b2c, ib::_16b4c, ib::_2c8b4c, ib::_8a16b2a, |
174 | ib::_4b64a4b, ib::_4b32a4b, ib::_4b16a4b, ib::_2b8a4b, |
175 | ib::_8b64a2b, ib::_8b32a2b, ib::_8b16a2b, ib::_8b16c2b, |
176 | ib::_4c16b4c, ib::_8c16b2c, ib::_2b4c2b, ib::_2c4b2c, |
177 | ib::_4b8c2b, ib::_4c8b2c, ib::_16a32b, ib::_16a48b, |
178 | ib::_16a64b, ib::_16a16b2a, ib::_16a32b2a, ib::_16a48b2a, |
179 | ib::_16a64b2a, ib::_16a16b4a, ib::_16a32b4a, ib::_16a48b4a, |
180 | ib::_16a64b4a, ib::_16b16a2b, ib::_16b16a4b, ib::_16b16c2b, |
181 | ib::_16c16b2c, ib::_16c16b4c, ib::_2a8b8a2b, ib::_2b8c8b2c, |
182 | ib::_4a8b8a4b, ib::_4b8c8b4c, ib::_16b32a2b, ib::_16b48a2b, |
183 | ib::_16b64a2b, ib::_16b32a4b, ib::_16b48a4b, ib::_16b64a4b, |
184 | ib::_16c32b2c, ib::_16c48b2c, ib::_16c64b2c, ib::_16c32b4c, |
185 | ib::_16c48b4c, ib::_16c64b4c, ib::_16b32c, ib::_16b48c, |
186 | ib::_16b64c, ib::_16b32c2b, ib::_16b48c2b, ib::_16b64c2b, |
187 | ib::_16b16c4b, ib::_16b32c4b, ib::_16b48c4b, ib::_16b64c4b), |
188 | "unexpected inner_blk format" ); |
189 | |
190 | // clang-format off |
191 | return false ? 0 |
192 | : (f == ib::_4a4b || f == ib::_4b4c) ? 4 * x0 + x1 |
193 | : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 |
194 | : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 |
195 | : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 |
196 | : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 |
197 | : (f == ib::_16b64a) ? 64 * x1 + x0 |
198 | : (f == ib::_16b32a) ? 32 * x1 + x0 |
199 | : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 |
200 | : (f == ib::_16a2b || f == ib::_16b2c) ? 2 * x0 + x1 |
201 | : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 |
202 | : (utils::one_of(f, ib::_32a32b, ib::_16a32b, ib::_16b32c)) ? 32 * x0 + x1 |
203 | : (utils::one_of(f, ib::_8a16b2a, ib::_8b16c2b, ib::_16a16b2a, ib::_16b16c2b)) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 |
204 | : (utils::one_of(f, ib::_16a48b, ib::_16b48c)) ? x0 * 48 + x1 |
205 | : (utils::one_of(f, ib::_16a64b, ib::_16b64c)) ? x0 * 64 + x1 |
206 | : (utils::one_of(f, ib::_16a32b2a, ib::_16b32c2b)) ? (x0 / 2) * 64 + x1 * 2 + x0 % 2 |
207 | : (utils::one_of(f, ib::_16a48b2a, ib::_16b48c2b)) ? (x0 / 2) * 96 + x1 * 2 + x0 % 2 |
208 | : (utils::one_of(f, ib::_16a64b2a, ib::_16b64c2b)) ? (x0 / 2) * 128 + x1 * 2 + x0 % 2 |
209 | : (utils::one_of(f, ib::_16a16b4a, ib::_16b16c4b)) ? (x0 / 4) * 64 + x1 * 4 + x0 % 4 |
210 | : (utils::one_of(f, ib::_16a32b4a, ib::_16b32c4b)) ? (x0 / 4) * 128 + x1 * 4 + x0 % 4 |
211 | : (utils::one_of(f, ib::_16a48b4a, ib::_16b48c4b)) ? (x0 / 4) * 192 + x1 * 4 + x0 % 4 |
212 | : (utils::one_of(f, ib::_16a64b4a, ib::_16b64c4b)) ? (x0 / 4) * 256 + x1 * 4 + x0 % 4 |
213 | : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4 |
214 | : (f == ib::_4b32a4b) ? (x1 / 4) * 128 + x0 * 4 + x1 % 4 |
215 | : (f == ib::_4b64a4b) ? (x1 / 4) * 256 + x0 * 4 + x1 % 4 |
216 | : (f == ib::_2b8a4b || f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 |
217 | : (f == ib::_16b16a2b || f == ib::_16c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 |
218 | : (f == ib::_16b16a4b || f == ib::_16c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4 |
219 | : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 |
220 | : (f == ib::_8b32a2b) ? (x1 / 2) * 64 + x0 * 2 + x1 % 2 |
221 | : (f == ib::_8b64a2b) ? (x1 / 2) * 128 + x0 * 2 + x1 % 2 |
222 | : (f == ib::_2b4c2b || f == ib::_2c4b2c) ? (x0 / 2) * 8 + x1 * 2 + x0 % 2 |
223 | : (f == ib::_4b8c2b || f == ib::_4c8b2c) ? (x0 / 2) * 16 + x1 * 2 + x0 % 2 |
224 | : (f == ib::_2a8b8a2b || f == ib::_2b8c8b2c) ? (x0 / 8) * 128 + (x1 / 2) * 16 + (x0 % 8) * 2 + x1 % 2 |
225 | : (f == ib::_4a8b8a4b || f == ib::_4b8c8b4c) ? (x0 / 8) * 256 + (x1 / 4) * 32 + (x0 % 8) * 4 + x1 % 4 |
226 | : (f == ib::_16b32a2b || f == ib::_16c32b2c) ? (x1 / 2) * 64 + x0 * 2 + x1 % 2 |
227 | : (f == ib::_16b48a2b || f == ib::_16c48b2c) ? (x1 / 2) * 96 + x0 * 2 + x1 % 2 |
228 | : (f == ib::_16b64a2b || f == ib::_16c64b2c) ? (x1 / 2) * 128 + x0 * 2 + x1 % 2 |
229 | : (f == ib::_16b32a4b || f == ib::_16c32b4c) ? (x1 / 4) * 128 + x0 * 4 + x1 % 4 |
230 | : (f == ib::_16b48a4b || f == ib::_16c48b4c) ? (x1 / 4) * 192 + x0 * 4 + x1 % 4 |
231 | : (f == ib::_16b64a4b || f == ib::_16c64b4c) ? (x1 / 4) * 256 + x0 * 4 + x1 % 4 |
232 | : INT_MIN; |
233 | // clang-format on |
234 | } |
235 | |
236 | template <inner_blk_t b> |
237 | struct inner_blk_traits { |
238 | using ib = inner_blk_t; |
239 | }; |
240 | |
241 | template <format_tag_t> |
242 | struct tag_traits { |
243 | // block_dim_t block_dims; |
244 | // inner_blk_t inner_blks; |
245 | // int ndims; |
246 | }; |
247 | |
248 | #define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ |
249 | template <> \ |
250 | struct tag_traits<format_tag::_tag> { \ |
251 | static constexpr block_dim_t |
---|