/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 | |
18 | #ifndef CPU_REF_FUSED_CONVOLUTION_HPP |
19 | #define CPU_REF_FUSED_CONVOLUTION_HPP |
20 | |
21 | #include "common/primitive.hpp" |
22 | #include "common/primitive_desc_iterator.hpp" |
23 | #include "common/reorder.hpp" |
24 | #include "common/stream.hpp" |
25 | |
26 | #include "cpu/cpu_convolution_pd.hpp" |
27 | #include "cpu/dw_convolution_utils.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | struct ref_fused_convolution_fwd_t : public primitive_t { |
34 | |
35 | struct arg_cache_t { |
36 | struct arg_info_t { |
37 | int op_arg; |
38 | bool is_ctx_arg; |
39 | bool is_const; |
40 | union { |
41 | size_t offset; |
42 | int ctx_arg; |
43 | }; |
44 | memory_desc_t md; |
45 | }; |
46 | |
47 | void append_ctx_arg(int op_arg, int ctx_arg) { |
48 | arg_info_t arg_info; |
49 | arg_info.op_arg = op_arg; |
50 | arg_info.is_ctx_arg = true; |
51 | arg_info.is_const = false; // unused |
52 | arg_info.ctx_arg = ctx_arg; |
53 | arg_info.md = glob_zero_md; |
54 | info_.push_back(arg_info); |
55 | } |
56 | |
57 | void append_inout_arg(int arg, size_t offset, const memory_desc_t *md, |
58 | bool is_const) { |
59 | arg_info_t arg_info; |
60 | arg_info.op_arg = arg; |
61 | arg_info.is_ctx_arg = false; |
62 | arg_info.is_const = is_const; |
63 | arg_info.offset = offset; |
64 | arg_info.md = *md; |
65 | info_.push_back(arg_info); |
66 | } |
67 | |
68 | void append_ctx_arg(int arg) { append_ctx_arg(arg, arg); } |
69 | |
70 | const std::vector<arg_info_t> &info() const { return info_; } |
71 | |
72 | private: |
73 | std::vector<arg_info_t> info_; |
74 | }; |
75 | |
76 | struct pd_t : public cpu_convolution_fwd_pd_t { |
77 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
78 | const typename pd_t::base_class *hint_fwd_pd) |
79 | : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) { |
80 | name_ = "ref_fused_convolution:any" ; |
81 | } |
82 | |
83 | pd_t(const pd_t &other) = default; |
84 | |
85 | DECLARE_COMMON_PD_T(name_.c_str(), ref_fused_convolution_fwd_t); |
86 | |
87 | virtual status_t init(engine_t *engine) { |
88 | bool ok = true && is_fwd() |
89 | && (attr()->post_ops_.find(primitive_kind::sum) == -1); |
90 | |
91 | if (!ok) return status::unimplemented; |
92 | |
93 | CHECK(init_ops(engine)); |
94 | init_name(); |
95 | return status::success; |
96 | } |
97 | |
98 | const memory_desc_t *src_md(int index = 0) const override { |
99 | return op_pds_.front()->src_md(index); |
100 | } |
101 | |
102 | const memory_desc_t *dst_md(int index = 0) const override { |
103 | return op_pds_.back()->dst_md(index); |
104 | } |
105 | |
106 | const memory_desc_t *weights_md(int index = 0) const override { |
107 | return op_pds_.front()->weights_md(index); // for now |
108 | } |
109 | |
110 | const memory_desc_t *arg_md(int index = 0) const override { |
111 | switch (index) { // for now |
112 | case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS: |
113 | return op_pds_.back()->weights_md(0); |
114 | case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS: |
115 | return op_pds_.back()->weights_md(1); |
116 | default: return convolution_fwd_pd_t::arg_md(index); |
117 | } |
118 | } |
119 | |
120 | arg_usage_t arg_usage(int arg) const override { |
121 | if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) |
122 | return arg_usage_t::input; |
123 | |
124 | if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) |
125 | && attr_post_op_dw_inputs() > 1) |
126 | return arg_usage_t::input; |
127 | |
128 | if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_SRC)) |
129 | return arg_usage_t::input; |
130 | |
131 | if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) |
132 | return arg_usage_t::input; |
133 | |
134 | if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST)) |
135 | return arg_usage_t::input; |
136 | |
137 | return convolution_fwd_pd_t::arg_usage(arg); |
138 | } |
139 | |
140 | size_t user_scratchpad_size_; |
141 | std::vector<std::shared_ptr<primitive_desc_t>> op_pds_; |
142 | std::vector<arg_cache_t> args_; |
143 | |
144 | private: |
145 | std::string name_; |
146 | const unsigned int max_fusions_ = 1; |
147 | |
148 | status_t append_op(std::shared_ptr<primitive_desc_t> &op_pd, |
149 | size_t &sp_begin, size_t &sp_end, engine_t *engine) { |
150 | auto from_md = op_pds_.back()->dst_md(); |
151 | auto to_md = op_pd->src_md(); |
152 | |
153 | if (*from_md != *to_md) { |
154 | //TODO: Find a test-case for this |
155 | std::shared_ptr<primitive_desc_t> pd; |
156 | CHECK(reorder_primitive_desc_create( |
157 | pd, engine, from_md, to_md)); |
158 | op_pds_.emplace_back(std::move(pd)); |
159 | |
160 | arg_cache_t arg_cache; |
161 | arg_cache.append_inout_arg( |
162 | DNNL_ARG_FROM, sp_begin, from_md, true); |
163 | arg_cache.append_inout_arg(DNNL_ARG_TO, sp_end, to_md, false); |
164 | args_.push_back(arg_cache); |
165 | |
166 | // Increment scratchpad offsets |
167 | sp_begin = sp_end; |
168 | sp_end += memory_desc_wrapper(to_md).size(); |
169 | |
170 | user_scratchpad_size_ = nstl::max<size_t>(user_scratchpad_size_, |
171 | op_pds_.back()->scratchpad_size( |
172 | attr()->scratchpad_mode_)); |
173 | } |
174 | |
175 | op_pds_.emplace_back(std::move(op_pd)); |
176 | user_scratchpad_size_ = nstl::max<size_t>(user_scratchpad_size_, |
177 | op_pds_.back()->scratchpad_size(attr()->scratchpad_mode_)); |
178 | return status::success; |
179 | } |
180 | |
181 | status_t init_ops(engine_t *engine) { |
182 | using namespace data_type; |
183 | primitive_attr_t root_attr(*attr()); |
184 | if (!root_attr.is_initialized()) return status::out_of_memory; |
185 | auto po_op_iter |
186 | = attr()->post_ops_.find(primitive_kind::convolution); |
187 | if (po_op_iter == -1) return status::unimplemented; |
188 | |
189 | primitive_attr_t attr_1x1(*attr()); |
190 | // erase dw_conv post-op scales |
191 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) { |
192 | auto &scale |
193 | = attr_1x1.scales_.get(DNNL_ARG_ATTR_POST_OP_DW | arg); |
194 | if (!scale.has_default_values()) |
195 | attr_1x1.scales_.reset(DNNL_ARG_ATTR_POST_OP_DW | arg); |
196 | } |
197 | // erase post-ops after fusion as they will be handled separately |
198 | auto &e = attr_1x1.post_ops_.entry_; |
199 | e.erase(e.begin() + po_op_iter, e.end()); |
200 | |
201 | primitive_desc_iterator_t it(engine, op_desc(), &attr_1x1, nullptr); |
202 | if (!it.is_initialized()) return status::out_of_memory; |
203 | std::shared_ptr<primitive_desc_t> root_pd = *(++it); |
204 | if (!root_pd) return status::unimplemented; |
205 | op_pds_.emplace_back(root_pd); |
206 | // Scratchpad offsets. Simulate offset computation so that offset |
207 | // computation can be avoided during execution. |
208 | size_t inout_sp_offset_begin = 0; |
209 | size_t inout_sp_offset_end = 0; |
210 | user_scratchpad_size_ |
211 | = root_pd->scratchpad_size(attr()->scratchpad_mode_); |
212 | |
213 | // Create arg cache for the root pd |
214 | arg_cache_t arg_cache; |
215 | arg_cache.append_ctx_arg(DNNL_ARG_SRC); |
216 | arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS); |
217 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) |
218 | if (!attr_1x1.scales_.get(arg).has_default_values()) |
219 | arg_cache.append_ctx_arg(DNNL_ARG_ATTR_SCALES | arg); |
220 | if (desc()->bias_desc.data_type != data_type::undef) |
221 | arg_cache.append_ctx_arg(DNNL_ARG_BIAS); |
222 | arg_cache.append_inout_arg(DNNL_ARG_DST, inout_sp_offset_end, |
223 | root_pd->dst_md(), false); |
224 | for (int idx = 0; idx < attr_1x1.post_ops_.len(); ++idx) { |
225 | if (attr_1x1.post_ops_.contain(primitive_kind::binary, idx)) |
226 | arg_cache.append_ctx_arg(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) |
227 | | DNNL_ARG_SRC_1); |
228 | } |
229 | args_.push_back(arg_cache); |
230 | |
231 | // Increment scratchpad offsets |
232 | inout_sp_offset_begin = inout_sp_offset_end; |
233 | inout_sp_offset_end |
234 | += memory_desc_wrapper(root_pd->dst_md()).size(); |
235 | |
236 | const auto &po = attr()->post_ops_; |
237 | const auto &end = po.len(); |
238 | |
239 | unsigned int fusion_ops = 0; |
240 | // Loop through the post-ops until we reach the end |
241 | // (if we have more than one op to fuse later) |
242 | while (po_op_iter < end) { |
243 | if (fusion_ops++ > max_fusions_) return status::unimplemented; |
244 | |
245 | const auto &prev_op_pd = op_pds_.back(); |
246 | |
247 | if (po.entry_[po_op_iter].kind != primitive_kind::convolution) |
248 | return status::unimplemented; |
249 | |
250 | if (prev_op_pd->kind() != primitive_kind::convolution) |
251 | return status::unimplemented; |
252 | |
253 | auto conv_pd = reinterpret_cast<convolution_pd_t *>( |
254 | prev_op_pd.get()); |
255 | bool ok = true && is_fwd() |
256 | && utils::everyone_is( |
257 | 1, conv_pd->KD(), conv_pd->KH(), conv_pd->KW()); |
258 | if (!ok) return status::unimplemented; |
259 | |
260 | convolution_desc_t cd_dw; |
261 | primitive_attr_t attr_dw; |
262 | CHECK(get_depthwise_conv_desc(cd_dw, *(conv_pd->dst_md()), |
263 | root_attr, attr_dw, po_op_iter)); |
264 | primitive_desc_iterator_t it( |
265 | engine, (op_desc_t *)&cd_dw, &attr_dw, nullptr); |
266 | if (!it.is_initialized()) return status::out_of_memory; |
267 | |
268 | std::shared_ptr<primitive_desc_t> append_conv_pd = *(++it); |
269 | if (!append_conv_pd) return status::unimplemented; |
270 | |
271 | CHECK(append_op(append_conv_pd, inout_sp_offset_begin, |
272 | inout_sp_offset_end, engine)); |
273 | |
274 | const auto &op = op_pds_.back(); |
275 | arg_cache_t arg_cache; |
276 | arg_cache.append_inout_arg(DNNL_ARG_SRC, inout_sp_offset_begin, |
277 | op->src_md(), true); |
278 | arg_cache.append_ctx_arg(DNNL_ARG_DST); |
279 | arg_cache.append_ctx_arg(DNNL_ARG_WEIGHTS, |
280 | DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
281 | for (auto arg : {DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) |
282 | if (!attr_dw.scales_.get(arg).has_default_values()) |
283 | arg_cache.append_ctx_arg(DNNL_ARG_ATTR_SCALES | arg, |
284 | DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_ATTR_SCALES |
285 | | arg); |
286 | // dw_conv src_scale = 1x1_conv dst_scale |
287 | if (!attr_1x1.scales_.get(DNNL_ARG_DST).has_default_values()) |
288 | arg_cache.append_ctx_arg( |
289 | DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, |
290 | DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
291 | if (op->weights_md(1)->data_type != data_type::undef) |
292 | arg_cache.append_ctx_arg(DNNL_ARG_BIAS, |
293 | DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
294 | for (int idx = 0; idx < attr_dw.post_ops_.len(); ++idx) { |
295 | if (attr_dw.post_ops_.contain(primitive_kind::binary, idx)) |
296 | arg_cache.append_ctx_arg( |
297 | (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) |
298 | | DNNL_ARG_SRC_1), |
299 | (DNNL_ARG_ATTR_MULTIPLE_POST_OP( |
300 | idx + po_op_iter + 1) |
301 | | DNNL_ARG_SRC_1)); |
302 | } |
303 | |
304 | args_.push_back(arg_cache); |
305 | |
306 | while (++po_op_iter < end) { |
307 | if (utils::one_of(po.entry_[po_op_iter].kind, |
308 | primitive_kind::convolution)) |
309 | break; |
310 | } |
311 | } |
312 | |
313 | assert(!op_pds_.empty()); |
314 | |
315 | CHECK(init_scratchpad_memory(inout_sp_offset_end)); |
316 | |
317 | return status::success; |
318 | } |
319 | |
320 | status_t init_scratchpad_memory(size_t inout_buffer_size) { |
321 | |
322 | auto scratchpad = scratchpad_registry().registrar(); |
323 | |
324 | scratchpad.book(memory_tracking::names::key_fusion_inout_buffer, |
325 | inout_buffer_size, 1, 16); |
326 | scratchpad.book( |
327 | memory_tracking::names::key_fusion_forward_scratchpad, |
328 | user_scratchpad_size_, 1, 16); |
329 | return status::success; |
330 | } |
331 | |
332 | void init_name() { |
333 | for (const auto &op_pd : op_pds_) { |
334 | name_.append(":" ); |
335 | name_.append(op_pd->name()); |
336 | } |
337 | return; |
338 | } |
339 | }; |
340 | |
341 | ref_fused_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
342 | |
343 | status_t init(engine_t *engine) override { |
344 | const auto &op_pds = pd()->op_pds_; |
345 | for (auto &op_pd : op_pds) { |
346 | std::shared_ptr<primitive_t> p; |
347 | op_pd->create_primitive(p, engine); |
348 | primitives_.emplace_back(p); |
349 | } |
350 | return status::success; |
351 | } |
352 | |
353 | #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL |
354 | status_t create_resource( |
355 | engine_t *engine, resource_mapper_t &mapper) const override { |
356 | for (auto &p : primitives_) { |
357 | CHECK(p->create_resource(engine, mapper)); |
358 | } |
359 | return status::success; |
360 | } |
361 | #endif |
362 | |
363 | status_t execute(const exec_ctx_t &ctx) const override { |
364 | engine_t *engine = ctx.stream()->engine(); |
365 | const auto scratchpad = ctx.get_scratchpad_grantor(); |
366 | |
367 | const auto inout_buffer = scratchpad.get_memory_storage( |
368 | memory_tracking::names::key_fusion_inout_buffer); |
369 | |
370 | const auto &ctx_args = ctx.args(); |
371 | const auto op_count = primitives_.size(); |
372 | std::vector<std::unique_ptr<memory_t>> inout_memory; |
373 | |
374 | for (size_t i = 0; i < op_count; ++i) { |
375 | const auto &op = primitives_[i]; |
376 | const auto &arg_cache = pd()->args_[i]; |
377 | |
378 | exec_args_t exec_args; |
379 | |
380 | for (const auto &arg_info : arg_cache.info()) { |
381 | if (arg_info.is_ctx_arg) { |
382 | exec_args[arg_info.op_arg] = ctx_args.at(arg_info.ctx_arg); |
383 | } else { |
384 | inout_memory.emplace_back(new memory_t(engine, &arg_info.md, |
385 | inout_buffer->get_sub_storage(arg_info.offset, |
386 | memory_desc_wrapper(arg_info.md).size()))); |
387 | exec_args[arg_info.op_arg].mem = inout_memory.back().get(); |
388 | exec_args[arg_info.op_arg].is_const = arg_info.is_const; |
389 | } |
390 | } |
391 | |
392 | exec_ctx_t op_ctx(ctx, std::move(exec_args)); |
393 | |
394 | nested_scratchpad_t ns(ctx, |
395 | memory_tracking::names::key_fusion_forward_scratchpad, op); |
396 | op_ctx.set_scratchpad_grantor(ns.grantor()); |
397 | CHECK(op->execute(op_ctx)); |
398 | } |
399 | |
400 | return status::success; |
401 | } |
402 | |
403 | private: |
404 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
405 | std::vector<std::shared_ptr<primitive_t>> primitives_; |
406 | }; |
407 | |
408 | } // namespace cpu |
409 | } // namespace impl |
410 | } // namespace dnnl |
411 | |
412 | #endif |
413 | |
414 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
415 | |