1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/convolution_inner_product.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/convolution_pd.hpp" |
20 | #include "common/reorder.hpp" |
21 | |
22 | using namespace dnnl::impl::memory_tracking; |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace ocl { |
28 | |
29 | static int adjust_dims(dims_t &dims, const memory_desc_t *dst, int ndims) { |
30 | utils::array_copy(&dims[0], &dst->dims[0], dst->ndims); |
31 | int max_dims = nstl::max(3, nstl::max(ndims, dst->ndims)); |
32 | utils::array_set(&dims[dst->ndims], 1, max_dims - dst->ndims); |
33 | return max_dims; |
34 | } |
35 | |
// Sets up the inner product as a 1x1 (unit-spatial) convolution: builds a
// convolution descriptor from the IP shapes, lets the convolution
// implementation pick blocked layouts, and records whether the convolution
// destination needs a reorder back to the plain IP destination layout.
status_t convolution_inner_product_fwd_t::pd_t::init_conf(engine_t *engine) {
    const inner_product_desc_t &ipd = *desc();

    const auto *src_md = invariant_src_md();
    const auto *wei_md = invariant_wei_md();
    const auto *dst_md = invariant_dst_md();

    convolution_desc_t cd;
    memory_desc_t conv_src_md, conv_wei_md, conv_dst_md, ip_dst_md;

    conf.ndims = src_md->ndims;
    conf.attr_info = attr_info_t::create(attr());

    dims_t dims;

    // Pad dst dims with trailing 1s so the shape has convolution rank (>= 3).
    int max_dims = adjust_dims(dims, dst_md, conf.ndims);

    // Destination layout is left to the convolution implementation ("any").
    memory_desc_init_by_tag(
            conv_dst_md, max_dims, dims, dst_md->data_type, format_tag::any);

    // For src/weights: propagate "any" so the convolution can choose, or,
    // when the caller fixed a layout, reuse it with the rank extended to
    // match the padded dims. NOTE: captures `dims`/`max_dims` by reference
    // and overwrites them on each call.
    auto init_md = [&](memory_desc_t &out_md, const memory_desc_t *in_md) {
        max_dims = adjust_dims(dims, in_md, conf.ndims);
        if (in_md->format_kind == format_kind::any) {
            memory_desc_init_by_tag(
                    out_md, max_dims, dims, in_md->data_type, format_tag::any);
        } else {
            out_md = *in_md;
            out_md.ndims = max_dims;

            utils::array_copy(&out_md.dims[0], &dims[0], max_dims);
        }
    };

    init_md(conv_src_md, src_md);
    init_md(conv_wei_md, wei_md);

    // Unit strides, zero padding, no dilation: the convolution reduces to a
    // per-pixel (here, single-pixel) matrix multiply — exactly inner product.
    dim_t strides[] = {1, 1, 1};
    dim_t padding[] = {0, 0, 0};
    dim_t padding_r[] = {0, 0, 0};

    alg_kind_t alg = alg_kind::convolution_direct;
    CHECK(conv_desc_init(&cd, ipd.prop_kind, alg, &conv_src_md, &conv_wei_md,
            invariant_bia_md(), &conv_dst_md, &strides[0], nullptr, &padding[0],
            &padding_r[0]));

    // Forward the IP attributes (post-ops, scales, ...) to the convolution.
    primitive_attr_t conv_attr(*attr());
    if (!conv_attr.is_initialized()) return status::out_of_memory;

    // Take the first convolution implementation the iterator offers.
    primitive_desc_iterator_t it(engine, (op_desc_t *)&cd, &conv_attr, nullptr);
    if (!it.is_initialized()) return status::out_of_memory;
    cpd_ = *(++it);
    if (!cpd_) return status::unimplemented;
    std::string impl_name(cpd_->name());
    // Bail out if only the reference convolution matched — going through a
    // reference kernel would defeat the purpose of this wrapper.
    if (impl_name.find("ref" ) != std::string::npos)
        return status::unimplemented;

    // Layouts actually chosen by the convolution implementation.
    auto src_conv = *cpd_->src_md();
    auto wei_conv = *cpd_->weights_md();
    auto dst_conv = *cpd_->dst_md();

    // Plain (nc/ncw/nchw/ncdhw) destination layout expected by IP callers.
    memory_desc_init_by_tag(ip_dst_md, conv_dst_md.ndims, conv_dst_md.dims,
            dst_md->data_type,
            utils::pick(conv_dst_md.ndims - 2, format_tag::nc, format_tag::ncw,
                    format_tag::nchw, format_tag::ncdhw));

    // If the convolution picked a blocked dst layout that differs from the
    // plain IP one, set up a conv-dst -> ip-dst reorder.
    if (dst_conv != ip_dst_md
            && dst_conv.format_desc.blocking.inner_nblks > 0) {
        conf.reorder_dst = true;
        primitive_attr_t r_attr(default_attr());
        if (!r_attr.is_initialized()) return status::out_of_memory;
        CHECK(reorder_primitive_desc_create(
                rpd_dst_, engine, &dst_conv, &ip_dst_md, &r_attr));

        // A sum post-op reads the original dst, so it must first be brought
        // into the convolution's blocked layout (reverse-direction reorder).
        if (conf.attr_info.with_sum) {
            primitive_attr_t r_attr(default_attr());
            if (!r_attr.is_initialized()) return status::out_of_memory;
            CHECK(reorder_primitive_desc_create(
                    rpd_postop_, engine, &ip_dst_md, &dst_conv, &r_attr));
        }
    }

    // Adopt the convolution's chosen blocked layouts for any user-"any"
    // src/weights so no extra reorders are needed at execution time.
    if (src_md_.format_kind == format_kind::any) {
        memory_desc_init_by_blocking_desc(
                src_md_, src_conv.format_desc.blocking);
    }
    if (weights_md_.format_kind == format_kind::any) {
        memory_desc_init_by_blocking_desc(
                weights_md_, wei_conv.format_desc.blocking);
    }

    memory_desc_wrapper src_d(src_md_);
    memory_desc_wrapper dst_d(dst_md_);
    // Heuristic cut-off: for mostly-plain layouts and large tensors
    // (>= ~20 MB of src+dst, element sizes included via md sizes) the
    // convolution path is presumably not profitable — TODO confirm tuning.
    if (conv_src_md.format_desc.blocking.inner_nblks < 2
            && conv_wei_md.format_desc.blocking.inner_nblks < 2
            && src_d.size() + dst_d.size() >= 20000000)
        return status::unimplemented;

    return status::success;
}
135 | |
136 | status_t convolution_inner_product_fwd_t::pd_t::init_scratchpad() { |
137 | auto scratchpad = scratchpad_registry().registrar(); |
138 | if (conf.reorder_dst) { |
139 | memory_desc_wrapper md_d(*cpd_->dst_md()); |
140 | scratchpad.book(memory_tracking::names::key_iprod_dst_reorder, |
141 | md_d.size(), 1, OCL_BUFFER_ALIGNMENT); |
142 | scratchpad.book(memory_tracking::names::key_nested_multiple + 1, |
143 | rpd_dst_->scratchpad_registry()); |
144 | if (conf.attr_info.with_sum) |
145 | scratchpad.book(memory_tracking::names::key_nested_multiple + 2, |
146 | rpd_postop_->scratchpad_registry()); |
147 | } |
148 | scratchpad.book(memory_tracking::names::key_nested_multiple, |
149 | cpd_->scratchpad_registry()); |
150 | return status::success; |
151 | } |
152 | |
// Runs the inner product by delegating to the nested convolution, inserting
// dst reorders (IP layout <-> conv blocked layout) when init_conf decided
// they are needed.
status_t convolution_inner_product_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto &conf = pd()->conf;

    auto src = ctx.input(DNNL_ARG_SRC);
    auto wei = ctx.input(DNNL_ARG_WEIGHTS);
    auto bia = ctx.input(DNNL_ARG_BIAS);
    auto dst = ctx.output(DNNL_ARG_DST);

    // Scratchpad-backed memory holding the conv dst in its blocked layout.
    std::unique_ptr<memory_t> wspace_dst;
    // Runs a nested reorder primitive `prim` from `in` to `out`; `r_num`
    // selects the scratchpad slot booked in init_scratchpad.
    auto exec_reorder = [&](memory_t *in, memory_t *out,
                                const std::shared_ptr<primitive_t> &prim,
                                int r_num) -> status_t {
        exec_args_t r_args;
        r_args[DNNL_ARG_FROM] = memory_arg_t {in, true};
        r_args[DNNL_ARG_TO] = memory_arg_t {out, false};
        exec_ctx_t r_ctx(ctx, std::move(r_args));
        nested_scratchpad_t ns(
                ctx, memory_tracking::names::key_nested_multiple + r_num, prim);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        return prim->execute(r_ctx);
    };
    if (conf.reorder_dst) {
        auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_iprod_dst_reorder);
        CHECK(safe_ptr_assign(wspace_dst,
                new memory_t(ctx.stream()->engine(), pd()->cpd_->dst_md(),
                        std::move(scratchpad))));
    }

    // A sum post-op reads the current dst contents, so they must first be
    // reordered into the conv's blocked layout (slot 2).
    if (pd()->conf.attr_info.with_sum && conf.reorder_dst) {
        CHECK(exec_reorder(dst, wspace_dst.get(), postop_reorder_, 2));
    }

    exec_args_t c_args;
    c_args[DNNL_ARG_SRC] = memory_arg_t {src, true};
    c_args[DNNL_ARG_WEIGHTS] = memory_arg_t {wei, true};
    c_args[DNNL_ARG_BIAS] = memory_arg_t {bia, true};
    // Convolution writes either directly to dst or to the workspace that is
    // reordered back afterwards.
    c_args[DNNL_ARG_DST]
            = memory_arg_t {conf.reorder_dst ? wspace_dst.get() : dst, false};

    // Forward binary post-op source tensors through to the convolution.
    const auto &args = ctx.args();
    for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
        if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
            c_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
                    = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                            | DNNL_ARG_SRC_1);
        }
    }

    exec_ctx_t c_ctx(ctx, std::move(c_args));
    nested_scratchpad_t ns(
            ctx, memory_tracking::names::key_nested_multiple, conv_);
    c_ctx.set_scratchpad_grantor(ns.grantor());
    CHECK(conv_->execute(c_ctx));

    // Bring the conv's blocked dst back to the plain IP layout (slot 1).
    if (conf.reorder_dst) {
        CHECK(exec_reorder(wspace_dst.get(), dst, dst_reorder_, 1));
    }

    return status::success;
}
216 | |
217 | } // namespace ocl |
218 | } // namespace gpu |
219 | } // namespace impl |
220 | } // namespace dnnl |
221 | |