1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_REF_INNER_PRODUCT_HPP
18#define GPU_OCL_REF_INNER_PRODUCT_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "gpu/compute/compute.hpp"
25#include "gpu/gpu_inner_product_pd.hpp"
26#include "gpu/gpu_primitive.hpp"
27#include "gpu/gpu_resource.hpp"
28#include "gpu/ocl/ocl_stream.hpp"
29#include "gpu/ocl/ocl_utils.hpp"
30#include "gpu/primitive_conf.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace ocl {
36
37struct ref_inner_product_fwd_t : public gpu_primitive_t {
38 using gpu_primitive_t::gpu_primitive_t;
39 struct pd_t : public gpu_inner_product_fwd_pd_t {
40 pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
41 const inner_product_fwd_pd_t *hint_fwd_pd)
42 : gpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
43
44 DECLARE_COMMON_PD_T("ocl:ref:any", ref_inner_product_fwd_t);
45
46 status_t init(engine_t *engine) {
47 using namespace data_type;
48 using namespace prop_kind;
49 using namespace data_type;
50 assert(engine->kind() == engine_kind::gpu);
51 auto *compute_engine
52 = utils::downcast<compute::compute_engine_t *>(engine);
53
54 const auto attr_skip_mask
55 = primitive_attr_t::skip_mask_t::scales_runtime
56 | primitive_attr_t::skip_mask_t::post_ops;
57
58 bool ok = true
59 && utils::one_of(desc()->prop_kind, forward_training,
60 forward_inference)
61 && set_default_params() == status::success
62 && utils::one_of(true,
63 expect_data_types(
64 u8, s8, data_type::undef, s8, s32),
65 expect_data_types(
66 u8, s8, data_type::undef, u8, s32),
67 expect_data_types(
68 u8, s8, data_type::undef, s32, s32),
69 expect_data_types(
70 s8, s8, data_type::undef, s8, s32),
71 expect_data_types(
72 s8, s8, data_type::undef, u8, s32),
73 expect_data_types(
74 s8, s8, data_type::undef, s32, s32),
75 expect_data_types(
76 bf16, bf16, data_type::undef, bf16, f32),
77 expect_data_types(
78 bf16, bf16, data_type::undef, f32, f32),
79 expect_data_types(f32, f32, f32, f32, f32),
80 expect_data_types(f16, f16, f16, f16, f32))
81 && IMPLICATION(with_bias(),
82 utils::one_of(desc()->bias_desc.data_type, u8, s8,
83 bf16, f16, f32))
84 && attr()->has_default_values(attr_skip_mask)
85 && post_ops_with_binary_ok(
86 attr(), desc()->dst_desc.data_type)
87 && attr_.set_default_formats(dst_md(0)) == status::success
88 && IMPLICATION(!attr()->scales_.has_default_values(),
89 utils::one_of(src_md_.data_type, s8, u8)
90 && arg_scales_ok())
91 && IMPLICATION(desc()->src_desc.data_type == f16,
92 compute_engine->mayiuse(
93 compute::device_ext_t::khr_fp16));
94 if (!ok) return status::unimplemented;
95
96 return init_conf(engine);
97 }
98
99 status_t init_conf(engine_t *engine);
100 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
101
102 inner_product_conf_t conf;
103 offsets_t off;
104
105 private:
106 bool arg_scales_ok() const {
107 std::vector<int> supported_args
108 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
109 if (!attr()->scales_.has_default_values(supported_args))
110 return false;
111 for (int arg : supported_args) {
112 auto &scales = attr()->scales_.get(arg);
113 if (scales.has_default_values()) continue;
114 int mask = scales.mask_;
115 if (arg == DNNL_ARG_WEIGHTS) {
116 if (!utils::one_of(mask, 0, 1 << 0)) return false;
117 } else {
118 if (mask != 0) return false;
119 }
120 }
121 return true;
122 }
123 };
124
125 status_t init(engine_t *engine) override {
126 compute::kernel_ctx_t kernel_ctx;
127 status_t status = pd()->init_kernel_ctx(kernel_ctx);
128 CHECK(status);
129
130 create_kernel(engine, &kernel_, "ref_inner_product_fwd", kernel_ctx);
131 if (!kernel_) return status::runtime_error;
132
133 return status::success;
134 }
135
136 status_t execute(const exec_ctx_t &ctx) const override {
137 return execute_forward(ctx);
138 }
139
140private:
141 status_t execute_forward(const exec_ctx_t &ctx) const;
142 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
143 compute::kernel_t kernel_;
144};
145
146struct ref_inner_product_bwd_data_t : public gpu_primitive_t {
147 using gpu_primitive_t::gpu_primitive_t;
148 struct pd_t : public gpu_inner_product_bwd_data_pd_t {
149 pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
150 const inner_product_fwd_pd_t *hint_fwd_pd)
151 : gpu_inner_product_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}
152
153 DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t);
154
155 status_t init(engine_t *engine) {
156 using namespace data_type;
157 using namespace prop_kind;
158 assert(engine->kind() == engine_kind::gpu);
159
160 bool ok = true
161 && utils::one_of(
162 this->desc()->prop_kind, backward, backward_data)
163 && this->set_default_params() == status::success
164 && utils::one_of(true,
165 expect_data_types(
166 bf16, bf16, data_type::undef, bf16, f32),
167 expect_data_types(
168 f32, bf16, data_type::undef, bf16, f32),
169 expect_data_types(
170 f32, f32, data_type::undef, f32, f32))
171 && attr()->has_default_values();
172 if (!ok) return status::unimplemented;
173
174 return init_conf(engine);
175 }
176
177 status_t init_conf(engine_t *engine);
178 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
179
180 inner_product_conf_t conf;
181 offsets_t off;
182 };
183
184 status_t init(engine_t *engine) override {
185 compute::kernel_ctx_t kernel_ctx;
186 status_t status = pd()->init_kernel_ctx(kernel_ctx);
187 CHECK(status);
188
189 create_kernel(
190 engine, &kernel_, "ref_inner_product_bwd_data", kernel_ctx);
191 if (!kernel_) return status::runtime_error;
192
193 return status::success;
194 }
195
196 status_t execute(const exec_ctx_t &ctx) const override {
197 return execute_backward_data(ctx);
198 }
199
200private:
201 status_t execute_backward_data(const exec_ctx_t &ctx) const;
202 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
203 compute::kernel_t kernel_;
204};
205
206struct ref_inner_product_bwd_weights_t : public gpu_primitive_t {
207 using gpu_primitive_t::gpu_primitive_t;
208 struct pd_t : public gpu_inner_product_bwd_weights_pd_t {
209 pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
210 const inner_product_fwd_pd_t *hint_fwd_pd)
211 : gpu_inner_product_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}
212
213 DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t);
214
215 status_t init(engine_t *engine) {
216 using namespace data_type;
217 using namespace prop_kind;
218 assert(engine->kind() == engine_kind::gpu);
219 bool ok = true
220 && utils::one_of(
221 this->desc()->prop_kind, backward, backward_weights)
222 && this->set_default_params() == status::success
223 && utils::one_of(true,
224 expect_data_types(bf16, bf16, bf16, bf16, f32),
225 expect_data_types(bf16, f32, f32, bf16, f32),
226 expect_data_types(f32, f32, f32, f32, f32))
227 && attr()->has_default_values();
228 if (!ok) return status::unimplemented;
229
230 return init_conf(engine);
231 }
232
233 status_t init_conf(engine_t *engine);
234 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
235
236 inner_product_conf_t conf;
237 offsets_t off;
238 };
239
240 status_t init(engine_t *engine) override {
241 compute::kernel_ctx_t kernel_ctx;
242 status_t status = pd()->init_kernel_ctx(kernel_ctx);
243 CHECK(status);
244
245 create_kernel(
246 engine, &kernel_, "ref_inner_product_bwd_weights", kernel_ctx);
247 if (!kernel_) return status::runtime_error;
248
249 return status::success;
250 }
251
252 status_t execute(const exec_ctx_t &ctx) const override {
253 return execute_backward_weights(ctx);
254 }
255
256private:
257 status_t execute_backward_weights(const exec_ctx_t &ctx) const;
258 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
259 compute::kernel_t kernel_;
260};
261
262} // namespace ocl
263} // namespace gpu
264} // namespace impl
265} // namespace dnnl
266
267#endif
268
269// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
270