1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REF_ELTWISE_HPP |
18 | #define CPU_REF_ELTWISE_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/platform.hpp" |
28 | #include "cpu/primitive_attr_postops.hpp" |
29 | |
30 | #include "cpu/cpu_eltwise_pd.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | |
36 | template <impl::data_type_t data_type> |
37 | struct ref_eltwise_fwd_t : public primitive_t { |
38 | struct pd_t : public cpu_eltwise_fwd_pd_t { |
39 | using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; |
40 | |
41 | DECLARE_COMMON_PD_T("ref:any" , ref_eltwise_fwd_t); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | using namespace utils; |
45 | using sm = primitive_attr_t::skip_mask_t; |
46 | |
47 | const memory_desc_wrapper src_d(src_md()); |
48 | const memory_desc_wrapper dst_d(dst_md()); |
49 | |
50 | bool ok = is_fwd() |
51 | && utils::everyone_is( |
52 | data_type, src_md()->data_type, dst_md()->data_type) |
53 | && platform::has_data_type_support(data_type) |
54 | && attr()->has_default_values(sm::post_ops) |
55 | && set_default_formats_common() && src_d == dst_d |
56 | && attr_.set_default_formats(dst_md(0)) == status::success; |
57 | if (!ok) return status::unimplemented; |
58 | |
59 | use_dense_ = src_d.is_dense(true) && dst_d.is_dense(true) |
60 | && IMPLICATION(!src_d.is_dense() || !dst_d.is_dense(), |
61 | is_zero_preserved()); |
62 | |
63 | use_nCspBc_padded_ = !use_dense_ |
64 | && src_d.blocking_desc().inner_nblks == 1 |
65 | && one_of(src_d.blocking_desc().inner_blks[0], 8, 16) |
66 | && src_d.blocking_desc().inner_idxs[0] == 1 |
67 | && src_d.only_padded_dim(1) && src_d.is_dense(true); |
68 | |
69 | const auto &po = attr()->post_ops_; |
70 | if (has_zero_dim_memory() || !po.has_default_values()) |
71 | use_dense_ = use_nCspBc_padded_ = false; |
72 | |
73 | return status::success; |
74 | } |
75 | |
76 | bool use_dense_, use_nCspBc_padded_; |
77 | }; |
78 | |
79 | ref_eltwise_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
80 | |
81 | status_t init(engine_t *engine) override { |
82 | ref_post_ops |
83 | = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_); |
84 | if (!ref_post_ops) return status::out_of_memory; |
85 | return status::success; |
86 | } |
87 | |
88 | using data_t = typename prec_traits<data_type>::type; |
89 | |
90 | status_t execute(const exec_ctx_t &ctx) const override { |
91 | if (pd()->use_dense_) |
92 | return execute_forward_dense(ctx); |
93 | else if (pd()->use_nCspBc_padded_) |
94 | return execute_forward_nCspBc_padded(ctx); |
95 | else |
96 | return execute_forward_generic(ctx); |
97 | } |
98 | |
99 | private: |
100 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
101 | status_t execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const; |
102 | status_t execute_forward_dense(const exec_ctx_t &ctx) const; |
103 | status_t execute_forward_generic(const exec_ctx_t &ctx) const; |
104 | std::unique_ptr<ref_post_ops_t> ref_post_ops; |
105 | }; |
106 | |
107 | template <impl::data_type_t data_type> |
108 | struct ref_eltwise_bwd_t : public primitive_t { |
109 | struct pd_t : public cpu_eltwise_bwd_pd_t { |
110 | using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; |
111 | |
112 | DECLARE_COMMON_PD_T("ref:any" , ref_eltwise_bwd_t); |
113 | |
114 | status_t init(engine_t *engine) { |
115 | using namespace utils; |
116 | using namespace data_type; |
117 | |
118 | const memory_desc_wrapper diff_src_d(diff_src_md()); |
119 | const memory_desc_wrapper diff_dst_d(diff_dst_md()); |
120 | |
121 | bool ok = !is_fwd() |
122 | && utils::everyone_is(data_type, data_md()->data_type, |
123 | diff_src_md()->data_type, diff_dst_md()->data_type) |
124 | && platform::has_data_type_support(data_type) |
125 | && attr()->has_default_values() |
126 | && set_default_formats_common() && diff_dst_d == diff_src_d; |
127 | if (!ok) return status::unimplemented; |
128 | |
129 | use_dense_ = diff_dst_d.is_dense() |
130 | || (diff_dst_d.is_dense(true) && is_zero_preserved()); |
131 | |
132 | if (has_zero_dim_memory()) use_dense_ = false; |
133 | if (diff_dst_d != memory_desc_wrapper(data_md())) |
134 | use_dense_ = false; |
135 | |
136 | if (utils::one_of(data_type, bf16, f16)) init_scratchpad(); |
137 | |
138 | return status::success; |
139 | } |
140 | |
141 | bool use_dense_; |
142 | |
143 | private: |
144 | void init_scratchpad() { |
145 | const memory_desc_wrapper data_d(data_md()); |
146 | const memory_desc_wrapper diff_dst_d(diff_dst_md()); |
147 | using namespace memory_tracking::names; |
148 | auto scratchpad = scratchpad_registry().registrar(); |
149 | const auto diff_dst_size = diff_dst_d.nelems(true); |
150 | scratchpad.template book<float>( |
151 | key_eltwise_src, data_d.nelems(true)); |
152 | scratchpad.template book<float>( |
153 | key_eltwise_diff_dst, diff_dst_size); |
154 | } |
155 | }; |
156 | |
157 | ref_eltwise_bwd_t(const pd_t *apd) : primitive_t(apd) {} |
158 | typedef typename prec_traits<data_type>::type data_t; |
159 | |
160 | status_t execute(const exec_ctx_t &ctx) const override { |
161 | if (pd()->use_dense_) |
162 | return execute_backward_dense(ctx); |
163 | else |
164 | return execute_backward_generic(ctx); |
165 | } |
166 | |
167 | private: |
168 | status_t execute_backward_dense(const exec_ctx_t &ctx) const; |
169 | status_t execute_backward_generic(const exec_ctx_t &ctx) const; |
170 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
171 | }; |
172 | |
173 | } // namespace cpu |
174 | } // namespace impl |
175 | } // namespace dnnl |
176 | |
177 | #endif |
178 | |
179 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
180 | |