/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/op/tensor/transform.h
 * \brief Transform op attributes that can be shared among Relay and its dialects.
 */
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_

#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/op_attr_types.h>

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../transforms/infer_layout_utils.h"
#include "../make_op.h"

namespace tvm {
namespace relay {

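/*!
 * \brief Type relation for concatenate.
 *
 * Checks that the first input is a tuple of tensors that agree in ndim and
 * dtype, validates the concat axis, and assigns the result tensor type. The
 * output dim on the concat axis is the sum of the input dims, or Any() if
 * any input is dynamic on that axis.
 *
 * \param types The types: [tuple-of-tensors, result].
 * \param num_inputs The number of inputs (unused).
 * \param attrs The call attributes, expected to provide an `axis` field.
 * \param reporter The type reporter.
 * \return Whether the relation has been resolved.
 */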
template <typename AttrType>
bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                    const TypeReporter& reporter) {
  // types: [data, result]
  ICHECK_EQ(types.size(), 2) << "the arity of concatenate is 2, not " << types.size();
  /* If we receive an incomplete type we defer to a later round of
   * inference; if we receive a tuple we can continue; anything
   * else is an error.
   */
  if (types[0].as<IncompleteTypeNode>() != nullptr) {
    return false;
  }
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
    reporter->GetDiagCtx().EmitFatal(
        Diagnostic::Error(reporter->GetSpan())
        << "concatenate requires a tuple of tensors as the first argument, found "
        << PrettyPrint(types[0]));
    return false;
  }

  const auto* param = attrs.as<AttrType>();
  if (param == nullptr) {
    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                     << "the call attributes are not defined");
    return false;
  }

  if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
    return false;
  }
  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
  // Sanity check: ndim and dtype.
  const int ndim = static_cast<int>(first->shape.size());
  const DataType dtype = first->dtype;

  // Sanity check: axis
  int axis = param->axis;
  if (!(-ndim <= axis && axis < ndim)) {
    throw CompileError(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)"
                                      << ", but got axis = " << axis << ", and ndim = " << ndim);
  }
  axis = axis < 0 ? ndim + axis : axis;

  for (const Type& ele : tensor_tuple->fields) {
    if (ele.as<IncompleteTypeNode>()) {
      return false;
    }

    const auto& e = Downcast<TensorType>(ele);

    int e_ndim = static_cast<int>(e->shape.size());
    const DataType& e_dtype = e->dtype;
    if (e_ndim != ndim) {
      throw CompileError("relay.concatenate requires all tensors have the same ndim");
    }
    if (e_dtype != dtype) {
      throw CompileError("relay.concatenate requires all tensors have the same dtype");
    }
  }

  // Calculate shape
  std::vector<IndexExpr> oshape(ndim);
  const size_t data_length = tensor_tuple->fields.size();

  // Accumulate the output dim on the concat axis, or decide that this is a dynamic concat.
  bool is_dynamic_concat = false;
  std::vector<TensorType> input_tensors;
  IndexExpr concat_output_dim = first->shape[axis];
  for (size_t i = 0; i < data_length; ++i) {
    const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
    input_tensors.push_back(e);
    if (e->shape[axis].as<AnyNode>()) {
      is_dynamic_concat = true;
      concat_output_dim = Any();
    } else if (i > 0 && !is_dynamic_concat) {
      // accumulate axis dimension
      concat_output_dim += e->shape[axis];
    }
  }
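  // For example, concat([(2, 3), (4, 3)], axis=0) accumulates 2 + 4 = 6 on the
  // concat axis, while concat([(2, 3), (?, 3)], axis=0) yields Any() there.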

  oshape[axis] = concat_output_dim;

  for (int i = 0; i < ndim; ++i) {
    if (i == axis) {
      // The concat axis is already handled above.
      // The rest of the loop body sets the output shape for the non-concat axes.
      continue;
    }
    std::vector<IndexExpr> non_any;
    for (size_t j = 0; j < data_length; ++j) {
      const auto& e = input_tensors[j];
      if (!e->shape[i].as<AnyNode>()) {
        non_any.push_back(e->shape[i]);
      }
    }
    size_t non_any_size = non_any.size();
    for (size_t k = 1; k < non_any_size; k++) {
      if (reporter->AssertEQ(non_any[0], non_any[k])) continue;
      throw CompileError(
          "relay.concatenate requires all tensors have the same shape "
          "on non-concatenating axes");
    }

    if (non_any_size == data_length) {
      // All static case
      oshape[i] = non_any[0];
    } else if (non_any_size > 0 && is_dynamic_concat) {
      // For non-concat axes, we want to enforce the static shape constraint.
      // However, if the concat axis is static, the output shape would become static while
      // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack
      // of runtime input shape checking for such cases, the static shape constraint is only
      // enforced when the output concat axis is dynamic.
      //
      // Examples (both concat on the first axis):
      // * [(?, 3), (?, ?)] -> (?, 3)
      // * [(1, 3), (1, ?)] -> (2, ?)
      oshape[i] = non_any[0];
    } else {
      oshape[i] = Any();
    }
  }

  auto rtype = TensorType(oshape, dtype);
  reporter->Assign(types[1], rtype);
  return true;
}

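/*!
 * \brief Infer the input and output layouts for concatenate.
 *
 * If new input layouts are available and all of them agree, that layout is
 * adopted and the concat axis is remapped to its index in the new layout;
 * otherwise the first defined old input layout is reused.
 *
 * \param attrs The ConcatenateAttrs of the call.
 * \param new_in_layouts The layouts proposed by a layout-changing pass, if any.
 * \param old_in_layouts The layouts before the pass.
 * \param old_in_types The input types (a single tuple of tensors).
 * \return The inferred input/output layouts and possibly updated attrs.
 */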
static inline InferCorrectLayoutOutput ConcatenateLayout(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  const auto* attrs_ptr = attrs.as<ConcatenateAttrs>();
  ICHECK(attrs_ptr);
  ObjectPtr<ConcatenateAttrs> param = make_object<ConcatenateAttrs>(*attrs_ptr);

  Array<Array<IndexExpr>> old_in_shapes;
  ICHECK_EQ(old_in_types.size(), 1);
  for (auto old_in_tuple_t : old_in_types) {
    ICHECK(old_in_tuple_t.as<TupleTypeNode>());
    for (auto old_in_t : old_in_tuple_t.as<TupleTypeNode>()->fields) {
      old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
    }
  }

  size_t axis =
      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);

  Layout ret;
  bool is_new_layout_selected = false;
  if (new_in_layouts.defined()) {  // this function is called after some operators are altered.
    // If all the new input layouts are the same, the new input layout is selected. The concat
    // axis is then re-identified in the new layout, and param->axis is modified on the fly to
    // conform to it.
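    // For example (a sketch): with old layout NCHW and axis = 1 (the 'C' axis),
    // if every new input layout is NHWC, param->axis is remapped to 3.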
    const auto& concat_dim = old_in_layouts[0][axis];
    bool all_input_layouts_same = true;
    for (auto new_layout : new_in_layouts) {
      if (!new_layout.Equals(new_in_layouts[0])) {
        all_input_layouts_same = false;
      }
    }
    if (all_input_layouts_same) {
      auto new_index = new_in_layouts[0].IndexOf(concat_dim);
      ret = new_in_layouts[0];
      param->axis = new_index;
      is_new_layout_selected = true;
    }
  }

  if (!is_new_layout_selected) {
    // This branch runs on the original (unaltered) Relay IR.
    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
      if (old_in_layouts[i].defined()) {
        ret = old_in_layouts[i];
        break;
      }
    }

    if (ret.ndim() <= axis || !ret[axis].IsPrimal()) {
      return InferCorrectLayoutOutput({Layout::Undef()}, {Layout::Undef()}, attrs);
    }
  }

  return InferCorrectLayoutOutput(Array<Layout>(old_in_layouts.size(), ret), {ret}, Attrs(param));
}
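
// Usage sketch (for reference; the actual registration lives in
// src/relay/op/tensor/transform.cc and is registered roughly like this):
//
//   RELAY_REGISTER_OP("concatenate")
//       .add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
//       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout);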

/*!
 * \brief Infer output shape for reshape.
 *
 * \param data_shape The input data shape.
 * \param attrs The attributes.
 * \param reverse Whether to infer the shape in reverse order (as used by reverse_reshape).
 * \return Output shape.
 */
Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs,
                               bool reverse);
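
// Informal examples of the newshape semantics shared with relay.reshape
// (0 copies a dim, -1 infers a dim, -2 copies the remaining dims):
//   data_shape = [2, 3, 4], newshape = [6, -1]  =>  [6, 4]
//   data_shape = [2, 3, 4], newshape = [0, -2]  =>  [2, 3, 4]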

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_OP_TENSOR_TRANSFORM_H_