1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CONV_GEN_CONVOLUTION_HPP |
18 | #define GPU_JIT_CONV_GEN_CONVOLUTION_HPP |
19 | |
20 | #include <memory> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "gpu/compute/compute.hpp" |
24 | #include "gpu/gpu_convolution_pd.hpp" |
25 | #include "gpu/gpu_primitive.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace jit { |
31 | |
// Forward declarations. In this header both types are referenced only through
// std::shared_ptr members and friend declarations, so complete definitions are
// not required here.
class gen_convolution_t;
struct conv_pd_data_t;
34 | |
35 | class gen_convolution_fwd_t : public gpu_primitive_t { |
36 | public: |
37 | friend gen_convolution_t; |
38 | |
39 | struct pd_t : public gpu_convolution_fwd_pd_t { |
40 | friend gen_convolution_t; |
41 | |
42 | using gpu_convolution_fwd_pd_t::gpu_convolution_fwd_pd_t; |
43 | |
44 | DECLARE_COMMON_PD_T("jit:ir" , gen_convolution_fwd_t); |
45 | |
46 | status_t init(engine_t *engine); |
47 | |
48 | std::shared_ptr<conv_pd_data_t> data; |
49 | }; |
50 | |
51 | using gpu_primitive_t::gpu_primitive_t; |
52 | |
53 | status_t init(engine_t *engine) override; |
54 | status_t execute(const exec_ctx_t &ctx) const override; |
55 | |
56 | private: |
57 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
58 | |
59 | status_t init_res_storage( |
60 | engine_t *engine, gpu_resource_t *r) const override; |
61 | |
62 | std::shared_ptr<gen_convolution_t> impl_; |
63 | }; |
64 | |
65 | class gen_convolution_bwd_data_t : public gpu_primitive_t { |
66 | public: |
67 | friend gen_convolution_t; |
68 | |
69 | struct pd_t : public gpu_convolution_bwd_data_pd_t { |
70 | friend gen_convolution_t; |
71 | |
72 | using gpu_convolution_bwd_data_pd_t::gpu_convolution_bwd_data_pd_t; |
73 | |
74 | DECLARE_COMMON_PD_T("jit:ir" , gen_convolution_bwd_data_t); |
75 | |
76 | status_t init(engine_t *engine); |
77 | |
78 | std::shared_ptr<conv_pd_data_t> data; |
79 | }; |
80 | |
81 | using gpu_primitive_t::gpu_primitive_t; |
82 | |
83 | status_t init(engine_t *engine) override; |
84 | status_t execute(const exec_ctx_t &ctx) const override; |
85 | |
86 | private: |
87 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
88 | |
89 | status_t init_res_storage( |
90 | engine_t *engine, gpu_resource_t *r) const override; |
91 | |
92 | std::shared_ptr<gen_convolution_t> impl_; |
93 | }; |
94 | |
95 | class gen_convolution_bwd_weights_t : public gpu_primitive_t { |
96 | public: |
97 | friend gen_convolution_t; |
98 | |
99 | struct pd_t : public gpu_convolution_bwd_weights_pd_t { |
100 | friend gen_convolution_t; |
101 | |
102 | using gpu_convolution_bwd_weights_pd_t:: |
103 | gpu_convolution_bwd_weights_pd_t; |
104 | |
105 | DECLARE_COMMON_PD_T("jit:ir" , gen_convolution_bwd_weights_t); |
106 | |
107 | status_t init(engine_t *engine); |
108 | |
109 | std::shared_ptr<conv_pd_data_t> data; |
110 | }; |
111 | |
112 | using gpu_primitive_t::gpu_primitive_t; |
113 | |
114 | status_t init(engine_t *engine) override; |
115 | status_t execute(const exec_ctx_t &ctx) const override; |
116 | |
117 | private: |
118 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
119 | |
120 | status_t init_res_storage( |
121 | engine_t *engine, gpu_resource_t *r) const override; |
122 | |
123 | std::shared_ptr<gen_convolution_t> impl_; |
124 | }; |
125 | |
126 | } // namespace jit |
127 | } // namespace gpu |
128 | } // namespace impl |
129 | } // namespace dnnl |
130 | |
131 | #endif |
132 | |