/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_CONCAT_HPP
#define GPU_OCL_REF_CONCAT_HPP

#include "common/engine.hpp"
#include "common/primitive.hpp"
#include "common/reorder.hpp"
#include "common/reorder_pd.hpp"
#include "common/stream.hpp"
#include "gpu/gpu_concat_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
struct ref_concat_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_concat_pd_t {
        pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n,
                int concat_dim, const memory_desc_t *const *src_mds)
            : gpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds)
            , tent_dst_md_(types::zero_md()) {}

        pd_t(const pd_t &rhs) = default;
        ~pd_t() = default;

        DECLARE_CONCAT_PD_T("ref:any", ref_concat_t);

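        // Tries to initialize the concat with the requested dst layout; on
        // failure, falls back to a dense strided tentative dst. One reorder
        // primitive descriptor is created per source (src[i] -> its image in
        // dst), plus a final tentative-dst -> dst reorder when the fallback
        // is in use. Per-source runtime scales map onto each reorder's
        // DNNL_ARG_SRC scales.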
        status_t init(engine_t *engine) {
            using sm = primitive_attr_t::skip_mask_t;
            if (!attr()->has_default_values(sm::scales_runtime))
                return status::unimplemented;
            status_t status = gpu_concat_pd_t::init();
            if (status != status::success) {
                assert(dst_md_.format_kind != format_kind::undef);
                status = memory_desc_init_by_strides(tent_dst_md_,
                        dst_md_.ndims, dst_md_.dims, dst_md_.data_type,
                        nullptr);
                if (status != status::success) return status::unimplemented;

                status = gpu_concat_pd_t::init(&tent_dst_md_);
                if (status != status::success) return status::unimplemented;
            }

            const auto &sc = attr()->scales_;
            reorder_pds_.resize(n_ + use_tent_dst());
            for (int i = 0; i < n_; ++i) {
                primitive_attr_t r_attr;
                int mask = 0;
                bool is_set = false;
                CHECK(sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, &is_set));
                if (is_set) {
                    if (mask != 0) return status::unimplemented;
                    r_attr.scales_.set(DNNL_ARG_SRC, mask);
                }
                CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine,
                        src_md(i), src_image_md(i), &r_attr));
            }

            if (use_tent_dst()) {
                assert(tent_dst_md_.format_kind != format_kind::undef);
                assert(dst_md_.format_kind != format_kind::undef);
                CHECK(reorder_primitive_desc_create(
                        reorder_pds_[n_], engine, &tent_dst_md_, &dst_md_));
            }
            init_scratchpad();
            return status;
        }
        // True when the requested dst cannot be written directly, so the
        // concatenation goes through a tentative dst in scratchpad first.
        bool use_tent_dst() const { return !types::is_zero_md(&tent_dst_md_); }

        std::vector<std::shared_ptr<primitive_desc_t>> reorder_pds_;
        memory_desc_t tent_dst_md_;

    private:
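        // Books scratchpad space for the tentative dst buffer (when used)
        // plus a nested slot for each reorder's own scratchpad.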
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();

            if (use_tent_dst()) {
                const memory_desc_wrapper wtent_dst_md(tent_dst_md_);
                scratchpad.book(memory_tracking::names::key_concat_tent_dst,
                        wtent_dst_md.size(), 1, OCL_BUFFER_ALIGNMENT);
            }

            for (size_t i = 0; i < reorder_pds_.size(); i++) {
                scratchpad.book(
                        memory_tracking::names::key_nested_multiple + (int)i,
                        reorder_pds_[i]->scratchpad_registry());
            }
        }
    };

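    // Instantiates the nested reorder primitives from the descriptors
    // prepared in pd_t::init().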
    status_t init(engine_t *engine) override {
        const size_t n = pd()->reorder_pds_.size();
        reorders_.resize(n);
        for (size_t i = 0; i < n; ++i) {
            CHECK(create_nested_primitive(
                    reorders_[i], pd()->reorder_pds_[i], engine));
        }
        return status::success;
    }

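    // Executes the prepared reorders. With a tentative dst, every source is
    // first reordered into the scratchpad buffer, and a final reorder copies
    // that buffer into the user-provided dst; otherwise each source is
    // reordered straight into its slice of dst.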
    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        engine_t *engine = ctx.stream()->engine();
        const auto n = pd()->n_inputs();

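        // Helper running one nested reorder: builds its exec args (src, dst,
        // and the optional per-source scales), then hands it the nested
        // scratchpad slot keyed by r_num.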
        auto execute_reorder = [&](const std::shared_ptr<primitive_t> &reorder,
                                       const memory_arg_t &src,
                                       const memory_arg_t &dst,
                                       const memory_arg_t *src_scales,
                                       int r_num) {
            exec_args_t r_args;
            r_args[DNNL_ARG_SRC] = src;
            r_args[DNNL_ARG_DST] = dst;
            if (src_scales)
                r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = *src_scales;
            exec_ctx_t r_ctx(ctx, std::move(r_args));

            nested_scratchpad_t ns(ctx, key_nested_multiple + r_num, reorder);
            r_ctx.set_scratchpad_grantor(ns.grantor());
            return reorder->execute(r_ctx);
        };

        if (pd()->use_tent_dst()) {
            auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage(
                    memory_tracking::names::key_concat_tent_dst);

            memory_t tent_dst(
                    engine, &pd()->tent_dst_md_, std::move(scratchpad));

            for (int i = 0; i < n; ++i) {
                const auto &src_scales_arg = ctx.args().find(
                        DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i));

                const memory_arg_t *src_scales = nullptr;
                if (src_scales_arg != ctx.args().end())
                    src_scales = &src_scales_arg->second;
                CHECK(execute_reorder(reorders_[i],
                        ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i),
                        {&tent_dst, false}, src_scales, i));
            }

            CHECK(execute_reorder(reorders_[n], {&tent_dst, true},
                    ctx.args().at(DNNL_ARG_DST), nullptr, n));
        } else {
            for (int i = 0; i < n; ++i) {
                const auto &src_scales_arg = ctx.args().find(
                        DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i));

                const memory_arg_t *src_scales = nullptr;
                if (src_scales_arg != ctx.args().end()) {
                    src_scales = &src_scales_arg->second;
                }
                CHECK(execute_reorder(reorders_[i],
                        ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i),
                        ctx.args().at(DNNL_ARG_DST), src_scales, i));
            }
        }

        return status::success;
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::vector<std::shared_ptr<primitive_t>> reorders_;
};
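
// Illustrative only: this "ref:any" implementation sits behind the public
// dnnl::concat API and is picked when no specialized GPU concat applies. A
// minimal usage sketch, assuming the v3.x-style C++ API (argument order
// differs in older releases):
//
//     std::vector<memory::desc> srcs = {src0_md, src1_md};
//     auto pd = concat::primitive_desc(gpu_engine, /*concat_dim=*/1, srcs);
//     concat(pd).execute(gpu_stream,
//             {{DNNL_ARG_MULTIPLE_SRC + 0, src0_mem},
//              {DNNL_ARG_MULTIPLE_SRC + 1, src1_mem},
//              {DNNL_ARG_DST, dst_mem}});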

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif