/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/convolution_inner_product.hpp"
#include "common/c_types_map.hpp"
#include "common/convolution_pd.hpp"
#include "common/reorder.hpp"

using namespace dnnl::impl::memory_tracking;

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

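// Copies the dims of dst into dims, padding with size-1 dims up to at least
// rank 3 (and at least ndims), so the shape can be handed to a convolution
// descriptor. Returns the resulting rank.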
static int adjust_dims(dims_t &dims, const memory_desc_t *dst, int ndims) {
    utils::array_copy(&dims[0], &dst->dims[0], dst->ndims);
    int max_dims = nstl::max(3, nstl::max(ndims, dst->ndims));
    utils::array_set(&dims[dst->ndims], 1, max_dims - dst->ndims);
    return max_dims;
}

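// Sets up this implementation by expressing the inner product as an
// equivalent direct convolution (unit strides, no padding) and delegating
// the heavy lifting to the convolution implementations available on this
// engine.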
status_t convolution_inner_product_fwd_t::pd_t::init_conf(engine_t *engine) {
    const inner_product_desc_t &ipd = *desc();

    const auto *src_md = invariant_src_md();
    const auto *wei_md = invariant_wei_md();
    const auto *dst_md = invariant_dst_md();

    convolution_desc_t cd;
    memory_desc_t conv_src_md, conv_wei_md, conv_dst_md, ip_dst_md;

    conf.ndims = src_md->ndims;
    conf.attr_info = attr_info_t::create(attr());

    dims_t dims;

    int max_dims = adjust_dims(dims, dst_md, conf.ndims);

    memory_desc_init_by_tag(
            conv_dst_md, max_dims, dims, dst_md->data_type, format_tag::any);

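    // For "any" memory formats, defer the layout choice to the convolution;
    // otherwise keep the caller's layout and only pad the dims out to the
    // adjusted rank.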
    auto init_md = [&](memory_desc_t &out_md, const memory_desc_t *in_md) {
        max_dims = adjust_dims(dims, in_md, conf.ndims);
        if (in_md->format_kind == format_kind::any) {
            memory_desc_init_by_tag(
                    out_md, max_dims, dims, in_md->data_type, format_tag::any);
        } else {
            out_md = *in_md;
            out_md.ndims = max_dims;

            utils::array_copy(&out_md.dims[0], &dims[0], max_dims);
        }
    };

    init_md(conv_src_md, src_md);
    init_md(conv_wei_md, wei_md);

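    // Unit strides, no dilation and zero padding: over the size-1 padded
    // spatial dims the convolution computes exactly the inner product.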
    dim_t strides[] = {1, 1, 1};
    dim_t padding[] = {0, 0, 0};
    dim_t padding_r[] = {0, 0, 0};

    alg_kind_t alg = alg_kind::convolution_direct;
    CHECK(conv_desc_init(&cd, ipd.prop_kind, alg, &conv_src_md, &conv_wei_md,
            invariant_bia_md(), &conv_dst_md, &strides[0], nullptr, &padding[0],
            &padding_r[0]));

    primitive_attr_t conv_attr(*attr());
    if (!conv_attr.is_initialized()) return status::out_of_memory;

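    // Take the first convolution implementation the engine offers, but
    // reject reference kernels by name: a reference convolution would
    // presumably be slower than the other inner product implementations.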
    primitive_desc_iterator_t it(engine, (op_desc_t *)&cd, &conv_attr, nullptr);
    if (!it.is_initialized()) return status::out_of_memory;
    cpd_ = *(++it);
    if (!cpd_) return status::unimplemented;
    std::string impl_name(cpd_->name());
    if (impl_name.find("ref") != std::string::npos)
        return status::unimplemented;

    auto src_conv = *cpd_->src_md();
    auto wei_conv = *cpd_->weights_md();
    auto dst_conv = *cpd_->dst_md();

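    // Plain (nc, ncw, nchw, ncdhw) layout the inner product's destination
    // must end up in, at the convolution destination's rank.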
    memory_desc_init_by_tag(ip_dst_md, conv_dst_md.ndims, conv_dst_md.dims,
            dst_md->data_type,
            utils::pick(conv_dst_md.ndims - 2, format_tag::nc, format_tag::ncw,
                    format_tag::nchw, format_tag::ncdhw));

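    // If the convolution picked a blocked destination layout that differs
    // from the plain inner product layout, the result has to be reordered
    // back after the convolution runs. With a sum post-op, the original dst
    // contents must additionally be reordered *into* the blocked workspace
    // beforehand so the convolution accumulates on top of them.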
    if (dst_conv != ip_dst_md
            && dst_conv.format_desc.blocking.inner_nblks > 0) {
        conf.reorder_dst = true;
        primitive_attr_t r_attr(default_attr());
        if (!r_attr.is_initialized()) return status::out_of_memory;
        CHECK(reorder_primitive_desc_create(
                rpd_dst_, engine, &dst_conv, &ip_dst_md, &r_attr));

        if (conf.attr_info.with_sum) {
            primitive_attr_t r_attr(default_attr());
            if (!r_attr.is_initialized()) return status::out_of_memory;
            CHECK(reorder_primitive_desc_create(
                    rpd_postop_, engine, &ip_dst_md, &dst_conv, &r_attr));
        }
    }

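    // Adopt the convolution's preferred blocked layouts for any src and
    // weights descriptors the caller left unspecified.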
    if (src_md_.format_kind == format_kind::any) {
        memory_desc_init_by_blocking_desc(
                src_md_, src_conv.format_desc.blocking);
    }
    if (weights_md_.format_kind == format_kind::any) {
        memory_desc_init_by_blocking_desc(
                weights_md_, wei_conv.format_desc.blocking);
    }

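    // Heuristic cut-off: when the convolution ends up with plain
    // (non-blocked) src and weights layouts, large problems are presumably
    // served better by the other inner product implementations.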
    memory_desc_wrapper src_d(src_md_);
    memory_desc_wrapper dst_d(dst_md_);
    if (conv_src_md.format_desc.blocking.inner_nblks < 2
            && conv_wei_md.format_desc.blocking.inner_nblks < 2
            && src_d.size() + dst_d.size() >= 20000000)
        return status::unimplemented;

    return status::success;
}

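// Books scratchpad space for the intermediate blocked destination and for
// the nested reorder and convolution primitives.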
status_t convolution_inner_product_fwd_t::pd_t::init_scratchpad() {
    auto scratchpad = scratchpad_registry().registrar();
    if (conf.reorder_dst) {
        memory_desc_wrapper md_d(*cpd_->dst_md());
        scratchpad.book(memory_tracking::names::key_iprod_dst_reorder,
                md_d.size(), 1, OCL_BUFFER_ALIGNMENT);
        scratchpad.book(memory_tracking::names::key_nested_multiple + 1,
                rpd_dst_->scratchpad_registry());
        if (conf.attr_info.with_sum)
            scratchpad.book(memory_tracking::names::key_nested_multiple + 2,
                    rpd_postop_->scratchpad_registry());
    }
    scratchpad.book(memory_tracking::names::key_nested_multiple,
            cpd_->scratchpad_registry());
    return status::success;
}

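// Runs the nested convolution, with optional reorders around it: dst is
// pre-reordered into the workspace when a sum post-op needs the previous
// contents, and the convolution result is reordered back to dst at the end.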
status_t convolution_inner_product_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {

    const auto &conf = pd()->conf;

    auto src = ctx.input(DNNL_ARG_SRC);
    auto wei = ctx.input(DNNL_ARG_WEIGHTS);
    auto bia = ctx.input(DNNL_ARG_BIAS);
    auto dst = ctx.output(DNNL_ARG_DST);

    std::unique_ptr<memory_t> wspace_dst;
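    // Runs one of the nested reorder primitives with its own scratchpad,
    // keyed by key_nested_multiple + r_num.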
    auto exec_reorder = [&](memory_t *in, memory_t *out,
                                const std::shared_ptr<primitive_t> &prim,
                                int r_num) -> status_t {
        exec_args_t r_args;
        r_args[DNNL_ARG_FROM] = memory_arg_t {in, true};
        r_args[DNNL_ARG_TO] = memory_arg_t {out, false};
        exec_ctx_t r_ctx(ctx, std::move(r_args));
        nested_scratchpad_t ns(
                ctx, memory_tracking::names::key_nested_multiple + r_num, prim);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        return prim->execute(r_ctx);
    };

    if (conf.reorder_dst) {
        auto scratchpad = ctx.get_scratchpad_grantor().get_memory_storage(
                memory_tracking::names::key_iprod_dst_reorder);
        CHECK(safe_ptr_assign(wspace_dst,
                new memory_t(ctx.stream()->engine(), pd()->cpd_->dst_md(),
                        std::move(scratchpad))));
    }

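    // A sum post-op reads the destination, so bring the current dst
    // contents into the blocked workspace layout first.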
    if (conf.attr_info.with_sum && conf.reorder_dst) {
        CHECK(exec_reorder(dst, wspace_dst.get(), postop_reorder_, 2));
    }

    exec_args_t c_args;
    c_args[DNNL_ARG_SRC] = memory_arg_t {src, true};
    c_args[DNNL_ARG_WEIGHTS] = memory_arg_t {wei, true};
    c_args[DNNL_ARG_BIAS] = memory_arg_t {bia, true};
    c_args[DNNL_ARG_DST]
            = memory_arg_t {conf.reorder_dst ? wspace_dst.get() : dst, false};

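    // Pass any binary post-op sources straight through to the nested
    // convolution.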
    const auto &args = ctx.args();
    for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
        if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
            c_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
                    = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                            | DNNL_ARG_SRC_1);
        }
    }

    exec_ctx_t c_ctx(ctx, std::move(c_args));
    nested_scratchpad_t ns(
            ctx, memory_tracking::names::key_nested_multiple, conv_);
    c_ctx.set_scratchpad_grantor(ns.grantor());
    CHECK(conv_->execute(c_ctx));

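    // Reorder the convolution result from the blocked workspace back into
    // the caller's plain destination layout.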
    if (conf.reorder_dst) {
        CHECK(exec_reorder(wspace_dst.get(), dst, dst_reorder_, 1));
    }

    return status::success;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl