1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/op/tensor/transform.h |
22 | * \brief Transform op attributes that can be shared among Relay and its dialects. |
23 | */ |
24 | #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ |
25 | #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ |
26 | |
27 | #include <tvm/relay/attrs/transform.h> |
28 | #include <tvm/relay/error.h> |
29 | #include <tvm/relay/op_attr_types.h> |
30 | |
31 | #include <algorithm> |
32 | #include <limits> |
33 | #include <string> |
34 | #include <unordered_set> |
35 | #include <utility> |
36 | #include <vector> |
37 | |
38 | #include "../../transforms/infer_layout_utils.h" |
39 | #include "../make_op.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | |
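/*!
 * \brief Type relation for concatenate.
 *
 * Checks that all input tensors share the same ndim and dtype, sums the
 * extents along the concat axis, and requires the remaining axes to agree
 * (falling back to Any where an input extent is dynamic). For example,
 * concatenating [(2, 3), (4, 3)] on axis 0 yields (6, 3).
 */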
44 | template <typename AttrType> |
45 | bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
46 | const TypeReporter& reporter) { |
47 | // types: [data, result] |
48 | ICHECK_EQ(types.size(), 2) << "the arity of concatenate is 2, not " << types.size(); |
  /* If we receive an incomplete type we defer checking; if we
   * receive a tuple we can continue; anything else signals an
   * error.
   */
  if (types[0].as<IncompleteTypeNode>() != nullptr) {
    return false;
  }
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
    reporter->GetDiagCtx().EmitFatal(
        Diagnostic::Error(reporter->GetSpan())
        << "concatenate requires a tuple of tensors as the first argument, found "
        << PrettyPrint(types[0]));
    return false;
  }
63 | |
64 | const auto* param = attrs.as<AttrType>(); |
65 | if (param == nullptr) { |
66 | reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) |
67 | << "the call attributes are not defined" ); |
68 | return false; |
69 | } |
70 | |
71 | if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) { |
72 | return false; |
73 | } |
74 | const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); |
75 | // Sanity check: ndim and dtype. |
76 | const int ndim = static_cast<int>(first->shape.size()); |
77 | const DataType dtype = first->dtype; |
78 | |
79 | // Sanity check: axis |
80 | int axis = param->axis; |
81 | if (!(-ndim <= axis && axis < ndim)) { |
82 | throw CompileError(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)" |
83 | << ", but got axis = " << axis << ", and ndim = " << ndim); |
84 | } |
85 | axis = axis < 0 ? ndim + axis : axis; |
86 | |
87 | for (const Type& ele : tensor_tuple->fields) { |
88 | if (ele.as<IncompleteTypeNode>()) { |
89 | return false; |
90 | } |
91 | |
92 | const auto& e = Downcast<TensorType>(ele); |
93 | |
94 | int e_ndim = static_cast<int>(e->shape.size()); |
95 | const DataType& e_dtype = e->dtype; |
96 | if (e_ndim != ndim) { |
97 | throw Error("relay.concatenate requires all tensors have the same ndim" ); |
98 | } |
99 | if (e_dtype != dtype) { |
100 | throw Error("relay.concatenate requires all tensors have the same dtype" ); |
101 | } |
102 | } |
103 | |
104 | // Calculate shape |
105 | std::vector<IndexExpr> oshape(ndim); |
106 | const size_t data_length = tensor_tuple->fields.size(); |
107 | |
  // Accumulate the output extent along the concat axis, or mark the concat as dynamic.
109 | bool is_dynamic_concat = false; |
110 | std::vector<TensorType> input_tensors; |
111 | IndexExpr concat_output_dim = first->shape[axis]; |
112 | for (size_t i = 0; i < data_length; ++i) { |
113 | const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]); |
114 | input_tensors.push_back(e); |
115 | if (e->shape[axis].as<AnyNode>()) { |
116 | is_dynamic_concat = true; |
117 | concat_output_dim = Any(); |
118 | } else if (i > 0 && !is_dynamic_concat) { |
119 | // accumulate axis dimension |
120 | concat_output_dim += e->shape[axis]; |
121 | } |
122 | } |
123 | |
124 | oshape[axis] = concat_output_dim; |
125 | |
126 | for (int i = 0; i < ndim; ++i) { |
127 | if (i == axis) { |
      // The concat axis was handled above; the rest of the loop body
      // sets the output shape for the non-concat axes.
130 | continue; |
131 | } |
132 | std::vector<IndexExpr> non_any; |
133 | for (size_t j = 0; j < data_length; ++j) { |
134 | const auto& e = input_tensors[j]; |
135 | if (!e->shape[i].as<AnyNode>()) { |
136 | non_any.push_back(e->shape[i]); |
137 | } |
138 | } |
139 | size_t non_any_size = non_any.size(); |
140 | for (size_t k = 1; k < non_any_size; k++) { |
141 | if (reporter->AssertEQ(non_any[0], non_any[k])) continue; |
142 | throw Error( |
143 | "relay.concatenate requires all tensors have the same shape " |
144 | "on non-concatenating axes" ); |
145 | } |
146 | |
147 | if (non_any_size == data_length) { |
148 | // All static case |
149 | oshape[i] = non_any[0]; |
150 | } else if (non_any_size > 0 && is_dynamic_concat) { |
      // For non-concat axes, we want to enforce a static shape constraint.
      // However, if the concat axis is static, the output shape would become static while
      // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack
      // of runtime input shape checking for such cases, the static shape constraint is only
      // enforced when the output concat axis is dynamic.
156 | // |
157 | // Examples (both concat on the first axis): |
158 | // * [(?, 3), (?, ?)] -> (?, 3) |
159 | // * [(1, 3), (1, ?)] -> (2, ?) |
160 | oshape[i] = non_any[0]; |
161 | } else { |
162 | oshape[i] = Any(); |
163 | } |
164 | } |
165 | |
166 | auto rtype = TensorType(oshape, dtype); |
167 | reporter->Assign(types[1], rtype); |
168 | return true; |
169 | } |
170 | |
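/*!
 * \brief Layout inference for concatenate.
 *
 * When new input layouts are available (after a layout-altering pass) and
 * they all agree, that layout is adopted for every input and the output,
 * and the concat axis is remapped into it. Otherwise the first defined old
 * layout is reused, or Undef is returned when the concat axis cannot be
 * expressed in it.
 */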
171 | static inline InferCorrectLayoutOutput ConcatenateLayout( |
172 | const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
173 | const Array<tvm::relay::Type>& old_in_types) { |
174 | const auto* attrs_ptr = attrs.as<ConcatenateAttrs>(); |
175 | ICHECK(attrs_ptr); |
176 | ObjectPtr<ConcatenateAttrs> param = make_object<ConcatenateAttrs>(*attrs_ptr); |
177 | |
178 | Array<Array<IndexExpr>> old_in_shapes; |
179 | ICHECK_EQ(old_in_types.size(), 1); |
180 | for (auto old_in_tuple_t : old_in_types) { |
181 | ICHECK(old_in_tuple_t.as<TupleTypeNode>()); |
182 | for (auto old_in_t : old_in_tuple_t.as<TupleTypeNode>()->fields) { |
183 | old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape); |
184 | } |
185 | } |
186 | |
187 | size_t axis = |
188 | param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis); |
189 | |
190 | Layout ret; |
191 | bool is_new_layout_selected = false; |
  if (new_in_layouts.defined()) {  // This branch runs after some operators' layouts were altered.
    // If all the new input layouts are the same, that layout is selected, the concat axis is
    // located in the new layout, and param->axis is updated on the fly to conform to the new
    // input layout.
    const auto& concat_dim = old_in_layouts[0][axis];
197 | bool all_input_layouts_same = true; |
198 | for (auto new_layout : new_in_layouts) { |
199 | if (!new_layout.Equals(new_in_layouts[0])) { |
200 | all_input_layouts_same = false; |
201 | } |
202 | } |
203 | if (all_input_layouts_same) { |
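      // Locate the concat axis in the new layout, e.g. concat on axis 1 (C)
      // in NCHW maps to axis 3 after an NCHW -> NHWC layout alteration.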
      auto new_index = new_in_layouts[0].IndexOf(concat_dim);
205 | ret = new_in_layouts[0]; |
206 | param->axis = new_index; |
207 | is_new_layout_selected = true; |
208 | } |
209 | } |
210 | |
211 | if (!is_new_layout_selected) { |
    // This branch handles the original, unaltered Relay IR.
213 | for (size_t i = 0; i < old_in_layouts.size(); ++i) { |
214 | if (old_in_layouts[i].defined()) { |
215 | ret = old_in_layouts[i]; |
216 | break; |
217 | } |
218 | } |
219 | |
220 | if (ret.ndim() <= axis || !ret[axis].IsPrimal()) { |
221 | return InferCorrectLayoutOutput({Layout::Undef()}, {Layout::Undef()}, attrs); |
222 | } |
223 | } |
224 | |
225 | return InferCorrectLayoutOutput(Array<Layout>(old_in_layouts.size(), ret), {ret}, Attrs(param)); |
226 | } |
227 | |
228 | /*! |
229 | * \brief Infer output shape for reshape. |
230 | * |
231 | * \param data_shape The input data shape. |
232 | * \param attrs The attributes. |
233 | * \param reverse Whether to reverse the indices. |
234 | * \return Output shape. |
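 *
 * For example, data_shape = (2, 3, 4) with newshape = (6, -1) infers
 * (6, 4); -1 means the extent is inferred from the remaining elements.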
235 | */ |
236 | Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs, |
237 | bool reverse); |
238 | |
239 | } // namespace relay |
240 | } // namespace tvm |
241 | #endif // TVM_RELAY_OP_TENSOR_TRANSFORM_H_ |
242 | |