/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_GEMM_POST_OPS_INNER_PRODUCT_HPP
#define GPU_OCL_GEMM_POST_OPS_INNER_PRODUCT_HPP

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/gemm_types.hpp"
#include "common/gemm_utils.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gemm/gpu_gemm.hpp"
#include "gpu/gpu_inner_product_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

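// Forward inner product implemented by lowering the operation to a 2D GEMM
// plus a separate "post process" OpenCL kernel that handles everything
// attribute-related. Schematically:
//
//     acc = src_2d * weights_2d^T       (nested GEMM primitive)
//     dst = post_process(acc, bias)     (bias, output scales, post-ops)
//
// where src collapses to an (MB x IC) matrix and the (OC x IC) weights are
// passed to the GEMM as transposed.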
struct gemm_post_ops_inner_product_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_inner_product_fwd_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : gpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &rhs) = default;

        DECLARE_COMMON_PD_T(
                "ocl:gemm_post_ops_fwd", gemm_post_ops_inner_product_fwd_t);

        status_t init(engine_t *engine) {
            using namespace status;
            using namespace utils;
            using namespace data_type;
            using namespace primitive_kind;
            assert(engine->kind() == engine_kind::gpu);

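            // Only output scales (common or per-oc, possibly run-time) and
            // post-ops are accepted; both are applied by the post-process
            // kernel rather than by the GEMM itself.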
            const primitive_attr_t::skip_mask_t attr_skip_mask
                    = primitive_attr_t::skip_mask_t::oscale_runtime
                    | primitive_attr_t::skip_mask_t::post_ops;

            bool ok = is_fwd() && set_default_params() == success
                    && dense_consistency_check(src_md(), weights_md(), dst_md())
                    && dense_gemm_consistency_check(
                            src_md(), weights_md(), dst_md())
                    && attr()->has_default_values(attr_skip_mask)
                    && post_ops_with_binary_ok(attr(), dst_md()->data_type)
                    && attr_.set_default_formats(dst_md(0)) == status::success
                    && IMPLICATION(!attr()->output_scales_.has_default_values(),
                            one_of(attr()->output_scales_.mask_, 0, 1 << 1));
            if (!ok) return unimplemented;

            attr_info_ = attr_info_t::create(attr());

            // XXX: Using empty attributes increases the chances of
            // successfully creating the GEMM primitive. Ideally the GEMM
            // would be created with the different attribute combinations,
            // but such a mechanism would be tricky. Instead, the current
            // implementation handles everything attribute-related in the
            // post-process kernel.
            primitive_attr_t gemm_attr;
            is_int8_ = weights_md()->data_type == s8;

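            // Collapse the inner product into a 2D GEMM: A is src viewed as
            // (MB x IC), B is the weights with the transpose flag set, and
            // C is dst in the accumulation data type so that bias, scales,
            // and post-ops can be applied afterwards.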
            memory_desc_t a_md, b_md, c_md;
            init_2d_desc(&a_md, src_md());
            init_2d_desc(&b_md, weights_md(), true);
            init_2d_desc(&c_md, dst_md());
            c_md.data_type = desc()->accum_data_type;
            bool gemm_ok = status::success
                    == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md,
                            &glob_zero_md, desc()->accum_data_type, &gemm_attr,
                            true);
            if (!gemm_ok) return status::unimplemented;

            status_t scratchpad_status = init_ip_scratchpad_md();
            if (scratchpad_status != success) return scratchpad_status;
            init_scratchpad();

            return success;
        }

        bool with_post_process() const {
            return use_scratchpad() || with_bias() || attr_info_.with_oscales
                    || attr_info_.with_eltwise || attr_info_.with_binary
                    || attr_info_.with_sum;
        }
        bool use_scratchpad() const { return use_temp_dst(); }

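        // A temporary destination in the accumulation data type is needed
        // whenever the GEMM cannot produce the final answer directly: int8
        // GEMM with a narrow dst type, a sum post-op that must read the
        // original dst, or a dst type that differs from the accumulation
        // type.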
        bool use_temp_dst() const {
            using namespace data_type;
            return (is_int8_ && !utils::one_of(dst_md()->data_type, s32, f32))
                    || attr_info_.with_sum
                    || desc()->accum_data_type != dst_md()->data_type;
        }
        const memory_desc_t *ip_scratchpad_md() const {
            return &ip_scratchpad_md_;
        }

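        // The scratchpad holds the temporary destination as a flat buffer
        // of MB * OC elements in the accumulation data type.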
        status_t init_ip_scratchpad_md() {
            if (use_scratchpad()) {
                ip_scratchpad_md_.data_type = desc()->accum_data_type;
                ip_scratchpad_md_.ndims = 1;
                ip_scratchpad_md_.dims[0] = 0;

                if (use_temp_dst()) {
                    const size_t temp_dst_size = MB() * OC();
                    ip_scratchpad_md_.dims[0] += temp_dst_size;
                }
                return memory_desc_init_by_tag(
                        ip_scratchpad_md_, format_tag::x);
            }

            return status::success;
        }

        std::shared_ptr<primitive_desc_t> gemm_pd_;

        memory_desc_t ip_scratchpad_md_;
        bool is_int8_ = false;
        attr_info_t attr_info_ = {};

    private:
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();

            if (use_scratchpad()) {
                memory_desc_wrapper scratchpad_mdw(ip_scratchpad_md());
                size_t sz = scratchpad_mdw.size();
                scratchpad.book(
                        memory_tracking::names::key_iprod_int_dat_in_acc_dt, sz,
                        1, OCL_BUFFER_ALIGNMENT);
            }

            scratchpad.book(memory_tracking::names::key_nested,
                    gemm_pd_->scratchpad_registry());
        }
    };

    status_t init(engine_t *engine) override {
        CHECK(create_nested_primitive(gemm_, pd()->gemm_pd_, engine));

        const size_t mb = pd()->MB();
        const size_t oc = pd()->OC();

        // Prepare the post-process kernel.
        if (pd()->with_post_process()) {
            compute::kernel_ctx_t kernel_ctx;

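            // Compile-time configuration for the post-process kernel; for
            // int8 the kernel computes in f32 so that scales and post-ops
            // are applied in floating point.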
            kernel_ctx.define_int("MB", mb);
            kernel_ctx.define_int("OC", oc);
            bool int8 = pd()->is_int8_;
            kernel_ctx.set_data_type(
                    int8 ? data_type::f32 : pd()->dst_md()->data_type);
            // Here SRC is the output tensor of the GEMM call.
            def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "SRC");
            def_data_type(kernel_ctx,
                    int8 ? data_type::f32 : pd()->desc()->accum_data_type,
                    "ACC");
            def_data_type(kernel_ctx,
                    pd()->with_bias()
                            ? pd()->weights_md(1)->data_type
                            : int8 ? data_type::f32 : pd()->dst_md()->data_type,
                    "BIAS");
            def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "SPAD");
            def_data_type(kernel_ctx, pd()->dst_md()->data_type, "DST");

            kernel_ctx.define_int("USE_TEMP_DST", pd()->use_temp_dst());

            kernel_ctx.define_int("WITH_BIAS", pd()->with_bias());

            def_attr_info(
                    kernel_ctx, pd()->attr_info_, pd()->attr()->post_ops_);

            create_kernel(engine, &post_process_kernel_,
                    "gemm_post_ops_inner_product", kernel_ctx);
            if (!post_process_kernel_) return status::runtime_error;
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
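    // Runs the nested GEMM (into the scratchpad when a temporary
    // destination is used), then the post-process kernel that adds the
    // bias and applies output scales and post-ops.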
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::shared_ptr<primitive_t> gemm_;
    compute::kernel_t post_process_kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif