1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/op/nn/convolution.h
 * \brief Properties def of convolution operator for sharing.
23 */
24#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
25#define TVM_RELAY_OP_NN_CONVOLUTION_H_
26
27#include <tvm/auto_scheduler/compute_dag.h>
28#include <tvm/runtime/logging.h>
29#include <tvm/tir/analysis.h>
30
31#include <string>
32#include <utility>
33#include <vector>
34
35#include "../op_common.h"
36
37namespace tvm {
38namespace relay {
39
/*!
 * \brief Type relation for the conv2d operator; infers the output tensor type
 *        from the data and weight types. Implementation lives in the
 *        corresponding .cc file.
 */
bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter);

/*!
 * \brief Type relation for the conv2d_transpose operator. Implementation lives
 *        in the corresponding .cc file.
 */
bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter);
45
46template <typename AttrType>
47bool Conv2DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
48 const TypeReporter& reporter) {
49 ICHECK_EQ(types.size(), 3);
50 const auto* data = types[0].as<TensorTypeNode>();
51 if (data == nullptr) return false;
52 static const Layout kNCHW("NCHW");
53 static const Layout kOIHW("OIHW");
54
55 const AttrType* param = attrs.as<AttrType>();
56 ICHECK(param != nullptr);
57 const Layout in_layout(param->data_layout);
58 const Layout kernel_layout(param->kernel_layout);
59
60 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
61 ICHECK(trans_in_layout.defined())
62 << "Conv only support input layouts that are convertible from NCHW."
63 << " But got " << in_layout;
64
65 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
66 ICHECK(trans_kernel_layout.defined())
67 << "Conv only support kernel layouts that are convertible from OIHW."
68 << " But got " << kernel_layout;
69
70 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
71 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
72 ICHECK(trans_out_layout.defined())
73 << "Conv only support output layouts that are convertible from NCHW."
74 << " But got " << out_layout;
75
76 Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
77
78 IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
79
80 ICHECK(param->kernel_size.defined() && param->channels.defined())
81 << "The kernel size and channels of a Conv must be set or inferred by previous pass";
82
83 ICHECK_EQ(param->kernel_size.size(), 2);
84 ICHECK_EQ(param->dilation.size(), 2);
85
86 channels = param->channels;
87 dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
88 dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
89
90 // NOTE: Do not check weight shape here!
91 // Different backend requires different layout to compute
92 // the batch gemm stage in winograd efficiently, but we want to
93 // make this op work for all backends.
94 // So we accept all weight shapes, and assume the TOPI developers
95 // can handle this correctly in alter_op_layout.
96
97 // dilation
98 Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
99
100 IndexExpr pad_h, pad_w;
101 GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
102 if (!dshape_nchw[2].as<tir::AnyNode>()) {
103 oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1);
104 } else {
105 oshape.Set(2, dshape_nchw[2]);
106 }
107 if (!dshape_nchw[3].as<tir::AnyNode>()) {
108 oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1);
109 } else {
110 oshape.Set(3, dshape_nchw[3]);
111 }
112
113 DataType out_dtype = param->out_dtype;
114 if (out_dtype.bits() == 0) {
115 out_dtype = data->dtype;
116 }
117 oshape = trans_out_layout.BackwardShape(oshape);
118 // assign output type
119 reporter->Assign(types[2], TensorType(oshape, out_dtype));
120 return true;
121}
122
123template <typename T>
124InferCorrectLayoutOutput ConvInferCorrectLayout(const Attrs& attrs,
125 const Array<Layout>& new_in_layouts,
126 const Array<Layout>& old_in_layouts,
127 const Array<tvm::relay::Type>& old_in_types) {
128 const T* params = attrs.as<T>();
129 // We always make other operators to fit the layouts of convolution layers
130 // So this inference ignores all inputs
131 return InferCorrectLayoutOutput(
132 {params->data_layout, params->kernel_layout},
133 {params->out_layout == "" ? params->data_layout : params->out_layout}, attrs);
134}
135
136} // namespace relay
137} // namespace tvm
138#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
139