1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_MULTI_CONCAT_HPP
18#define GPU_OCL_MULTI_CONCAT_HPP
19
#include <algorithm>
#include <memory>
#include <vector>

#include "common/concat.hpp"
#include "common/engine.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc.hpp"
#include "common/stream.hpp"
#include "gpu/gpu_concat_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/ocl/ocl_utils.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace gpu {
32namespace ocl {
33
34struct multi_concat_t : public gpu_primitive_t {
35 using gpu_primitive_t::gpu_primitive_t;
36 struct pd_t : public gpu_concat_pd_t {
37 using gpu_concat_pd_t::gpu_concat_pd_t;
38
39 pd_t(const pd_t &rhs) = default;
40 ~pd_t() = default;
41
42 DECLARE_CONCAT_PD_T("multi:any", multi_concat_t);
43
44 int max_batch_size() const {
45 if (n_inputs() > 64) return 64;
46 if (n_inputs() > 16) return 16;
47 return 0;
48 }
49
50 status_t init(engine_t *engine) {
51 if (max_batch_size() == 0) return status::unimplemented;
52
53 auto n_batches = utils::div_up(n_inputs(), max_batch_size());
54 concat_pds_.resize(n_batches);
55 dst_chunk_mds_.resize(n_batches);
56
57 dim_t concat_dim_offset = 0;
58 const auto ndims = dst_md()->ndims;
59 status_t status = status::success;
60 for (int i = 0; i < n_batches; ++i) {
61 const auto src_offset = max_batch_size() * i;
62 const auto remaining = n_inputs() - src_offset;
63 const auto batch_size = std::min(max_batch_size(), remaining);
64 dim_t batch_width = 0;
65 dims_t dims, offsets = {0};
66 utils::array_copy(dims, dst_md()->dims, ndims);
67 for (int j = 0; j < batch_size; ++j) {
68 const auto &src = src_md(src_offset + j);
69 batch_width += src->dims[concat_dim_];
70 }
71 dims[concat_dim_] = batch_width;
72 offsets[concat_dim_] = concat_dim_offset;
73 status = memory_desc_init_submemory(
74 dst_chunk_mds_[i], *dst_md(), dims, offsets);
75 if (status != status::success) {
76 concat_pds_.clear();
77 dst_chunk_mds_.clear();
78 return status;
79 }
80 status = concat_primitive_desc_create(concat_pds_[i], engine,
81 &dst_chunk_mds_[i], batch_size, concat_dim_,
82 src_md(src_offset), attr());
83 if (status != status::success) {
84 concat_pds_.clear();
85 dst_chunk_mds_.clear();
86 return status;
87 }
88 concat_dim_offset += batch_width;
89 }
90 return status;
91 }
92
93 std::vector<std::shared_ptr<primitive_desc_t>> concat_pds_;
94 std::vector<memory_desc_t> dst_chunk_mds_;
95 };
96
97 status_t init(engine_t *engine) override {
98 const auto &pds = pd()->concat_pds_;
99 const size_t n = pds.size();
100 concats_.resize(n);
101 for (size_t i = 0; i < n; ++i)
102 CHECK(create_nested_primitive(concats_[i], pds[i], engine));
103 return status::success;
104 }
105
106 status_t execute(const exec_ctx_t &ctx) const override {
107 using namespace memory_tracking::names;
108 const auto n = pd()->n_inputs();
109 const auto max_batch_size = pd()->max_batch_size();
110
111 auto execute_concat = [&](const std::shared_ptr<primitive_t> &concat,
112 int c_num, int n_inputs) {
113 exec_args_t r_args;
114 const auto arg_offset = DNNL_ARG_MULTIPLE_SRC;
115 for (int i = 0; i < n_inputs; ++i)
116 r_args[arg_offset + i] = ctx.args().at(
117 arg_offset + max_batch_size * c_num + i);
118 r_args[DNNL_ARG_DST] = ctx.args().at(DNNL_ARG_DST);
119 exec_ctx_t r_ctx(ctx, std::move(r_args));
120
121 nested_scratchpad_t ns(ctx, key_nested_multiple + c_num, concat);
122 r_ctx.set_scratchpad_grantor(ns.grantor());
123 return concat->execute(r_ctx);
124 };
125
126 const auto n_batches = utils::div_up(n, max_batch_size);
127 for (int i = 0; i < n_batches; ++i) {
128 const auto remaining = n - max_batch_size * i;
129 const auto batch_size = std::min(max_batch_size, remaining);
130 CHECK(execute_concat(concats_[i], i, batch_size));
131 }
132 return status::success;
133 }
134
135private:
136 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
137 std::vector<std::shared_ptr<primitive_t>> concats_;
138};
139
140} // namespace ocl
141} // namespace gpu
142} // namespace impl
143} // namespace dnnl
144
145#endif
146