1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/op/nn/convolution.h |
 * \brief Properties definition of the convolution operator, shared across its variants.
23 | */ |
24 | #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ |
25 | #define TVM_RELAY_OP_NN_CONVOLUTION_H_ |
26 | |
27 | #include <tvm/auto_scheduler/compute_dag.h> |
28 | #include <tvm/runtime/logging.h> |
29 | #include <tvm/tir/analysis.h> |
30 | |
31 | #include <string> |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | #include "../op_common.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
/*!
 * \brief Type relation for the conv2d operator.
 *
 * Declared here so multiple translation units can share it; the definition
 * lives elsewhere (presumably convolution.cc — confirm in the repo).
 *
 * \param types The input/output types; by convention types.back() is the
 *        output slot to be inferred.
 * \param num_inputs The number of input arguments.
 * \param attrs The conv2d attributes.
 * \param reporter The type reporter used to assign/unify types.
 * \return true if inference succeeded, false if inputs are not yet resolved.
 */
bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter);

/*!
 * \brief Type relation for the conv2d_transpose operator.
 *
 * Same calling convention as Conv2DRel; definition is outside this header.
 */
bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter);
45 | |
46 | template <typename AttrType> |
47 | bool Conv2DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
48 | const TypeReporter& reporter) { |
49 | ICHECK_EQ(types.size(), 3); |
50 | const auto* data = types[0].as<TensorTypeNode>(); |
51 | if (data == nullptr) return false; |
52 | static const Layout kNCHW("NCHW" ); |
53 | static const Layout kOIHW("OIHW" ); |
54 | |
55 | const AttrType* param = attrs.as<AttrType>(); |
56 | ICHECK(param != nullptr); |
57 | const Layout in_layout(param->data_layout); |
58 | const Layout kernel_layout(param->kernel_layout); |
59 | |
60 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); |
61 | ICHECK(trans_in_layout.defined()) |
62 | << "Conv only support input layouts that are convertible from NCHW." |
63 | << " But got " << in_layout; |
64 | |
65 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); |
66 | ICHECK(trans_kernel_layout.defined()) |
67 | << "Conv only support kernel layouts that are convertible from OIHW." |
68 | << " But got " << kernel_layout; |
69 | |
70 | Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); |
71 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); |
72 | ICHECK(trans_out_layout.defined()) |
73 | << "Conv only support output layouts that are convertible from NCHW." |
74 | << " But got " << out_layout; |
75 | |
76 | Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape); |
77 | |
78 | IndexExpr channels, dilated_ksize_y, dilated_ksize_x; |
79 | |
80 | ICHECK(param->kernel_size.defined() && param->channels.defined()) |
81 | << "The kernel size and channels of a Conv must be set or inferred by previous pass" ; |
82 | |
83 | ICHECK_EQ(param->kernel_size.size(), 2); |
84 | ICHECK_EQ(param->dilation.size(), 2); |
85 | |
86 | channels = param->channels; |
87 | dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; |
88 | dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; |
89 | |
90 | // NOTE: Do not check weight shape here! |
91 | // Different backend requires different layout to compute |
92 | // the batch gemm stage in winograd efficiently, but we want to |
93 | // make this op work for all backends. |
94 | // So we accept all weight shapes, and assume the TOPI developers |
95 | // can handle this correctly in alter_op_layout. |
96 | |
97 | // dilation |
98 | Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); |
99 | |
100 | IndexExpr pad_h, pad_w; |
101 | GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); |
102 | if (!dshape_nchw[2].as<tir::AnyNode>()) { |
103 | oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1); |
104 | } else { |
105 | oshape.Set(2, dshape_nchw[2]); |
106 | } |
107 | if (!dshape_nchw[3].as<tir::AnyNode>()) { |
108 | oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1); |
109 | } else { |
110 | oshape.Set(3, dshape_nchw[3]); |
111 | } |
112 | |
113 | DataType out_dtype = param->out_dtype; |
114 | if (out_dtype.bits() == 0) { |
115 | out_dtype = data->dtype; |
116 | } |
117 | oshape = trans_out_layout.BackwardShape(oshape); |
118 | // assign output type |
119 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
120 | return true; |
121 | } |
122 | |
123 | template <typename T> |
124 | InferCorrectLayoutOutput ConvInferCorrectLayout(const Attrs& attrs, |
125 | const Array<Layout>& new_in_layouts, |
126 | const Array<Layout>& old_in_layouts, |
127 | const Array<tvm::relay::Type>& old_in_types) { |
128 | const T* params = attrs.as<T>(); |
129 | // We always make other operators to fit the layouts of convolution layers |
130 | // So this inference ignores all inputs |
131 | return InferCorrectLayoutOutput( |
132 | {params->data_layout, params->kernel_layout}, |
133 | {params->out_layout == "" ? params->data_layout : params->out_layout}, attrs); |
134 | } |
135 | |
136 | } // namespace relay |
137 | } // namespace tvm |
138 | #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ |
139 | |