1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CONV_IR_BUILDER_HPP |
18 | #define GPU_JIT_CONV_IR_BUILDER_HPP |
19 | |
20 | #include <array> |
21 | |
22 | #include "common/convolution_pd.hpp" |
23 | #include "gpu/jit/ir/gemm_schedule.hpp" |
24 | #include "gpu/jit/ir/ir.hpp" |
25 | #include "gpu/jit/ir/ir_builder.hpp" |
26 | #include "gpu/jit/ir/tensor.hpp" |
27 | |
28 | #include "gpu/jit/conv/config.hpp" |
29 | #include "gpu/jit/conv/post_ops.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace jit { |
35 | |
36 | class conv_ir_builder_t : public ir_builder_t { |
37 | public: |
38 | conv_ir_builder_t( |
39 | const conv_config_t &cfg, const kernel_info_t &kernel_info) |
40 | : ir_builder_t(kernel_info), prb_(cfg.prb()), cfg_(cfg) { |
41 | build(); |
42 | } |
43 | |
44 | private: |
45 | void build() override; |
46 | void init_fwd(gemm_schedule_t &gemm_schedule, view_t &src_view, |
47 | view_t &wei_view, view_t &dst_view, expr_t &src_buf, |
48 | expr_t &wei_buf, expr_t &dst_buf); |
49 | void init_bwd_d(gemm_schedule_t &gemm_schedule, view_t &dst_view, |
50 | view_t &wei_view, view_t &src_view, expr_t &dst_buf, |
51 | expr_t &wei_buf, expr_t &src_buf); |
52 | void init_bwd_w(gemm_schedule_t &gemm_schedule, view_t &src_view, |
53 | view_t &dst_view, view_t &wei_view, view_t &bia_view, |
54 | expr_t &src_buf, expr_t &dst_buf, expr_t &wei_buf, expr_t &bia_buf, |
55 | expr_t &bia_reduction_condition); |
56 | |
57 | const conv_problem_t &prb_; |
58 | const conv_config_t &cfg_; |
59 | }; |
60 | |
61 | } // namespace jit |
62 | } // namespace gpu |
63 | } // namespace impl |
64 | } // namespace dnnl |
65 | |
66 | #endif |
67 | |