/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_REF_FUSED_CONVOLUTION_HPP
#define CPU_REF_FUSED_CONVOLUTION_HPP

#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/reorder.hpp"
#include "common/stream.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/dw_convolution_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
struct ref_fused_convolution_fwd_t : public primitive_t {

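    // Describes how each argument of one nested primitive is satisfied at
    // execution time: either forwarded from the fused primitive's context
    // (is_ctx_arg) or mapped to a precomputed offset inside the shared
    // fusion scratchpad buffer.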
    struct arg_cache_t {
        struct arg_info_t {
            int op_arg;
            bool is_ctx_arg;
            bool is_const;
            union {
                size_t offset;
                int ctx_arg;
            };
            memory_desc_t md;
        };

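        // Map the nested op's argument op_arg to the fused primitive's
        // context argument ctx_arg.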
        void append_ctx_arg(int op_arg, int ctx_arg) {
            arg_info_t arg_info;
            arg_info.op_arg = op_arg;
            arg_info.is_ctx_arg = true;
            arg_info.is_const = false; // unused
            arg_info.ctx_arg = ctx_arg;
            arg_info.md = glob_zero_md;
            info_.push_back(arg_info);
        }

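        // Map the nested op's argument to a sub-buffer of the fusion
        // scratchpad at a fixed offset.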
        void append_inout_arg(int arg, size_t offset, const memory_desc_t *md,
                bool is_const) {
            arg_info_t arg_info;
            arg_info.op_arg = arg;
            arg_info.is_ctx_arg = false;
            arg_info.is_const = is_const;
            arg_info.offset = offset;
            arg_info.md = *md;
            info_.push_back(arg_info);
        }

        void append_ctx_arg(int arg) { append_ctx_arg(arg, arg); }

        const std::vector<arg_info_t> &info() const { return info_; }

    private:
        std::vector<arg_info_t> info_;
    };

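    // The fused pd builds and stores the chain of nested primitive
    // descriptors (op_pds_) plus one arg_cache_t per nested primitive
    // (args_) describing how to wire its arguments.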
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {
            name_ = "ref_fused_convolution:any";
        }

        pd_t(const pd_t &other) = default;

        DECLARE_COMMON_PD_T(name_.c_str(), ref_fused_convolution_fwd_t);

        virtual status_t init(engine_t *engine) {
            bool ok = true && is_fwd()
                    && (attr()->post_ops_.find(primitive_kind::sum) == -1);

            if (!ok) return status::unimplemented;

            CHECK(init_ops(engine));
            init_name();
            return status::success;
        }

        const memory_desc_t *src_md(int index = 0) const override {
            return op_pds_.front()->src_md(index);
        }

        const memory_desc_t *dst_md(int index = 0) const override {
            return op_pds_.back()->dst_md(index);
        }

        const memory_desc_t *weights_md(int index = 0) const override {
            return op_pds_.front()->weights_md(index); // for now
        }

        const memory_desc_t *arg_md(int index = 0) const override {
            switch (index) { // for now
                case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                    return op_pds_.back()->weights_md(0);
                case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                    return op_pds_.back()->weights_md(1);
                default: return convolution_fwd_pd_t::arg_md(index);
            }
        }

        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES
                                | DNNL_ARG_SRC))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES
                                | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES
                                | DNNL_ARG_DST))
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        size_t user_scratchpad_size_;
        std::vector<std::shared_ptr<primitive_desc_t>> op_pds_;
        std::vector<arg_cache_t> args_;

    private:
        std::string name_;
        const unsigned int max_fusions_ = 1;

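        // Append op_pd to the chain. If the previous op's dst layout does
        // not match op_pd's src layout, a reorder is inserted first; the
        // scratchpad offsets [sp_begin, sp_end) are advanced to cover the
        // intermediate tensor.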
        status_t append_op(std::shared_ptr<primitive_desc_t> &op_pd,
                size_t &sp_begin, size_t &sp_end, engine_t *engine) {
            auto from_md = op_pds_.back()->dst_md();
            auto to_md = op_pd->src_md();

            if (*from_md != *to_md) {
                //TODO: Find a test-case for this
                std::shared_ptr<primitive_desc_t> pd;
                CHECK(reorder_primitive_desc_create(
                        pd, engine, from_md, to_md));
                op_pds_.emplace_back(std::move(pd));

                arg_cache_t arg_cache;
                arg_cache.append_inout_arg(
                        DNNL_ARG_FROM, sp_begin, from_md, true);
                arg_cache.append_inout_arg(DNNL_ARG_TO, sp_end, to_md, false);
                args_.push_back(arg_cache);

                // Increment scratchpad offsets
                sp_begin = sp_end;
                sp_end += memory_desc_wrapper(to_md).size();

                user_scratchpad_size_ = nstl::max<size_t>(user_scratchpad_size_,
                        op_pds_.back()->scratchpad_size(
                                attr()->scratchpad_mode_));
            }

            op_pds_.emplace_back(std::move(op_pd));
            user_scratchpad_size_ = nstl::max<size_t>(user_scratchpad_size_,
                    op_pds_.back()->scratchpad_size(attr()->scratchpad_mode_));
            return status::success;
        }

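        // Build the chain: the root 1x1 convolution keeps the post-ops up
        // to (but excluding) the fused dw convolution; the dw convolution
        // is then created from the post-op entry, with its scales and
        // binary post-op arguments remapped from the fused primitive's
        // DNNL_ARG_ATTR_POST_OP_DW argument space.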
        status_t init_ops(engine_t *engine) {
            using namespace data_type;
            primitive_attr_t root_attr(*attr());
            if (!root_attr.is_initialized()) return status::out_of_memory;
            auto po_op_iter
                    = attr()->post_ops_.find(primitive_kind::convolution);
            if (po_op_iter == -1) return status::unimplemented;

            primitive_attr_t attr_1x1(*attr());
            // erase dw_conv post-op scales
            for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
                auto &scale
                        = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | arg);
                if (!scale.has_default_values())
                    attr_1x1.scales_.reset(DNNL_ARG_ATTR_POST_OP_DW | arg);
            }
            // erase post-ops after fusion as they will be handled separately
            auto &e = attr_1x1.post_ops_.entry_;
            e.erase(e.begin() + po_op_iter, e.end());

            primitive_desc_iterator_t it(engine, op_desc(), &attr_1x1, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            std::shared_ptr<primitive_desc_t> root_pd = *(++it);
            if (!root_pd) return status::unimplemented;
            op_pds_.emplace_back(root_pd);
            // Scratchpad offsets. Simulate the offset computation here so
            // that it can be avoided during execution.
            size_t inout_sp_offset_begin = 0;
            size_t inout_sp_offset_end = 0;
            user_scratchpad_size_
                    = root_pd->scratchpad_size(attr()->scratchpad_mode_);

            // Create arg cache for the root pd
            arg_cache_t arg_cache;
            arg_cache.append_ctx_arg(DNNL_ARG_SRC);
            arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS);
            for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
                if (!attr_1x1.scales_.get(arg).has_default_values())
                    arg_cache.append_ctx_arg(DNNL_ARG_ATTR_SCALES | arg);
            if (desc()->bias_desc.data_type != data_type::undef)
                arg_cache.append_ctx_arg(DNNL_ARG_BIAS);
            arg_cache.append_inout_arg(DNNL_ARG_DST, inout_sp_offset_end,
                    root_pd->dst_md(), false);
            for (int idx = 0; idx < attr_1x1.post_ops_.len(); ++idx) {
                if (attr_1x1.post_ops_.contain(primitive_kind::binary, idx))
                    arg_cache.append_ctx_arg(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                            | DNNL_ARG_SRC_1);
            }
            args_.push_back(arg_cache);

            // Increment scratchpad offsets
            inout_sp_offset_begin = inout_sp_offset_end;
            inout_sp_offset_end
                    += memory_desc_wrapper(root_pd->dst_md()).size();

            const auto &po = attr()->post_ops_;
            const auto &end = po.len();

            unsigned int fusion_ops = 0;
            // Loop through the post-ops until we reach the end
            // (if we have more than one op to fuse later)
            while (po_op_iter < end) {
                if (fusion_ops++ > max_fusions_) return status::unimplemented;

                const auto &prev_op_pd = op_pds_.back();

                if (po.entry_[po_op_iter].kind != primitive_kind::convolution)
                    return status::unimplemented;

                if (prev_op_pd->kind() != primitive_kind::convolution)
                    return status::unimplemented;

                auto conv_pd = reinterpret_cast<convolution_pd_t *>(
                        prev_op_pd.get());
                bool ok = true && is_fwd()
                        && utils::everyone_is(
                                1, conv_pd->KD(), conv_pd->KH(), conv_pd->KW());
                if (!ok) return status::unimplemented;

                convolution_desc_t cd_dw;
                primitive_attr_t attr_dw;
                CHECK(get_depthwise_conv_desc(cd_dw, *(conv_pd->dst_md()),
                        root_attr, attr_dw, po_op_iter));
                primitive_desc_iterator_t it(
                        engine, (op_desc_t *)&cd_dw, &attr_dw, nullptr);
                if (!it.is_initialized()) return status::out_of_memory;

                std::shared_ptr<primitive_desc_t> append_conv_pd = *(++it);
                if (!append_conv_pd) return status::unimplemented;

                CHECK(append_op(append_conv_pd, inout_sp_offset_begin,
                        inout_sp_offset_end, engine));

                const auto &op = op_pds_.back();
                arg_cache_t arg_cache;
                arg_cache.append_inout_arg(DNNL_ARG_SRC, inout_sp_offset_begin,
                        op->src_md(), true);
                arg_cache.append_ctx_arg(DNNL_ARG_DST);
                arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS,
                        DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
                for (auto arg : {DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
                    if (!attr_dw.scales_.get(arg).has_default_values())
                        arg_cache.append_ctx_arg(DNNL_ARG_ATTR_SCALES | arg,
                                DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES
                                        | arg);
                // dw_conv src_scale = 1x1_conv dst_scale
                if (!attr_1x1.scales_.get(DNNL_ARG_DST).has_default_values())
                    arg_cache.append_ctx_arg(
                            DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC,
                            DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
                if (op->weights_md(1)->data_type != data_type::undef)
                    arg_cache.append_ctx_arg(DNNL_ARG_BIAS,
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
                for (int idx = 0; idx < attr_dw.post_ops_.len(); ++idx) {
                    if (attr_dw.post_ops_.contain(primitive_kind::binary, idx))
                        arg_cache.append_ctx_arg(
                                (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                        | DNNL_ARG_SRC_1),
                                (DNNL_ARG_ATTR_MULTIPLE_POST_OP(
                                         idx + po_op_iter + 1)
                                        | DNNL_ARG_SRC_1));
                }

                args_.push_back(arg_cache);

                while (++po_op_iter < end) {
                    if (utils::one_of(po.entry_[po_op_iter].kind,
                                primitive_kind::convolution))
                        break;
                }
            }

            assert(!op_pds_.empty());

            CHECK(init_scratchpad_memory(inout_sp_offset_end));

            return status::success;
        }

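        // Book two scratchpad regions: the inout buffer holding the
        // intermediate tensors exchanged between nested primitives, and a
        // region sized for the largest nested scratchpad requirement.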
        status_t init_scratchpad_memory(size_t inout_buffer_size) {

            auto scratchpad = scratchpad_registry().registrar();

            scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    inout_buffer_size, 1, 16);
            scratchpad.book(
                    memory_tracking::names::key_fusion_forward_scratchpad,
                    user_scratchpad_size_, 1, 16);
            return status::success;
        }

        void init_name() {
            for (const auto &op_pd : op_pds_) {
                name_.append(":");
                name_.append(op_pd->name());
            }
            return;
        }
    };

    ref_fused_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

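    // Instantiate one nested primitive per nested primitive descriptor,
    // preserving chain order.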
    status_t init(engine_t *engine) override {
        const auto &op_pds = pd()->op_pds_;
        for (auto &op_pd : op_pds) {
            std::shared_ptr<primitive_t> p;
            CHECK(op_pd->create_primitive(p, engine));
            primitives_.emplace_back(p);
        }
        return status::success;
    }

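    // Builds that use the Arm Compute Library (ACL) keep per-primitive
    // state in resources, so resource creation is forwarded to every
    // nested primitive.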
#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
    status_t create_resource(
            engine_t *engine, resource_mapper_t &mapper) const override {
        for (auto &p : primitives_) {
            CHECK(p->create_resource(engine, mapper));
        }
        return status::success;
    }
#endif

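    // Execute the nested primitives in chain order. Context arguments are
    // forwarded as-is; intermediate arguments are wrapped in temporary
    // memory_t objects backed by sub-storages of the fusion inout buffer
    // at the offsets precomputed in the pd.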
    status_t execute(const exec_ctx_t &ctx) const override {
        engine_t *engine = ctx.stream()->engine();
        const auto scratchpad = ctx.get_scratchpad_grantor();

        const auto inout_buffer = scratchpad.get_memory_storage(
                memory_tracking::names::key_fusion_inout_buffer);

        const auto &ctx_args = ctx.args();
        const auto op_count = primitives_.size();
        std::vector<std::unique_ptr<memory_t>> inout_memory;

        for (size_t i = 0; i < op_count; ++i) {
            const auto &op = primitives_[i];
            const auto &arg_cache = pd()->args_[i];

            exec_args_t exec_args;

            for (const auto &arg_info : arg_cache.info()) {
                if (arg_info.is_ctx_arg) {
                    exec_args[arg_info.op_arg] = ctx_args.at(arg_info.ctx_arg);
                } else {
                    inout_memory.emplace_back(new memory_t(engine, &arg_info.md,
                            inout_buffer->get_sub_storage(arg_info.offset,
                                    memory_desc_wrapper(arg_info.md).size())));
                    exec_args[arg_info.op_arg].mem = inout_memory.back().get();
                    exec_args[arg_info.op_arg].is_const = arg_info.is_const;
                }
            }

            exec_ctx_t op_ctx(ctx, std::move(exec_args));

            nested_scratchpad_t ns(ctx,
                    memory_tracking::names::key_fusion_forward_scratchpad, op);
            op_ctx.set_scratchpad_grantor(ns.grantor());
            CHECK(op->execute(op_ctx));
        }

        return status::success;
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::vector<std::shared_ptr<primitive_t>> primitives_;
};

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s