1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/op/nn/convolution_make.h |
22 | * \brief utilities for creating convolution ops |
23 | */ |
24 | #ifndef TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ |
25 | #define TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ |
26 | |
27 | #include <tvm/relay/attrs/nn.h> |
28 | #include <tvm/relay/op.h> |
29 | |
30 | #include <string> |
31 | #include <utility> |
32 | #include <vector> |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | |
37 | template <typename T> |
38 | inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding, |
39 | Array<IndexExpr> dilation, int groups, IndexExpr channels, |
40 | Array<IndexExpr> kernel_size, std::string data_layout, |
41 | std::string kernel_layout, std::string out_layout, DataType out_dtype, |
42 | std::string op_name) { |
43 | auto attrs = make_object<T>(); |
44 | attrs->strides = std::move(strides); |
45 | attrs->padding = std::move(padding); |
46 | attrs->dilation = std::move(dilation); |
47 | attrs->groups = groups; |
48 | attrs->channels = std::move(channels); |
49 | attrs->kernel_size = std::move(kernel_size); |
50 | attrs->data_layout = std::move(data_layout); |
51 | attrs->kernel_layout = std::move(kernel_layout); |
52 | attrs->out_layout = std::move(out_layout); |
53 | attrs->out_dtype = std::move(out_dtype); |
54 | const Op& op = Op::Get(op_name); |
55 | return Call(op, {data, weight}, Attrs(attrs), {}); |
56 | } |
57 | |
58 | template <typename T> |
59 | inline Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> strides, |
60 | Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups, |
61 | IndexExpr channels, Array<IndexExpr> kernel_size, |
62 | std::string data_layout, std::string kernel_layout, |
63 | std::string out_layout, DataType out_dtype, std::string op_name) { |
64 | auto attrs = make_object<T>(); |
65 | attrs->tile_size = tile_size; |
66 | attrs->strides = std::move(strides); |
67 | attrs->padding = std::move(padding); |
68 | attrs->dilation = std::move(dilation); |
69 | attrs->groups = groups; |
70 | attrs->channels = std::move(channels); |
71 | attrs->kernel_size = std::move(kernel_size); |
72 | attrs->data_layout = std::move(data_layout); |
73 | attrs->kernel_layout = std::move(kernel_layout); |
74 | attrs->out_layout = std::move(out_layout); |
75 | attrs->out_dtype = std::move(out_dtype); |
76 | const Op& op = Op::Get(op_name); |
77 | return Call(op, {data, weight}, Attrs(attrs), {}); |
78 | } |
79 | |
80 | template <typename T> |
81 | inline Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding, |
82 | Array<IndexExpr> dilation, int groups, IndexExpr channels, |
83 | Array<IndexExpr> kernel_size, std::string data_layout, |
84 | std::string kernel_layout, std::string out_layout, DataType out_dtype, |
85 | std::string op_name) { |
86 | auto attrs = make_object<T>(); |
87 | attrs->strides = std::move(strides); |
88 | attrs->padding = std::move(padding); |
89 | attrs->dilation = std::move(dilation); |
90 | attrs->groups = groups; |
91 | attrs->channels = std::move(channels); |
92 | attrs->kernel_size = std::move(kernel_size); |
93 | attrs->data_layout = std::move(data_layout); |
94 | attrs->kernel_layout = std::move(kernel_layout); |
95 | attrs->out_layout = std::move(out_layout); |
96 | attrs->out_dtype = std::move(out_dtype); |
97 | const Op& op = Op::Get(op_name); |
98 | return Call(op, {data, weight}, Attrs(attrs), {}); |
99 | } |
100 | |
101 | template <typename T> |
102 | inline Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, |
103 | Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups, |
104 | IndexExpr channels, Array<IndexExpr> kernel_size, |
105 | std::string data_layout, std::string kernel_layout, |
106 | std::string out_layout, Array<IndexExpr> output_padding, |
107 | DataType out_dtype, std::string op_name) { |
108 | auto attrs = make_object<T>(); |
109 | attrs->strides = std::move(strides); |
110 | attrs->padding = std::move(padding); |
111 | attrs->dilation = std::move(dilation); |
112 | attrs->groups = groups; |
113 | attrs->channels = std::move(channels); |
114 | attrs->kernel_size = std::move(kernel_size); |
115 | attrs->data_layout = std::move(data_layout); |
116 | attrs->kernel_layout = std::move(kernel_layout); |
117 | attrs->out_layout = std::move(out_layout); |
118 | attrs->output_padding = std::move(output_padding); |
119 | attrs->out_dtype = std::move(out_dtype); |
120 | const Op& op = Op::Get(op_name); |
121 | return Call(op, {data, weight}, Attrs(attrs), {}); |
122 | } |
123 | |
124 | template <typename T> |
125 | inline Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array<IndexExpr> strides, |
126 | Array<IndexExpr> padding, Array<IndexExpr> dilation, |
127 | int deformable_groups, int groups, int channels, |
128 | Array<IndexExpr> kernel_size, std::string data_layout, |
129 | std::string kernel_layout, std::string out_layout, |
130 | DataType out_dtype, std::string op_name) { |
131 | auto attrs = make_object<T>(); |
132 | attrs->strides = strides; |
133 | attrs->padding = padding; |
134 | attrs->dilation = dilation; |
135 | attrs->deformable_groups = deformable_groups; |
136 | attrs->groups = groups; |
137 | attrs->channels = channels; |
138 | attrs->kernel_size = kernel_size; |
139 | attrs->data_layout = data_layout; |
140 | attrs->kernel_layout = kernel_layout; |
141 | attrs->out_layout = out_layout; |
142 | attrs->out_dtype = out_dtype; |
143 | const Op& op = Op::Get(op_name); |
144 | return Call(op, {data, offset, weight}, Attrs{attrs}, {}); |
145 | } |
146 | |
147 | } // namespace relay |
148 | } // namespace tvm |
149 | #endif // TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ |
150 | |