/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file reduce.cc
 * \brief Reduction operators.
 */
#include <tvm/relay/attrs/reduce.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/topi/elemwise.h>
#include <tvm/topi/reduction.h>

#include <algorithm>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../make_op.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(ReduceAttrs);
TVM_REGISTER_NODE_TYPE(ArgReduceAttrs);
TVM_REGISTER_NODE_TYPE(VarianceAttrs);

/*!
 * \brief GetReduceAxes, compute the reduce axes from indim and the other arguments.
 * \param indim Number of dimensions of input data.
 * \param inaxis The input axis vector; negative values are counted from the last dimension.
 * \param exclude Whether `inaxis` specifies the axes to be excluded from the reduction.
 * \return r_axes The axes of the input that will actually be reduced.
 */
inline std::vector<int64_t> GetReduceAxes(const uint32_t indim, const Array<Integer>& inaxis,
                                          bool exclude) {
  if (!inaxis.defined() || inaxis.empty()) {
    std::vector<int64_t> r_axes(indim);
    std::iota(r_axes.begin(), r_axes.end(), 0);
    return r_axes;
  }

  std::vector<int64_t> in_axes;
  for (auto i : inaxis) {
    int64_t axis = i->value;
    if (axis < 0) {
      axis = axis + indim;
    }

    // Check out of bounds error
    ICHECK(axis >= 0) << "Axis out of bounds in reduce operator.";
    ICHECK(axis < indim) << "Axis out of bounds in reduce operator.";
    in_axes.push_back(axis);
  }

  std::sort(in_axes.begin(), in_axes.end());

  if (!exclude) {
    return in_axes;
  }

  auto r_size = indim - in_axes.size();
  std::vector<int64_t> r_axes(r_size);
  for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) {
    if (j < in_axes.size() && in_axes[j] == i) {
      ++j;
      continue;
    }
    r_axes[k++] = i;
  }
  return r_axes;
}
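
// Illustrative examples of the semantics above (a sketch, not exercised by this
// file): for a 4-d input,
//   GetReduceAxes(4, {1, -1}, /*exclude=*/false) -> {1, 3}        // -1 normalizes to 3
//   GetReduceAxes(4, {1, -1}, /*exclude=*/true)  -> {0, 2}        // complement of {1, 3}
//   GetReduceAxes(4, {}, /*exclude=*/false)      -> {0, 1, 2, 3}  // empty axis reduces all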

// Get the axes to reduce when exclude = true, i.e. the complement of inaxis.
Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
  ICHECK(inaxis.defined()) << "Cannot set exclude when axis=None";
  std::vector<bool> axis_flag(indim, true);
  for (auto i : inaxis) {
    int64_t axis = i->value;
    if (axis < 0) {
      axis = axis + static_cast<int64_t>(indim);
    }
    // Check out of bounds error
    ICHECK_GE(axis, 0) << "Axis out of bounds in reduce operator.";
    ICHECK_LT(axis, static_cast<int64_t>(indim)) << "Axis out of bounds in reduce operator.";
    axis_flag[axis] = false;
  }

  Array<Integer> r_axes;

  for (size_t i = 0; i < axis_flag.size(); ++i) {
    if (axis_flag[i]) {
      r_axes.push_back(static_cast<int>(i));
    }
  }
  return r_axes;
}
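
// For example (hypothetical values): with indim = 4 and inaxis = {1},
// GetExcludeAxes returns {0, 2, 3} -- the axes that a reduction with
// exclude = true actually reduces over.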

// Return the modified layout for AlterOpLayout pass.
template <typename T>
InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
                                                  const Array<Layout>& new_in_layouts,
                                                  const Array<Layout>& old_in_layouts,
                                                  const Array<tvm::relay::Type>& old_in_types) {
  const auto* attrs_ptr = attrs.as<T>();
  ICHECK(attrs_ptr);
  ObjectPtr<T> params = make_object<T>(*attrs_ptr);

  // Get the reduce axes.
  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    ICHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }
  uint32_t indim = old_in_shapes[0].size();
  auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);

  Layout inferred_in = Layout::Undef();
  Layout inferred_out = Layout::Undef();

  // Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
  auto infer = [&](const Layout& layout) {
    // 1) Collect the original axes
    std::unordered_set<std::string> old_r_dims;
    for (auto r_axis : r_axes) {
      old_r_dims.emplace(old_in_layouts[0][r_axis].name());
    }

    // 2) Collect the new axes by walking new_layout.
    tvm::Array<tvm::Integer> new_r_axes;
    std::string inferred_in_string = "";
    std::string inferred_out_string = "";
    auto push_new_axis = [&](const std::string& layout_dim, int axis) {
      if ((old_r_dims.count(layout_dim) && !params->exclude) ||
          (!old_r_dims.count(layout_dim) && params->exclude)) {
        new_r_axes.push_back(tvm::Integer(axis));
        return true;
      }
      return false;
    };
    for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
      const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
      const std::string& layout_dim = layout_axis.name();
      if (layout_axis.IsPrimal()) {
        push_new_axis(layout_dim, axis_index);
        inferred_in_string += layout_dim;
        if (!old_r_dims.count(layout_dim) || params->keepdims) {
          inferred_out_string += layout_dim;
        }
      } else {
        // For example, if the original layout is NCHW, the new layout is NCHW8c, and the
        // original reduce axes are [1], the new reduce axes become [1, 4].
        auto primal_dim = layout_axis.ToPrimal().name();
        auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
        inferred_in_string += packed_dim;
        if (push_new_axis(primal_dim, axis_index)) {
          if (params->exclude) {
            // The primal axis is not reduced, so keep the input packed dim.
            inferred_out_string += packed_dim;
          } else if (params->keepdims) {
            // If the primal axis is part of reduce axes in the original layout, the inner dim
            // becomes 1 after reduction.
            inferred_out_string += "1" + layout_dim;
          }
        } else {
          inferred_out_string += packed_dim;
        }
      }
    }

    // 3) Set the new axis and layout.
    return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
  };

  Array<Integer> new_r_axes;
  Array<Layout> new_input_layouts;

  auto check_num_input_layouts = [](Array<Layout> in_layouts) {
    // The second case is for variance op
    ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2);
  };

  if (new_in_layouts.defined() && r_axes.size()) {
    // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
    // modified layout axes.
    check_num_input_layouts(new_in_layouts);
    check_num_input_layouts(old_in_layouts);

    // Get inferred_in and inferred_out from new_in_layout.
    std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
    params->axis = new_r_axes;
  } else if (old_in_layouts.defined()) {
    check_num_input_layouts(old_in_layouts);

    // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
    if (old_in_layouts[0].defined()) {
      std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
    }
  }

  new_input_layouts.push_back(inferred_in);

  if (old_in_layouts.size() == 2) {
    new_input_layouts.push_back(inferred_in);
  }

  return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params));
}
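
// A worked trace of infer() (a sketch under assumed inputs, not executed here):
// reducing axis [1] of an NCHW tensor rewritten to NCHW8c, with keepdims = true
// and exclude = false, yields
//   inferred_in  = NCHW8c
//   inferred_out = NCHW1c  // the packed 8c dim collapses to size 1
//   new_r_axes   = [1, 4]  // both the primal C and the packed c are reduced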

template <typename F>
Array<te::Tensor> ReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type, F f) {
  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
  ICHECK(param != nullptr);
  if (inputs[0]->shape.size() == 0) {
    return {topi::identity(inputs[0])};
  }
  auto axes = param->axis;
  if (param->exclude) {
    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
    if (axes.size() == 0) {
      return {topi::identity(inputs[0])};
    }
  }

  return {f(inputs[0], axes, param->keepdims, false)};
}

template <typename F>
Array<te::Tensor> ArgReduceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                   const Type& out_type, F f) {
  const ArgReduceAttrs* param = attrs.as<ArgReduceAttrs>();
  ICHECK(param != nullptr);
  if (inputs[0]->shape.size() == 0) {
    return {topi::identity(inputs[0])};
  }
  auto axes = param->axis;
  if (param->exclude) {
    axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis);
    if (axes.size() == 0) {
      return {topi::identity(inputs[0])};
    }
  }

  return {f(inputs[0], axes, param->keepdims, false, param->select_last_index)};
}
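
// Both templates above share one flow: a scalar input short-circuits to
// topi::identity, and when `exclude` is set the attribute axes are first
// converted to the complementary concrete axes before reaching TOPI.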

/*!
 * \brief ReduceShapeImpl computes the output shape for a reduction operator.
 * \param in_shape Shape of the input data.
 * \param param The attributes of the reduction operator.
 * \param reporter The reporter to report solution to.
 * \return oshape The inferred output shape.
 * \tparam AttrsType The attribute type.
 */
template <typename AttrsType>
inline std::vector<IndexExpr> ReduceShapeImpl(const std::vector<IndexExpr>& in_shape,
                                              const AttrsType* param,
                                              const TypeReporter& reporter) {
  uint32_t indim = in_shape.size();
  auto r_axes = GetReduceAxes(indim, param->axis, param->exclude);
  if (!r_axes.size()) {
    return in_shape;
  }

  auto max_shape = tir::make_const(DataType::Int(64), 1);
  bool is_dynamic_input = false;
  for (int64_t axis : r_axes) {
    if (in_shape[axis].as<IntImmNode>()) {
      max_shape *= in_shape[axis];
    } else {
      is_dynamic_input = true;
      break;
    }
  }

  if (is_dynamic_input) {
    ICHECK(reporter->Assert(
        max_shape < tir::make_const(DataType::Int(64), std::numeric_limits<int32_t>::max())))
        << "The maximum possible index of reduced shape cannot be more than int32 max.";
  }

  if (param->keepdims) {
    std::vector<IndexExpr> oshape(in_shape);
    for (unsigned i = 0, j = 0; i < indim; ++i) {
      if (j >= r_axes.size() || !(r_axes[j] == i)) {
        continue;
      }
      oshape[i] = 1;
      ++j;
    }
    return oshape;
  } else {
    auto osize = indim - r_axes.size();
    std::vector<IndexExpr> oshape(osize);
    for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
      if (j < r_axes.size() && (r_axes[j] == i)) {
        ++j;
        continue;
      }
      oshape[k++] = in_shape[i];
    }
    return oshape;
  }
}
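
// Shape-only illustration (hypothetical shapes): for in_shape = (2, 3, 4) and
// axis = [1],
//   keepdims = false -> oshape = (2, 4)
//   keepdims = true  -> oshape = (2, 1, 4)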

template <class T>
bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  ICHECK(static_cast<int>(data->shape.size()) != 0);
  std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());

  const T* param = attrs.as<T>();
  ICHECK(param != nullptr);

  // Assign output type and shape. This relation serves the arg reductions, which
  // return indices, so the output dtype is int32 regardless of the input dtype.
  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
  reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
  return true;
}
/*!
 * \brief ArgReduceRel Output type and shape relation evaluation function.
 * \param types The types of input and output.
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report solution to.
 * \return false if this relation cannot be resolved, true if it has been resolved.
 */
bool ArgReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter) {
  return GenericReduceRel<ArgReduceAttrs>(types, num_inputs, attrs, reporter);
}

/*!
 * \brief ReduceRel Output type and shape relation evaluation function.
 * \param types The types of input and output.
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report solution to.
 * \return false if this relation cannot be resolved, true if it has been resolved.
 */
bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());

  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
  ICHECK(param != nullptr);

  // assign output type and shape
  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
  return true;
}

Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name) {
  auto attrs = make_object<ReduceAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  attrs->exclude = exclude;
  return Call(Op::Get(op_name), {data}, Attrs(attrs), {});
}

Expr MakeOneElementReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude,
                          bool select_last_index, String op_name) {
  auto attrs = make_object<ArgReduceAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  attrs->exclude = exclude;
  attrs->select_last_index = select_last_index;
  return Call(Op::Get(op_name), {data}, Attrs(attrs), {});
}

#define RELAY_REGISTER_REDUCE_OP(OpName)                                                \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                         \
      .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude) { \
        return MakeReduce(data, axis, keepdims, exclude, OpName);                       \
      });                                                                               \
  RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")

#define RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP(OpName)                                           \
  TVM_REGISTER_GLOBAL("relay.op._make." OpName)                                                \
      .set_body_typed([](Expr data, Array<Integer> axis, bool keepdims, bool exclude,          \
                         bool select_last_index) {                                             \
        return MakeOneElementReduce(data, axis, keepdims, exclude, select_last_index, OpName); \
      });                                                                                      \
  RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.")
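
// For reference, the globals registered by these macros surface in the Python
// frontend roughly as follows (a sketch; the actual wrappers live in the
// Python package):
//   relay.op._make.sum(data, axis, keepdims, exclude)
//   relay.op._make.argmax(data, axis, keepdims, exclude, select_last_index)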

Array<te::Tensor> ArgMaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  return ArgReduceCompute(attrs, inputs, out_type, topi::argmax);
}

RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmax")
    .describe(R"code(Creates an operation that finds the indices of the maximum
values over a given axis.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ArgReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
    .set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                const Type& out_type) {
  return ArgReduceCompute(attrs, inputs, out_type, topi::argmin);
}

RELAY_REGISTER_ONE_ELEMENT_REDUCE_OP("argmin")
    .describe(R"code(Creates an operation that finds the indices of the minimum
values over a given axis.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ArgReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
    .set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::sum);
}

RELAY_REGISTER_REDUCE_OP("sum")
    .describe(R"code(Computes the sum of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  sum(data, axis=1)
  [[  4.   8.]
   [ 10.   9.]
   [ 21.   6.]]

  sum(data, axis=[1,2])
  [ 12.  19.  27.]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<FTVMCompute>("FTVMCompute", SumCompute)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> AllCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::all);
}

RELAY_REGISTER_REDUCE_OP("all")
    .describe(R"code(Computes the logical AND of boolean array elements over given axes.

Example::

  data = [[[ True,  True,  True],
           [ True,  True,  True],
           [False,  True, False]],
          [[ True, False, False],
           [ True,  True, False],
           [False,  True,  True]]]

  all(data, axis=1)
  [[False,  True, False],
   [False, False, False]]

  all(data, axis=0)
  [[ True, False, False],
   [ True,  True, False],
   [False,  True, False]]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", AllCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::any);
}

RELAY_REGISTER_REDUCE_OP("any")
    .describe(R"code(Computes the logical OR of boolean array elements over given axes.

Example::

  data = [[[ True,  True,  True],
           [ True,  True,  True],
           [False,  True, False]],
          [[ True, False, False],
           [ True,  True, False],
           [False,  True,  True]]]

  any(data, axis=1)
  [[ True,  True,  True],
   [ True,  True,  True]]

  any(data, axis=0)
  [[ True,  True,  True],
   [ True,  True,  True],
   [False,  True,  True]]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", AnyCompute)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> MaxCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::max);
}

RELAY_REGISTER_REDUCE_OP("max")
    .describe(R"code(Computes the max of array elements over given axes.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::min);
}

RELAY_REGISTER_REDUCE_OP("min")
    .describe(R"code(Computes the min of array elements over given axes.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", MinCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  return ReduceCompute(attrs, inputs, out_type, topi::prod);
}

Expr Prod(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
  return MakeReduce(data, axis, keepdims, exclude, "prod");
}

TVM_REGISTER_GLOBAL("relay.op._make.prod").set_body_typed(Prod);

RELAY_REGISTER_OP("prod")
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .describe(R"code(Computes the product of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  prod(data)
  [35562240]

  prod(data, axis=[1,2])
  [ 36  480  2058]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_type) {
  auto data = inputs[0];
  IndexExpr count = tir::make_const(DataType::Int(64), 1);
  const ReduceAttrs* param = attrs.as<ReduceAttrs>();
  ICHECK(param != nullptr);
  for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) {
    count *= data->shape[i];
  }
  // Check the datatype of input data. If it's fp16, we'll have trouble representing all
  // indices and the summation needed, so we instead just cast to fp32.
  bool recast_fp16 = false;
  if (data->dtype.is_float16()) {
    recast_fp16 = true;
    data = topi::cast(data, DataType::Float(32));
  }
  count = cast(data->dtype, count);
  auto res = ReduceCompute(attrs, {data}, out_type, topi::sum);
  auto output = topi::divide(res[0], count);
  // Set the output back to the appropriate fp16 type if needed.
  if (recast_fp16) {
    output = topi::cast(output, DataType::Float(16));
  }
  return {output};
}
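
// Numerically (an illustrative sketch): for a float16 tensor of shape (3, 3, 2)
// reduced over all axes, count = 3 * 3 * 2 = 18; the sum is accumulated in
// fp32, divided by 18, and only then cast back to fp16.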

RELAY_REGISTER_REDUCE_OP("mean")
    .describe(R"code(Computes the mean of array elements over given axes.

Example::

  data = [[[1,2],[2,3],[1,3]],
          [[1,4],[4,3],[5,2]],
          [[7,1],[7,2],[7,3]]]

  mean(data)
  [3.22]

  mean(data, axis=[1,2])
  [ 2.  3.16666667  4.5]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ReduceAttrs>()
    .set_support_level(4)
    .add_type_rel("Reduce", ReduceRel)
    .set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  ICHECK(static_cast<int>(data->shape.size()) != 0);
  const auto* mean = types[1].as<TensorTypeNode>();
  if (mean == nullptr) return false;

  std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
  std::vector<IndexExpr> mean_shape(mean->shape.begin(), mean->shape.end());
  ICHECK_EQ(in_shape.size(), mean_shape.size());

  const VarianceAttrs* param = attrs.as<VarianceAttrs>();
  ICHECK(param != nullptr);

  // assign output type and shape
  auto oshape = ReduceShapeImpl(in_shape, param, reporter);
  reporter->Assign(types[2], TensorType(oshape, data->dtype));
  return true;
}

Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
  IndexExpr count = tir::make_const(DataType::Int(64), 1);
  const VarianceAttrs* param = attrs.as<VarianceAttrs>();
  ICHECK(param != nullptr);
  auto axes = param->axis;
  bool unbiased = param->unbiased;
  auto data = inputs[0];
  auto mean = inputs[1];
  for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) {
    count *= data->shape[i];
  }
  if (unbiased) {
    count -= 1;
  }
  auto diff = topi::subtract(data, mean);
  auto sq_diff = topi::multiply(diff, diff);
  if (param->exclude) {
    axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
    ICHECK_NE(axes.size(), 0);
  }
  // If the input is fp16, we might have trouble representing the full sum of
  // indices or values. We recast to fp32 to avoid this issue.
  bool recast_fp16 = false;
  if (data->dtype.is_float16()) {
    recast_fp16 = true;
    sq_diff = topi::cast(sq_diff, DataType::Float(32));
  }
  auto var = topi::divide(topi::sum(sq_diff, axes, param->keepdims, false), count);

  // Recast back to fp16 if needed.
  if (recast_fp16) {
    var = topi::cast(var, DataType::Float(16));
  }

  return {var};
}
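
// In formula terms, the kernel above computes
//   var = sum((data - mean)^2, axes) / count
// where count is the number of reduced elements, or count - 1 when `unbiased`
// is set (Bessel's correction).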

Expr MakeVariance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
                  bool unbiased = false) {
  auto attrs = make_object<VarianceAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  attrs->exclude = exclude;
  attrs->unbiased = unbiased;
  static const Op& op = Op::Get("variance");
  return Call(op, {data, mean}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body_typed(MakeVariance);

RELAY_REGISTER_OP("variance")
    .describe(R"code(Computes the variance of array elements over given axes.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<VarianceAttrs>()
    .set_support_level(4)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("mean", "Tensor", "The mean tensor.")
    .add_type_rel("Variance", VarianceRel)
    .set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<VarianceAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kCommReduce);

}  // namespace relay
}  // namespace tvm