1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef EXAMPLE_UTILS_H
18#define EXAMPLE_UTILS_H
19
20#include <assert.h>
21#include <stdbool.h>
22#include <stdio.h>
23#include <stdlib.h>
24#include <string.h>
25
26#include "dnnl.h"
27#include "dnnl_debug.h"
28
29#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
30#include "dnnl_ocl.h"
31#endif
32
// Reports a failed oneDNN call on stdout and terminates the process with
// exit status 1.
//   what   - C string describing the failing call (CHECK passes the
//            stringified call expression)
//   status - dnnl_status_t returned by the call, printed via dnnl_status2str
#define COMPLAIN_DNNL_ERROR_AND_EXIT(what, status) \
    do { \
        printf("[%s:%d] `%s` returns oneDNN error: %s.\n" \
               "Example failed.\n", \
                __FILE__, __LINE__, (what), dnnl_status2str(status)); \
        exit(1); \
    } while (0)
40
// Reports an example-level (non-oneDNN) error on stdout and terminates the
// process with exit status 2.
//
// Arguments are a printf-style format string followed by its arguments.
// The message is prefixed with "[file:line] Error in the example: " and
// followed by ".\n", so the format itself should not end with a period.
//
// The format is taken as part of `...` (rather than as a separate named
// parameter) so that a message without arguments, e.g.
// COMPLAIN_EXAMPLE_ERROR_AND_EXIT("bad input"), is valid ISO C: with a named
// parameter before `...`, C99/C11 require at least one further argument.
#define COMPLAIN_EXAMPLE_ERROR_AND_EXIT(...) \
    do { \
        printf("[%s:%d] Error in the example: ", __FILE__, __LINE__); \
        printf(__VA_ARGS__); \
        printf(".\n"); \
        printf("Example failed.\n"); \
        exit(2); \
    } while (0)
48
49static dnnl_engine_kind_t validate_engine_kind(dnnl_engine_kind_t akind) {
50 // Checking if a GPU exists on the machine
51 if (akind == dnnl_gpu) {
52 if (!dnnl_engine_get_count(dnnl_gpu)) {
53 printf("Application couldn't find GPU, please run with CPU "
54 "instead.\n");
55 exit(0);
56 }
57 }
58 return akind;
59}
60
// Evaluates the oneDNN call `f` exactly once. If the resulting
// dnnl_status_t is not dnnl_success, the example is aborted through
// COMPLAIN_DNNL_ERROR_AND_EXIT, reporting the call's source text (#f).
#define CHECK(f) \
    do { \
        dnnl_status_t check_status_ = (f); \
        if (check_status_ != dnnl_success) { \
            COMPLAIN_DNNL_ERROR_AND_EXIT(#f, check_status_); \
        } \
    } while (0)
66
67static inline dnnl_engine_kind_t parse_engine_kind(int argc, char **argv) {
68 // Returns default engine kind, i.e. CPU, if none given
69 if (argc == 1) {
70 return validate_engine_kind(dnnl_cpu);
71 } else if (argc == 2) {
72 // Checking the engine type, i.e. CPU or GPU
73 char *engine_kind_str = argv[1];
74 if (!strcmp(engine_kind_str, "cpu")) {
75 return validate_engine_kind(dnnl_cpu);
76 } else if (!strcmp(engine_kind_str, "gpu")) {
77 return validate_engine_kind(dnnl_gpu);
78 }
79 }
80
81 // If all above fails, the example should be run properly
82 COMPLAIN_EXAMPLE_ERROR_AND_EXIT(
83 "inappropriate engine kind.\n"
84 "Please run the example like this: %s [cpu|gpu].",
85 argv[0]);
86}
87
88static inline const char *engine_kind2str_upper(dnnl_engine_kind_t kind) {
89 if (kind == dnnl_cpu) return "CPU";
90 if (kind == dnnl_gpu) return "GPU";
91 return "<Unknown engine>";
92}
93
94// Read from memory, write to handle
95static inline void read_from_dnnl_memory(void *handle, dnnl_memory_t mem) {
96 dnnl_engine_t eng;
97 dnnl_engine_kind_t eng_kind;
98 const_dnnl_memory_desc_t md;
99
100 if (!handle) COMPLAIN_EXAMPLE_ERROR_AND_EXIT("%s", "handle is NULL.");
101
102 CHECK(dnnl_memory_get_engine(mem, &eng));
103 CHECK(dnnl_engine_get_kind(eng, &eng_kind));
104 CHECK(dnnl_memory_get_memory_desc(mem, &md));
105 size_t bytes = dnnl_memory_desc_get_size(md);
106
107 bool is_cpu_sycl
108 = (DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL && eng_kind == dnnl_cpu);
109
110 if (eng_kind == dnnl_gpu || is_cpu_sycl) {
111 void *mapped_ptr = NULL;
112 CHECK(dnnl_memory_map_data(mem, &mapped_ptr));
113 if (mapped_ptr) memcpy(handle, mapped_ptr, bytes);
114 CHECK(dnnl_memory_unmap_data(mem, mapped_ptr));
115 return;
116 }
117
118 if (eng_kind == dnnl_cpu) {
119 void *ptr = NULL;
120 CHECK(dnnl_memory_get_data_handle(mem, &ptr));
121 if (ptr) {
122 for (size_t i = 0; i < bytes; ++i) {
123 ((char *)handle)[i] = ((char *)ptr)[i];
124 }
125 }
126 return;
127 }
128
129 assert(!"not expected");
130}
131
132// Read from handle, write to memory
133static inline void write_to_dnnl_memory(void *handle, dnnl_memory_t mem) {
134 dnnl_engine_t eng;
135 dnnl_engine_kind_t eng_kind;
136 const_dnnl_memory_desc_t md;
137
138 if (!handle) COMPLAIN_EXAMPLE_ERROR_AND_EXIT("%s", "handle is NULL.");
139
140 CHECK(dnnl_memory_get_engine(mem, &eng));
141 CHECK(dnnl_engine_get_kind(eng, &eng_kind));
142 CHECK(dnnl_memory_get_memory_desc(mem, &md));
143 size_t bytes = dnnl_memory_desc_get_size(md);
144
145#ifdef DNNL_WITH_SYCL
146 bool is_cpu_sycl
147 = (DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL && eng_kind == dnnl_cpu);
148 bool is_gpu_sycl
149 = (DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL && eng_kind == dnnl_gpu);
150 if (is_cpu_sycl || is_gpu_sycl) {
151 void *mapped_ptr = NULL;
152 CHECK(dnnl_memory_map_data(mem, &mapped_ptr));
153 if (mapped_ptr) {
154 for (size_t i = 0; i < bytes; ++i) {
155 ((char *)mapped_ptr)[i] = ((char *)handle)[i];
156 }
157 }
158 CHECK(dnnl_memory_unmap_data(mem, mapped_ptr));
159 return;
160 }
161#endif
162
163#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
164 if (eng_kind == dnnl_gpu) {
165 void *mapped_ptr = NULL;
166 CHECK(dnnl_memory_map_data(mem, &mapped_ptr));
167 if (mapped_ptr) memcpy(mapped_ptr, handle, bytes);
168 CHECK(dnnl_memory_unmap_data(mem, mapped_ptr));
169 return;
170 }
171#endif
172
173 if (eng_kind == dnnl_cpu) {
174 void *ptr = NULL;
175 CHECK(dnnl_memory_get_data_handle(mem, &ptr));
176 if (ptr) {
177 for (size_t i = 0; i < bytes; ++i) {
178 ((char *)ptr)[i] = ((char *)handle)[i];
179 }
180 }
181 return;
182 }
183
184 assert(!"not expected");
185}
186
187#endif
188