1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #include "src/cpu/platform.hpp" |
23 | |
24 | namespace dnnl { |
25 | |
26 | // short names for brevity |
27 | using data_type = memory::data_type; |
28 | using tag = memory::format_tag; |
29 | |
// Fixture for checking which quantization attributes (scales / zero points)
// each primitive kind accepts or rejects at primitive-descriptor creation
// time.
class attr_quantization_test_t : public ::testing::Test {
protected:
    // Engine under test; CPU or GPU depending on the test configuration.
    engine eng = get_test_engine();
    void SetUp() override {}

    // Attribute with common per-tensor (mask = 0) scales set on SRC,
    // WEIGHTS and DST at once.
    static primitive_attr gen_attr_with_scales() {
        primitive_attr attr;
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
        attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
        attr.set_scales_mask(DNNL_ARG_DST, 0);
        return attr;
    }

    // Attribute with scales set on a single execution argument `arg`.
    static primitive_attr gen_attr_with_scales(int arg, int mask = 0) {
        primitive_attr attr;
        attr.set_scales_mask(arg, mask);
        return attr;
    }

    // Attribute with zero points set on a single execution argument `arg`.
    static primitive_attr gen_attr_with_zp(int arg, int mask = 0) {
        primitive_attr attr;
        attr.set_zero_points_mask(arg, mask);
        return attr;
    }

    // Runs `f` and verifies it completes with `status`: an exception
    // carrying `status` is expected whenever `status` != dnnl_success.
    template <typename F>
    static void check_status(const F &f, dnnl_status_t status) {
        catch_expected_failures(f, status != dnnl_success, status, false);
    }
};
// Two-level expansion so that macro arguments (e.g. a status that is itself
// a macro) are fully expanded before being pasted into the lambda; the
// lower-case 's' in the inner name avoids recursive self-reference.
#define CHECK_STATUs(status, ...) check_status([&]() { __VA_ARGS__; }, status)
#define CHECK_STATUS(status, ...) CHECK_STATUs(status, __VA_ARGS__)

// Shorthands for the three statuses primitive-descriptor creation is
// expected to produce in these tests.
#define CHECK_OK(...) CHECK_STATUS(dnnl_success, __VA_ARGS__)
#define CHECK_INVALID(...) CHECK_STATUS(dnnl_invalid_arguments, __VA_ARGS__)
#define CHECK_UNIMPL(...) CHECK_STATUS(dnnl_unimplemented, __VA_ARGS__)
66 | |
67 | // TODO: replace primitive descriptor creation with iterator fetching |
68 | // to test all possible implementations |
69 | |
70 | TEST_F(attr_quantization_test_t, TestBNorm) { |
71 | for (auto dt : {data_type::f32, data_type::s8}) { |
72 | // no s8 -> s8 batch norm on GPU yet |
73 | if (get_test_engine_kind() == engine::kind::gpu && dt == data_type::s8) |
74 | continue; |
75 | |
76 | memory::desc md {{1, 16, 3, 3}, dt, tag::abcd}; |
77 | normalization_flags flags = normalization_flags::use_global_stats; |
78 | CHECK_OK(batch_normalization_forward::primitive_desc( |
79 | eng, prop_kind::forward_inference, md, md, 0.1f, flags)); |
80 | CHECK_UNIMPL(batch_normalization_forward::primitive_desc(eng, |
81 | prop_kind::forward_inference, md, md, 0.1f, flags, |
82 | gen_attr_with_scales())); |
83 | |
84 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS, |
85 | DNNL_ARG_MEAN, DNNL_ARG_VARIANCE, DNNL_ARG_DST}) { |
86 | CHECK_UNIMPL(batch_normalization_forward::primitive_desc(eng, |
87 | prop_kind::forward_inference, md, md, 0.1f, flags, |
88 | gen_attr_with_zp(arg))); |
89 | } |
90 | } |
91 | } |
92 | |
93 | TEST_F(attr_quantization_test_t, TestBinary) { |
94 | memory::desc md {{1, 16, 3, 3}, data_type::f32, tag::abcd}; |
95 | CHECK_OK(binary::primitive_desc(eng, algorithm::binary_add, md, md, md)); |
96 | |
97 | for (auto arg : {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_DST}) { |
98 | CHECK_OK(binary::primitive_desc(eng, algorithm::binary_add, md, md, md, |
99 | gen_attr_with_scales(arg))); |
100 | CHECK_UNIMPL(binary::primitive_desc( |
101 | eng, algorithm::binary_add, md, md, md, gen_attr_with_zp(arg))); |
102 | } |
103 | } |
104 | |
105 | TEST_F(attr_quantization_test_t, TestConcat) { |
106 | memory::desc md {{1, 16, 3, 3}, data_type::s8, tag::abcd}; |
107 | CHECK_OK(concat::primitive_desc(eng, 1, {md, md})); |
108 | |
109 | for (auto arg : |
110 | {DNNL_ARG_MULTIPLE_SRC, DNNL_ARG_MULTIPLE_SRC + 1, DNNL_ARG_DST}) { |
111 | CHECK_OK(concat::primitive_desc( |
112 | eng, 1, {md, md}, gen_attr_with_scales(arg))); |
113 | CHECK_UNIMPL(concat::primitive_desc( |
114 | eng, 1, {md, md}, gen_attr_with_zp(arg))); |
115 | } |
116 | } |
117 | |
118 | TEST_F(attr_quantization_test_t, TestConv) { |
119 | // Datatype u8 is not supported in the Nvidia backend |
120 | SKIP_IF_CUDA(true, "Unsupported datatype for CUDA" ); |
121 | memory::desc src_md {{1, 16, 7, 7}, data_type::u8, tag::any}; |
122 | memory::desc wei_md {{32, 16, 3, 3}, data_type::s8, tag::any}; |
123 | memory::desc dst_md {{1, 32, 7, 7}, data_type::s32, tag::any}; |
124 | |
125 | CHECK_OK(convolution_forward::primitive_desc(eng, prop_kind::forward, |
126 | algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1}, |
127 | {1, 1}, {1, 1})); |
128 | CHECK_OK(convolution_forward::primitive_desc(eng, prop_kind::forward, |
129 | algorithm::convolution_direct, src_md, wei_md, dst_md, {1, 1}, |
130 | {1, 1}, {1, 1}, gen_attr_with_scales())); |
131 | |
132 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { |
133 | if ((src_md.get_data_type() == data_type::s8 |
134 | || src_md.get_data_type() == data_type::u8) |
135 | && (arg == DNNL_ARG_SRC || arg == DNNL_ARG_DST)) { |
136 | CHECK_OK(convolution_forward::primitive_desc(eng, |
137 | prop_kind::forward, algorithm::convolution_direct, src_md, |
138 | wei_md, dst_md, {1, 1}, {1, 1}, {1, 1}, |
139 | gen_attr_with_zp(arg))); |
140 | } else { |
141 | CHECK_UNIMPL(convolution_forward::primitive_desc(eng, |
142 | prop_kind::forward, algorithm::convolution_direct, src_md, |
143 | wei_md, dst_md, {1, 1}, {1, 1}, {1, 1}, |
144 | gen_attr_with_zp(arg))); |
145 | } |
146 | } |
147 | } |
148 | |
149 | TEST_F(attr_quantization_test_t, TestDeconv) { |
150 | memory::desc src_md {{1, 16, 7, 7}, data_type::f32, tag::any}; |
151 | memory::desc wei_md {{32, 16, 3, 3}, data_type::f32, tag::any}; |
152 | memory::desc dst_md {{1, 32, 7, 7}, data_type::f32, tag::any}; |
153 | CHECK_OK(deconvolution_forward::primitive_desc(eng, prop_kind::forward, |
154 | algorithm::deconvolution_direct, src_md, wei_md, dst_md, {1, 1}, |
155 | {1, 1}, {1, 1}, gen_attr_with_scales())); |
156 | |
157 | for (auto arg : |
158 | {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS, DNNL_ARG_DST}) { |
159 | CHECK_UNIMPL(deconvolution_forward::primitive_desc(eng, |
160 | prop_kind::forward, algorithm::deconvolution_direct, src_md, |
161 | wei_md, dst_md, {1, 1}, {1, 1}, {1, 1}, gen_attr_with_zp(arg))); |
162 | } |
163 | } |
164 | |
165 | TEST_F(attr_quantization_test_t, TestEltwise) { |
166 | for (auto dt : {data_type::f32, data_type::s8}) { |
167 | memory::desc md {{1, 16, 3, 3}, dt, tag::abcd}; |
168 | |
169 | CHECK_OK(eltwise_forward::primitive_desc( |
170 | eng, prop_kind::forward, algorithm::eltwise_relu, md, md, 0.f)); |
171 | |
172 | CHECK_UNIMPL(eltwise_forward::primitive_desc(eng, prop_kind::forward, |
173 | algorithm::eltwise_relu, md, md, 0.f, gen_attr_with_scales())); |
174 | |
175 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
176 | CHECK_UNIMPL(eltwise_forward::primitive_desc(eng, |
177 | prop_kind::forward, algorithm::eltwise_relu, md, md, 0.f, |
178 | gen_attr_with_zp(arg))); |
179 | } |
180 | } |
181 | } |
182 | |
183 | TEST_F(attr_quantization_test_t, TestInnerProduct) { |
184 | // Datatype u8 is not supported in the Nvidia backend |
185 | SKIP_IF_CUDA(true, "Unsupported datatype for CUDA" ); |
186 | memory::desc src_md {{1, 16, 7, 7}, data_type::u8, tag::any}; |
187 | memory::desc wei_md {{32, 16, 7, 7}, data_type::s8, tag::any}; |
188 | memory::desc dst_md {{1, 32}, data_type::s32, tag::any}; |
189 | CHECK_OK(inner_product_forward::primitive_desc( |
190 | eng, prop_kind::forward, src_md, wei_md, dst_md)); |
191 | CHECK_OK(inner_product_forward::primitive_desc(eng, prop_kind::forward, |
192 | src_md, wei_md, dst_md, gen_attr_with_scales())); |
193 | |
194 | for (auto arg : |
195 | {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS, DNNL_ARG_DST}) { |
196 | CHECK_UNIMPL( |
197 | inner_product_forward::primitive_desc(eng, prop_kind::forward, |
198 | src_md, wei_md, dst_md, gen_attr_with_zp(arg))); |
199 | } |
200 | } |
201 | |
202 | TEST_F(attr_quantization_test_t, TestLNorm) { |
203 | SKIP_IF_CUDA(true, "Layer normalization primitive not supported for CUDA" ); |
204 | |
205 | memory::desc md {{1, 16, 16}, data_type::s8, tag::abc}; |
206 | memory::desc stat_md {{1, 16}, data_type::f32, tag::ab}; |
207 | normalization_flags flags = normalization_flags::use_global_stats; |
208 | |
209 | if (get_test_engine_kind() == engine::kind::gpu) { |
210 | CHECK_UNIMPL(layer_normalization_forward::primitive_desc(eng, |
211 | prop_kind::forward_inference, md, md, stat_md, 0.1f, flags)); |
212 | CHECK_UNIMPL(layer_normalization_forward::primitive_desc(eng, |
213 | prop_kind::forward_inference, md, md, stat_md, 0.1f, flags, |
214 | gen_attr_with_scales())); |
215 | } else { |
216 | CHECK_OK(layer_normalization_forward::primitive_desc(eng, |
217 | prop_kind::forward_inference, md, md, stat_md, 0.1f, flags)); |
218 | CHECK_OK(layer_normalization_forward::primitive_desc(eng, |
219 | prop_kind::forward_inference, md, md, stat_md, 0.1f, flags, |
220 | gen_attr_with_scales())); |
221 | } |
222 | |
223 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_MEAN, DNNL_ARG_VARIANCE, |
224 | DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS, DNNL_ARG_DST}) { |
225 | CHECK_UNIMPL(layer_normalization_forward::primitive_desc(eng, |
226 | prop_kind::forward_inference, md, md, stat_md, 0.1f, flags, |
227 | gen_attr_with_zp(arg))); |
228 | } |
229 | } |
230 | |
231 | TEST_F(attr_quantization_test_t, TestLRN) { |
232 | for (auto dt : {data_type::f32}) { |
233 | memory::desc md {{1, 16, 3, 3}, dt, tag::abcd}; |
234 | CHECK_OK(lrn_forward::primitive_desc(eng, prop_kind::forward_inference, |
235 | algorithm::lrn_across_channels, md, md, 5, 1.f, 0.75f, 1.0f)); |
236 | CHECK_UNIMPL(lrn_forward::primitive_desc(eng, |
237 | prop_kind::forward_inference, algorithm::lrn_across_channels, |
238 | md, md, 5, 1.f, 0.75f, 1.0f, gen_attr_with_scales())); |
239 | |
240 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
241 | CHECK_UNIMPL(lrn_forward::primitive_desc(eng, |
242 | prop_kind::forward_inference, |
243 | algorithm::lrn_across_channels, md, md, 5, 1.f, 0.75f, 1.0f, |
244 | gen_attr_with_zp(arg))); |
245 | } |
246 | } |
247 | } |
248 | |
249 | CPU_TEST_F(attr_quantization_test_t, TestMatmul) { |
250 | for (auto a_dt : {data_type::f32, data_type::u8}) { |
251 | const data_type b_dt |
252 | = a_dt == data_type::f32 ? data_type::f32 : data_type::s8; |
253 | |
254 | memory::desc a_md {{10, 3}, a_dt, tag::ab}; |
255 | memory::desc b_md {{3, 20}, b_dt, tag::ba}; |
256 | memory::desc c_md {{10, 20}, data_type::f32, tag::ab}; |
257 | |
258 | CHECK_OK(matmul::primitive_desc(eng, a_md, b_md, c_md)); |
259 | CHECK_OK(matmul::primitive_desc( |
260 | eng, a_md, b_md, c_md, gen_attr_with_scales())); |
261 | |
262 | for (auto arg : |
263 | {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_BIAS, DNNL_ARG_DST}) { |
264 | if ((a_dt != data_type::u8 && a_dt != data_type::s8) |
265 | || arg == DNNL_ARG_BIAS) { |
266 | CHECK_UNIMPL(matmul::primitive_desc( |
267 | eng, a_md, b_md, c_md, gen_attr_with_zp(arg))); |
268 | } else { |
269 | CHECK_OK(matmul::primitive_desc( |
270 | eng, a_md, b_md, c_md, gen_attr_with_zp(arg))); |
271 | } |
272 | } |
273 | } |
274 | } |
275 | |
276 | TEST_F(attr_quantization_test_t, TestPool) { |
277 | memory::desc src_md {{1, 16, 8, 8}, data_type::s8, tag::abcd}; |
278 | memory::desc dst_md {{1, 16, 4, 4}, data_type::s8, tag::abcd}; |
279 | |
280 | CHECK_OK(pooling_forward::primitive_desc(eng, prop_kind::forward_inference, |
281 | algorithm::pooling_max, src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, |
282 | {0, 0}, {0, 0})); |
283 | CHECK_UNIMPL( |
284 | pooling_forward::primitive_desc(eng, prop_kind::forward_inference, |
285 | algorithm::pooling_max, src_md, dst_md, {2, 2}, {2, 2}, |
286 | {0, 0}, {0, 0}, {0, 0}, gen_attr_with_scales())); |
287 | |
288 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
289 | CHECK_UNIMPL(pooling_forward::primitive_desc(eng, |
290 | prop_kind::forward_inference, algorithm::pooling_max, src_md, |
291 | dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0}, {0, 0}, |
292 | gen_attr_with_zp(arg))); |
293 | } |
294 | } |
295 | |
296 | TEST_F(attr_quantization_test_t, TestPReLU) { |
297 | SKIP_IF_CUDA(true, "Unsupported primitive not supported for CUDA" ); |
298 | memory::desc data_md {{1, 16, 3, 3}, data_type::f32, tag::abcd}; |
299 | memory::desc weights_md {{1, 16, 3, 3}, data_type::f32, tag::abcd}; |
300 | |
301 | CHECK_OK(prelu_forward::primitive_desc( |
302 | eng, prop_kind::forward, data_md, weights_md, data_md)); |
303 | |
304 | CHECK_UNIMPL(prelu_forward::primitive_desc(eng, prop_kind::forward, data_md, |
305 | weights_md, data_md, gen_attr_with_scales())); |
306 | |
307 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
308 | CHECK_UNIMPL(prelu_forward::primitive_desc(eng, prop_kind::forward, |
309 | data_md, weights_md, data_md, gen_attr_with_zp(arg))); |
310 | } |
311 | } |
312 | |
313 | CPU_TEST_F(attr_quantization_test_t, TestReorder) { |
314 | memory::desc src_md {{1, 16, 8, 8}, data_type::s8, tag::abcd}; |
315 | memory::desc dst_md {{1, 16, 8, 8}, data_type::s8, tag::acdb}; |
316 | CHECK_OK(reorder::primitive_desc(eng, src_md, eng, dst_md)); |
317 | |
318 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
319 | CHECK_OK(reorder::primitive_desc( |
320 | eng, src_md, eng, dst_md, gen_attr_with_scales())); |
321 | CHECK_OK(reorder::primitive_desc( |
322 | eng, src_md, eng, dst_md, gen_attr_with_zp(arg))); |
323 | } |
324 | } |
325 | |
TEST_F(attr_quantization_test_t, TestRNN) {
    SKIP_IF_CUDA(true, "RNN primitive not supported for CUDA" );
    // Int8 RNN relies on packed API solely which is available only for X64.
#if !DNNL_X64
    return;
#endif
    // XXX: Threadpool doesn't work correctly with packed API which is the only
    // working mechanism for int8 computations. Disable it for now.
    SKIP_IF(DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL,
            "Threadpool does not have working packed API" );

    // Int8 LSTM setup: u8 activations, s8 weights, f32 bias and cell state.
    memory::dim n = 1, t = 1, l = 10, c = 8, g = 4, d = 1;
    memory::desc src_layer_md {{t, n, c}, data_type::u8, tag::tnc};
    memory::desc src_iter_md {{l, d, n, c}, data_type::u8, tag::ldnc};
    memory::desc src_iter_c_md {{l, d, n, c}, data_type::f32, tag::ldnc};
    memory::desc wei_layer_md {{l, d, c, g, c}, data_type::s8, tag::any};
    memory::desc wei_iter_md {{l, d, c, g, c}, data_type::s8, tag::any};
    memory::desc bia_md {{l, d, g, c}, data_type::f32, tag::ldgo};
    memory::desc dst_layer_md {{t, n, c}, data_type::u8, tag::tnc};
    memory::desc dst_iter_md {{l, d, n, c}, data_type::u8, tag::ldnc};
    memory::desc dst_iter_c_md {{l, d, n, c}, data_type::f32, tag::ldnc};

    // All eight combinations of runtime vs. compile-time RNN quantization
    // parameters: any runtime value (DNNL_RUNTIME_F32_VAL) must make
    // primitive-descriptor creation fail with unimplemented.
    for_(auto is_runtime_data_scale : {true, false})
    for_(auto is_runtime_data_shift : {true, false})
    for_(auto is_runtime_weights_scale : {true, false})
    {
        primitive_attr attr;
        attr.set_rnn_data_qparams(
                is_runtime_data_scale ? DNNL_RUNTIME_F32_VAL : 2.f,
                is_runtime_data_shift ? DNNL_RUNTIME_F32_VAL : 2.f);
        attr.set_rnn_weights_qparams(
                0, {is_runtime_weights_scale ? DNNL_RUNTIME_F32_VAL : 2.f});
        bool rt = is_runtime_data_scale || is_runtime_data_shift
                || is_runtime_weights_scale;
        CHECK_STATUS(rt ? dnnl_unimplemented : dnnl_success,
                lstm_forward::primitive_desc(eng, prop_kind::forward_inference,
                        rnn_direction::unidirectional_left2right, src_layer_md,
                        src_iter_md, src_iter_c_md, wei_layer_md, wei_iter_md,
                        bia_md, dst_layer_md, dst_iter_md, dst_iter_c_md,
                        attr));
    }

    // Generic zero-point attributes (as opposed to the RNN-specific
    // qparams above) are rejected on every LSTM argument.
    for (auto arg : {DNNL_ARG_SRC_LAYER, DNNL_ARG_SRC_ITER, DNNL_ARG_SRC_ITER_C,
                 DNNL_ARG_WEIGHTS_LAYER, DNNL_ARG_WEIGHTS_ITER, DNNL_ARG_BIAS,
                 DNNL_ARG_DST_LAYER, DNNL_ARG_DST_ITER, DNNL_ARG_DST_ITER_C}) {
        CHECK_UNIMPL(
                lstm_forward::primitive_desc(eng, prop_kind::forward_inference,
                        rnn_direction::unidirectional_left2right, src_layer_md,
                        src_iter_md, src_iter_c_md, wei_layer_md, wei_iter_md,
                        bia_md, dst_layer_md, dst_iter_md, dst_iter_c_md,
                        gen_attr_with_zp(arg)));
    }
}
379 | |
380 | TEST_F(attr_quantization_test_t, TestShuffle) { |
381 | SKIP_IF_CUDA(true, "Shuffle primitive not supported for CUDA" ); |
382 | memory::desc md {{1, 16, 3, 3}, data_type::f32, tag::abcd}; |
383 | |
384 | CHECK_OK(shuffle_forward::primitive_desc pd( |
385 | eng, prop_kind::forward, md, md, 1, 4)); |
386 | CHECK_UNIMPL(shuffle_forward::primitive_desc pd( |
387 | eng, prop_kind::forward, md, md, 1, 4, gen_attr_with_scales())); |
388 | |
389 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
390 | CHECK_UNIMPL(shuffle_forward::primitive_desc pd( |
391 | eng, prop_kind::forward, md, md, 1, 4, gen_attr_with_zp(arg))); |
392 | } |
393 | } |
394 | |
395 | TEST_F(attr_quantization_test_t, TestSoftmax) { |
396 | SKIP_IF_CUDA(true, "Unsupported datatype for CUDA" ); |
397 | SKIP_IF_HIP(true, "Unsupported datatype for HIP" ); |
398 | |
399 | memory::desc md {{2, 16}, data_type::u8, tag::ab}; |
400 | |
401 | CHECK_OK(softmax_forward::primitive_desc( |
402 | eng, prop_kind::forward, algorithm::softmax_accurate, md, md, 1)); |
403 | |
404 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
405 | CHECK_OK(softmax_forward::primitive_desc(eng, prop_kind::forward, |
406 | algorithm::softmax_accurate, md, md, 1, |
407 | gen_attr_with_scales(arg))); |
408 | CHECK_UNIMPL(softmax_forward::primitive_desc(eng, prop_kind::forward, |
409 | algorithm::softmax_accurate, md, md, 1, gen_attr_with_zp(arg))); |
410 | } |
411 | } |
412 | |
413 | TEST_F(attr_quantization_test_t, TestSum) { |
414 | memory::desc md {{1, 16, 3, 3}, data_type::s8, tag::abcd}; |
415 | CHECK_OK(sum::primitive_desc(eng, {1.f, 1.f}, {md, md})); |
416 | CHECK_UNIMPL(sum::primitive_desc( |
417 | eng, {1.f, 1.f}, {md, md}, gen_attr_with_scales())); |
418 | |
419 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
420 | CHECK_UNIMPL(sum::primitive_desc( |
421 | eng, {1.f, 1.f}, {md, md}, gen_attr_with_zp(arg))); |
422 | } |
423 | } |
424 | |
425 | } // namespace dnnl |
426 | |