1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_POOLING_HPP |
18 | #define GPU_OCL_REF_POOLING_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "gpu/compute/compute.hpp" |
23 | #include "gpu/gpu_pooling_pd.hpp" |
24 | #include "gpu/gpu_primitive.hpp" |
25 | #include "gpu/gpu_resource.hpp" |
26 | #include "gpu/ocl/ocl_stream.hpp" |
27 | #include "gpu/ocl/ocl_utils.hpp" |
28 | #include "gpu/primitive_conf.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace gpu { |
33 | namespace ocl { |
34 | |
35 | struct ref_pooling_fwd_t : public gpu_primitive_t { |
36 | using gpu_primitive_t::gpu_primitive_t; |
37 | struct pd_t : public gpu_pooling_fwd_pd_t { |
38 | pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr, |
39 | const pooling_fwd_pd_t *hint_fwd_pd) |
40 | : gpu_pooling_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
41 | |
42 | DECLARE_COMMON_PD_T("ocl:ref" , ref_pooling_fwd_t); |
43 | |
44 | status_t init(engine_t *engine) { |
45 | using namespace data_type; |
46 | using namespace prop_kind; |
47 | using namespace alg_kind; |
48 | auto src_data_t = src_md()->data_type; |
49 | auto dst_data_t = dst_md()->data_type; |
50 | auto acc_data_t = desc()->accum_data_type; |
51 | |
52 | const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops; |
53 | |
54 | bool ok = set_default_params() == status::success |
55 | && utils::one_of(desc()->prop_kind, forward_training, |
56 | forward_inference) |
57 | && utils::one_of(desc()->alg_kind, pooling_max, |
58 | pooling_avg_include_padding, |
59 | pooling_avg_exclude_padding) |
60 | && IMPLICATION(utils::one_of(src_data_t, f16, s8, u8, s32), |
61 | desc()->prop_kind == forward_inference) |
62 | && IMPLICATION(src_data_t != dst_data_t, |
63 | desc()->prop_kind == forward_inference) |
64 | && IMPLICATION(src_data_t == bf16, src_data_t == dst_data_t) |
65 | && IMPLICATION(utils::one_of(src_data_t, s8, u8), |
66 | utils::one_of(dst_data_t, s8, u8, f16, f32)) |
67 | && IMPLICATION(src_data_t == f16, |
68 | utils::one_of(dst_data_t, s8, u8, f16)) |
69 | && IMPLICATION(src_data_t == f32, |
70 | utils::one_of(dst_data_t, s8, u8, f32)) |
71 | && IMPLICATION(utils::one_of(f32, src_data_t, dst_data_t), |
72 | acc_data_t == f32) |
73 | && IMPLICATION(utils::one_of(src_data_t, s8, u8) |
74 | && dst_data_t != f32, |
75 | acc_data_t == s32) |
76 | && attr()->has_default_values(attr_skip_mask) |
77 | && post_ops_with_binary_ok(attr(), dst_md()->data_type, 5) |
78 | && attr_.set_default_formats(dst_md(0)) == status::success; |
79 | if (!ok) return status::unimplemented; |
80 | |
81 | bool is_training = desc_.prop_kind == forward_training; |
82 | if (desc()->alg_kind == pooling_max && is_training) |
83 | init_default_ws(s32); |
84 | |
85 | return init_conf(engine); |
86 | } |
87 | |
88 | status_t init_conf(engine_t *engine); |
89 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
90 | |
91 | pool_conf_t conf; |
92 | offsets_t off; |
93 | }; |
94 | |
95 | status_t init(engine_t *engine) override { |
96 | compute::kernel_ctx_t kernel_ctx; |
97 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
98 | CHECK(status); |
99 | |
100 | create_kernel(engine, &kernel_, "ref_pooling_fwd" , kernel_ctx); |
101 | if (!kernel_) return status::runtime_error; |
102 | |
103 | return status::success; |
104 | } |
105 | |
106 | status_t execute(const exec_ctx_t &ctx) const override { |
107 | return execute_forward(ctx); |
108 | } |
109 | |
110 | private: |
111 | status_t execute_forward(const exec_ctx_t &ctx) const; |
112 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
113 | compute::kernel_t kernel_; |
114 | }; |
115 | |
116 | struct ref_pooling_bwd_t : public gpu_primitive_t { |
117 | using gpu_primitive_t::gpu_primitive_t; |
118 | struct pd_t : public gpu_pooling_bwd_pd_t { |
119 | pd_t(const pooling_desc_t *adesc, const primitive_attr_t *attr, |
120 | const pooling_fwd_pd_t *hint_fwd_pd) |
121 | : gpu_pooling_bwd_pd_t(adesc, attr, hint_fwd_pd) {} |
122 | |
123 | DECLARE_COMMON_PD_T("ocl:ref:any" , ref_pooling_bwd_t); |
124 | |
125 | status_t init(engine_t *engine) { |
126 | using namespace prop_kind; |
127 | using namespace alg_kind; |
128 | |
129 | bool ok = set_default_params() == status::success |
130 | && utils::one_of(desc()->prop_kind, backward_data) |
131 | && utils::one_of(desc()->alg_kind, pooling_max, |
132 | pooling_avg_include_padding, |
133 | pooling_avg_exclude_padding) |
134 | && (utils::everyone_is(data_type::f32, |
135 | diff_dst_md()->data_type, |
136 | diff_src_md()->data_type) |
137 | || utils::everyone_is(data_type::bf16, |
138 | diff_dst_md()->data_type, |
139 | diff_src_md()->data_type)) |
140 | && attr()->has_default_values(); |
141 | if (!ok) return status::unimplemented; |
142 | |
143 | if (desc()->alg_kind == pooling_max) { |
144 | init_default_ws(data_type::s32); |
145 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
146 | } |
147 | |
148 | return init_conf(engine); |
149 | } |
150 | |
151 | status_t init_conf(engine_t *engine); |
152 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
153 | |
154 | pool_conf_t conf; |
155 | offsets_t off; |
156 | }; |
157 | |
158 | status_t init(engine_t *engine) override { |
159 | compute::kernel_ctx_t kernel_ctx; |
160 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
161 | CHECK(status); |
162 | |
163 | create_kernel(engine, &kernel_, "ref_pooling_bwd" , kernel_ctx); |
164 | if (!kernel_) return status::runtime_error; |
165 | |
166 | return status::success; |
167 | } |
168 | |
169 | status_t execute(const exec_ctx_t &ctx) const override { |
170 | return execute_backward(ctx); |
171 | } |
172 | |
173 | private: |
174 | status_t execute_backward(const exec_ctx_t &ctx) const; |
175 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
176 | compute::kernel_t kernel_; |
177 | }; |
178 | |
179 | } // namespace ocl |
180 | } // namespace gpu |
181 | } // namespace impl |
182 | } // namespace dnnl |
183 | |
184 | #endif |
185 | |
186 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
187 | |