1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_REF_CONCAT_HPP |
18 | #define CPU_REF_CONCAT_HPP |
19 | |
20 | #include "common/engine.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/reorder.hpp" |
23 | #include "common/reorder_pd.hpp" |
24 | #include "common/stream.hpp" |
25 | |
26 | #include "cpu/cpu_concat_pd.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | struct ref_concat_t : public primitive_t { |
33 | struct pd_t : public cpu_concat_pd_t { |
34 | pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, |
35 | int concat_dim, const memory_desc_t *const *src_mds) |
36 | : cpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds) |
37 | , tent_dst_md_(types::zero_md()) {} |
38 | pd_t(const pd_t &rhs) = default; |
39 | ~pd_t() = default; |
40 | |
41 | DECLARE_CONCAT_PD_T("ref:any" , ref_concat_t); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | using sm = primitive_attr_t::skip_mask_t; |
45 | if (!attr()->has_default_values(sm::scales_runtime)) |
46 | return status::unimplemented; |
47 | status_t status = cpu_concat_pd_t::init(); |
48 | if (status != status::success) { |
49 | assert(dst_md_.format_kind != format_kind::undef); |
50 | status = memory_desc_init_by_strides(tent_dst_md_, |
51 | dst_md_.ndims, dst_md_.dims, dst_md_.data_type, |
52 | nullptr); |
53 | if (status != status::success) return status::unimplemented; |
54 | |
55 | status = cpu_concat_pd_t::init(&tent_dst_md_); |
56 | if (status != status::success) return status::unimplemented; |
57 | } |
58 | |
59 | const auto &sc = attr()->scales_; |
60 | reorder_pds_.resize(n_ + use_tent_dst()); |
61 | for (int i = 0; i < n_; ++i) { |
62 | primitive_attr_t r_attr; |
63 | if (!sc.get(DNNL_ARG_MULTIPLE_SRC + i).has_default_values()) { |
64 | int mask = 0; |
65 | CHECK(sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, nullptr)); |
66 | if (mask != 0) return status::unimplemented; |
67 | r_attr.scales_.set(DNNL_ARG_SRC, mask); |
68 | } |
69 | CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, |
70 | src_md(i), src_image_md(i), &r_attr)); |
71 | } |
72 | if (use_tent_dst()) { |
73 | assert(tent_dst_md_.format_kind != format_kind::undef); |
74 | assert(dst_md_.format_kind != format_kind::undef); |
75 | CHECK(reorder_primitive_desc_create( |
76 | reorder_pds_[n_], engine, &tent_dst_md_, &dst_md_)); |
77 | } |
78 | init_scratchpad(); |
79 | return status; |
80 | } |
81 | |
82 | // if dst is forced and cannot be used directly. |
83 | bool use_tent_dst() const { return !types::is_zero_md(&tent_dst_md_); } |
84 | |
85 | std::vector<std::shared_ptr<primitive_desc_t>> reorder_pds_; |
86 | memory_desc_t tent_dst_md_; |
87 | |
88 | private: |
89 | void init_scratchpad() { |
90 | using namespace memory_tracking::names; |
91 | auto scratchpad = scratchpad_registry().registrar(); |
92 | if (use_tent_dst()) { |
93 | const memory_desc_wrapper tent_dst_d(&tent_dst_md_); |
94 | scratchpad.book(memory_tracking::names::key_concat_tent_dst, |
95 | tent_dst_d.size(), 1, tent_dst_d.data_type_size()); |
96 | } |
97 | |
98 | for (size_t i = 0; i < reorder_pds_.size(); i++) { |
99 | scratchpad.book(key_nested_multiple + (int)i, |
100 | reorder_pds_[i]->scratchpad_registry()); |
101 | } |
102 | } |
103 | }; |
104 | |
105 | ref_concat_t(const pd_t *apd) : primitive_t(apd) {} |
106 | |
107 | status_t init(engine_t *engine) override { |
108 | const size_t n = pd()->reorder_pds_.size(); |
109 | reorders_.resize(n); |
110 | for (size_t i = 0; i < n; ++i) |
111 | pd()->reorder_pds_[i]->create_primitive(reorders_[i], engine); |
112 | return status::success; |
113 | } |
114 | |
115 | ~ref_concat_t() = default; |
116 | |
117 | status_t execute(const exec_ctx_t &ctx) const override { |
118 | using namespace memory_tracking::names; |
119 | engine_t *engine = ctx.stream()->engine(); |
120 | const auto n = pd()->n_inputs(); |
121 | |
122 | auto execute_reorder = [&](const std::shared_ptr<primitive_t> &reorder, |
123 | const memory_arg_t &src, |
124 | const memory_arg_t &dst, |
125 | const memory_arg_t *src_scales, |
126 | int r_num) { |
127 | exec_args_t r_args; |
128 | r_args[DNNL_ARG_SRC] = src; |
129 | r_args[DNNL_ARG_DST] = dst; |
130 | if (src_scales) |
131 | r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = *src_scales; |
132 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
133 | |
134 | nested_scratchpad_t ns(ctx, key_nested_multiple + r_num, reorder); |
135 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
136 | reorder->execute(r_ctx); |
137 | }; |
138 | |
139 | if (pd()->use_tent_dst()) { |
140 | using namespace memory_tracking::names; |
141 | auto scratchpad = ctx.get_scratchpad_grantor(); |
142 | auto tent_dst_storage |
143 | = scratchpad.get_memory_storage(key_concat_tent_dst); |
144 | |
145 | for (int i = 0; i < n; ++i) { |
146 | memory_t tent_dst_i(engine, pd()->src_image_md(i), |
147 | tent_dst_storage->clone()); |
148 | const auto &src_scales_arg = ctx.args().find( |
149 | DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i)); |
150 | const memory_arg_t *src_scales = nullptr; |
151 | if (src_scales_arg != ctx.args().end()) |
152 | src_scales = &src_scales_arg->second; |
153 | execute_reorder(reorders_[i], |
154 | ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i), |
155 | {&tent_dst_i, false}, src_scales, i); |
156 | } |
157 | |
158 | memory_t tent_dst( |
159 | engine, &pd()->tent_dst_md_, tent_dst_storage->clone()); |
160 | execute_reorder(reorders_[n], {&tent_dst, true}, |
161 | ctx.args().at(DNNL_ARG_DST), nullptr, n); |
162 | } else { |
163 | auto &dst_mem_storage = CTX_OUT_STORAGE(DNNL_ARG_DST); |
164 | for (int i = 0; i < n; ++i) { |
165 | memory_t tent_dst_i( |
166 | engine, pd()->src_image_md(i), dst_mem_storage.clone()); |
167 | const auto &src_scales_arg = ctx.args().find( |
168 | DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i)); |
169 | const memory_arg_t *src_scales = nullptr; |
170 | if (src_scales_arg != ctx.args().end()) |
171 | src_scales = &src_scales_arg->second; |
172 | execute_reorder(reorders_[i], |
173 | ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i), |
174 | {&tent_dst_i, false}, src_scales, i); |
175 | } |
176 | } |
177 | return status::success; |
178 | } |
179 | |
180 | private: |
181 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
182 | std::vector<std::shared_ptr<primitive_t>> reorders_; |
183 | }; |
184 | |
185 | } // namespace cpu |
186 | } // namespace impl |
187 | } // namespace dnnl |
188 | |
189 | #endif |
190 | |