1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_ELTWISE_HPP
18#define CPU_REF_ELTWISE_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/platform.hpp"
28#include "cpu/primitive_attr_postops.hpp"
29
30#include "cpu/cpu_eltwise_pd.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35
36template <impl::data_type_t data_type>
37struct ref_eltwise_fwd_t : public primitive_t {
38 struct pd_t : public cpu_eltwise_fwd_pd_t {
39 using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
40
41 DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t);
42
43 status_t init(engine_t *engine) {
44 using namespace utils;
45 using sm = primitive_attr_t::skip_mask_t;
46
47 const memory_desc_wrapper src_d(src_md());
48 const memory_desc_wrapper dst_d(dst_md());
49
50 bool ok = is_fwd()
51 && utils::everyone_is(
52 data_type, src_md()->data_type, dst_md()->data_type)
53 && platform::has_data_type_support(data_type)
54 && attr()->has_default_values(sm::post_ops)
55 && set_default_formats_common() && src_d == dst_d
56 && attr_.set_default_formats(dst_md(0)) == status::success;
57 if (!ok) return status::unimplemented;
58
59 use_dense_ = src_d.is_dense(true) && dst_d.is_dense(true)
60 && IMPLICATION(!src_d.is_dense() || !dst_d.is_dense(),
61 is_zero_preserved());
62
63 use_nCspBc_padded_ = !use_dense_
64 && src_d.blocking_desc().inner_nblks == 1
65 && one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
66 && src_d.blocking_desc().inner_idxs[0] == 1
67 && src_d.only_padded_dim(1) && src_d.is_dense(true);
68
69 const auto &po = attr()->post_ops_;
70 if (has_zero_dim_memory() || !po.has_default_values())
71 use_dense_ = use_nCspBc_padded_ = false;
72
73 return status::success;
74 }
75
76 bool use_dense_, use_nCspBc_padded_;
77 };
78
79 ref_eltwise_fwd_t(const pd_t *apd) : primitive_t(apd) {}
80
81 status_t init(engine_t *engine) override {
82 ref_post_ops
83 = utils::make_unique<ref_post_ops_t>(pd()->attr()->post_ops_);
84 if (!ref_post_ops) return status::out_of_memory;
85 return status::success;
86 }
87
88 using data_t = typename prec_traits<data_type>::type;
89
90 status_t execute(const exec_ctx_t &ctx) const override {
91 if (pd()->use_dense_)
92 return execute_forward_dense(ctx);
93 else if (pd()->use_nCspBc_padded_)
94 return execute_forward_nCspBc_padded(ctx);
95 else
96 return execute_forward_generic(ctx);
97 }
98
99private:
100 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
101 status_t execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const;
102 status_t execute_forward_dense(const exec_ctx_t &ctx) const;
103 status_t execute_forward_generic(const exec_ctx_t &ctx) const;
104 std::unique_ptr<ref_post_ops_t> ref_post_ops;
105};
106
107template <impl::data_type_t data_type>
108struct ref_eltwise_bwd_t : public primitive_t {
109 struct pd_t : public cpu_eltwise_bwd_pd_t {
110 using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
111
112 DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t);
113
114 status_t init(engine_t *engine) {
115 using namespace utils;
116 using namespace data_type;
117
118 const memory_desc_wrapper diff_src_d(diff_src_md());
119 const memory_desc_wrapper diff_dst_d(diff_dst_md());
120
121 bool ok = !is_fwd()
122 && utils::everyone_is(data_type, data_md()->data_type,
123 diff_src_md()->data_type, diff_dst_md()->data_type)
124 && platform::has_data_type_support(data_type)
125 && attr()->has_default_values()
126 && set_default_formats_common() && diff_dst_d == diff_src_d;
127 if (!ok) return status::unimplemented;
128
129 use_dense_ = diff_dst_d.is_dense()
130 || (diff_dst_d.is_dense(true) && is_zero_preserved());
131
132 if (has_zero_dim_memory()) use_dense_ = false;
133 if (diff_dst_d != memory_desc_wrapper(data_md()))
134 use_dense_ = false;
135
136 if (utils::one_of(data_type, bf16, f16)) init_scratchpad();
137
138 return status::success;
139 }
140
141 bool use_dense_;
142
143 private:
144 void init_scratchpad() {
145 const memory_desc_wrapper data_d(data_md());
146 const memory_desc_wrapper diff_dst_d(diff_dst_md());
147 using namespace memory_tracking::names;
148 auto scratchpad = scratchpad_registry().registrar();
149 const auto diff_dst_size = diff_dst_d.nelems(true);
150 scratchpad.template book<float>(
151 key_eltwise_src, data_d.nelems(true));
152 scratchpad.template book<float>(
153 key_eltwise_diff_dst, diff_dst_size);
154 }
155 };
156
157 ref_eltwise_bwd_t(const pd_t *apd) : primitive_t(apd) {}
158 typedef typename prec_traits<data_type>::type data_t;
159
160 status_t execute(const exec_ctx_t &ctx) const override {
161 if (pd()->use_dense_)
162 return execute_backward_dense(ctx);
163 else
164 return execute_backward_generic(ctx);
165 }
166
167private:
168 status_t execute_backward_dense(const exec_ctx_t &ctx) const;
169 status_t execute_backward_generic(const exec_ctx_t &ctx) const;
170 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
171};
172
173} // namespace cpu
174} // namespace impl
175} // namespace dnnl
176
177#endif
178
179// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
180