1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef EXAMPLE_UTILS_H |
18 | #define EXAMPLE_UTILS_H |
19 | |
20 | #include <assert.h> |
21 | #include <stdbool.h> |
22 | #include <stdio.h> |
23 | #include <stdlib.h> |
24 | #include <string.h> |
25 | |
26 | #include "dnnl.h" |
27 | #include "dnnl_debug.h" |
28 | |
29 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL |
30 | #include "dnnl_ocl.h" |
31 | #endif |
32 | |
/* Reports a failed oneDNN call on stdout and terminates the example with exit
 * status 1.
 *   what   - C string naming the failing call (CHECK passes the stringified
 *            call expression).
 *   status - the dnnl_status_t the call returned; printed via
 *            dnnl_status2str().
 * The reported file/line are those of the macro's expansion site. */
#define COMPLAIN_DNNL_ERROR_AND_EXIT(what, status) \
    do { \
        fprintf(stdout, "[%s:%d] `%s` returns oneDNN error: %s.\n", \
                __FILE__, __LINE__, (what), dnnl_status2str(status)); \
        fputs("Example failed.\n", stdout); \
        exit(1); \
    } while (0)
40 | |
/* Reports an example-level (non-oneDNN) error on stdout and terminates with
 * exit status 2.
 *   complain_fmt - a string literal; it is spliced into a larger format string
 *                  and ".\n" is appended after it, so the message should not
 *                  end with its own period.
 *   ...          - arguments for complain_fmt. At least one is required
 *                  (C11 has no empty __VA_ARGS__), so fixed messages are
 *                  passed as ("%s", "message"). */
#define COMPLAIN_EXAMPLE_ERROR_AND_EXIT(complain_fmt, ...) \
    do { \
        fprintf(stdout, \
                "[%s:%d] Error in the example: " complain_fmt ".\n", \
                __FILE__, __LINE__, __VA_ARGS__); \
        fputs("Example failed.\n", stdout); \
        exit(2); \
    } while (0)
48 | |
49 | static dnnl_engine_kind_t validate_engine_kind(dnnl_engine_kind_t akind) { |
50 | // Checking if a GPU exists on the machine |
51 | if (akind == dnnl_gpu) { |
52 | if (!dnnl_engine_get_count(dnnl_gpu)) { |
53 | printf("Application couldn't find GPU, please run with CPU " |
54 | "instead.\n" ); |
55 | exit(0); |
56 | } |
57 | } |
58 | return akind; |
59 | } |
60 | |
/* Evaluates the oneDNN call `f` exactly once. If it returns anything other
 * than dnnl_success, the call's source text and the status are reported via
 * COMPLAIN_DNNL_ERROR_AND_EXIT, which terminates the process. */
#define CHECK(f) \
    do { \
        const dnnl_status_t s_ = (f); \
        if (s_ != dnnl_success) { COMPLAIN_DNNL_ERROR_AND_EXIT(#f, s_); } \
    } while (0)
66 | |
67 | static inline dnnl_engine_kind_t parse_engine_kind(int argc, char **argv) { |
68 | // Returns default engine kind, i.e. CPU, if none given |
69 | if (argc == 1) { |
70 | return validate_engine_kind(dnnl_cpu); |
71 | } else if (argc == 2) { |
72 | // Checking the engine type, i.e. CPU or GPU |
73 | char *engine_kind_str = argv[1]; |
74 | if (!strcmp(engine_kind_str, "cpu" )) { |
75 | return validate_engine_kind(dnnl_cpu); |
76 | } else if (!strcmp(engine_kind_str, "gpu" )) { |
77 | return validate_engine_kind(dnnl_gpu); |
78 | } |
79 | } |
80 | |
81 | // If all above fails, the example should be run properly |
82 | COMPLAIN_EXAMPLE_ERROR_AND_EXIT( |
83 | "inappropriate engine kind.\n" |
84 | "Please run the example like this: %s [cpu|gpu]." , |
85 | argv[0]); |
86 | } |
87 | |
88 | static inline const char *engine_kind2str_upper(dnnl_engine_kind_t kind) { |
89 | if (kind == dnnl_cpu) return "CPU" ; |
90 | if (kind == dnnl_gpu) return "GPU" ; |
91 | return "<Unknown engine>" ; |
92 | } |
93 | |
94 | // Read from memory, write to handle |
95 | static inline void read_from_dnnl_memory(void *handle, dnnl_memory_t mem) { |
96 | dnnl_engine_t eng; |
97 | dnnl_engine_kind_t eng_kind; |
98 | const_dnnl_memory_desc_t md; |
99 | |
100 | if (!handle) COMPLAIN_EXAMPLE_ERROR_AND_EXIT("%s" , "handle is NULL." ); |
101 | |
102 | CHECK(dnnl_memory_get_engine(mem, &eng)); |
103 | CHECK(dnnl_engine_get_kind(eng, &eng_kind)); |
104 | CHECK(dnnl_memory_get_memory_desc(mem, &md)); |
105 | size_t bytes = dnnl_memory_desc_get_size(md); |
106 | |
107 | bool is_cpu_sycl |
108 | = (DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL && eng_kind == dnnl_cpu); |
109 | |
110 | if (eng_kind == dnnl_gpu || is_cpu_sycl) { |
111 | void *mapped_ptr = NULL; |
112 | CHECK(dnnl_memory_map_data(mem, &mapped_ptr)); |
113 | if (mapped_ptr) memcpy(handle, mapped_ptr, bytes); |
114 | CHECK(dnnl_memory_unmap_data(mem, mapped_ptr)); |
115 | return; |
116 | } |
117 | |
118 | if (eng_kind == dnnl_cpu) { |
119 | void *ptr = NULL; |
120 | CHECK(dnnl_memory_get_data_handle(mem, &ptr)); |
121 | if (ptr) { |
122 | for (size_t i = 0; i < bytes; ++i) { |
123 | ((char *)handle)[i] = ((char *)ptr)[i]; |
124 | } |
125 | } |
126 | return; |
127 | } |
128 | |
129 | assert(!"not expected" ); |
130 | } |
131 | |
132 | // Read from handle, write to memory |
133 | static inline void write_to_dnnl_memory(void *handle, dnnl_memory_t mem) { |
134 | dnnl_engine_t eng; |
135 | dnnl_engine_kind_t eng_kind; |
136 | const_dnnl_memory_desc_t md; |
137 | |
138 | if (!handle) COMPLAIN_EXAMPLE_ERROR_AND_EXIT("%s" , "handle is NULL." ); |
139 | |
140 | CHECK(dnnl_memory_get_engine(mem, &eng)); |
141 | CHECK(dnnl_engine_get_kind(eng, &eng_kind)); |
142 | CHECK(dnnl_memory_get_memory_desc(mem, &md)); |
143 | size_t bytes = dnnl_memory_desc_get_size(md); |
144 | |
145 | #ifdef DNNL_WITH_SYCL |
146 | bool is_cpu_sycl |
147 | = (DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL && eng_kind == dnnl_cpu); |
148 | bool is_gpu_sycl |
149 | = (DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL && eng_kind == dnnl_gpu); |
150 | if (is_cpu_sycl || is_gpu_sycl) { |
151 | void *mapped_ptr = NULL; |
152 | CHECK(dnnl_memory_map_data(mem, &mapped_ptr)); |
153 | if (mapped_ptr) { |
154 | for (size_t i = 0; i < bytes; ++i) { |
155 | ((char *)mapped_ptr)[i] = ((char *)handle)[i]; |
156 | } |
157 | } |
158 | CHECK(dnnl_memory_unmap_data(mem, mapped_ptr)); |
159 | return; |
160 | } |
161 | #endif |
162 | |
163 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL |
164 | if (eng_kind == dnnl_gpu) { |
165 | void *mapped_ptr = NULL; |
166 | CHECK(dnnl_memory_map_data(mem, &mapped_ptr)); |
167 | if (mapped_ptr) memcpy(mapped_ptr, handle, bytes); |
168 | CHECK(dnnl_memory_unmap_data(mem, mapped_ptr)); |
169 | return; |
170 | } |
171 | #endif |
172 | |
173 | if (eng_kind == dnnl_cpu) { |
174 | void *ptr = NULL; |
175 | CHECK(dnnl_memory_get_data_handle(mem, &ptr)); |
176 | if (ptr) { |
177 | for (size_t i = 0; i < bytes; ++i) { |
178 | ((char *)ptr)[i] = ((char *)handle)[i]; |
179 | } |
180 | } |
181 | return; |
182 | } |
183 | |
184 | assert(!"not expected" ); |
185 | } |
186 | |
187 | #endif |
188 | |