/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_REF_PRELU_HPP
#define CPU_REF_PRELU_HPP

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "common/broadcast_strategy.hpp"
#include "cpu/cpu_prelu_pd.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

namespace prelu {
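// Splits `work_amount` reduction elements into groups: `buf_size` elements
// are accumulated per group and `group_size` partial results are kept (the
// exact splitting policy is defined in the implementation file).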
void set_reduction_buffers(
        const dim_t work_amount, dim_t &group_size, dim_t &buf_size);
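// Returns the offset of thread `ithr`'s chunk within the f32 reduction
// scratchpad for the scalar-weights case; calling it with `ithr == nthr`
// (as init_scratchpad() does below) yields the total buffer size.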
dim_t get_scalar_scratchpad_offset(const std::size_t ithr,
        const std::size_t nthr, const dim_t work_amount);
} // namespace prelu

using byte = unsigned char;

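// Reference (plain C++) forward PReLU: per element,
//   dst = src            if src >= 0,
//   dst = src * weights  otherwise,
// with weights broadcast to the shape of src as needed.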
struct ref_prelu_fwd_t : public primitive_t {
    struct pd_t : public cpu_prelu_fwd_pd_t {
        using cpu_prelu_fwd_pd_t::cpu_prelu_fwd_pd_t;

        DECLARE_COMMON_PD_T("ref:any", ref_prelu_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            bool ok = is_fwd() && src_md(0)->data_type == dst_md(0)->data_type
                    && platform::has_data_type_support(src_md(0)->data_type)
                    && platform::has_data_type_support(weights_md(0)->data_type)
                    && attr()->has_default_values() && set_default_formats()
                    && memory_desc_wrapper(src_md())
                            == memory_desc_wrapper(dst_md());
            if (!ok) return status::unimplemented;

            return status::success;
        }
    };

    ref_prelu_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

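// Reference backward PReLU. Per element:
//   diff_src               = diff_dst            if src >= 0,
//                            diff_dst * weights  otherwise;
//   diff_weights increment = 0                   if src >= 0,
//                            diff_dst * src      otherwise,
// with the increments reduced over every dimension the weights broadcast
// along.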
struct ref_prelu_bwd_t : public primitive_t {
    struct pd_t : public cpu_prelu_bwd_pd_t {
        using cpu_prelu_bwd_pd_t::cpu_prelu_bwd_pd_t;

        DECLARE_COMMON_PD_T("ref:any", ref_prelu_bwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            bool ok = !is_fwd()
                    && diff_src_md(0)->data_type == src_md(0)->data_type
                    && diff_weights_md(0)->data_type == weights_md(0)->data_type
                    && diff_dst_md(0)->data_type == diff_src_md(0)->data_type
                    && platform::has_data_type_support(src_md(0)->data_type)
                    && platform::has_data_type_support(weights_md(0)->data_type)
                    && attr()->has_default_values() && set_default_formats()
                    && memory_desc_wrapper(diff_dst_md())
                            == memory_desc_wrapper(diff_src_md());
            if (!ok) return status::unimplemented;

            init_scratchpad();

            return status::success;
        }
        int nthr_; // Fixed at pd creation so execute() never uses more
                   // threads than the scratchpad was sized for.

    private:
        void init_scratchpad() {
            auto scratchpad = this->scratchpad_registry().registrar();
            dim_t scratchpad_size;
            const memory_desc_wrapper src_d(src_md());
            const memory_desc_wrapper weights_d(weights_md());
            auto broadcast_strategy
                    = get_rhs_arg_broadcasting_strategy(*weights_md(), src_d);
            // Assign nthr_ here since the number of threads needed may be
            // reduced below.
            nthr_ = dnnl_get_max_threads();
            // Scratchpad is needed to correctly reduce the calculated
            // diff_weights in cases where broadcast is used.
            //
            // Example: if the data tensor is NxCxW and the weights tensor is
            // 1xCx1, the diff_weights tensor is also 1xCx1, so each value
            // along the C axis equals the per-element results summed over N
            // and W for that channel.
            //
            // In the current implementation the reduction takes two steps:
            // partial results are first copied to a buffer and reduced, and
            // each reduced value is stored in a group buffer. The values in
            // the group buffer are then reduced to obtain the final value.
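            // For instance, with N=2, C=3, W=4 there are N*W = 8 per-element
            // results per channel: they are accumulated in the element
            // buffer, the partial sums land in the group buffer, and
            // reducing the group buffer yields diff_weights[c].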
            if (broadcast_strategy == broadcasting_strategy_t::no_broadcast) {
                return;
            } else if (broadcast_strategy == broadcasting_strategy_t::scalar) {
                int work_amount = static_cast<int>(src_d.nelems());
                nthr_ = nstl::min(nthr_, work_amount);
                scratchpad_size = prelu::get_scalar_scratchpad_offset(
                        nthr_, nthr_, src_d.nelems());
            } else {
                dim_t group_size, buf_size;
                nthr_ = nstl::min(nthr_, static_cast<int>(weights_d.nelems()));
                dim_t work_amount = src_d.nelems() / weights_d.nelems();
                prelu::set_reduction_buffers(work_amount, group_size, buf_size);
                scratchpad_size = nthr_ * (group_size + buf_size);
            }
            scratchpad.book(memory_tracking::names::key_prelu_reduction,
                    scratchpad_size, types::data_type_size(dnnl_f32));
        }
    };

    ref_prelu_bwd_t(const pd_t *apd) : primitive_t(apd) {}

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

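    // Computes diff_src for the element at data_off and returns its
    // contribution to diff_weights at weight_off, which the callers below
    // accumulate.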
    float ker(const byte *src, const byte *weights, const byte *diff_dst,
            byte *diff_src, dim_t data_off, dim_t weight_off) const;
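    // One driver per broadcast strategy chosen in pd_t::init_scratchpad():
    // scalar weights, full-size weights (no broadcast), and weights shared
    // along a subset of axes.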
    void calculate_scalar(const byte *src, const byte *weights,
            byte *diff_weights, const byte *diff_dst, byte *diff_src,
            float *scratchpad_buf) const;
    void calculate_no_broadcast(const byte *src, const byte *weights,
            byte *diff_weights, const byte *diff_dst, byte *diff_src,
            float *scratchpad_buf) const;
    void calculate_shared_axes(const byte *src, const byte *weights,
            byte *diff_weights, const byte *diff_dst, byte *diff_src,
            float *scratchpad_buf) const;
};

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s