1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_WINO_CONVOLUTION_HPP
18#define GPU_OCL_GEN9_WINO_CONVOLUTION_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "gpu/compute/compute.hpp"
25#include "gpu/gpu_convolution_pd.hpp"
26#include "gpu/gpu_eltwise_pd.hpp"
27#include "gpu/gpu_primitive.hpp"
28#include "gpu/gpu_resource.hpp"
29#include "gpu/ocl/ocl_stream.hpp"
30#include "gpu/ocl/ocl_utils.hpp"
31#include "gpu/primitive_conf.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace gpu {
36namespace ocl {
37
38struct gen9_wino_convolution_fwd_t : public gpu_primitive_t {
39 using gpu_primitive_t::gpu_primitive_t;
40 struct pd_t : public gpu_convolution_fwd_pd_t {
41 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
42 const convolution_fwd_pd_t *hint_fwd_pd)
43 : gpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
44
45 DECLARE_COMMON_PD_T("ocl:gen9:wino", gen9_wino_convolution_fwd_t);
46
47 status_t init(engine_t *engine) {
48 using namespace prop_kind;
49 using namespace data_type;
50 assert(engine->kind() == engine_kind::gpu);
51 auto *compute_engine
52 = utils::downcast<compute::compute_engine_t *>(engine);
53
54 auto src_data_t = this->desc()->src_desc.data_type;
55 auto dst_data_t = this->desc()->dst_desc.data_type;
56
57 const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
58
59 bool ok = utils::one_of(this->desc()->prop_kind, forward_training,
60 forward_inference)
61 && (this->desc()->alg_kind == alg_kind::convolution_winograd
62 || this->desc()->alg_kind
63 == alg_kind::convolution_auto)
64 && utils::one_of(true,
65 expect_data_types(f32, f32, f32, f32, f32),
66 expect_data_types(f16, f16, f16, f16, f32))
67 && compute_engine->mayiuse(
68 compute::device_ext_t::intel_subgroups)
69 && IMPLICATION(src_data_t == f16,
70 true
71 && compute_engine->mayiuse(
72 compute::device_ext_t::khr_fp16)
73 && compute_engine->mayiuse(
74 compute::device_ext_t::
75 intel_subgroups_short))
76 && !has_zero_dim_memory()
77 && attr()->has_default_values(attr_skip_mask, dst_data_t)
78 && post_ops_with_binary_ok(attr(), dst_data_t);
79 if (!ok) return status::unimplemented;
80
81 CHECK(init_conf(compute_engine));
82
83 int sub_group_size = conf.wino_ic_block / 2; // LWX
84 if (!compute_engine->mayiuse_sub_group(sub_group_size))
85 return status::unimplemented;
86
87 init_scratchpad();
88
89 ok = set_default_formats_common(
90 conf.src_tag, conf.wei_tag, conf.dst_tag);
91 if (!ok) return status::unimplemented;
92
93 CHECK(attr_.set_default_formats(dst_md(0)));
94
95 return status::success;
96 }
97
98 status_t init_conf(compute::compute_engine_t *engine);
99 void init_scratchpad();
100 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
101
102 conv_conf_t conf;
103 };
104
105 status_t init(engine_t *engine) override {
106 bool is_fused = pd()->conf.is_fused;
107 bool is_nonfused_2x3 = pd()->conf.wino_m == 2 && !is_fused;
108
109 std::vector<const char *> kernel_names;
110 if (is_fused) {
111 kernel_names.push_back("gen9_wino_conv_fwd");
112 kernel_names.push_back("gen9_wino_wei_transform");
113 } else if (is_nonfused_2x3) {
114 kernel_names.push_back("gen9_wino_conv_fwd_2x3");
115 kernel_names.push_back("gen9_wino_wei_transform_2x3");
116 kernel_names.push_back("gen9_wino_src_transform_2x3");
117 kernel_names.push_back("gen9_wino_dst_transform_2x3");
118 } else {
119 assert(!"Invalid Winograd version chosen by init_conf");
120 return status::unimplemented;
121 }
122
123 compute::kernel_ctx_t kernel_ctx;
124 status_t status = pd()->init_kernel_ctx(kernel_ctx);
125 if (status != status::success) return status;
126
127 std::vector<compute::kernel_t> kernels;
128 CHECK(create_kernels(engine, &kernels, kernel_names, kernel_ctx));
129 kernel_ = kernels[0];
130 wei_trans_kernel_ = kernels[1];
131 if (!kernel_ || !wei_trans_kernel_) return status::runtime_error;
132 if (!is_fused) {
133 src_trans_kernel_ = kernels[2];
134 dst_trans_kernel_ = kernels[3];
135 if (!src_trans_kernel_ || !dst_trans_kernel_)
136 return status::runtime_error;
137 }
138
139 return status::success;
140 }
141
142 status_t execute(const exec_ctx_t &ctx) const override {
143 return execute_forward(ctx);
144 }
145
146private:
147 status_t execute_forward(const exec_ctx_t &ctx) const;
148 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
149 compute::kernel_t kernel_;
150 compute::kernel_t wei_trans_kernel_;
151 compute::kernel_t src_trans_kernel_;
152 compute::kernel_t dst_trans_kernel_;
153};
154
155} // namespace ocl
156} // namespace gpu
157} // namespace impl
158} // namespace dnnl
159
160#endif
161
162// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
163