1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CONV_GEN_CONVOLUTION_HPP
18#define GPU_JIT_CONV_GEN_CONVOLUTION_HPP
19
20#include <memory>
21
22#include "common/c_types_map.hpp"
23#include "gpu/compute/compute.hpp"
24#include "gpu/gpu_convolution_pd.hpp"
25#include "gpu/gpu_primitive.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30namespace jit {
31
// Types below are only named in friend declarations and std::shared_ptr
// members in this header, so incomplete (forward) declarations suffice.
struct conv_pd_data_t;
class gen_convolution_t;
34
35class gen_convolution_fwd_t : public gpu_primitive_t {
36public:
37 friend gen_convolution_t;
38
39 struct pd_t : public gpu_convolution_fwd_pd_t {
40 friend gen_convolution_t;
41
42 using gpu_convolution_fwd_pd_t::gpu_convolution_fwd_pd_t;
43
44 DECLARE_COMMON_PD_T("jit:ir", gen_convolution_fwd_t);
45
46 status_t init(engine_t *engine);
47
48 std::shared_ptr<conv_pd_data_t> data;
49 };
50
51 using gpu_primitive_t::gpu_primitive_t;
52
53 status_t init(engine_t *engine) override;
54 status_t execute(const exec_ctx_t &ctx) const override;
55
56private:
57 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
58
59 status_t init_res_storage(
60 engine_t *engine, gpu_resource_t *r) const override;
61
62 std::shared_ptr<gen_convolution_t> impl_;
63};
64
65class gen_convolution_bwd_data_t : public gpu_primitive_t {
66public:
67 friend gen_convolution_t;
68
69 struct pd_t : public gpu_convolution_bwd_data_pd_t {
70 friend gen_convolution_t;
71
72 using gpu_convolution_bwd_data_pd_t::gpu_convolution_bwd_data_pd_t;
73
74 DECLARE_COMMON_PD_T("jit:ir", gen_convolution_bwd_data_t);
75
76 status_t init(engine_t *engine);
77
78 std::shared_ptr<conv_pd_data_t> data;
79 };
80
81 using gpu_primitive_t::gpu_primitive_t;
82
83 status_t init(engine_t *engine) override;
84 status_t execute(const exec_ctx_t &ctx) const override;
85
86private:
87 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
88
89 status_t init_res_storage(
90 engine_t *engine, gpu_resource_t *r) const override;
91
92 std::shared_ptr<gen_convolution_t> impl_;
93};
94
95class gen_convolution_bwd_weights_t : public gpu_primitive_t {
96public:
97 friend gen_convolution_t;
98
99 struct pd_t : public gpu_convolution_bwd_weights_pd_t {
100 friend gen_convolution_t;
101
102 using gpu_convolution_bwd_weights_pd_t::
103 gpu_convolution_bwd_weights_pd_t;
104
105 DECLARE_COMMON_PD_T("jit:ir", gen_convolution_bwd_weights_t);
106
107 status_t init(engine_t *engine);
108
109 std::shared_ptr<conv_pd_data_t> data;
110 };
111
112 using gpu_primitive_t::gpu_primitive_t;
113
114 status_t init(engine_t *engine) override;
115 status_t execute(const exec_ctx_t &ctx) const override;
116
117private:
118 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
119
120 status_t init_res_storage(
121 engine_t *engine, gpu_resource_t *r) const override;
122
123 std::shared_ptr<gen_convolution_t> impl_;
124};
125
126} // namespace jit
127} // namespace gpu
128} // namespace impl
129} // namespace dnnl
130
131#endif
132