1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @example cnn_training_bf16.cpp |
18 | /// @copybrief cnn_training_bf16_cpp |
19 | /// > Annotated version: @ref cnn_training_bf16_cpp |
20 | /// |
21 | /// @page cnn_training_bf16_cpp CNN bf16 training example |
/// This C++ API example demonstrates how to build training for an AlexNet
/// model using the bfloat16 data type.
///
/// The example implements a few layers from the AlexNet model.
26 | /// |
27 | /// @include cnn_training_bf16.cpp |
28 | |
29 | #include <cassert> |
30 | #include <cmath> |
31 | #include <iostream> |
32 | #include <stdexcept> |
33 | |
34 | #include "oneapi/dnnl/dnnl.hpp" |
35 | |
36 | #include "example_utils.hpp" |
37 | |
38 | using namespace dnnl; |
39 | |
// Builds and executes one training iteration of the first few AlexNet
// layers (conv -> relu -> lrn -> max pooling), computing in the bfloat16
// data type while user data stays in f32 and is reordered as needed.
// Runs a forward pass, then a backward pass that produces diff weights
// and diff bias for the convolution.
//
// @param engine_kind oneDNN engine kind (CPU or GPU) to execute on.
// @throws example_allows_unimplemented if no bf16 convolution
//         implementation is available on this platform.
void simple_net(engine::kind engine_kind) {
    using tag = memory::format_tag;
    using dt = memory::data_type;

    auto eng = engine(engine_kind, 0);
    stream s(eng);

    // Vector of primitives and their execute arguments.
    // net_fwd_args[i] holds the arguments for net_fwd[i] (likewise for the
    // backward lists), so pushes into the paired vectors must stay in sync.
    std::vector<primitive> net_fwd, net_bwd;
    std::vector<std::unordered_map<int, memory>> net_fwd_args, net_bwd_args;

    const int batch = 32;

    // float data type is used for user data
    std::vector<float> net_src(batch * 3 * 227 * 227);

    // initializing non-zero values for src
    for (size_t i = 0; i < net_src.size(); ++i)
        net_src[i] = sinf((float)i);

    // AlexNet: conv
    // {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
    // strides: {4, 4}

    memory::dims conv_src_tz = {batch, 3, 227, 227};
    memory::dims conv_weights_tz = {96, 3, 11, 11};
    memory::dims conv_bias_tz = {96};
    memory::dims conv_dst_tz = {batch, 96, 55, 55};
    memory::dims conv_strides = {4, 4};
    memory::dims conv_padding = {0, 0};

    // float data type is used for user data
    std::vector<float> conv_weights(product(conv_weights_tz));
    std::vector<float> conv_bias(product(conv_bias_tz));

    // initializing non-zero values for weights and bias
    for (size_t i = 0; i < conv_weights.size(); ++i)
        conv_weights[i] = sinf((float)i);
    for (size_t i = 0; i < conv_bias.size(); ++i)
        conv_bias[i] = sinf((float)i);

    // create memory for user data (f32, plain nchw/oihw/x layouts)
    auto conv_user_src_memory
            = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng);
    write_to_dnnl_memory(net_src.data(), conv_user_src_memory);

    auto conv_user_weights_memory
            = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng);
    write_to_dnnl_memory(conv_weights.data(), conv_user_weights_memory);

    auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
    write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory);

    // create memory descriptors for bfloat16 convolution data w/ no specified
    // format tag(`any`)
    // tag `any` lets a primitive(convolution in this case)
    // choose the memory format preferred for best performance.
    auto conv_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any);
    auto conv_weights_md = memory::desc({conv_weights_tz}, dt::bf16, tag::any);
    auto conv_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any);
    // here bias data type is set to bf16.
    // additionally, f32 data type is supported for bf16 convolution.
    auto conv_bias_md = memory::desc({conv_bias_tz}, dt::bf16, tag::any);

    // create a convolution primitive descriptor

    // check if bf16 convolution is supported: primitive descriptor creation
    // throws dnnl_unimplemented on platforms without bf16 support, which is
    // translated into a "skip" for this example rather than a failure.
    try {
        convolution_forward::primitive_desc(eng, prop_kind::forward,
                algorithm::convolution_direct, conv_src_md, conv_weights_md,
                conv_bias_md, conv_dst_md, conv_strides, conv_padding,
                conv_padding);
    } catch (error &e) {
        if (e.status == dnnl_unimplemented)
            throw example_allows_unimplemented {
                "No bf16 convolution implementation is available for this "
                "platform.\n"
                "Please refer to the developer guide for details." };

        // on any other error just re-throw
        throw;
    }

    auto conv_pd = convolution_forward::primitive_desc(eng, prop_kind::forward,
            algorithm::convolution_direct, conv_src_md, conv_weights_md,
            conv_bias_md, conv_dst_md, conv_strides, conv_padding,
            conv_padding);

    // create reorder primitives between user input and conv src if needed
    // (also converts f32 user data to the bf16 layout chosen by the conv)
    auto conv_src_memory = conv_user_src_memory;
    if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) {
        conv_src_memory = memory(conv_pd.src_desc(), eng);
        net_fwd.push_back(reorder(conv_user_src_memory, conv_src_memory));
        net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_src_memory},
                {DNNL_ARG_TO, conv_src_memory}});
    }

    auto conv_weights_memory = conv_user_weights_memory;
    if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) {
        conv_weights_memory = memory(conv_pd.weights_desc(), eng);
        net_fwd.push_back(
                reorder(conv_user_weights_memory, conv_weights_memory));
        net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_weights_memory},
                {DNNL_ARG_TO, conv_weights_memory}});
    }

    // convert bias from f32 to bf16 as convolution descriptor is created with
    // bias data type as bf16.
    auto conv_bias_memory = conv_user_bias_memory;
    if (conv_pd.bias_desc() != conv_user_bias_memory.get_desc()) {
        conv_bias_memory = memory(conv_pd.bias_desc(), eng);
        net_fwd.push_back(reorder(conv_user_bias_memory, conv_bias_memory));
        net_fwd_args.push_back({{DNNL_ARG_FROM, conv_user_bias_memory},
                {DNNL_ARG_TO, conv_bias_memory}});
    }

    // create memory for conv dst
    auto conv_dst_memory = memory(conv_pd.dst_desc(), eng);

    // finally create a convolution primitive
    net_fwd.push_back(convolution_forward(conv_pd));
    net_fwd_args.push_back({{DNNL_ARG_SRC, conv_src_memory},
            {DNNL_ARG_WEIGHTS, conv_weights_memory},
            {DNNL_ARG_BIAS, conv_bias_memory},
            {DNNL_ARG_DST, conv_dst_memory}});

    // AlexNet: relu
    // {batch, 96, 55, 55} -> {batch, 96, 55, 55}
    memory::dims relu_data_tz = {batch, 96, 55, 55};
    const float negative_slope = 0.0f;

    // create relu primitive desc
    // keep memory format tag of source same as the format tag of convolution
    // output in order to avoid reorder
    auto relu_pd = eltwise_forward::primitive_desc(eng, prop_kind::forward,
            algorithm::eltwise_relu, conv_pd.dst_desc(), conv_pd.dst_desc(),
            negative_slope);

    // create relu dst memory
    auto relu_dst_memory = memory(relu_pd.dst_desc(), eng);

    // finally create a relu primitive
    net_fwd.push_back(eltwise_forward(relu_pd));
    net_fwd_args.push_back(
            {{DNNL_ARG_SRC, conv_dst_memory}, {DNNL_ARG_DST, relu_dst_memory}});

    // AlexNet: lrn
    // {batch, 96, 55, 55} -> {batch, 96, 55, 55}
    // local size: 5
    // alpha: 0.0001
    // beta: 0.75
    // k: 1.0
    memory::dims lrn_data_tz = {batch, 96, 55, 55};
    const uint32_t local_size = 5;
    const float alpha = 0.0001f;
    const float beta = 0.75f;
    const float k = 1.0f;

    // create a lrn primitive descriptor
    auto lrn_pd = lrn_forward::primitive_desc(eng, prop_kind::forward,
            algorithm::lrn_across_channels, relu_pd.dst_desc(),
            relu_pd.dst_desc(), local_size, alpha, beta, k);

    // create lrn dst memory
    auto lrn_dst_memory = memory(lrn_pd.dst_desc(), eng);

    // create workspace only in training and only for forward primitive
    // query lrn_pd for workspace, this memory will be shared with backward lrn
    // (the backward primitive reads the data the forward pass stores here)
    auto lrn_workspace_memory = memory(lrn_pd.workspace_desc(), eng);

    // finally create a lrn primitive
    net_fwd.push_back(lrn_forward(lrn_pd));
    net_fwd_args.push_back(
            {{DNNL_ARG_SRC, relu_dst_memory}, {DNNL_ARG_DST, lrn_dst_memory},
                    {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});

    // AlexNet: pool
    // {batch, 96, 55, 55} -> {batch, 96, 27, 27}
    // kernel: {3, 3}
    // strides: {2, 2}

    memory::dims pool_dst_tz = {batch, 96, 27, 27};
    memory::dims pool_kernel = {3, 3};
    memory::dims pool_strides = {2, 2};
    memory::dims pool_dilation = {0, 0};
    memory::dims pool_padding = {0, 0};

    // create memory for pool dst data in user format
    auto pool_user_dst_memory
            = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);

    // create pool dst memory descriptor in format any for bfloat16 data type
    auto pool_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any);

    // create a pooling primitive descriptor
    auto pool_pd = pooling_forward::primitive_desc(eng, prop_kind::forward,
            algorithm::pooling_max, lrn_dst_memory.get_desc(), pool_dst_md,
            pool_strides, pool_kernel, pool_dilation, pool_padding,
            pool_padding);

    // create pooling workspace memory if training
    // (max pooling records the argmax indices needed by backward)
    auto pool_workspace_memory = memory(pool_pd.workspace_desc(), eng);

    // create a pooling primitive
    net_fwd.push_back(pooling_forward(pool_pd));
    // leave DST unknown for now (see the next reorder)
    net_fwd_args.push_back({{DNNL_ARG_SRC, lrn_dst_memory},
            // delay putting DST until reorder (if needed)
            {DNNL_ARG_WORKSPACE, pool_workspace_memory}});

    // create reorder primitive between pool dst and user dst format
    // if needed; either way the pooling args pushed above get their
    // DNNL_ARG_DST entry filled in here
    auto pool_dst_memory = pool_user_dst_memory;
    if (pool_pd.dst_desc() != pool_user_dst_memory.get_desc()) {
        pool_dst_memory = memory(pool_pd.dst_desc(), eng);
        net_fwd_args.back().insert({DNNL_ARG_DST, pool_dst_memory});

        net_fwd.push_back(reorder(pool_dst_memory, pool_user_dst_memory));
        net_fwd_args.push_back({{DNNL_ARG_FROM, pool_dst_memory},
                {DNNL_ARG_TO, pool_user_dst_memory}});
    } else {
        net_fwd_args.back().insert({DNNL_ARG_DST, pool_dst_memory});
    }

    //-----------------------------------------------------------------------
    //----------------- Backward Stream -------------------------------------
    // ... user diff_data in float data type ...
    std::vector<float> net_diff_dst(batch * 96 * 27 * 27);
    for (size_t i = 0; i < net_diff_dst.size(); ++i)
        net_diff_dst[i] = sinf((float)i);

    // create memory for user diff dst data stored in float data type
    auto pool_user_diff_dst_memory
            = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
    write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory);

    // Backward pooling
    // create memory descriptors for pooling
    auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any);
    auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any);

    // backward primitive descriptor needs to hint forward descriptor
    auto pool_bwd_pd = pooling_backward::primitive_desc(eng,
            algorithm::pooling_max, pool_diff_src_md, pool_diff_dst_md,
            pool_strides, pool_kernel, pool_dilation, pool_padding,
            pool_padding, pool_pd);

    // create reorder primitive between user diff dst and pool diff dst
    // if required (also converts f32 user gradients to bf16)
    auto pool_diff_dst_memory = pool_user_diff_dst_memory;
    if (pool_dst_memory.get_desc() != pool_user_diff_dst_memory.get_desc()) {
        pool_diff_dst_memory = memory(pool_dst_memory.get_desc(), eng);
        net_bwd.push_back(
                reorder(pool_user_diff_dst_memory, pool_diff_dst_memory));
        net_bwd_args.push_back({{DNNL_ARG_FROM, pool_user_diff_dst_memory},
                {DNNL_ARG_TO, pool_diff_dst_memory}});
    }

    // create memory for pool diff src
    auto pool_diff_src_memory = memory(pool_bwd_pd.diff_src_desc(), eng);

    // finally create backward pooling primitive
    // (reuses the workspace produced by forward pooling)
    net_bwd.push_back(pooling_backward(pool_bwd_pd));
    net_bwd_args.push_back({{DNNL_ARG_DIFF_DST, pool_diff_dst_memory},
            {DNNL_ARG_DIFF_SRC, pool_diff_src_memory},
            {DNNL_ARG_WORKSPACE, pool_workspace_memory}});

    // Backward lrn
    auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any);
    // diff src and diff dst share the same descriptor (same shape/type)
    const auto &lrn_diff_src_md = lrn_diff_dst_md;

    // create backward lrn primitive descriptor
    auto lrn_bwd_pd = lrn_backward::primitive_desc(eng,
            algorithm::lrn_across_channels, lrn_diff_src_md, lrn_diff_dst_md,
            lrn_pd.src_desc(), local_size, alpha, beta, k, lrn_pd);

    // create reorder primitive between pool diff src and lrn diff dst
    // if required
    auto lrn_diff_dst_memory = pool_diff_src_memory;
    if (lrn_diff_dst_memory.get_desc() != lrn_bwd_pd.diff_dst_desc()) {
        lrn_diff_dst_memory = memory(lrn_bwd_pd.diff_dst_desc(), eng);
        net_bwd.push_back(reorder(pool_diff_src_memory, lrn_diff_dst_memory));
        net_bwd_args.push_back({{DNNL_ARG_FROM, pool_diff_src_memory},
                {DNNL_ARG_TO, lrn_diff_dst_memory}});
    }

    // create memory for lrn diff src
    auto lrn_diff_src_memory = memory(lrn_bwd_pd.diff_src_desc(), eng);

    // finally create a lrn backward primitive
    // backward lrn needs src: relu dst in this topology
    net_bwd.push_back(lrn_backward(lrn_bwd_pd));
    net_bwd_args.push_back({{DNNL_ARG_SRC, relu_dst_memory},
            {DNNL_ARG_DIFF_DST, lrn_diff_dst_memory},
            {DNNL_ARG_DIFF_SRC, lrn_diff_src_memory},
            {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});

    // Backward relu
    auto relu_diff_src_md = memory::desc({relu_data_tz}, dt::bf16, tag::any);
    auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::bf16, tag::any);
    auto relu_src_md = conv_pd.dst_desc();

    // create backward relu primitive_descriptor
    auto relu_bwd_pd = eltwise_backward::primitive_desc(eng,
            algorithm::eltwise_relu, relu_diff_src_md, relu_diff_dst_md,
            relu_src_md, negative_slope, relu_pd);

    // create reorder primitive between lrn diff src and relu diff dst
    // if required
    auto relu_diff_dst_memory = lrn_diff_src_memory;
    if (relu_diff_dst_memory.get_desc() != relu_bwd_pd.diff_dst_desc()) {
        relu_diff_dst_memory = memory(relu_bwd_pd.diff_dst_desc(), eng);
        net_bwd.push_back(reorder(lrn_diff_src_memory, relu_diff_dst_memory));
        net_bwd_args.push_back({{DNNL_ARG_FROM, lrn_diff_src_memory},
                {DNNL_ARG_TO, relu_diff_dst_memory}});
    }

    // create memory for relu diff src
    auto relu_diff_src_memory = memory(relu_bwd_pd.diff_src_desc(), eng);

    // finally create a backward relu primitive
    // backward relu needs the forward src (conv dst) to mask the gradient
    net_bwd.push_back(eltwise_backward(relu_bwd_pd));
    net_bwd_args.push_back({{DNNL_ARG_SRC, conv_dst_memory},
            {DNNL_ARG_DIFF_DST, relu_diff_dst_memory},
            {DNNL_ARG_DIFF_SRC, relu_diff_src_memory}});

    // Backward convolution with respect to weights
    // create user format diff weights and diff bias memory for float data type

    auto conv_user_diff_weights_memory
            = memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng);
    auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);

    // create memory descriptors for bfloat16 convolution data
    auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any);
    auto conv_diff_weights_md
            = memory::desc({conv_weights_tz}, dt::bf16, tag::any);
    auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any);

    // use diff bias provided by the user (f32 descriptor, no reorder needed)
    auto conv_diff_bias_md = conv_diff_bias_memory.get_desc();

    // create backward convolution primitive descriptor
    auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(eng,
            algorithm::convolution_direct, conv_bwd_src_md,
            conv_diff_weights_md, conv_diff_bias_md, conv_diff_dst_md,
            conv_strides, conv_padding, conv_padding, conv_pd);

    // for best performance convolution backward might choose
    // different memory format for src and diff_dst
    // than the memory formats preferred by forward convolution
    // for src and dst respectively
    // create reorder primitives for src from forward convolution to the
    // format chosen by backward convolution
    auto conv_bwd_src_memory = conv_src_memory;
    if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) {
        conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng);
        net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory));
        net_bwd_args.push_back({{DNNL_ARG_FROM, conv_src_memory},
                {DNNL_ARG_TO, conv_bwd_src_memory}});
    }

    // create reorder primitives for diff_dst between diff_src from relu_bwd
    // and format preferred by conv_diff_weights
    auto conv_diff_dst_memory = relu_diff_src_memory;
    if (conv_bwd_weights_pd.diff_dst_desc()
            != relu_diff_src_memory.get_desc()) {
        conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng);
        net_bwd.push_back(reorder(relu_diff_src_memory, conv_diff_dst_memory));
        net_bwd_args.push_back({{DNNL_ARG_FROM, relu_diff_src_memory},
                {DNNL_ARG_TO, conv_diff_dst_memory}});
    }

    // create backward convolution primitive
    net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
    net_bwd_args.push_back({{DNNL_ARG_SRC, conv_bwd_src_memory},
            {DNNL_ARG_DIFF_DST, conv_diff_dst_memory},
            // delay putting DIFF_WEIGHTS until reorder (if needed)
            {DNNL_ARG_DIFF_BIAS, conv_diff_bias_memory}});

    // create reorder primitives between conv diff weights and user diff weights
    // if needed; either way the backward conv args pushed above get their
    // DNNL_ARG_DIFF_WEIGHTS entry filled in here
    auto conv_diff_weights_memory = conv_user_diff_weights_memory;
    if (conv_bwd_weights_pd.diff_weights_desc()
            != conv_user_diff_weights_memory.get_desc()) {
        conv_diff_weights_memory
                = memory(conv_bwd_weights_pd.diff_weights_desc(), eng);
        net_bwd_args.back().insert(
                {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});

        net_bwd.push_back(reorder(
                conv_diff_weights_memory, conv_user_diff_weights_memory));
        net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_weights_memory},
                {DNNL_ARG_TO, conv_user_diff_weights_memory}});
    } else {
        net_bwd_args.back().insert(
                {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory});
    }

    // didn't we forget anything? every primitive must have its argument map
    assert(net_fwd.size() == net_fwd_args.size() && "something is missing" );
    assert(net_bwd.size() == net_bwd_args.size() && "something is missing" );

    int n_iter = 1; // number of iterations for training
    // execute the training loop: forward pass, (user-supplied) loss
    // gradient update, backward pass, (user-supplied) weight update
    while (n_iter) {
        // forward
        for (size_t i = 0; i < net_fwd.size(); ++i)
            net_fwd.at(i).execute(s, net_fwd_args.at(i));

        // update net_diff_dst
        // auto net_output = pool_user_dst_memory.get_data_handle();
        // ..user updates net_diff_dst using net_output...
        // some user defined func update_diff_dst(net_diff_dst.data(),
        // net_output)

        for (size_t i = 0; i < net_bwd.size(); ++i)
            net_bwd.at(i).execute(s, net_bwd_args.at(i));
        // update weights and bias using diff weights and bias
        //
        // auto net_diff_weights
        //     = conv_user_diff_weights_memory.get_data_handle();
        // auto net_diff_bias = conv_diff_bias_memory.get_data_handle();
        //
        // ...user updates weights and bias using diff weights and bias...
        //
        // some user defined func update_weights(conv_weights.data(),
        // conv_bias.data(), net_diff_weights, net_diff_bias);

        --n_iter;
    }

    // block until all queued primitives have finished executing
    s.wait();
}
474 | |
475 | int main(int argc, char **argv) { |
476 | return handle_example_errors(simple_net, parse_engine_kind(argc, argv)); |
477 | } |
478 | |