1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_GEMM_POST_OPS_INNER_PRODUCT_HPP |
18 | #define GPU_OCL_GEMM_POST_OPS_INNER_PRODUCT_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/gemm_types.hpp" |
24 | #include "common/gemm_utils.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/primitive_desc_iterator.hpp" |
27 | #include "gpu/compute/compute.hpp" |
28 | #include "gpu/gemm/gpu_gemm.hpp" |
29 | #include "gpu/gpu_inner_product_pd.hpp" |
30 | #include "gpu/gpu_primitive.hpp" |
31 | #include "gpu/gpu_resource.hpp" |
32 | #include "gpu/ocl/ocl_utils.hpp" |
33 | #include "gpu/primitive_conf.hpp" |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | namespace gpu { |
38 | namespace ocl { |
39 | |
40 | struct gemm_post_ops_inner_product_fwd_t : public gpu_primitive_t { |
41 | using gpu_primitive_t::gpu_primitive_t; |
42 | struct pd_t : public gpu_inner_product_fwd_pd_t { |
43 | pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr, |
44 | const inner_product_fwd_pd_t *hint_fwd_pd) |
45 | : gpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
46 | |
47 | pd_t(const pd_t &rhs) = default; |
48 | |
49 | DECLARE_COMMON_PD_T( |
50 | "ocl:gemm_post_ops_fwd" , gemm_post_ops_inner_product_fwd_t); |
51 | |
52 | status_t init(engine_t *engine) { |
53 | using namespace status; |
54 | using namespace utils; |
55 | using namespace data_type; |
56 | using namespace primitive_kind; |
57 | assert(engine->kind() == engine_kind::gpu); |
58 | |
59 | const primitive_attr_t::skip_mask_t attr_skip_mask |
60 | = primitive_attr_t::skip_mask_t::oscale_runtime |
61 | | primitive_attr_t::skip_mask_t::post_ops; |
62 | |
63 | bool ok = is_fwd() && set_default_params() == success |
64 | && dense_consistency_check(src_md(), weights_md(), dst_md()) |
65 | && dense_gemm_consistency_check( |
66 | src_md(), weights_md(), dst_md()) |
67 | && attr()->has_default_values(attr_skip_mask) |
68 | && post_ops_with_binary_ok(attr(), dst_md()->data_type) |
69 | && attr_.set_default_formats(dst_md(0)) == status::success |
70 | && IMPLICATION(!attr()->output_scales_.has_default_values(), |
71 | one_of(attr()->output_scales_.mask_, 0, 1 << 1)); |
72 | if (!ok) return unimplemented; |
73 | |
74 | attr_info_ = attr_info_t::create(attr()); |
75 | |
76 | // XXX: Empty attributes increase chances of creating a gemm |
77 | // primitive. Ideally gemm should be created multiple times with |
78 | // different attr combinations, but this mechanism might be tricky. |
79 | // Current implementation computes attr - related things in the post |
80 | // process kernel. |
81 | primitive_attr_t gemm_attr; |
82 | is_int8_ = weights_md()->data_type == s8; |
83 | |
84 | memory_desc_t a_md, b_md, c_md; |
85 | init_2d_desc(&a_md, src_md()); |
86 | init_2d_desc(&b_md, weights_md(), true); |
87 | init_2d_desc(&c_md, dst_md()); |
88 | c_md.data_type = desc()->accum_data_type; |
89 | bool gemm_ok = status::success |
90 | == create_gemm_pd(gemm_pd_, engine, &a_md, &b_md, &c_md, |
91 | &glob_zero_md, desc()->accum_data_type, &gemm_attr, |
92 | true); |
93 | if (!gemm_ok) return status::unimplemented; |
94 | |
95 | status_t scratchpad_status = init_ip_scratchpad_md(); |
96 | if (scratchpad_status != success) return scratchpad_status; |
97 | init_scratchpad(); |
98 | |
99 | return success; |
100 | } |
101 | |
102 | bool with_post_process() const { |
103 | return use_scratchpad() || with_bias() || attr_info_.with_oscales |
104 | || attr_info_.with_eltwise || attr_info_.with_binary |
105 | || attr_info_.with_sum; |
106 | } |
107 | bool use_scratchpad() const { return use_temp_dst(); } |
108 | |
109 | bool use_temp_dst() const { |
110 | using namespace data_type; |
111 | return (is_int8_ && !utils::one_of(dst_md()->data_type, s32, f32)) |
112 | || attr_info_.with_sum |
113 | || desc()->accum_data_type != dst_md()->data_type; |
114 | } |
115 | const memory_desc_t *ip_scratchpad_md() const { |
116 | return &ip_scratchpad_md_; |
117 | } |
118 | |
119 | status_t init_ip_scratchpad_md() { |
120 | if (use_scratchpad()) { |
121 | ip_scratchpad_md_.data_type = desc()->accum_data_type; |
122 | ip_scratchpad_md_.ndims = 1; |
123 | ip_scratchpad_md_.dims[0] = 0; |
124 | |
125 | if (use_temp_dst()) { |
126 | const size_t temp_dst_size = MB() * OC(); |
127 | ip_scratchpad_md_.dims[0] += temp_dst_size; |
128 | } |
129 | return memory_desc_init_by_tag( |
130 | ip_scratchpad_md_, format_tag::x); |
131 | } |
132 | |
133 | return status::success; |
134 | } |
135 | |
136 | std::shared_ptr<primitive_desc_t> gemm_pd_; |
137 | |
138 | memory_desc_t ip_scratchpad_md_; |
139 | bool is_int8_ = false; |
140 | attr_info_t attr_info_ = {}; |
141 | |
142 | private: |
143 | void init_scratchpad() { |
144 | auto scratchpad = scratchpad_registry().registrar(); |
145 | |
146 | if (use_scratchpad()) { |
147 | memory_desc_wrapper scratchpad_mdw(ip_scratchpad_md()); |
148 | size_t sz = scratchpad_mdw.size(); |
149 | scratchpad.book( |
150 | memory_tracking::names::key_iprod_int_dat_in_acc_dt, sz, |
151 | 1, OCL_BUFFER_ALIGNMENT); |
152 | } |
153 | |
154 | scratchpad.book(memory_tracking::names::key_nested, |
155 | gemm_pd_->scratchpad_registry()); |
156 | } |
157 | }; |
158 | |
159 | status_t init(engine_t *engine) override { |
160 | CHECK(create_nested_primitive(gemm_, pd()->gemm_pd_, engine)); |
161 | |
162 | const size_t mb = pd()->MB(); |
163 | const size_t oc = pd()->OC(); |
164 | |
165 | // Prepare post process kernel |
166 | if (pd()->with_post_process()) { |
167 | compute::kernel_ctx_t kernel_ctx; |
168 | |
169 | kernel_ctx.define_int("MB" , mb); |
170 | kernel_ctx.define_int("OC" , oc); |
171 | bool int8 = pd()->is_int8_; |
172 | kernel_ctx.set_data_type( |
173 | int8 ? data_type::f32 : pd()->dst_md()->data_type); |
174 | //here SRC is output tensor of gemm call |
175 | def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "SRC" ); |
176 | def_data_type(kernel_ctx, |
177 | int8 ? data_type::f32 : pd()->desc()->accum_data_type, |
178 | "ACC" ); |
179 | def_data_type(kernel_ctx, |
180 | pd()->with_bias() |
181 | ? pd()->weights_md(1)->data_type |
182 | : int8 ? data_type::f32 : pd()->dst_md()->data_type, |
183 | "BIAS" ); |
184 | def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "SPAD" ); |
185 | def_data_type(kernel_ctx, pd()->dst_md()->data_type, "DST" ); |
186 | |
187 | kernel_ctx.define_int("USE_TEMP_DST" , pd()->use_temp_dst()); |
188 | |
189 | kernel_ctx.define_int("WITH_BIAS" , pd()->with_bias()); |
190 | |
191 | def_attr_info( |
192 | kernel_ctx, pd()->attr_info_, pd()->attr()->post_ops_); |
193 | |
194 | create_kernel(engine, &post_process_kernel_, |
195 | "gemm_post_ops_inner_product" , kernel_ctx); |
196 | if (!post_process_kernel_) return status::runtime_error; |
197 | } |
198 | |
199 | return status::success; |
200 | } |
201 | |
202 | status_t execute(const exec_ctx_t &ctx) const override { |
203 | return execute_forward(ctx); |
204 | } |
205 | |
206 | private: |
207 | status_t execute_forward(const exec_ctx_t &ctx) const; |
208 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
209 | |
210 | std::shared_ptr<primitive_t> gemm_; |
211 | compute::kernel_t post_process_kernel_; |
212 | }; |
213 | |
214 | } // namespace ocl |
215 | } // namespace gpu |
216 | } // namespace impl |
217 | } // namespace dnnl |
218 | |
219 | #endif |
220 | |