1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/op/nn/convolution_make.h
22 * \brief utilities for creating convolution ops
23 */
24#ifndef TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_
25#define TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_
26
27#include <tvm/relay/attrs/nn.h>
28#include <tvm/relay/op.h>
29
30#include <string>
31#include <utility>
32#include <vector>
33
34namespace tvm {
35namespace relay {
36
37template <typename T>
38inline Expr MakeConv(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
39 Array<IndexExpr> dilation, int groups, IndexExpr channels,
40 Array<IndexExpr> kernel_size, std::string data_layout,
41 std::string kernel_layout, std::string out_layout, DataType out_dtype,
42 std::string op_name) {
43 auto attrs = make_object<T>();
44 attrs->strides = std::move(strides);
45 attrs->padding = std::move(padding);
46 attrs->dilation = std::move(dilation);
47 attrs->groups = groups;
48 attrs->channels = std::move(channels);
49 attrs->kernel_size = std::move(kernel_size);
50 attrs->data_layout = std::move(data_layout);
51 attrs->kernel_layout = std::move(kernel_layout);
52 attrs->out_layout = std::move(out_layout);
53 attrs->out_dtype = std::move(out_dtype);
54 const Op& op = Op::Get(op_name);
55 return Call(op, {data, weight}, Attrs(attrs), {});
56}
57
58template <typename T>
59inline Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
60 Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
61 IndexExpr channels, Array<IndexExpr> kernel_size,
62 std::string data_layout, std::string kernel_layout,
63 std::string out_layout, DataType out_dtype, std::string op_name) {
64 auto attrs = make_object<T>();
65 attrs->tile_size = tile_size;
66 attrs->strides = std::move(strides);
67 attrs->padding = std::move(padding);
68 attrs->dilation = std::move(dilation);
69 attrs->groups = groups;
70 attrs->channels = std::move(channels);
71 attrs->kernel_size = std::move(kernel_size);
72 attrs->data_layout = std::move(data_layout);
73 attrs->kernel_layout = std::move(kernel_layout);
74 attrs->out_layout = std::move(out_layout);
75 attrs->out_dtype = std::move(out_dtype);
76 const Op& op = Op::Get(op_name);
77 return Call(op, {data, weight}, Attrs(attrs), {});
78}
79
80template <typename T>
81inline Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
82 Array<IndexExpr> dilation, int groups, IndexExpr channels,
83 Array<IndexExpr> kernel_size, std::string data_layout,
84 std::string kernel_layout, std::string out_layout, DataType out_dtype,
85 std::string op_name) {
86 auto attrs = make_object<T>();
87 attrs->strides = std::move(strides);
88 attrs->padding = std::move(padding);
89 attrs->dilation = std::move(dilation);
90 attrs->groups = groups;
91 attrs->channels = std::move(channels);
92 attrs->kernel_size = std::move(kernel_size);
93 attrs->data_layout = std::move(data_layout);
94 attrs->kernel_layout = std::move(kernel_layout);
95 attrs->out_layout = std::move(out_layout);
96 attrs->out_dtype = std::move(out_dtype);
97 const Op& op = Op::Get(op_name);
98 return Call(op, {data, weight}, Attrs(attrs), {});
99}
100
101template <typename T>
102inline Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides,
103 Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
104 IndexExpr channels, Array<IndexExpr> kernel_size,
105 std::string data_layout, std::string kernel_layout,
106 std::string out_layout, Array<IndexExpr> output_padding,
107 DataType out_dtype, std::string op_name) {
108 auto attrs = make_object<T>();
109 attrs->strides = std::move(strides);
110 attrs->padding = std::move(padding);
111 attrs->dilation = std::move(dilation);
112 attrs->groups = groups;
113 attrs->channels = std::move(channels);
114 attrs->kernel_size = std::move(kernel_size);
115 attrs->data_layout = std::move(data_layout);
116 attrs->kernel_layout = std::move(kernel_layout);
117 attrs->out_layout = std::move(out_layout);
118 attrs->output_padding = std::move(output_padding);
119 attrs->out_dtype = std::move(out_dtype);
120 const Op& op = Op::Get(op_name);
121 return Call(op, {data, weight}, Attrs(attrs), {});
122}
123
124template <typename T>
125inline Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array<IndexExpr> strides,
126 Array<IndexExpr> padding, Array<IndexExpr> dilation,
127 int deformable_groups, int groups, int channels,
128 Array<IndexExpr> kernel_size, std::string data_layout,
129 std::string kernel_layout, std::string out_layout,
130 DataType out_dtype, std::string op_name) {
131 auto attrs = make_object<T>();
132 attrs->strides = strides;
133 attrs->padding = padding;
134 attrs->dilation = dilation;
135 attrs->deformable_groups = deformable_groups;
136 attrs->groups = groups;
137 attrs->channels = channels;
138 attrs->kernel_size = kernel_size;
139 attrs->data_layout = data_layout;
140 attrs->kernel_layout = kernel_layout;
141 attrs->out_layout = out_layout;
142 attrs->out_dtype = out_dtype;
143 const Op& op = Op::Get(op_name);
144 return Call(op, {data, offset, weight}, Attrs{attrs}, {});
145}
146
147} // namespace relay
148} // namespace tvm
149#endif // TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_
150