1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_REF_CONCAT_HPP
18#define CPU_REF_CONCAT_HPP
19
20#include "common/engine.hpp"
21#include "common/primitive.hpp"
22#include "common/reorder.hpp"
23#include "common/reorder_pd.hpp"
24#include "common/stream.hpp"
25
26#include "cpu/cpu_concat_pd.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31
32struct ref_concat_t : public primitive_t {
33 struct pd_t : public cpu_concat_pd_t {
34 pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n,
35 int concat_dim, const memory_desc_t *const *src_mds)
36 : cpu_concat_pd_t(attr, dst_md, n, concat_dim, src_mds)
37 , tent_dst_md_(types::zero_md()) {}
38 pd_t(const pd_t &rhs) = default;
39 ~pd_t() = default;
40
41 DECLARE_CONCAT_PD_T("ref:any", ref_concat_t);
42
43 status_t init(engine_t *engine) {
44 using sm = primitive_attr_t::skip_mask_t;
45 if (!attr()->has_default_values(sm::scales_runtime))
46 return status::unimplemented;
47 status_t status = cpu_concat_pd_t::init();
48 if (status != status::success) {
49 assert(dst_md_.format_kind != format_kind::undef);
50 status = memory_desc_init_by_strides(tent_dst_md_,
51 dst_md_.ndims, dst_md_.dims, dst_md_.data_type,
52 nullptr);
53 if (status != status::success) return status::unimplemented;
54
55 status = cpu_concat_pd_t::init(&tent_dst_md_);
56 if (status != status::success) return status::unimplemented;
57 }
58
59 const auto &sc = attr()->scales_;
60 reorder_pds_.resize(n_ + use_tent_dst());
61 for (int i = 0; i < n_; ++i) {
62 primitive_attr_t r_attr;
63 if (!sc.get(DNNL_ARG_MULTIPLE_SRC + i).has_default_values()) {
64 int mask = 0;
65 CHECK(sc.get(DNNL_ARG_MULTIPLE_SRC + i, &mask, nullptr));
66 if (mask != 0) return status::unimplemented;
67 r_attr.scales_.set(DNNL_ARG_SRC, mask);
68 }
69 CHECK(reorder_primitive_desc_create(reorder_pds_[i], engine,
70 src_md(i), src_image_md(i), &r_attr));
71 }
72 if (use_tent_dst()) {
73 assert(tent_dst_md_.format_kind != format_kind::undef);
74 assert(dst_md_.format_kind != format_kind::undef);
75 CHECK(reorder_primitive_desc_create(
76 reorder_pds_[n_], engine, &tent_dst_md_, &dst_md_));
77 }
78 init_scratchpad();
79 return status;
80 }
81
82 // if dst is forced and cannot be used directly.
83 bool use_tent_dst() const { return !types::is_zero_md(&tent_dst_md_); }
84
85 std::vector<std::shared_ptr<primitive_desc_t>> reorder_pds_;
86 memory_desc_t tent_dst_md_;
87
88 private:
89 void init_scratchpad() {
90 using namespace memory_tracking::names;
91 auto scratchpad = scratchpad_registry().registrar();
92 if (use_tent_dst()) {
93 const memory_desc_wrapper tent_dst_d(&tent_dst_md_);
94 scratchpad.book(memory_tracking::names::key_concat_tent_dst,
95 tent_dst_d.size(), 1, tent_dst_d.data_type_size());
96 }
97
98 for (size_t i = 0; i < reorder_pds_.size(); i++) {
99 scratchpad.book(key_nested_multiple + (int)i,
100 reorder_pds_[i]->scratchpad_registry());
101 }
102 }
103 };
104
105 ref_concat_t(const pd_t *apd) : primitive_t(apd) {}
106
107 status_t init(engine_t *engine) override {
108 const size_t n = pd()->reorder_pds_.size();
109 reorders_.resize(n);
110 for (size_t i = 0; i < n; ++i)
111 pd()->reorder_pds_[i]->create_primitive(reorders_[i], engine);
112 return status::success;
113 }
114
115 ~ref_concat_t() = default;
116
117 status_t execute(const exec_ctx_t &ctx) const override {
118 using namespace memory_tracking::names;
119 engine_t *engine = ctx.stream()->engine();
120 const auto n = pd()->n_inputs();
121
122 auto execute_reorder = [&](const std::shared_ptr<primitive_t> &reorder,
123 const memory_arg_t &src,
124 const memory_arg_t &dst,
125 const memory_arg_t *src_scales,
126 int r_num) {
127 exec_args_t r_args;
128 r_args[DNNL_ARG_SRC] = src;
129 r_args[DNNL_ARG_DST] = dst;
130 if (src_scales)
131 r_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = *src_scales;
132 exec_ctx_t r_ctx(ctx, std::move(r_args));
133
134 nested_scratchpad_t ns(ctx, key_nested_multiple + r_num, reorder);
135 r_ctx.set_scratchpad_grantor(ns.grantor());
136 reorder->execute(r_ctx);
137 };
138
139 if (pd()->use_tent_dst()) {
140 using namespace memory_tracking::names;
141 auto scratchpad = ctx.get_scratchpad_grantor();
142 auto tent_dst_storage
143 = scratchpad.get_memory_storage(key_concat_tent_dst);
144
145 for (int i = 0; i < n; ++i) {
146 memory_t tent_dst_i(engine, pd()->src_image_md(i),
147 tent_dst_storage->clone());
148 const auto &src_scales_arg = ctx.args().find(
149 DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i));
150 const memory_arg_t *src_scales = nullptr;
151 if (src_scales_arg != ctx.args().end())
152 src_scales = &src_scales_arg->second;
153 execute_reorder(reorders_[i],
154 ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i),
155 {&tent_dst_i, false}, src_scales, i);
156 }
157
158 memory_t tent_dst(
159 engine, &pd()->tent_dst_md_, tent_dst_storage->clone());
160 execute_reorder(reorders_[n], {&tent_dst, true},
161 ctx.args().at(DNNL_ARG_DST), nullptr, n);
162 } else {
163 auto &dst_mem_storage = CTX_OUT_STORAGE(DNNL_ARG_DST);
164 for (int i = 0; i < n; ++i) {
165 memory_t tent_dst_i(
166 engine, pd()->src_image_md(i), dst_mem_storage.clone());
167 const auto &src_scales_arg = ctx.args().find(
168 DNNL_ARG_ATTR_SCALES | (DNNL_ARG_MULTIPLE_SRC + i));
169 const memory_arg_t *src_scales = nullptr;
170 if (src_scales_arg != ctx.args().end())
171 src_scales = &src_scales_arg->second;
172 execute_reorder(reorders_[i],
173 ctx.args().at(DNNL_ARG_MULTIPLE_SRC + i),
174 {&tent_dst_i, false}, src_scales, i);
175 }
176 }
177 return status::success;
178 }
179
180private:
181 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
182 std::vector<std::shared_ptr<primitive_t>> reorders_;
183};
184
185} // namespace cpu
186} // namespace impl
187} // namespace dnnl
188
189#endif
190