1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_CONCAT_HPP |
18 | #define GPU_OCL_REF_CONCAT_HPP |
19 | |
20 | #include "common/engine.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/reorder.hpp" |
23 | #include "common/reorder_pd.hpp" |
24 | #include "common/stream.hpp" |
25 | #include "gpu/gpu_concat_pd.hpp" |
26 | #include "gpu/gpu_primitive.hpp" |
27 | #include "gpu/ocl/ocl_utils.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace gpu { |
32 | namespace ocl { |
33 | |
34 | struct ref_concat_t : public gpu_primitive_t { |
35 | using gpu_primitive_t::gpu_primitive_t; |
36 | struct pd_t : public gpu_concat_pd_t { |
37 | pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, |
38 | int concat_dim, const memory_desc_t *const *src_mds) |
39 | : gpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds) |
40 | , tent_dst_md_(types::zero_md()) {} |
41 | |
42 | pd_t(const pd_t &rhs) = default; |
43 | ~pd_t() = default; |
44 | |
45 | DECLARE_CONCAT_PD_T("ref:any" , ref_concat_t); |
46 | |
47 | status_t init(engine_t *engine) { |
48 | using sm = primitive_attr_t::skip_mask_t; |
49 | if (!attr()->has_default_values(sm::scales_runtime)) |
50 | return status::unimplemented; |
51 | status_t status = gpu_concat_pd_t::init(); |
52 | if (status != status::success) { |
53 | assert(dst_md_.format_kind != format_kind::undef); |
54 | status = memory_desc_init_by_strides(tent_dst_md_, |
55 | dst_md_.ndims, dst_md_.dims, dst_md_.data_type, |
56 | nullptr); |
57 | if (status != status::success) return status::unimplemented; |
58 | |
59 | status = gpu_concat_pd_t::init(&tent_dst_md_); |
60 | if (status != status::success) return status::unimplemented; |
61 | } |
62 | |
63 | const auto &sc = attr()->scales_; |
64 | reorder_pds_.resize(n_ + use_tent_dst()); |
65 | for (int i = 0; i < n_; ++i) { |
66 | primitive_attr_t r_attr; |
67 | int mask = 0; |
68 | bool is_set = false; |
69 | CHECK(sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, &is_set)); |
70 | if (is_set) { |
71 | if (mask != 0) return status::unimplemented; |
72 | r_attr.scales_.set(DNNL_ARG_SRC, mask); |
73 | } |
74 | CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine, |
75 | src_md(i), src_image_md(i), &r_attr)); |
76 | } |
77 | |
78 | if (use_tent_dst()) { |
79 | assert(tent_dst_md_.format_kind != format_kind::undef); |
80 | assert(dst_md_.format_kind != format_kind::undef); |
81 | CHECK(reorder_primitive_desc_create( |
82 | reorder_pds_[n_], engine, &tent_dst_md_, &dst_md_)); |
83 | } |
84 | init_scratchpad(); |
85 | return status; |
86 | } |
87 | |
88 | // if dst is forced and cannot be used directly. |
89 | bool use_tent_dst() const { return !types::is_zero_md(&tent_dst_md_); } |
90 | |
91 | std::vector<std::shared_ptr<primitive_desc_t>> reorder_pds_; |
92 | memory_desc_t tent_dst_md_; |
93 | |
94 | private: |
95 | void init_scratchpad() { |
96 | auto scratchpad = scratchpad_registry().registrar(); |
97 | |
98 | if (use_tent_dst()) { |
99 | const memory_desc_wrapper wtent_dst_md(tent_dst_md_); |
100 | scratchpad.book(memory_tracking::names::key_concat_tent_dst, |
101 | wtent_dst_md.size(), 1, OCL_BUFFER_ALIGNMENT); |
102 | } |
103 | |
104 | for (size_t i = 0; i < reorder_pds_.size(); i++) { |
105 | scratchpad.book( |
106 | memory_tracking::names::key_nested_multiple + (int)i, |
107 | reorder_pds_[i]->scratchpad_registry()); |
108 | } |
109 | } |
110 | }; |
111 | |
112 | status_t init(engine_t *engine) override { |
113 | const size_t n = pd()->reorder_pds_.size(); |
114 | reorders_.resize(n); |
115 | for (size_t i = 0; i < n; ++i) { |
116 | CHECK(create_nested_primitive( |
117 | reorders_[i], pd()->reorder_pds_[i], engine)); |
118 | } |
119 | return status::success; |
120 | } |
121 | |
122 | status_t execute(const exec_ctx_t &ctx) const override { |
123 | using namespace memory_tracking::names; |
124 | engine_t *engine = ctx.stream()->engine(); |
125 | const auto n = pd()->n_inputs(); |
126 | |
127 | auto execute_reorder = [&](const std::shared_ptr<primitive_t> &reorder, |
128 | const memory_arg_t &src, |
129 | const memory_arg_t &dst, |
130 | const memory_arg_t *src_scales, |
131 | int r_num) { |
132 | exec_args_t r_args; |
133 | r_args[DNNL_ARG_SRC] = src; |
134 | r_args[DNNL_ARG_DST] = dst; |
135 | if (src_scales) |
136 | r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = *src_scales; |
137 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
138 | |
139 | nested_scratchpad_t ns(ctx, key_nested_multiple + r_num, reorder); |
140 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
141 | return reorder->execute(r_ctx); |
142 | }; |
143 | |
144 | if (pd()->use_tent_dst()) { |
145 | auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage( |
146 | memory_tracking::names::key_concat_tent_dst); |
147 | |
148 | memory_t tent_dst( |
149 | engine, &pd()->tent_dst_md_, std::move(scratchpad)); |
150 | |
151 | for (int i = 0; i < n; ++i) { |
152 | const auto &src_scales_arg = ctx.args().find( |
153 | DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i)); |
154 | |
155 | const memory_arg_t *src_scales = nullptr; |
156 | if (src_scales_arg != ctx.args().end()) |
157 | src_scales = &src_scales_arg->second; |
158 | CHECK(execute_reorder(reorders_[i], |
159 | ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i), |
160 | {&tent_dst, false}, src_scales, i)); |
161 | } |
162 | |
163 | CHECK(execute_reorder(reorders_[n], {&tent_dst, true}, |
164 | ctx.args().at(DNNL_ARG_DST), nullptr, n)); |
165 | } else { |
166 | for (int i = 0; i < n; ++i) { |
167 | const auto &src_scales_arg = ctx.args().find( |
168 | DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i)); |
169 | |
170 | const memory_arg_t *src_scales = nullptr; |
171 | if (src_scales_arg != ctx.args().end()) { |
172 | src_scales = &src_scales_arg->second; |
173 | } |
174 | CHECK(execute_reorder(reorders_[i], |
175 | ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i), |
176 | ctx.args().at(DNNL_ARG_DST), src_scales, i)); |
177 | } |
178 | } |
179 | |
180 | return status::success; |
181 | } |
182 | |
183 | private: |
184 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
185 | std::vector<std::shared_ptr<primitive_t>> reorders_; |
186 | }; |
187 | |
188 | } // namespace ocl |
189 | } // namespace gpu |
190 | } // namespace impl |
191 | } // namespace dnnl |
192 | |
193 | #endif |
194 | |