1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <math.h>
18#include <stddef.h>
19#include <stdio.h>
20#include <stdlib.h>
21
22#include "oneapi/dnnl/dnnl.h"
23
/* `for_` expands to a plain `for`. It marks the outer levels of the
 * brace-less, perfectly nested loop nests used for result checking below.
 * NOTE(review): presumably a formatter is configured to treat `for_` as a
 * loop macro so such nests stay flat -- confirm against the project config. */
#define for_ for

/* Evaluates the oneDNN call `f` exactly once. If it returns anything other
 * than dnnl_success, the macro prints the call site, the stringified call and
 * the returned status code, then terminates the process with exit code 2.
 *
 * The temporary is named `s_` rather than `s`. Otherwise an argument
 * expression that mentions a caller variable `s` would silently read the
 * macro's own, still-uninitialized local instead. */
#define CHECK(f) \
    do { \
        dnnl_status_t s_ = (f); \
        if (s_ != dnnl_success) { \
            printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, \
                    (int)s_); \
            exit(2); \
        } \
    } while (0)
35
/* Evaluates the scalar condition `expr` exactly once. If it is false
 * (compares equal to zero), the macro prints the call site and the
 * stringified condition, then terminates the process with exit code 2.
 *
 * The condition is tested directly with `!(expr)` instead of being stored
 * in an `int`. This has three effects:
 *   - pointer conditions are accepted without an implicit pointer-to-int
 *     conversion;
 *   - floating-point conditions such as 0.5 are not truncated to 0;
 *   - no macro-local name can collide with names used inside `expr`. */
#define CHECK_TRUE(expr) \
    do { \
        if (!(expr)) { \
            printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \
            exit(2); \
        } \
    } while (0)
44
/* Element type of the user buffers in these tests. The memory descriptors
 * that wrap those buffers describe them as dnnl_f32. */
typedef float real_t;

/* Number of elements in the 1D buffer exercised by test1. */
enum { LENGTH_100 = 100 };
48
49void test1() {
50 dnnl_engine_t engine;
51 CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0));
52
53 dnnl_dims_t dims = {LENGTH_100};
54 real_t data[LENGTH_100];
55
56 dnnl_memory_desc_t md;
57 const_dnnl_memory_desc_t c_md_tmp;
58 dnnl_memory_t m;
59
60 CHECK(dnnl_memory_desc_create_with_tag(&md, 1, dims, dnnl_f32, dnnl_x));
61 CHECK(dnnl_memory_create(&m, md, engine, NULL));
62
63 void *req = NULL;
64
65 CHECK(dnnl_memory_get_data_handle(m, &req));
66 CHECK_TRUE(req == NULL);
67
68#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
69 CHECK(dnnl_memory_set_data_handle(m, data));
70 CHECK(dnnl_memory_get_data_handle(m, &req));
71 CHECK_TRUE(req == data);
72#endif
73
74 CHECK_TRUE(dnnl_memory_desc_get_size(md) == LENGTH_100 * sizeof(data[0]));
75
76 CHECK(dnnl_memory_get_memory_desc(m, &c_md_tmp));
77 CHECK_TRUE(dnnl_memory_desc_equal(md, c_md_tmp));
78
79 CHECK(dnnl_memory_destroy(m));
80 CHECK(dnnl_memory_desc_destroy(md));
81
82 CHECK(dnnl_engine_destroy(engine));
83}
84
85#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
86
87static size_t product(dnnl_dim_t *arr, size_t size) {
88 size_t prod = 1;
89 for (size_t i = 0; i < size; ++i)
90 prod *= arr[i];
91 return prod;
92}
93
94void test2() {
95 /* AlexNet: c3
96 * {2, 256, 13, 13} (x) {384, 256, 3, 3} -> {2, 384, 13, 13}
97 * pad: {1, 1}
98 * strides: {1, 1}
99 */
100
101 const dnnl_dim_t mb = 2;
102 const dnnl_dim_t groups = 2;
103 const int ndims = 4;
104 dnnl_dims_t c3_src_sizes = {mb, 256, 13, 13};
105 dnnl_dims_t c3_weights_sizes = {groups, 384 / groups, 256 / groups, 3, 3};
106 dnnl_dims_t c3_bias_sizes = {384};
107 dnnl_dims_t strides = {1, 1};
108 dnnl_dims_t dilation = {0, 0};
109 dnnl_dims_t padding = {0, 0}; // set proper values
110 dnnl_dims_t c3_dst_sizes = {mb, 384,
111 (c3_src_sizes[2] + 2 * padding[0] - c3_weights_sizes[3])
112 / strides[0]
113 + 1,
114 (c3_src_sizes[3] + 2 * padding[1] - c3_weights_sizes[4])
115 / strides[1]
116 + 1};
117
118 real_t *src
119 = (real_t *)calloc(product(c3_src_sizes, ndims), sizeof(real_t));
120 real_t *weights = (real_t *)calloc(
121 product(c3_weights_sizes, ndims + 1), sizeof(real_t));
122 real_t *bias = (real_t *)calloc(product(c3_bias_sizes, 1), sizeof(real_t));
123 real_t *dst
124 = (real_t *)calloc(product(c3_dst_sizes, ndims), sizeof(real_t));
125 real_t *out_mem
126 = (real_t *)calloc(product(c3_dst_sizes, ndims), sizeof(real_t));
127 CHECK_TRUE(src && weights && bias && dst && out_mem);
128
129 for (dnnl_dim_t i = 0; i < c3_bias_sizes[0]; ++i)
130 bias[i] = i;
131
132 dnnl_engine_t engine;
133 CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0));
134
135 dnnl_stream_t stream;
136 CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));
137
138 /* first describe user data and create data descriptors for future
139 * convolution w/ the specified format -- we do not want to do a reorder */
140 dnnl_memory_desc_t c3_src_md, c3_weights_md, c3_bias_md, c3_dst_md, out_md;
141 dnnl_memory_t c3_src, c3_weights, c3_bias, c3_dst, out;
142
143 // src
144 {
145 CHECK(dnnl_memory_desc_create_with_tag(
146 &c3_src_md, 4, c3_src_sizes, dnnl_f32, dnnl_nChw8c));
147 CHECK(dnnl_memory_create(&c3_src, c3_src_md, engine, src));
148 }
149
150 // weights
151 {
152 CHECK(dnnl_memory_desc_create_with_tag(&c3_weights_md,
153 4 + (groups != 1), c3_weights_sizes + (groups == 1), dnnl_f32,
154 groups == 1 ? dnnl_OIhw8i8o : dnnl_gOIhw8i8o));
155 CHECK(dnnl_memory_create(&c3_weights, c3_weights_md, engine, weights));
156 }
157
158 // bias
159 {
160 CHECK(dnnl_memory_desc_create_with_tag(
161 &c3_bias_md, 1, c3_bias_sizes, dnnl_f32, dnnl_x));
162 CHECK(dnnl_memory_create(&c3_bias, c3_bias_md, engine, bias));
163 }
164
165 // c3_dst
166 {
167 CHECK(dnnl_memory_desc_create_with_tag(
168 &c3_dst_md, 4, c3_dst_sizes, dnnl_f32, dnnl_nChw8c));
169 CHECK(dnnl_memory_create(&c3_dst, c3_dst_md, engine, dst));
170 }
171
172 // out
173 {
174 CHECK(dnnl_memory_desc_create_with_tag(
175 &out_md, 4, c3_dst_sizes, dnnl_f32, dnnl_nchw));
176 CHECK(dnnl_memory_create(&out, out_md, engine, out_mem));
177 }
178
179 /* create a convolution primitive descriptor */
180 dnnl_primitive_desc_t c3_pd;
181 dnnl_primitive_t c3;
182
183 CHECK(dnnl_convolution_forward_primitive_desc_create(&c3_pd, engine,
184 dnnl_forward_training, dnnl_convolution_direct, c3_src_md,
185 c3_weights_md, c3_bias_md, c3_dst_md, strides, dilation, padding,
186 NULL, NULL));
187
188 CHECK_TRUE(dnnl_memory_desc_equal(c3_src_md,
189 dnnl_primitive_desc_query_md(c3_pd, dnnl_query_src_md, 0)));
190 CHECK_TRUE(dnnl_memory_desc_equal(c3_weights_md,
191 dnnl_primitive_desc_query_md(c3_pd, dnnl_query_weights_md, 0)));
192 CHECK_TRUE(dnnl_memory_desc_equal(c3_bias_md,
193 dnnl_primitive_desc_query_md(c3_pd, dnnl_query_weights_md, 1)));
194 CHECK_TRUE(dnnl_memory_desc_equal(c3_dst_md,
195 dnnl_primitive_desc_query_md(c3_pd, dnnl_query_dst_md, 0)));
196
197 CHECK_TRUE(dnnl_memory_desc_equal(c3_src_md,
198 dnnl_primitive_desc_query_md(
199 c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_SRC)));
200 CHECK_TRUE(dnnl_memory_desc_equal(c3_weights_md,
201 dnnl_primitive_desc_query_md(
202 c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_WEIGHTS)));
203 CHECK_TRUE(dnnl_memory_desc_equal(c3_bias_md,
204 dnnl_primitive_desc_query_md(
205 c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_BIAS)));
206 CHECK_TRUE(dnnl_memory_desc_equal(c3_dst_md,
207 dnnl_primitive_desc_query_md(
208 c3_pd, dnnl_query_exec_arg_md, DNNL_ARG_DST)));
209
210 /* create a convolution and execute it */
211 CHECK(dnnl_primitive_create(&c3, c3_pd));
212 CHECK(dnnl_primitive_desc_destroy(c3_pd));
213
214 dnnl_exec_arg_t c3_args[4] = {
215 {DNNL_ARG_SRC, c3_src},
216 {DNNL_ARG_WEIGHTS, c3_weights},
217 {DNNL_ARG_BIAS, c3_bias},
218 {DNNL_ARG_DST, c3_dst},
219 };
220 CHECK(dnnl_primitive_execute(c3, stream, 4, c3_args));
221 CHECK(dnnl_primitive_destroy(c3));
222
223 /* create a reorder primitive descriptor */
224 dnnl_primitive_desc_t r_pd;
225 CHECK(dnnl_reorder_primitive_desc_create(
226 &r_pd, c3_dst_md, engine, out_md, engine, NULL));
227
228 /* create a reorder and execute it */
229 dnnl_primitive_t r;
230 CHECK(dnnl_primitive_create(&r, r_pd));
231 CHECK(dnnl_primitive_desc_destroy(r_pd));
232
233 dnnl_exec_arg_t r_args[2] = {
234 {DNNL_ARG_FROM, c3_dst},
235 {DNNL_ARG_TO, out},
236 };
237 CHECK(dnnl_primitive_execute(r, stream, 2, r_args));
238 CHECK(dnnl_primitive_destroy(r));
239
240 CHECK(dnnl_stream_wait(stream));
241
242 /* clean-up */
243 CHECK(dnnl_memory_destroy(c3_src));
244 CHECK(dnnl_memory_destroy(c3_weights));
245 CHECK(dnnl_memory_destroy(c3_bias));
246 CHECK(dnnl_memory_destroy(c3_dst));
247 CHECK(dnnl_memory_destroy(out));
248 CHECK(dnnl_stream_destroy(stream));
249 CHECK(dnnl_engine_destroy(engine));
250
251 CHECK(dnnl_memory_desc_destroy(c3_src_md));
252 CHECK(dnnl_memory_desc_destroy(c3_weights_md));
253 CHECK(dnnl_memory_desc_destroy(c3_bias_md));
254 CHECK(dnnl_memory_desc_destroy(c3_dst_md));
255 CHECK(dnnl_memory_desc_destroy(out_md));
256
257 const dnnl_dim_t N = c3_dst_sizes[0], C = c3_dst_sizes[1],
258 H = c3_dst_sizes[2], W = c3_dst_sizes[3];
259 for_(dnnl_dim_t n = 0; n < N; ++n)
260 for_(dnnl_dim_t c = 0; c < C; ++c)
261 for_(dnnl_dim_t h = 0; h < H; ++h)
262 for (dnnl_dim_t w = 0; w < W; ++w) {
263 dnnl_dim_t off = ((n * C + c) * H + h) * W + w;
264 CHECK_TRUE(out_mem[off] == bias[c]);
265 }
266
267 free(src);
268 free(weights);
269 free(bias);
270 free(dst);
271 free(out_mem);
272}
273
274void test3() {
275 const dnnl_dim_t mb = 2;
276 const int ndims = 4;
277 dnnl_dims_t l2_data_sizes = {mb, 256, 13, 13};
278
279 real_t *src
280 = (real_t *)calloc(product(l2_data_sizes, ndims), sizeof(real_t));
281 real_t *dst
282 = (real_t *)calloc(product(l2_data_sizes, ndims), sizeof(real_t));
283 CHECK_TRUE(src && dst);
284
285 for (size_t i = 0; i < product(l2_data_sizes, ndims); ++i)
286 src[i] = (i % 13) + 1;
287
288 dnnl_engine_t engine;
289 CHECK(dnnl_engine_create(&engine, dnnl_cpu, 0));
290
291 dnnl_stream_t stream;
292 CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));
293
294 dnnl_memory_desc_t l2_data_md;
295 dnnl_memory_t l2_src, l2_dst;
296
297 // src, dst
298 {
299 CHECK(dnnl_memory_desc_create_with_tag(
300 &l2_data_md, ndims, l2_data_sizes, dnnl_f32, dnnl_nchw));
301 CHECK(dnnl_memory_create(&l2_src, l2_data_md, engine, src));
302 CHECK(dnnl_memory_create(&l2_dst, l2_data_md, engine, dst));
303 }
304
305 /* create an lrn */
306 dnnl_primitive_desc_t l2_pd;
307 dnnl_primitive_t l2;
308
309 CHECK(dnnl_lrn_forward_primitive_desc_create(&l2_pd, engine,
310 dnnl_forward_inference, dnnl_lrn_across_channels, l2_data_md,
311 l2_data_md, 5, 1e-4, 0.75, 1.0, NULL));
312
313 CHECK_TRUE(dnnl_memory_desc_equal(l2_data_md,
314 dnnl_primitive_desc_query_md(l2_pd, dnnl_query_src_md, 0)));
315 CHECK_TRUE(dnnl_memory_desc_equal(l2_data_md,
316 dnnl_primitive_desc_query_md(l2_pd, dnnl_query_dst_md, 0)));
317 CHECK_TRUE(dnnl_primitive_desc_query_s32(
318 l2_pd, dnnl_query_num_of_inputs_s32, 0)
319 == 1);
320 CHECK_TRUE(dnnl_primitive_desc_query_s32(
321 l2_pd, dnnl_query_num_of_outputs_s32, 0)
322 == 1);
323
324 CHECK(dnnl_primitive_create(&l2, l2_pd));
325 CHECK(dnnl_primitive_desc_destroy(l2_pd));
326
327 dnnl_exec_arg_t l2_args[2] = {
328 {DNNL_ARG_SRC, l2_src},
329 {DNNL_ARG_DST, l2_dst},
330 };
331 CHECK(dnnl_primitive_execute(l2, stream, 2, l2_args));
332 CHECK(dnnl_primitive_destroy(l2));
333
334 CHECK(dnnl_stream_wait(stream));
335
336 /* clean-up */
337 CHECK(dnnl_memory_destroy(l2_src));
338 CHECK(dnnl_memory_destroy(l2_dst));
339 CHECK(dnnl_stream_destroy(stream));
340 CHECK(dnnl_engine_destroy(engine));
341
342 CHECK(dnnl_memory_desc_destroy(l2_data_md));
343
344 const dnnl_dim_t N = l2_data_sizes[0], C = l2_data_sizes[1],
345 H = l2_data_sizes[2], W = l2_data_sizes[3];
346 for_(dnnl_dim_t n = 0; n < N; ++n)
347 for_(dnnl_dim_t c = 0; c < C; ++c)
348 for_(dnnl_dim_t h = 0; h < H; ++h)
349 for (dnnl_dim_t w = 0; w < W; ++w) {
350 size_t off = ((n * C + c) * H + h) * W + w;
351 real_t e = (off % 13) + 1;
352 real_t diff = (real_t)fabs(dst[off] - e);
353 if (diff / fabs(e) > 0.0125) printf("exp: %g, got: %g\n", e, dst[off]);
354 CHECK_TRUE(diff / fabs(e) < 0.0125);
355 }
356
357 free(src);
358 free(dst);
359}
360#endif
361
/* Test driver. test1 always runs; test2 and test3 run only when the CPU
 * runtime is not SYCL, because they are compiled out otherwise. Every check
 * terminates the process with exit code 2 on failure, so reaching the end
 * means all enabled tests passed. */
int main(void) {
    test1();
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_SYCL
    test2();
    test3();
#endif
    return 0;
}
370