1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * Copyright 2022 Arm Ltd. and affiliates |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #ifndef CPU_REF_DECONVOLUTION_HPP |
19 | #define CPU_REF_DECONVOLUTION_HPP |
20 | |
21 | #include <assert.h> |
22 | #include <string.h> |
23 | |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/primitive_desc_iterator.hpp" |
27 | #include "common/stream.hpp" |
28 | #include "common/type_helpers.hpp" |
29 | #include "common/utils.hpp" |
30 | |
31 | #include "cpu/primitive_attr_postops.hpp" |
32 | |
33 | #include "cpu/cpu_convolution_pd.hpp" |
34 | #include "cpu/cpu_deconvolution_pd.hpp" |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | |
40 | static status_t weights_axes_permutation( |
41 | memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { |
42 | int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation |
43 | for (int d = 0; d < DNNL_MAX_NDIMS; ++d) |
44 | perm[d] = d; |
45 | nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); |
46 | |
47 | return memory_desc_permute_axes(*o_md, *i_md, perm); |
48 | } |
49 | |
50 | static status_t conv_descr_create(const deconvolution_desc_t *dd, |
51 | convolution_desc_t *cd, const memory_desc_t *bias_md = nullptr, |
52 | data_type_t src_dt = data_type::undef) { |
53 | using namespace prop_kind; |
54 | alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct |
55 | ? alg_kind::convolution_direct |
56 | : alg_kind::convolution_winograd; |
57 | |
58 | const memory_desc_t *src_md, *dst_md, *d_weights_d; |
59 | memory_desc_t src_md_patched; |
60 | prop_kind_t prop_kind; |
61 | |
62 | if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { |
63 | prop_kind = backward_data; |
64 | assert(src_dt != data_type::undef); |
65 | memory_desc_init_by_md_and_dt(src_md_patched, dd->dst_desc, src_dt); |
66 | src_md = &src_md_patched; |
67 | dst_md = &dd->src_desc; |
68 | d_weights_d = &dd->weights_desc; |
69 | } else if (dd->prop_kind == backward_data) { |
70 | assert(src_dt == data_type::undef); |
71 | prop_kind = forward_training; |
72 | src_md = &dd->diff_dst_desc; |
73 | dst_md = &dd->diff_src_desc; |
74 | d_weights_d = &dd->weights_desc; |
75 | } else { |
76 | assert(src_dt == data_type::undef); |
77 | prop_kind = dd->prop_kind; |
78 | src_md = &dd->diff_dst_desc; |
79 | dst_md = &dd->src_desc; |
80 | d_weights_d = &dd->diff_weights_desc; |
81 | } |
82 | |
83 | /* create weights desc for convolution */ |
84 | memory_desc_t c_weights_d; |
85 | const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; |
86 | CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups)); |
87 | |
88 | return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, |
89 | bias_md, dst_md, dd->strides, dd->dilates, dd->padding[0], |
90 | dd->padding[1]); |
91 | } |
92 | |
93 | struct ref_deconvolution_fwd_t : public primitive_t { |
94 | struct pd_t : public cpu_deconvolution_fwd_pd_t { |
95 | pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, |
96 | const deconvolution_fwd_pd_t *hint_fwd_pd) |
97 | : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
98 | |
99 | pd_t(const pd_t &other) |
100 | : cpu_deconvolution_fwd_pd_t(other) |
101 | , conv_pd_(other.conv_pd_->clone()) |
102 | , conv_supports_bias_(other.conv_supports_bias_) |
103 | , dst_tag_(other.dst_tag_) {} |
104 | |
105 | ~pd_t() = default; |
106 | |
107 | DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); |
108 | |
109 | status_t init_convolution(engine_t *engine) { |
110 | using namespace format_tag; |
111 | using namespace data_type; |
112 | |
113 | // Create empty attributes for bwd_d conv to pick up the fastest |
114 | // impl available and apply post-ops and/or bias update later in |
115 | // this impl via simple loop. |
116 | primitive_attr_t conv_attr; |
117 | |
118 | convolution_desc_t cd; |
119 | // When no attributes were requested, try to find a bwd_d conv impl |
120 | // which supports bias update in-place, if requested, in requested |
121 | // dst_dt. If appropriate conv impl was not found, enforce f32 |
122 | // diff_src for conv for correct result. If attributes are |
123 | // requested, enforce conv impl to return f32 output no matter what. |
124 | if (attr()->has_default_values()) { |
125 | CHECK(conv_descr_create( |
126 | desc(), &cd, weights_md(1), dst_md()->data_type)); |
127 | primitive_desc_iterator_t it( |
128 | engine, (op_desc_t *)&cd, &conv_attr, nullptr); |
129 | if (!it.is_initialized()) return status::out_of_memory; |
130 | |
131 | while (++it != it.end()) { |
132 | conv_pd_ = *it; |
133 | if (with_bias()) { |
134 | conv_supports_bias_ = utils::downcast< |
135 | cpu_convolution_bwd_data_pd_t *>(conv_pd_.get()) |
136 | ->support_bias(); |
137 | if (!conv_supports_bias_) continue; |
138 | } |
139 | bool ok = conv_pd_->weights_md()->extra.flags == 0; |
140 | if (ok) return status::success; |
141 | } |
142 | } |
143 | |
144 | // Intermediate f32 buffer is supported only for given condition. |
145 | if (!attr()->has_default_values() || with_bias()) { |
146 | // Enforce f32 dt for diff src and work with f32 output for bias |
147 | // update or post ops after conv execution. |
148 | CHECK(conv_descr_create(desc(), &cd, nullptr, data_type::f32)); |
149 | primitive_desc_iterator_t it( |
150 | engine, (op_desc_t *)&cd, &conv_attr, nullptr); |
151 | if (!it.is_initialized()) return status::out_of_memory; |
152 | |
153 | while (++it != it.end()) { |
154 | conv_pd_ = *it; |
155 | bool ok = conv_pd_->weights_md()->extra.flags == 0; |
156 | if (ok) return status::success; |
157 | } |
158 | } |
159 | return status::unimplemented; |
160 | } |
161 | |
162 | status_t init(engine_t *engine) { |
163 | using namespace format_tag; |
164 | using smask_t = primitive_attr_t::skip_mask_t; |
165 | |
166 | const bool ok = is_fwd() |
167 | && utils::one_of(desc()->alg_kind, |
168 | alg_kind::deconvolution_direct, |
169 | alg_kind::deconvolution_winograd) |
170 | && attr()->has_default_values(smask_t::scales_runtime |
171 | | smask_t::post_ops | smask_t::zero_points_runtime) |
172 | && scales_mask_ok() && post_ops_ok() && zero_points_ok(); |
173 | if (!ok) return status::unimplemented; |
174 | |
175 | CHECK(init_convolution(engine)); |
176 | |
177 | if (weights_md_.format_kind == format_kind::any) |
178 | CHECK(weights_axes_permutation( |
179 | &weights_md_, conv_pd_->weights_md(), with_groups())); |
180 | if (src_md_.format_kind == format_kind::any) |
181 | src_md_ = *conv_pd_->diff_dst_md(); |
182 | if (dst_md_.format_kind == format_kind::any) { |
183 | // re-apply dt manually since it could be changed due to bias |
184 | const auto dst_dt = dst_md_.data_type; |
185 | memory_desc_init_by_md_and_dt( |
186 | dst_md_, *conv_pd_->diff_src_md(), dst_dt); |
187 | } |
188 | if (bias_md_.format_kind == format_kind::any) |
189 | CHECK(memory_desc_init_by_tag(bias_md_, x)); |
190 | |
191 | dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, |
192 | utils::pick(ndims() - 3, ncw, nchw, ncdhw), |
193 | utils::pick(ndims() - 3, nwc, nhwc, ndhwc), |
194 | utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), |
195 | utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); |
196 | |
197 | init_scratchpad(); |
198 | return attr_.set_default_formats(dst_md(0)); |
199 | } |
200 | |
201 | std::shared_ptr<primitive_desc_t> conv_pd_; |
202 | bool conv_supports_bias_ = false; |
203 | format_tag_t dst_tag_; |
204 | |
205 | private: |
206 | void init_scratchpad() { |
207 | using namespace memory_tracking::names; |
208 | auto scratchpad = scratchpad_registry().registrar(); |
209 | scratchpad.book(key_nested, conv_pd_->scratchpad_registry()); |
210 | |
211 | // This scratchpad is required for intermediate f32 conv output |
212 | // since original memory can be of smaller size and will cause |
213 | // out of boundary access. |
214 | if ((with_bias() && !conv_supports_bias_) |
215 | || !attr()->has_default_values()) { |
216 | const memory_desc_wrapper diff_src_d(conv_pd_->diff_src_md()); |
217 | assert(diff_src_d.data_type_size() == sizeof(float)); |
218 | scratchpad.book(key_deconv_bias, diff_src_d.nelems(true), |
219 | diff_src_d.data_type_size()); |
220 | } |
221 | // This scratchpad is required to stash original dst memory for sum |
222 | // post-op. It will be overwritten by conv execution and will not |
223 | // be available to get the correct result. |
224 | const memory_desc_wrapper dst_d(dst_md()); |
225 | if (attr()->post_ops_.find(primitive_kind::sum) != -1) |
226 | scratchpad.book(key_deconv_sum, dst_d.nelems(true), |
227 | dst_d.data_type_size()); |
228 | |
229 | if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC)) { |
230 | scratchpad.book<int32_t>(key_deconv_zp, OC() * G()); |
231 | } |
232 | } |
233 | |
234 | bool scales_mask_ok() const { |
235 | using namespace data_type; |
236 | const std::vector<int> supported_args |
237 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; |
238 | bool ok = attr()->scales_.has_default_values(supported_args); |
239 | for (int arg : supported_args) { |
240 | const auto &mask = attr()->scales_.get(arg).mask_; |
241 | if (arg == DNNL_ARG_WEIGHTS) |
242 | ok = ok && (mask == 0 || mask == (1 << (int)with_groups())); |
243 | else |
244 | ok = ok && (mask == 0); |
245 | } |
246 | return ok; |
247 | } |
248 | |
249 | bool post_ops_ok() const { |
250 | return attr()->post_ops_.find(primitive_kind::convolution) == -1; |
251 | } |
252 | |
253 | bool zero_points_ok() const { |
254 | using namespace data_type; |
255 | int mask_src = 0, mask_dst = 0; |
256 | attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src); |
257 | attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst); |
258 | |
259 | return IMPLICATION(!utils::one_of(src_md()->data_type, s8, u8), |
260 | attr()->zero_points_.has_default_values()) |
261 | && attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) |
262 | && (mask_src == 0 || mask_src == 1 << 1) |
263 | && (mask_dst == 0 || mask_dst == 1 << 1); |
264 | } |
265 | }; |
266 | |
267 | ref_deconvolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
268 | |
269 | status_t init(engine_t *engine) override { |
270 | CHECK(pd()->conv_pd_->create_primitive(conv_p_, engine)); |
271 | |
272 | ref_post_ops |
273 | = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_); |
274 | if (!ref_post_ops) return status::out_of_memory; |
275 | return status::success; |
276 | } |
277 | |
278 | status_t execute(const exec_ctx_t &ctx) const override; |
279 | |
280 | private: |
281 | void compute_fwd_bias_common(const exec_ctx_t &ctx, void *dst, |
282 | const float *conv_output, bool non_default_attr) const; |
283 | |
284 | void compute_fwd_bias_ncdhw(const exec_ctx_t &ctx, void *dst, |
285 | const float *conv_output, bool non_default_attr) const; |
286 | |
287 | void compute_fwd_bias_ndhwc(const exec_ctx_t &ctx, void *dst, |
288 | const float *conv_output, bool non_default_attr) const; |
289 | |
290 | template <dim_t blk_size> |
291 | void compute_fwd_bias_nCdhwXc(const exec_ctx_t &ctx, void *dst, |
292 | const float *conv_output, bool non_default_attr) const; |
293 | |
294 | status_t compute_oscale(const exec_ctx_t &ctx, float *dst) const; |
295 | |
296 | void compute_fwd_bias(const exec_ctx_t &ctx, void *dst, |
297 | const float *conv_output, bool non_default_attr) const; |
298 | |
299 | status_t compute_ref_attrs(const exec_ctx_t &ctx, const float *conv_output, |
300 | void *original_dst) const; |
301 | |
302 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
303 | std::shared_ptr<primitive_t> conv_p_; |
304 | std::unique_ptr<ref_post_ops_t> ref_post_ops; |
305 | }; |
306 | |
// Reference deconvolution backward-data implementation.
//
// diff_src is computed by a *forward* convolution over the axes-swapped
// weights (see conv_descr_create(): conv src = deconv diff_dst, conv dst =
// deconv diff_src). All compute is delegated to the nested convolution
// primitive; no extra processing happens here.
struct ref_deconvolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : cpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        // Copy requires a deep clone of the nested convolution pd.
        pd_t(const pd_t &other)
            : cpu_deconvolution_bwd_data_pd_t(other)
            , conv_pd_(other.conv_pd_->clone()) {}

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);

        // Creates the pd of the underlying forward convolution. Attributes
        // are forwarded unchanged; only impls whose weights carry no extra
        // compensation metadata (extra.flags == 0) are accepted.
        status_t init_convolution(engine_t *engine) {
            using namespace types;

            convolution_desc_t cd;
            status_t status = conv_descr_create(desc(), &cd);
            if (status != status::success) return status;
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;

            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            while (++it != it.end()) {
                conv_pd_ = *it;
                if (conv_pd_->weights_md()->extra.flags == 0)
                    return status::success;
            }

            return status::unimplemented;
        }

        status_t init(engine_t *engine) {
            using namespace data_type;
            auto dsrc_type = desc()->diff_src_desc.data_type;
            auto wei_type = desc()->weights_desc.data_type;
            auto ddst_type = desc()->diff_dst_desc.data_type;
            // Supported: f32/bf16/f16 weights; diff_dst of the same type;
            // diff_src of the same type or f32; no attributes.
            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && utils::one_of(wei_type, f32, bf16, f16)
                    && ddst_type == wei_type
                    && utils::one_of(dsrc_type, wei_type, f32)
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::deconvolution_direct,
                            alg_kind::deconvolution_winograd)
                    && attr()->has_default_values();

            if (ok) {
                CHECK(init_convolution(engine));
                // Resolve "any" memory formats from the formats the chosen
                // convolution impl prefers.
                if (weights_md_.format_kind == format_kind::any)
                    CHECK(weights_axes_permutation(&weights_md_,
                            conv_pd_->weights_md(), with_groups()));
                if (diff_src_md_.format_kind == format_kind::any)
                    diff_src_md_ = *conv_pd_->dst_md();
                if (diff_dst_md_.format_kind == format_kind::any)
                    diff_dst_md_ = *conv_pd_->src_md();
                init_scratchpad();
                return status::success;
            }

            return status::unimplemented;
        }

        std::shared_ptr<primitive_desc_t> conv_pd_; // nested fwd conv pd

    private:
        // Only the nested convolution's scratchpad is needed.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    typedef typename prec_traits<data_type::f32>::type data_t;

    ref_deconvolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        return pd()->conv_pd_->create_primitive(conv_p_, engine);
    }

#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
    // Arm Compute Library builds: forward resource creation to the nested
    // convolution so it can register its per-primitive state in the mapper.
    status_t create_resource(
            engine_t *engine, resource_mapper_t &mapper) const override {
        CHECK(conv_p_->create_resource(engine, mapper));
        return status::success;
    }
#endif

    status_t execute(const exec_ctx_t &ctx) const override;

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::shared_ptr<primitive_t> conv_p_; // nested fwd conv primitive
};
404 | |
405 | struct ref_deconvolution_bwd_weights_t : public primitive_t { |
406 | struct pd_t : public cpu_deconvolution_bwd_weights_pd_t { |
407 | pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr, |
408 | const deconvolution_fwd_pd_t *hint_fwd_pd) |
409 | : cpu_deconvolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {} |
410 | |
411 | pd_t(const pd_t &other) |
412 | : cpu_deconvolution_bwd_weights_pd_t(other) |
413 | , conv_pd_(other.conv_pd_->clone()) |
414 | , dst_tag_(other.dst_tag_) {} |
415 | |
416 | ~pd_t() = default; |
417 | |
418 | DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); |
419 | |
420 | status_t init_convolution(engine_t *engine) { |
421 | using namespace types; |
422 | using namespace format_tag; |
423 | |
424 | convolution_desc_t cd; |
425 | status_t status = conv_descr_create(desc(), &cd); |
426 | if (status != status::success) return status; |
427 | primitive_attr_t conv_attr(*attr()); |
428 | if (!conv_attr.is_initialized()) return status::out_of_memory; |
429 | |
430 | primitive_desc_iterator_t it( |
431 | engine, (op_desc_t *)&cd, &conv_attr, nullptr); |
432 | if (!it.is_initialized()) return status::out_of_memory; |
433 | while (++it != it.end()) { |
434 | conv_pd_ = *it; |
435 | bool bf16_ref_deconv_supports_bias = IMPLICATION(with_bias() |
436 | && desc()->src_desc.data_type |
437 | == data_type::bf16, |
438 | memory_desc_matches_one_of_tag(*conv_pd_->src_md(), |
439 | utils::pick(ndims() - 3, ncw, nchw, ncdhw), |
440 | utils::pick(ndims() - 3, nwc, nhwc, ndhwc), |
441 | utils::pick(ndims() - 3, nCw16c, nChw16c, |
442 | nCdhw16c))); |
443 | if (conv_pd_->diff_weights_md()->extra.flags == 0 |
444 | && bf16_ref_deconv_supports_bias) { |
445 | return status::success; |
446 | } |
447 | } |
448 | return status::unimplemented; |
449 | } |
450 | |
451 | status_t init(engine_t *engine) { |
452 | using namespace format_tag; |
453 | using namespace data_type; |
454 | auto src_type = desc()->src_desc.data_type; |
455 | auto dwei_type = desc()->diff_weights_desc.data_type; |
456 | auto ddst_type = desc()->diff_dst_desc.data_type; |
457 | bool ok = true && desc()->prop_kind == prop_kind::backward_weights |
458 | && utils::one_of(src_type, f32, bf16, f16) |
459 | && ddst_type == src_type |
460 | && utils::one_of(dwei_type, src_type, f32) |
461 | && utils::one_of(desc()->alg_kind, |
462 | alg_kind::deconvolution_direct, |
463 | alg_kind::deconvolution_winograd) |
464 | && attr()->has_default_values(); |
465 | |
466 | if (ok) { |
467 | CHECK(init_convolution(engine)); |
468 | if (diff_weights_md_.format_kind == format_kind::any) |
469 | CHECK(weights_axes_permutation(&diff_weights_md_, |
470 | conv_pd_->diff_weights_md(), with_groups())); |
471 | if (src_md_.format_kind == format_kind::any) |
472 | src_md_ = *conv_pd_->diff_dst_md(); |
473 | if (diff_dst_md_.format_kind == format_kind::any) |
474 | diff_dst_md_ = *conv_pd_->src_md(); |
475 | if (diff_bias_md_.format_kind == format_kind::any) |
476 | CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); |
477 | |
478 | dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, |
479 | utils::pick(ndims() - 3, ncw, nchw, ncdhw), |
480 | utils::pick(ndims() - 3, nwc, nhwc, ndhwc), |
481 | utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), |
482 | utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); |
483 | init_scratchpad(); |
484 | return status::success; |
485 | } |
486 | |
487 | return status::unimplemented; |
488 | } |
489 | |
490 | std::shared_ptr<primitive_desc_t> conv_pd_; |
491 | format_tag_t dst_tag_; |
492 | |
493 | private: |
494 | void init_scratchpad() { |
495 | auto scratchpad = scratchpad_registry().registrar(); |
496 | scratchpad.book(memory_tracking::names::key_nested, |
497 | conv_pd_->scratchpad_registry()); |
498 | } |
499 | }; |
500 | |
501 | ref_deconvolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} |
502 | |
503 | status_t init(engine_t *engine) override { |
504 | return pd()->conv_pd_->create_primitive(conv_p_, engine); |
505 | } |
506 | |
507 | status_t execute(const exec_ctx_t &ctx) const override; |
508 | |
509 | private: |
510 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
511 | void compute_bwd_bias(float *diff_bias, const float *diff_dst) const; |
512 | |
513 | template <data_type_t dbia_type, data_type_t ddst_type> |
514 | void compute_bwd_bias_ncdhw( |
515 | typename prec_traits<dbia_type>::type *diff_bias, |
516 | const typename prec_traits<ddst_type>::type *diff_dst) const; |
517 | |
518 | template <data_type_t dbia_type, data_type_t ddst_type> |
519 | void compute_bwd_bias_ndhwc( |
520 | typename prec_traits<dbia_type>::type *diff_bias, |
521 | const typename prec_traits<ddst_type>::type *diff_dst) const; |
522 | |
523 | template <data_type_t dbia_type, data_type_t ddst_type, dim_t blksize> |
524 | void compute_bwd_bias_nCdhwXc( |
525 | typename prec_traits<dbia_type>::type *diff_bias, |
526 | const typename prec_traits<ddst_type>::type *diff_dst) const; |
527 | |
528 | template <data_type_t dbia_type, data_type_t ddst_type> |
529 | void compute_bias(const exec_ctx_t &ctx) const; |
530 | std::shared_ptr<primitive_t> conv_p_; |
531 | }; |
532 | |
533 | } // namespace cpu |
534 | } // namespace impl |
535 | } // namespace dnnl |
536 | |
537 | #endif |
538 | |
539 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
540 | |