1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <math.h> |
18 | #include <stddef.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
/* Alias for `for`, used on the outer headers of a perfectly nested loop
 * chain so every level can sit at the same indentation.
 * NOTE(review): presumably a clang-format workaround -- confirm. */
#define for_ for

/* Evaluates the oneDNN call `f` exactly once. If it does not return
 * dnnl_success, prints the call text and the returned status, then
 * terminates the process with exit code 2.
 * The local deliberately has an unusual name. A short name such as `s` would
 * shadow a caller variable referenced inside `f`: CHECK(g(s)) would then read
 * the freshly declared, uninitialized local. */
#define CHECK(f) \
    do { \
        dnnl_status_t check_status_ = (f); \
        if (check_status_ != dnnl_success) { \
            printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, \
                    (int)check_status_); \
            exit(2); \
        } \
    } while (0)

/* Evaluates `expr` exactly once. If it is false (zero), prints the expression
 * text and terminates the process with exit code 2.
 * The value is normalized with `!!` so that pointers, floating-point values
 * and integers wider than int are tested for non-zero. Converting them to int
 * directly would turn e.g. 0.5 or 1ULL << 32 into 0. */
#define CHECK_TRUE(expr) \
    do { \
        int check_ok_ = !!(expr); \
        if (!check_ok_) { \
            printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \
            exit(2); \
        } \
    } while (0)

/* Element type of all user buffers in these tests; matches the dnnl_f32
 * data type used in every memory descriptor below. */
typedef float real_t;

/* Number of elements of the 1-D tensor used by test1. */
#define LENGTH_100 100
48 | |
49 | void test1() { |
50 | dnnl_engine_t engine; |
51 | CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); |
52 | |
53 | dnnl_dims_t dims = {LENGTH_100}; |
54 | real_t data[LENGTH_100]; |
55 | |
56 | dnnl_memory_desc_t md; |
57 | const_dnnl_memory_desc_t c_md_tmp; |
58 | dnnl_memory_t m; |
59 | |
60 | CHECK(dnnl_memory_desc_create_with_tag(&md, 1, dims, dnnl_f32, dnnl_x)); |
61 | CHECK(dnnl_memory_create(&m, md, engine, NULL)); |
62 | |
63 | void *req = NULL; |
64 | |
65 | CHECK(dnnl_memory_get_data_handle(m, &req)); |
66 | CHECK_TRUE(req == NULL); |
67 | |
68 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL |
69 | CHECK(dnnl_memory_set_data_handle(m, data)); |
70 | CHECK(dnnl_memory_get_data_handle(m, &req)); |
71 | CHECK_TRUE(req == data); |
72 | #endif |
73 | |
74 | CHECK_TRUE(dnnl_memory_desc_get_size(md) == LENGTH_100 * sizeof(data[0])); |
75 | |
76 | CHECK(dnnl_memory_get_memory_desc(m, &c_md_tmp)); |
77 | CHECK_TRUE(dnnl_memory_desc_equal(md, c_md_tmp)); |
78 | |
79 | CHECK(dnnl_memory_destroy(m)); |
80 | CHECK(dnnl_memory_desc_destroy(md)); |
81 | |
82 | CHECK(dnnl_engine_destroy(engine)); |
83 | } |
84 | |
85 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL |
86 | |
87 | static size_t product(dnnl_dim_t *arr, size_t size) { |
88 | size_t prod = 1; |
89 | for (size_t i = 0; i < size; ++i) |
90 | prod *= arr[i]; |
91 | return prod; |
92 | } |
93 | |
94 | void test2() { |
95 | /* AlexNet: c3 |
96 | * {2, 256, 13, 13} (x) {384, 256, 3, 3} -> {2, 384, 13, 13} |
97 | * pad: {1, 1} |
98 | * strides: {1, 1} |
99 | */ |
100 | |
101 | const dnnl_dim_t mb = 2; |
102 | const dnnl_dim_t groups = 2; |
103 | const int ndims = 4; |
104 | dnnl_dims_t c3_src_sizes = {mb, 256, 13, 13}; |
105 | dnnl_dims_t c3_weights_sizes = {groups, 384 / groups, 256 / groups, 3, 3}; |
106 | dnnl_dims_t c3_bias_sizes = {384}; |
107 | dnnl_dims_t strides = {1, 1}; |
108 | dnnl_dims_t dilation = {0, 0}; |
109 | dnnl_dims_t padding = {0, 0}; // set proper values |
110 | dnnl_dims_t c3_dst_sizes = {mb, 384, |
111 | (c3_src_sizes[2] + 2 * padding[0] - c3_weights_sizes[3]) |
112 | / strides[0] |
113 | + 1, |
114 | (c3_src_sizes[3] + 2 * padding[1] - c3_weights_sizes[4]) |
115 | / strides[1] |
116 | + 1}; |
117 | |
118 | real_t *src |
119 | = (real_t *)calloc(product(c3_src_sizes, ndims), sizeof(real_t)); |
120 | real_t *weights = (real_t *)calloc( |
121 | product(c3_weights_sizes, ndims + 1), sizeof(real_t)); |
122 | real_t *bias = (real_t *)calloc(product(c3_bias_sizes, 1), sizeof(real_t)); |
123 | real_t *dst |
124 | = (real_t *)calloc(product(c3_dst_sizes, ndims), sizeof(real_t)); |
125 | real_t *out_mem |
126 | = (real_t *)calloc(product(c3_dst_sizes, ndims), sizeof(real_t)); |
127 | CHECK_TRUE(src && weights && bias && dst && out_mem); |
128 | |
129 | for (dnnl_dim_t i = 0; i < c3_bias_sizes[0]; ++i) |
130 | bias[i] = i; |
131 | |
132 | dnnl_engine_t engine; |
133 | CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); |
134 | |
135 | dnnl_stream_t stream; |
136 | CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags)); |
137 | |
138 | /* first describe user data and create data descriptors for future |
139 | * convolution w/ the specified format -- we do not want to do a reorder */ |
140 | dnnl_memory_desc_t c3_src_md, c3_weights_md, c3_bias_md, c3_dst_md, out_md; |
141 | dnnl_memory_t c3_src, c3_weights, c3_bias, c3_dst, out; |
142 | |
143 | // src |
144 | { |
145 | CHECK(dnnl_memory_desc_create_with_tag( |
146 | &c3_src_md, 4, c3_src_sizes, dnnl_f32, dnnl_nChw8c)); |
147 | CHECK(dnnl_memory_create(&c3_src, c3_src_md, engine, src)); |
148 | } |
149 | |
150 | // weights |
151 | { |
152 | CHECK(dnnl_memory_desc_create_with_tag(&c3_weights_md, |
153 | 4 + (groups != 1), c3_weights_sizes + (groups == 1), dnnl_f32, |
154 | groups == 1 ? dnnl_OIhw8i8o : dnnl_gOIhw8i8o)); |
155 | CHECK(dnnl_memory_create(&c3_weights, c3_weights_md, engine, weights)); |
156 | } |
157 | |
158 | // bias |
159 | { |
160 | CHECK(dnnl_memory_desc_create_with_tag( |
161 | &c3_bias_md, 1, c3_bias_sizes, dnnl_f32, dnnl_x)); |
162 | CHECK(dnnl_memory_create(&c3_bias, c3_bias_md, engine, bias)); |
163 | } |
164 | |
165 | // c3_dst |
166 | { |
167 | CHECK(dnnl_memory_desc_create_with_tag( |
168 | &c3_dst_md, 4, c3_dst_sizes, dnnl_f32, dnnl_nChw8c)); |
169 | CHECK(dnnl_memory_create(&c3_dst, c3_dst_md, engine, dst)); |
170 | } |
171 | |
172 | // out |
173 | { |
174 | CHECK(dnnl_memory_desc_create_with_tag( |
175 | &out_md, 4, c3_dst_sizes, dnnl_f32, dnnl_nchw)); |
176 | CHECK(dnnl_memory_create(&out, out_md, engine, out_mem)); |
177 | } |
178 | |
179 | /* create a convolution primitive descriptor */ |
180 | dnnl_primitive_desc_t c3_pd; |
181 | dnnl_primitive_t c3; |
182 | |
183 | CHECK(dnnl_convolution_forward_primitive_desc_create(&c3_pd, engine, |
184 | dnnl_forward_training, dnnl_convolution_direct, c3_src_md, |
185 | c3_weights_md, c3_bias_md, c3_dst_md, strides, dilation, padding, |
186 | NULL, NULL)); |
187 | |
188 | CHECK_TRUE(dnnl_memory_desc_equal(c3_src_md, |
189 | dnnl_primitive_desc_query_md(c3_pd, dnnl_query_src_md, 0))); |
190 | CHECK_TRUE(dnnl_memory_desc_equal(c3_weights_md, |
191 | dnnl_primitive_desc_query_md(c3_pd, dnnl_query_weights_md, 0))); |
192 | CHECK_TRUE(dnnl_memory_desc_equal(c3_bias_md, |
193 | dnnl_primitive_desc_query_md(c3_pd, dnnl_query_weights_md, 1))); |
194 | CHECK_TRUE(dnnl_memory_desc_equal(c3_dst_md, |
195 | dnnl_primitive_desc_query_md(c3_pd, dnnl_query_dst_md, 0))); |
196 | |
197 | CHECK_TRUE(dnnl_memory_desc_equal(c3_src_md, |
198 | dnnl_primitive_desc_query_md( |
199 | c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_SRC))); |
200 | CHECK_TRUE(dnnl_memory_desc_equal(c3_weights_md, |
201 | dnnl_primitive_desc_query_md( |
202 | c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_WEIGHTS))); |
203 | CHECK_TRUE(dnnl_memory_desc_equal(c3_bias_md, |
204 | dnnl_primitive_desc_query_md( |
205 | c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_BIAS))); |
206 | CHECK_TRUE(dnnl_memory_desc_equal(c3_dst_md, |
207 | dnnl_primitive_desc_query_md( |
208 | c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_DST))); |
209 | |
210 | /* create a convolution and execute it */ |
211 | CHECK(dnnl_primitive_create(&c3, c3_pd)); |
212 | CHECK(dnnl_primitive_desc_destroy(c3_pd)); |
213 | |
214 | dnnl_exec_arg_t c3_args[4] = { |
215 | {DNNL_ARG_SRC, c3_src}, |
216 | {DNNL_ARG_WEIGHTS, c3_weights}, |
217 | {DNNL_ARG_BIAS, c3_bias}, |
218 | {DNNL_ARG_DST, c3_dst}, |
219 | }; |
220 | CHECK(dnnl_primitive_execute(c3, stream, 4, c3_args)); |
221 | CHECK(dnnl_primitive_destroy(c3)); |
222 | |
223 | /* create a reorder primitive descriptor */ |
224 | dnnl_primitive_desc_t r_pd; |
225 | CHECK(dnnl_reorder_primitive_desc_create( |
226 | &r_pd, c3_dst_md, engine, out_md, engine, NULL)); |
227 | |
228 | /* create a reorder and execute it */ |
229 | dnnl_primitive_t r; |
230 | CHECK(dnnl_primitive_create(&r, r_pd)); |
231 | CHECK(dnnl_primitive_desc_destroy(r_pd)); |
232 | |
233 | dnnl_exec_arg_t r_args[2] = { |
234 | {DNNL_ARG_FROM, c3_dst}, |
235 | {DNNL_ARG_TO, out}, |
236 | }; |
237 | CHECK(dnnl_primitive_execute(r, stream, 2, r_args)); |
238 | CHECK(dnnl_primitive_destroy(r)); |
239 | |
240 | CHECK(dnnl_stream_wait(stream)); |
241 | |
242 | /* clean-up */ |
243 | CHECK(dnnl_memory_destroy(c3_src)); |
244 | CHECK(dnnl_memory_destroy(c3_weights)); |
245 | CHECK(dnnl_memory_destroy(c3_bias)); |
246 | CHECK(dnnl_memory_destroy(c3_dst)); |
247 | CHECK(dnnl_memory_destroy(out)); |
248 | CHECK(dnnl_stream_destroy(stream)); |
249 | CHECK(dnnl_engine_destroy(engine)); |
250 | |
251 | CHECK(dnnl_memory_desc_destroy(c3_src_md)); |
252 | CHECK(dnnl_memory_desc_destroy(c3_weights_md)); |
253 | CHECK(dnnl_memory_desc_destroy(c3_bias_md)); |
254 | CHECK(dnnl_memory_desc_destroy(c3_dst_md)); |
255 | CHECK(dnnl_memory_desc_destroy(out_md)); |
256 | |
257 | const dnnl_dim_t N = c3_dst_sizes[0], C = c3_dst_sizes[1], |
258 | H = c3_dst_sizes[2], W = c3_dst_sizes[3]; |
259 | for_(dnnl_dim_t n = 0; n < N; ++n) |
260 | for_(dnnl_dim_t c = 0; c < C; ++c) |
261 | for_(dnnl_dim_t h = 0; h < H; ++h) |
262 | for (dnnl_dim_t w = 0; w < W; ++w) { |
263 | dnnl_dim_t off = ((n * C + c) * H + h) * W + w; |
264 | CHECK_TRUE(out_mem[off] == bias[c]); |
265 | } |
266 | |
267 | free(src); |
268 | free(weights); |
269 | free(bias); |
270 | free(dst); |
271 | free(out_mem); |
272 | } |
273 | |
274 | void test3() { |
275 | const dnnl_dim_t mb = 2; |
276 | const int ndims = 4; |
277 | dnnl_dims_t l2_data_sizes = {mb, 256, 13, 13}; |
278 | |
279 | real_t *src |
280 | = (real_t *)calloc(product(l2_data_sizes, ndims), sizeof(real_t)); |
281 | real_t *dst |
282 | = (real_t *)calloc(product(l2_data_sizes, ndims), sizeof(real_t)); |
283 | CHECK_TRUE(src && dst); |
284 | |
285 | for (size_t i = 0; i < product(l2_data_sizes, ndims); ++i) |
286 | src[i] = (i % 13) + 1; |
287 | |
288 | dnnl_engine_t engine; |
289 | CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0)); |
290 | |
291 | dnnl_stream_t stream; |
292 | CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags)); |
293 | |
294 | dnnl_memory_desc_t l2_data_md; |
295 | dnnl_memory_t l2_src, l2_dst; |
296 | |
297 | // src, dst |
298 | { |
299 | CHECK(dnnl_memory_desc_create_with_tag( |
300 | &l2_data_md, ndims, l2_data_sizes, dnnl_f32, dnnl_nchw)); |
301 | CHECK(dnnl_memory_create(&l2_src, l2_data_md, engine, src)); |
302 | CHECK(dnnl_memory_create(&l2_dst, l2_data_md, engine, dst)); |
303 | } |
304 | |
305 | /* create an lrn */ |
306 | dnnl_primitive_desc_t l2_pd; |
307 | dnnl_primitive_t l2; |
308 | |
309 | CHECK(dnnl_lrn_forward_primitive_desc_create(&l2_pd, engine, |
310 | dnnl_forward_inference, dnnl_lrn_across_channels, l2_data_md, |
311 | l2_data_md, 5, 1e-4, 0.75, 1.0, NULL)); |
312 | |
313 | CHECK_TRUE(dnnl_memory_desc_equal(l2_data_md, |
314 | dnnl_primitive_desc_query_md(l2_pd, dnnl_query_src_md, 0))); |
315 | CHECK_TRUE(dnnl_memory_desc_equal(l2_data_md, |
316 | dnnl_primitive_desc_query_md(l2_pd, dnnl_query_dst_md, 0))); |
317 | CHECK_TRUE(dnnl_primitive_desc_query_s32( |
318 | l2_pd, dnnl_query_num_of_inputs_s32, 0) |
319 | == 1); |
320 | CHECK_TRUE(dnnl_primitive_desc_query_s32( |
321 | l2_pd, dnnl_query_num_of_outputs_s32, 0) |
322 | == 1); |
323 | |
324 | CHECK(dnnl_primitive_create(&l2, l2_pd)); |
325 | CHECK(dnnl_primitive_desc_destroy(l2_pd)); |
326 | |
327 | dnnl_exec_arg_t l2_args[2] = { |
328 | {DNNL_ARG_SRC, l2_src}, |
329 | {DNNL_ARG_DST, l2_dst}, |
330 | }; |
331 | CHECK(dnnl_primitive_execute(l2, stream, 2, l2_args)); |
332 | CHECK(dnnl_primitive_destroy(l2)); |
333 | |
334 | CHECK(dnnl_stream_wait(stream)); |
335 | |
336 | /* clean-up */ |
337 | CHECK(dnnl_memory_destroy(l2_src)); |
338 | CHECK(dnnl_memory_destroy(l2_dst)); |
339 | CHECK(dnnl_stream_destroy(stream)); |
340 | CHECK(dnnl_engine_destroy(engine)); |
341 | |
342 | CHECK(dnnl_memory_desc_destroy(l2_data_md)); |
343 | |
344 | const dnnl_dim_t N = l2_data_sizes[0], C = l2_data_sizes[1], |
345 | H = l2_data_sizes[2], W = l2_data_sizes[3]; |
346 | for_(dnnl_dim_t n = 0; n < N; ++n) |
347 | for_(dnnl_dim_t c = 0; c < C; ++c) |
348 | for_(dnnl_dim_t h = 0; h < H; ++h) |
349 | for (dnnl_dim_t w = 0; w < W; ++w) { |
350 | size_t off = ((n * C + c) * H + h) * W + w; |
351 | real_t e = (off % 13) + 1; |
352 | real_t diff = (real_t)fabs(dst[off] - e); |
353 | if (diff / fabs(e) > 0.0125) printf("exp: %g, got: %g\n" , e, dst[off]); |
354 | CHECK_TRUE(diff / fabs(e) < 0.0125); |
355 | } |
356 | |
357 | free(src); |
358 | free(dst); |
359 | } |
360 | #endif |
361 | |
/* Runs the API tests in order. test2 and test3 are defined (and therefore
 * run) only when the CPU runtime is not SYCL, under the same preprocessor
 * guard as their definitions. Every test terminates the process with exit
 * code 2 on the first failed CHECK / CHECK_TRUE, so reaching the end means
 * all checks passed. */
int main(void) {
    test1();
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
    test2();
    test3();
#endif
    return 0;
}
370 | |