1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file reduce.cc |
22 | * \brief Reduction operators. |
23 | */ |
24 | #include <tvm/relay/attrs/reduce.h> |
25 | #include <tvm/relay/expr.h> |
26 | #include <tvm/relay/op.h> |
27 | #include <tvm/topi/elemwise.h> |
28 | #include <tvm/topi/reduction.h> |
29 | |
30 | #include <limits> |
31 | #include <numeric> |
32 | |
33 | #include "../make_op.h" |
34 | #include "../op_common.h" |
35 | #include "../type_relations.h" |
36 | |
namespace tvm {
namespace relay {

// Register the attribute node types so they can be reflected, serialized,
// and constructed from the frontend FFI.
TVM_REGISTER_NODE_TYPE(ReduceAttrs);
TVM_REGISTER_NODE_TYPE(ArgReduceAttrs);
TVM_REGISTER_NODE_TYPE(VarianceAttrs);
43 | |
44 | /*! |
45 | * \brief GetReduceAxes, get the new axis from indim and other arguments |
46 | * \param indim Number of dimensions of input data. |
47 | * \param axis The input axis vector. |
48 | * \param exclude Whether 'axis' input given is the excluded axis. |
49 | * \return r_axes The new reduced axes of the output. |
50 | */ |
51 | inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const Array<Integer>& inaxis, |
52 | bool exclude) { |
53 | if (!inaxis.defined() || inaxis.empty()) { |
54 | std::vector<int64_t> r_axes(indim); |
55 | std::iota(r_axes.begin(), r_axes.end(), 0); |
56 | return r_axes; |
57 | } |
58 | |
59 | std::vector<int64_t> in_axes; |
60 | for (auto i : inaxis) { |
61 | int64_t axis = i->value; |
62 | if (axis < 0) { |
63 | axis = axis + indim; |
64 | } |
65 | |
66 | // Check out of bounds error |
67 | ICHECK(axis >= 0) << "Axis out of bounds in reduce operator." ; |
68 | ICHECK(axis < indim) << "Axis out of bounds in reduce operator." ; |
69 | in_axes.push_back(axis); |
70 | } |
71 | |
72 | ICHECK(in_axes[in_axes.size() - 1] < indim) |
73 | << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input dimensions " << indim; |
74 | |
75 | std::sort(in_axes.begin(), in_axes.end()); |
76 | |
77 | if (!exclude) { |
78 | return in_axes; |
79 | } |
80 | |
81 | auto r_size = indim - in_axes.size(); |
82 | std::vector<int64_t> r_axes(r_size); |
83 | for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) { |
84 | if (j < in_axes.size() && in_axes[j] == i) { |
85 | ++j; |
86 | continue; |
87 | } |
88 | r_axes[k++] = i; |
89 | } |
90 | return r_axes; |
91 | } |
92 | |
93 | // Get axis under exclude condition. |
94 | Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) { |
95 | ICHECK(inaxis.defined()) << "Cannot set exclude when axis=None" ; |
96 | std::vector<bool> axis_flag(indim, true); |
97 | for (auto i : inaxis) { |
98 | int64_t axis = i->value; |
99 | if (axis < 0) { |
100 | axis = axis + static_cast<int64_t>(indim); |
101 | } |
102 | // Check out of bounds error |
103 | ICHECK_GE(axis, 0) << "Axis out of bounds in reduce operator." ; |
104 | ICHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in reduce operator." ; |
105 | axis_flag[axis] = false; |
106 | } |
107 | |
108 | Array<Integer> r_axes; |
109 | |
110 | for (size_t i = 0; i < axis_flag.size(); ++i) { |
111 | if (axis_flag[i]) { |
112 | r_axes.push_back(static_cast<int>(i)); |
113 | } |
114 | } |
115 | return r_axes; |
116 | } |
117 | |
// Return the modified layout for AlterOpLayout pass.
// Rewrites the reduce axes stored in the attrs to match a newly selected input
// layout (e.g. with layout NCHW -> NCHW8c and reduce axis [1], the axes become
// [1, 4]) and reports the input/output layouts the op then expects/produces.
template <typename T>
InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
  const auto* attrs_ptr = attrs.as<T>();
  ICHECK(attrs_ptr);
  // Copy the attrs so the reduce axes can be rewritten without mutating the input.
  ObjectPtr<T> params = make_object<T>(*attrs_ptr);

  // Get the reduce axes.
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    ICHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }
  uint32_t indim = old_in_shapes[0].size();
  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);

  Layout inferred_in = Layout::Undef();
  Layout inferred_out = Layout::Undef();

  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
  auto infer = [&](const Layout& layout) {
    // 1) Collect the original axes (by layout-dim name, e.g. "C").
    std::unordered_set<std::string> old_r_dims;
    for (auto r_axis : r_axes) {
      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
    }

    // 2) Collect the new axes by walking new_layout.
    tvm::Array<tvm::Integer> new_r_axes;
    std::string inferred_in_string = "";
    std::string inferred_out_string = "";
    // Records `axis` as a reduce axis when its layout dim is in (or, with
    // exclude, not in) the original reduce set. Returns true if recorded.
    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
          (!old_r_dims.count(layout_dim) && params->exclude)) {
        new_r_axes.push_back(tvm::Integer(axis));
        return true;
      }
      return false;
    };
    for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
      const std::string& layout_dim = layout_axis.name();
      if (layout_axis.IsPrimal()) {
        push_new_axis(layout_dim, axis_index);
        inferred_in_string += layout_dim;
        // A primal dim survives in the output unless it is reduced away;
        // keepdims keeps it (with extent 1 after reduction).
        if (!old_r_dims.count(layout_dim) || params->keepdims) {
          inferred_out_string += layout_dim;
        }
      } else {
        // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original
        // reduce axes is [1], the new reduce axes become [1, 4].
        auto primal_dim = layout_axis.ToPrimal().name();
        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
        inferred_in_string += packed_dim;
        if (push_new_axis(primal_dim, axis_index)) {
          if (params->exclude) {
            // The primal axis is not reduced, so keep the input packed dim.
            inferred_out_string += packed_dim;
          } else if (params->keepdims) {
            // If the primal axis is part of reduce axes in the original layout, the inner dim
            // becomes 1 after reduction.
            inferred_out_string += "1" + layout_dim;
          }
        } else {
          inferred_out_string += packed_dim;
        }
      }
    }

    // 3) Set the new axis and layout.
    return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
  };

  std::string new_layout_string;
  Array<Integer> new_r_axes;
  Array<Layout> new_input_layouts;

  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
    // The second case is for variance op
    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
  };

  if (new_in_layouts.defined() && r_axes.size()) {
    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
    // modified layout axes.
    check_num_input_layouts(new_in_layouts);
    check_num_input_layouts(old_in_layouts);

    // Get inferred_in and inferred_out from new_in_layout.
    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
    params->axis = new_r_axes;
  } else if (old_in_layouts.defined()) {
    check_num_input_layouts(old_in_layouts);

    // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
    if (old_in_layouts[0].defined()) {
      std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
    }
  }

  new_input_layouts.push_back(inferred_in);

  // Variance has a second (mean) input that shares the data input's layout.
  if (old_in_layouts.size() == 2) {
    new_input_layouts.push_back(inferred_in);
  }

  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
}
229 | |
230 | template <typename F> |
231 | Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
232 | const Type& out_type, F f) { |
233 | const ReduceAttrs* param = attrs.as<ReduceAttrs>(); |
234 | ICHECK(param != nullptr); |
235 | if (inputs[0]->shape.size() == 0) { |
236 | return {topi::identity(inputs[0])}; |
237 | } |
238 | auto axes = param->axis; |
239 | if (param->exclude) { |
240 | axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); |
241 | if (axes.size() == 0) { |
242 | return {topi::identity(inputs[0])}; |
243 | } |
244 | } |
245 | |
246 | return {f(inputs[0], axes, param->keepdims, false)}; |
247 | } |
248 | |
249 | template <typename F> |
250 | Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
251 | const Type& out_type, F f) { |
252 | const ArgReduceAttrs* param = attrs.as<ArgReduceAttrs>(); |
253 | ICHECK(param != nullptr); |
254 | if (inputs[0]->shape.size() == 0) { |
255 | return {topi::identity(inputs[0])}; |
256 | } |
257 | auto axes = param->axis; |
258 | if (param->exclude) { |
259 | axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); |
260 | if (axes.size() == 0) { |
261 | return {topi::identity(inputs[0])}; |
262 | } |
263 | } |
264 | |
265 | return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)}; |
266 | } |
267 | |
268 | /*! |
269 | * \brief ReduceShapeImpl get the outshape for the reduction operator |
270 | * \param in_shape Shape of input data. |
271 | * \param param Attrs details. |
272 | * \param reporter The reporter to report solution to. |
273 | * \return oshape Output shape inferred. |
274 | * \tparam AttrsType The attribute type. |
275 | */ |
276 | template <typename AttrsType> |
277 | inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape, |
278 | const AttrsType* param, |
279 | const TypeReporter& reporter) { |
280 | uint32_t indim = in_shape.size(); |
281 | auto r_axes = GetReduceAxes(indim, param->axis, param->exclude); |
282 | if (!r_axes.size()) { |
283 | return in_shape; |
284 | } |
285 | |
286 | auto max_shape = tir::make_const(DataType::Int(64), 1); |
287 | bool is_dynamic_input = false; |
288 | for (int64_t axis : r_axes) { |
289 | if (in_shape[axis].as<IntImmNode>()) { |
290 | max_shape *= in_shape[axis]; |
291 | } else { |
292 | is_dynamic_input = true; |
293 | break; |
294 | } |
295 | } |
296 | |
297 | if (is_dynamic_input) { |
298 | ICHECK(reporter->Assert( |
299 | max_shape < tir::make_const(DataType::Int(64), std::numeric_limits<int32_t>::max()))) |
300 | << "The maximum possible index of reduced shape cannot be more than int32 max." ; |
301 | } |
302 | |
303 | if (param->keepdims) { |
304 | std::vector<IndexExpr> oshape(in_shape); |
305 | for (unsigned i = 0, j = 0; i < indim; ++i) { |
306 | if (j >= r_axes.size() || !(r_axes[j] == i)) { |
307 | continue; |
308 | } |
309 | oshape[i] = 1; |
310 | ++j; |
311 | } |
312 | return oshape; |
313 | } else { |
314 | auto osize = indim - r_axes.size(); |
315 | std::vector<IndexExpr> oshape(osize); |
316 | for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) { |
317 | if (j < r_axes.size() && (r_axes[j] == i)) { |
318 | ++j; |
319 | continue; |
320 | } |
321 | oshape[k++] = in_shape[i]; |
322 | } |
323 | return oshape; |
324 | } |
325 | } |
326 | |
/*!
 * \brief Generic type/shape relation shared by reduce-style operators.
 * \tparam T The attrs type to extract from `attrs` (e.g. ArgReduceAttrs).
 * \param types [input type, output type]; the output slot is assigned here.
 * \return false if the input type is not yet resolved, true otherwise.
 */
template <class T>
bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  // Scalar (rank-0) inputs are rejected; shape[0] below would be invalid.
  ICHECK(static_cast<int>(data->shape.size()) != 0);
  std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());

  const T* param = attrs.as<T>();
  ICHECK(param != nullptr);

  // assign output type and shape
  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
  // NOTE(review): the output dtype comes from the first shape expression (an
  // index dtype), not from data->dtype — presumably so argmax/argmin yield
  // integer index tensors. Confirm against the op registrations using this rel.
  reporter->Assign(types[1], TensorType(oshape, data->shape[0].dtype()));
  return true;
}
344 | /*! |
345 | * \brief ArgReduceRel Output type and shape relation evaluation function. |
346 | * \param num_inputs Number of input types in the args. |
347 | * \param attrs The additional attributes of the operator. |
348 | * \param reporter The reporter to report solution to. |
349 | * \return false if This relation cannot be resolved. true if this relation has been resolved. |
350 | */ |
351 | bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
352 | const TypeReporter& reporter) { |
353 | return GenericReduceRel<ReduceAttrs>(types, num_inputs, attrs, reporter); |
354 | } |
355 | |
356 | /*! |
357 | * \brief ReduceRel Output type and shape relation evaluation function. |
358 | * \param num_inputs Number of input types in the args. |
359 | * \param attrs The additional attributes of the operator. |
360 | * \param reporter The reporter to report solution to. |
361 | * \return false if This relation cannot be resolved. true if this relation has been resolved. |
362 | */ |
363 | bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
364 | const TypeReporter& reporter) { |
365 | ICHECK_EQ(types.size(), 2); |
366 | const auto* data = types[0].as<TensorTypeNode>(); |
367 | if (data == nullptr) return false; |
368 | std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end()); |
369 | |
370 | const ReduceAttrs* param = attrs.as<ReduceAttrs>(); |
371 | ICHECK(param != nullptr); |
372 | |
373 | // assign output type and shape |
374 | auto oshape = ReduceShapeImpl(in_shape, param, reporter); |
375 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
376 | return true; |
377 | } |
378 | |
379 | Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name) { |
380 | auto attrs = make_object<ReduceAttrs>(); |
381 | attrs->axis = std::move(axis); |
382 | attrs->keepdims = keepdims; |
383 | attrs->exclude = exclude; |
384 | return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); |
385 | } |
386 | |
387 | Expr MakeOneElementReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, |
388 | bool select_last_index, String op_name) { |
389 | auto attrs = make_object<ArgReduceAttrs>(); |
390 | attrs->axis = std::move(axis); |
391 | attrs->keepdims = keepdims; |
392 | attrs->exclude = exclude; |
393 | attrs->select_last_index = select_last_index; |
394 | return Call(Op::Get(op_name), {data}, Attrs(attrs), {}); |
395 | } |
396 | |
// Registers a plain reduce operator `OpName`: a global FFI maker function
// ("relay.op._make.<OpName>") forwarding to MakeReduce, plus the Relay op
// itself with a single "data" tensor input.
#define RELAY_REGISTER_REDUCE_OP(OpName)                                                \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                         \
      .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude) { \
        return MakeReduce(data, axis, keepdims, exclude, OpName);                       \
      });                                                                               \
  RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")
403 | |
// Registers an arg-reduce operator `OpName` (argmax/argmin): a global FFI
// maker ("relay.op._make.<OpName>") forwarding to MakeOneElementReduce —
// which additionally takes select_last_index — plus the Relay op with a
// single "data" tensor input.
#define RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP(OpName)                                        \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                             \
      .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude,       \
                         bool select_last_index) {                                          \
        return MakeOneElementReduce(data, axis, keepdims, exclude, select_last_index,       \
                                    OpName);                                                \
      });                                                                                   \
  RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")
411 | |
// Compute lowering for argmax: indices of the maxima via topi::argmax.
Array<te::Tensor> ArgMaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  return ArgReduceCompute(attrs, inputs, out_type, topi::argmax);
}

RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax")
    .describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ArgReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
    .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
428 | |
// Compute lowering for argmin: indices of the minima via topi::argmin.
Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  return ArgReduceCompute(attrs, inputs, out_type, topi::argmin);
}

RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin")
    .describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ArgReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
    .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
445 | |
// Compute lowering for sum: delegates to topi::sum over the resolved axes.
Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::sum);
}

RELAY_REGISTER_REDUCE_OP("sum")
    .describe(R"code(Computes the sum of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  sum(data, axis=1)
  [[ 4. 8.]
   [ 10. 9.]
   [ 21. 6.]]

  sum(data, axis=[1,2])
  [ 12. 19. 27.]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
475 | |
// Compute lowering for all: logical AND reduction via topi::all.
Array<te::Tensor> AllCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::all);
}

RELAY_REGISTER_REDUCE_OP("all")
    .describe(R"code(Computes the logical AND of boolean array elements over given axes.

Example::

  data = [[[ True, True, True],
           [ True, True, True],
           [False, True, False]],
          [[ True, False, False],
           [ True, True, False],
           [False, True, True]]]

  all(data, axis=1)
  [[False, True, False],
   [False, False, False]]

  all(data, axis=0)
  [[ True, False, False],
   [ True, True, False],
   [False, True, False]]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", AllCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
509 | |
// Compute lowering for any: logical OR reduction via topi::any.
Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::any);
}

// NOTE(review): unlike sum/all/max/min, no FInferCorrectLayout attribute is
// registered for "any" — confirm whether this omission is intentional.
RELAY_REGISTER_REDUCE_OP("any")
    .describe(R"code(Computes the logical OR of boolean array elements over given axes.

Example::

  data = [[[ True, True, True],
           [ True, True, True],
           [False, True, False]],
          [[ True, False, False],
           [ True, True, False],
           [False, True, True]]]

  any(data, axis=1)
  [[True, True, True],
   [True, True, True]]

  any(data, axis=0)
  [[ True, True, True],
   [ True, True, True],
   [False, True, True]]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
542 | |
// Compute lowering for max: delegates to topi::max over the resolved axes.
Array<te::Tensor> MaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::max);
}

RELAY_REGISTER_REDUCE_OP("max")
    .describe(R"code(Computes the max of array elements over given axes.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
558 | |
// Compute lowering for min: delegates to topi::min over the resolved axes.
Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::min);
}

RELAY_REGISTER_REDUCE_OP("min")
    .describe(R"code(Computes the min of array elements over given axes.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", MinCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
574 | |
// Compute lowering for prod: delegates to topi::prod over the resolved axes.
Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::prod);
}

// NOTE(review): unlike the other reduce ops, prod wires its FFI maker to
// `Prod` (presumably declared in ../make_op.h) instead of going through
// RELAY_REGISTER_REDUCE_OP — verify against make_op.h.
TVM_REGISTER_GLOBAL("relay.op._make.prod").set_body_typed(Prod);

RELAY_REGISTER_OP("prod")
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .describe(R"code(Computes the products of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  prod(data, axis=1)
  [35562240]

  prod(data, axis=[1,2])
  [ 36 480 2058]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
606 | |
607 | Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
608 | const Type& out_type) { |
609 | auto data = inputs[0]; |
610 | IndexExpr count = tir::make_const(DataType::Int(64), 1); |
611 | const ReduceAttrs* param = attrs.as<ReduceAttrs>(); |
612 | ICHECK(param != nullptr); |
613 | auto axes = param->axis; |
614 | for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) { |
615 | count *= inputs[0]->shape[i]; |
616 | } |
617 | // Check the datatype of input data. If it's fp16, we'll have trouble representing all |
618 | // indices and summation needed so we instead just cast to fp32. |
619 | bool recast_fp16 = false; |
620 | if (data->dtype.is_float16()) { |
621 | recast_fp16 = true; |
622 | data = topi::cast(data, DataType::Float(32)); |
623 | } |
624 | count = cast(data->dtype, count); |
625 | auto res = ReduceCompute(attrs, {data}, out_type, topi::sum); |
626 | auto output = topi::divide(res[0], count); |
627 | // Set the output back to the appropriate fp16 type if needed. |
628 | if (recast_fp16) { |
629 | output = topi::cast(output, DataType::Float(16)); |
630 | } |
631 | return {output}; |
632 | } |
633 | |
RELAY_REGISTER_REDUCE_OP("mean")
    .describe(R"code(Computes the mean of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  mean(data)
  [3.22]

  mean(data, axis=[1,2])
  [ 2. 3.16666667 4.5]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
656 | |
657 | bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
658 | const TypeReporter& reporter) { |
659 | ICHECK_EQ(types.size(), 3); |
660 | const auto* data = types[0].as<TensorTypeNode>(); |
661 | if (data == nullptr) return false; |
662 | ICHECK(static_cast<int>(data->shape.size()) != 0); |
663 | const auto* mean = types[1].as<TensorTypeNode>(); |
664 | if (mean == nullptr) return false; |
665 | |
666 | std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end()); |
667 | std::vector<IndexExpr> mean_shape(mean->shape.begin(), mean->shape.end()); |
668 | ICHECK_EQ(in_shape.size(), mean_shape.size()); |
669 | |
670 | const VarianceAttrs* param = attrs.as<VarianceAttrs>(); |
671 | ICHECK(param != nullptr); |
672 | |
673 | // assign output type and shape |
674 | auto oshape = ReduceShapeImpl(in_shape, param, reporter); |
675 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
676 | return true; |
677 | } |
678 | |
679 | Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
680 | const Type& out_type) { |
681 | IndexExpr count = tir::make_const(DataType::Int(64), 1); |
682 | const VarianceAttrs* param = attrs.as<VarianceAttrs>(); |
683 | ICHECK(param != nullptr); |
684 | auto axes = param->axis; |
685 | bool unbiased = param->unbiased; |
686 | auto data = inputs[0]; |
687 | auto mean = inputs[1]; |
688 | for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) { |
689 | count *= data->shape[i]; |
690 | } |
691 | if (unbiased) { |
692 | count -= 1; |
693 | } |
694 | std::vector<Integer> expand_shape; |
695 | auto diff = topi::subtract(data, mean); |
696 | auto sq_diff = topi::multiply(diff, diff); |
697 | if (param->exclude) { |
698 | axes = GetExcludeAxes(sq_diff->shape.size(), param->axis); |
699 | ICHECK_NE(axes.size(), 0); |
700 | } |
701 | // If the input is fp16, we might have trouble representing the full sum of |
702 | // indices or values. We recast to fp32 to avoid this issue. |
703 | bool recast_fp16 = false; |
704 | if (data->dtype.is_float16()) { |
705 | recast_fp16 = true; |
706 | sq_diff = topi::cast(sq_diff, DataType::Float(32)); |
707 | } |
708 | auto var = topi::divide(topi::sum(sq_diff, axes, param->keepdims, false), count); |
709 | |
710 | // Recast back to fp16 if needed. |
711 | if (recast_fp16) { |
712 | var = topi::cast(var, DataType::Float(16)); |
713 | } |
714 | |
715 | return {var}; |
716 | } |
717 | |
718 | Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude, |
719 | bool unbiased = false) { |
720 | auto attrs = make_object<VarianceAttrs>(); |
721 | attrs->axis = std::move(axis); |
722 | attrs->keepdims = keepdims; |
723 | attrs->exclude = exclude; |
724 | attrs->unbiased = unbiased; |
725 | static const Op& op = Op::Get("variance" ); |
726 | return Call(op, {data, mean}, Attrs(attrs), {}); |
727 | } |
728 | |
// FFI hook for the Python frontend; note variance takes two tensor inputs.
TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body_typed(MakeVariance);

RELAY_REGISTER_OP("variance")
    .describe(R"code(Computes the variance of array elements over given axes.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<VarianceAttrs>()
    .set_support_level(4)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("mean", "Tensor", "The mean tensor.")
    .add_type_rel("Variance", VarianceRel)
    .set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<VarianceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
744 | |
745 | } // namespace relay |
746 | } // namespace tvm |
747 | |