1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_MULTI_CONCAT_HPP |
18 | #define GPU_OCL_MULTI_CONCAT_HPP |
19 | |
#include <algorithm>
#include <memory>
#include <vector>

#include "common/concat.hpp"
#include "common/engine.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc.hpp"
#include "common/stream.hpp"
#include "gpu/gpu_concat_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/ocl/ocl_utils.hpp"
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace gpu { |
32 | namespace ocl { |
33 | |
34 | struct multi_concat_t : public gpu_primitive_t { |
35 | using gpu_primitive_t::gpu_primitive_t; |
36 | struct pd_t : public gpu_concat_pd_t { |
37 | using gpu_concat_pd_t::gpu_concat_pd_t; |
38 | |
39 | pd_t(const pd_t &rhs) = default; |
40 | ~pd_t() = default; |
41 | |
42 | DECLARE_CONCAT_PD_T("multi:any" , multi_concat_t); |
43 | |
44 | int max_batch_size() const { |
45 | if (n_inputs() > 64) return 64; |
46 | if (n_inputs() > 16) return 16; |
47 | return 0; |
48 | } |
49 | |
50 | status_t init(engine_t *engine) { |
51 | if (max_batch_size() == 0) return status::unimplemented; |
52 | |
53 | auto n_batches = utils::div_up(n_inputs(), max_batch_size()); |
54 | concat_pds_.resize(n_batches); |
55 | dst_chunk_mds_.resize(n_batches); |
56 | |
57 | dim_t concat_dim_offset = 0; |
58 | const auto ndims = dst_md()->ndims; |
59 | status_t status = status::success; |
60 | for (int i = 0; i < n_batches; ++i) { |
61 | const auto src_offset = max_batch_size() * i; |
62 | const auto remaining = n_inputs() - src_offset; |
63 | const auto batch_size = std::min(max_batch_size(), remaining); |
64 | dim_t batch_width = 0; |
65 | dims_t dims, offsets = {0}; |
66 | utils::array_copy(dims, dst_md()->dims, ndims); |
67 | for (int j = 0; j < batch_size; ++j) { |
68 | const auto &src = src_md(src_offset + j); |
69 | batch_width += src->dims[concat_dim_]; |
70 | } |
71 | dims[concat_dim_] = batch_width; |
72 | offsets[concat_dim_] = concat_dim_offset; |
73 | status = memory_desc_init_submemory( |
74 | dst_chunk_mds_[i], *dst_md(), dims, offsets); |
75 | if (status != status::success) { |
76 | concat_pds_.clear(); |
77 | dst_chunk_mds_.clear(); |
78 | return status; |
79 | } |
80 | status = concat_primitive_desc_create(concat_pds_[i], engine, |
81 | &dst_chunk_mds_[i], batch_size, concat_dim_, |
82 | src_md(src_offset), attr()); |
83 | if (status != status::success) { |
84 | concat_pds_.clear(); |
85 | dst_chunk_mds_.clear(); |
86 | return status; |
87 | } |
88 | concat_dim_offset += batch_width; |
89 | } |
90 | return status; |
91 | } |
92 | |
93 | std::vector<std::shared_ptr<primitive_desc_t>> concat_pds_; |
94 | std::vector<memory_desc_t> dst_chunk_mds_; |
95 | }; |
96 | |
97 | status_t init(engine_t *engine) override { |
98 | const auto &pds = pd()->concat_pds_; |
99 | const size_t n = pds.size(); |
100 | concats_.resize(n); |
101 | for (size_t i = 0; i < n; ++i) |
102 | CHECK(create_nested_primitive(concats_[i], pds[i], engine)); |
103 | return status::success; |
104 | } |
105 | |
106 | status_t execute(const exec_ctx_t &ctx) const override { |
107 | using namespace memory_tracking::names; |
108 | const auto n = pd()->n_inputs(); |
109 | const auto max_batch_size = pd()->max_batch_size(); |
110 | |
111 | auto execute_concat = [&](const std::shared_ptr<primitive_t> &concat, |
112 | int c_num, int n_inputs) { |
113 | exec_args_t r_args; |
114 | const auto arg_offset = DNNL_ARG_MULTIPLE_SRC; |
115 | for (int i = 0; i < n_inputs; ++i) |
116 | r_args[arg_offset + i] = ctx.args().at( |
117 | arg_offset + max_batch_size * c_num + i); |
118 | r_args[DNNL_ARG_DST] = ctx.args().at(DNNL_ARG_DST); |
119 | exec_ctx_t r_ctx(ctx, std::move(r_args)); |
120 | |
121 | nested_scratchpad_t ns(ctx, key_nested_multiple + c_num, concat); |
122 | r_ctx.set_scratchpad_grantor(ns.grantor()); |
123 | return concat->execute(r_ctx); |
124 | }; |
125 | |
126 | const auto n_batches = utils::div_up(n, max_batch_size); |
127 | for (int i = 0; i < n_batches; ++i) { |
128 | const auto remaining = n - max_batch_size * i; |
129 | const auto batch_size = std::min(max_batch_size, remaining); |
130 | CHECK(execute_concat(concats_[i], i, batch_size)); |
131 | } |
132 | return status::success; |
133 | } |
134 | |
135 | private: |
136 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
137 | std::vector<std::shared_ptr<primitive_t>> concats_; |
138 | }; |
139 | |
140 | } // namespace ocl |
141 | } // namespace gpu |
142 | } // namespace impl |
143 | } // namespace dnnl |
144 | |
145 | #endif |
146 | |