1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file topi/reduction.h |
22 | * \brief Reduction op constructors |
23 | */ |
24 | #ifndef TVM_TOPI_REDUCTION_H_ |
25 | #define TVM_TOPI_REDUCTION_H_ |
26 | |
27 | #include <tvm/te/operation.h> |
28 | #include <tvm/topi/broadcast.h> |
29 | #include <tvm/topi/detail/constant_utils.h> |
30 | #include <tvm/topi/detail/ravel_unravel.h> |
31 | #include <tvm/topi/elemwise.h> |
32 | #include <tvm/topi/tags.h> |
33 | #include <tvm/topi/transform.h> |
34 | |
35 | #include <algorithm> |
36 | #include <iterator> |
37 | #include <string> |
38 | #include <vector> |
39 | |
40 | namespace tvm { |
41 | namespace topi { |
42 | |
43 | using namespace tvm::te; |
44 | |
/*!
 * \brief The operation to use for CommReduce.
 * Reduces `source` over `axis`, e.g. tvm::sum; `init` optionally supplies
 * initial values and `span` locates the reduction in the source.
 */
using FReduce = std::function<PrimExpr(PrimExpr source, const Array<IterVar>& axis,
                                       Array<PrimExpr> init, Span span)>;

/*!
 * \brief The operation to use for CommReduceIdx.
 * Takes the (index, value) expression tuple, the reduction axes, and an
 * optional condition (nullptr means unconditional); returns one Reduce
 * expression per tuple component.
 */
using FCommReduce = std::function<Array<PrimExpr>(Array<PrimExpr> exprs, const Array<IterVar>& axis,
                                                  PrimExpr* condition)>;
52 | |
53 | /*! |
54 | * \brief Convert a reduction axis which could be empty or have negative |
55 | * elements into a real axis with valid dimension indices. |
56 | * |
57 | * \param ndim Number of dimensions in the target. |
58 | * \param axis The axis parameter. |
59 | * |
60 | * \return A non-empty sorted array of valid dimension indices, with no duplicates. |
61 | * If the input axis is empty, the result will be an axis including all dimensions. |
62 | * If any input element is negative, it will be treated as an offset from the |
63 | * last dimension (same as python indexing rules). |
64 | */ |
65 | inline std::vector<int> GetRealAxis(int ndim, const Array<Integer>& axis) { |
66 | std::vector<int> real_axis; |
67 | if (!axis.defined() || axis.size() == 0) { |
68 | for (int i = 0; i < ndim; ++i) { |
69 | real_axis.push_back(i); |
70 | } |
71 | } else { |
72 | // Use a set so duplicates are removed and the dims are sorted |
73 | for (auto elem : axis) { |
74 | int64_t val = elem->value; |
75 | if (val < 0) { |
76 | val += ndim; |
77 | } |
78 | ICHECK_LE(val, ndim) << " exceeds the maximum dimension " << ndim; |
79 | ICHECK_GE(val, 0); |
80 | real_axis.push_back(static_cast<int>(val)); |
81 | } |
82 | std::sort(real_axis.begin(), real_axis.end()); |
83 | real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); |
84 | } |
85 | return real_axis; |
86 | } |
87 | |
88 | /*! \brief Enumerate the axes for a reduce op */ |
89 | inline Array<IterVar> MakeReduceAxes(const std::vector<int>& real_axis, const Tensor& data) { |
90 | Array<IterVar> reduce_axes; |
91 | for (auto i : real_axis) { |
92 | std::string name = "k" + std::to_string(i); |
93 | reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name)); |
94 | } |
95 | return reduce_axes; |
96 | } |
97 | |
98 | /*! \brief Calculate the target shape for a reduce op */ |
99 | inline Array<PrimExpr> MakeReduceTargetShape(const std::vector<int>& real_axis, const Tensor& data, |
100 | bool keepdims, bool atleast1d) { |
101 | auto ndim = data->shape.size(); |
102 | Array<PrimExpr> target_shape; |
103 | if (keepdims) { |
104 | for (size_t i = 0; i < ndim; ++i) { |
105 | if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) { |
106 | // real_axis contains i |
107 | target_shape.push_back(1); |
108 | } else { |
109 | target_shape.push_back(data->shape[i]); |
110 | } |
111 | } |
112 | } else { |
113 | for (size_t i = 0; i < ndim; ++i) { |
114 | if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) { |
115 | // real_axis does not contain i |
116 | target_shape.push_back(data->shape[i]); |
117 | } |
118 | } |
119 | } |
120 | if (target_shape.size() == 0 && atleast1d) { |
121 | target_shape.push_back(1); |
122 | } |
123 | return target_shape; |
124 | } |
125 | |
126 | /*! |
127 | * \brief Create a reduction operation. |
128 | * |
129 | * \param data The input tensor. |
130 | * \param func The reduction function eg. tvm::sum |
131 | * \param target_shape The output Tensor shape. |
132 | * \param reduce_axes The real axes along which the reduction is performed. |
133 | * \param squeeze_axes The real axes to squeeze. Unsqueezed, reduced axes will |
134 | * have shape 1 in the output tensor. |
135 | * \param span The location of this reducer in the source. |
136 | * |
137 | * \return The result tensor. |
138 | */ |
inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array<PrimExpr>& target_shape,
                           const std::vector<int>& reduce_axes,
                           const std::vector<int>& squeeze_axes, Span span = Span()) {
  auto r_axes = MakeReduceAxes(reduce_axes, data);
  auto compute = [&](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;  // per input dim: an output index or a reduce IterVar
    Array<Var> eval_indices;     // the reduce IterVars in order
    int arg_counter = 0;         // next position consumed in `indices` (output axes)
    int red_counter = 0;         // next position consumed in `r_axes`

    for (size_t i = 0; i < data->shape.size(); ++i) {
      bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end();
      if (std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end()) {
        // real_axis contains i: index this dim with its reduction IterVar.
        eval_range.push_back(r_axes[red_counter]);
        eval_indices.push_back(r_axes[red_counter]->var);
        red_counter++;
        // A reduced-but-not-squeezed dim still occupies a (size-1) slot in the
        // output shape, so it must consume one output index as well.
        arg_counter += !squeeze_i;
        continue;
      }
      // Non-reduced dim: use the next available output index.
      eval_range.push_back(indices[arg_counter]);
      arg_counter++;
    }

    // Apply the reduction (e.g. tvm::sum) over the gathered element.
    return func(data(eval_range), r_axes, {}, span);
  };

  return tvm::te::compute(target_shape, compute, data->op->name + "_red", kCommReduce);
}
168 | |
169 | /*! |
170 | * \brief Create a reduction operation. |
171 | * |
172 | * \param data The input tensor. |
173 | * \param axis The axes along which the reduction is performed. |
174 | * \param func The reduction function eg. tvm::sum |
175 | * \param keepdims If this is set to true, the axes which are reduced are |
176 | * left in the result as dimensions with size one. This enables the result |
177 | * to broadcast correctly against the input array. |
178 | * \param atleast1d Whether the output need to be atleast1d. |
179 | * |
180 | * \return The result tensor. |
181 | */ |
182 | inline Tensor CommReduce(const Tensor& data, const Array<Integer>& axis, FReduce func, |
183 | bool keepdims, bool atleast1d) { |
184 | auto ndim = data->shape.size(); |
185 | ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor" ; |
186 | auto real_axis = GetRealAxis(static_cast<int>(ndim), axis); |
187 | auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); |
188 | return DoCommReduce(data, func, target_shape, real_axis, |
189 | keepdims ? std::vector<int>() : real_axis); |
190 | } |
191 | |
192 | /*! |
193 | * \brief Create an index reduction operation. |
194 | * |
195 | * \param data The input tensor. |
196 | * \param axis The axes along which the reduction is performed. |
197 | * \param func The reduction function |
198 | * \param keepdims If this is set to true, the axes which are reduced are |
199 | * left in the result as dimensions with size one. This enables the result |
200 | * to broadcast correctly against the input array. |
201 | * \param atleast1d Whether the output need to be atleast1d. |
202 | * |
203 | * \return The result tensor. |
204 | */ |
inline Tensor CommReduceIdx(const Tensor& data, const Array<Integer>& axis, FCommReduce func,
                            bool keepdims, bool atleast1d) {
  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_axes = MakeReduceAxes(real_axis, data);
  auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d);

  auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func,
                  &data](const Array<Var>& indices) {
    Array<PrimExpr> eval_range;    // per input dim: output index or reduce IterVar
    Array<PrimExpr> eval_indices;  // reduce IterVars only, used to build the flat index
    int arg_counter = 0;           // next unused position in `indices`
    int red_counter = 0;           // next unused entry in `reduce_axes`

    for (size_t i = 0; i < ndim; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        // real_axis contains i: index this dim with its reduction IterVar.
        eval_range.push_back(reduce_axes[red_counter]);
        eval_indices.push_back(reduce_axes[red_counter]->var);
        red_counter++;
      } else {
        if (!keepdims) {
          eval_range.push_back(indices[arg_counter]);
          arg_counter++;
        } else {
          // With keepdims the output has the same rank as the input, so the
          // output index for dim i is simply indices[i].
          eval_range.push_back(indices[i]);
        }
      }
    }

    // Flatten the position along the reduced dims into one linear index; this
    // is the "idx" half of the (idx, value) tuple tracked by the reducer.
    Array<PrimExpr> ravel_shape;
    for (auto i : real_axis) {
      ravel_shape.push_back(data->shape[i]);
    }
    auto idx = detail::RavelIndex(eval_indices, ravel_shape);
    // nullptr condition: the reduction is unconditional.
    return func({idx, data(eval_range)}, reduce_axes, nullptr);
  };

  // First compute produces both tuple outputs; the arg-reduction result is the
  // index component (output 0), so a second compute extracts just that.
  auto temp_idx_val =
      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx);
  auto temp_idx = temp_idx_val[0];
  auto temp_val = temp_idx_val[1];
  return tvm::te::compute(
      target_shape, [&temp_idx](const Array<Var>& indices) { return temp_idx(indices); },
      data->op->name + "_red", kCommReduceIdx);
}
252 | |
/*!
 * \brief A combiner function for a reduction: given the lhs and rhs operand
 * tuples, returns the combined tuple (e.g. elementwise sum, or argmin select).
 */
using FCombine = std::function<Array<PrimExpr>(Array<Var> lhs, Array<Var> rhs)>;

/*!
 * \brief An initializer function for a reduction: given the operand dtypes,
 * returns the identity element for each tuple component.
 */
using FIdentity = std::function<Array<PrimExpr>(std::vector<DataType> types)>;
258 | |
259 | /*! |
260 | * \brief Create a commutative reducer for a reduction |
261 | * |
262 | * \param fcombine A function to combine exprs |
263 | * \param fidentity A function to initialize elements |
264 | * \param name The name of the operation |
265 | * |
266 | * \return A reducer function which creates a reduce expression over an axis. |
267 | */ |
inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity,
                                   std::string name = "reduce") {
  // Capture by value: the returned closure may outlive this call.
  return [fcombine, fidentity, name](Array<PrimExpr> exprs, const Array<IterVar>& axis,
                                     PrimExpr* condition) {
    Array<Var> lhs, rhs;
    std::vector<DataType> dtypes;

    // One (lhs, rhs) placeholder var pair per tuple component; the combiner
    // expression is defined over these placeholders.
    for (size_t i = 0; i < exprs.size(); ++i) {
      auto dtype = exprs[i].dtype();
      dtypes.push_back(dtype);
      lhs.push_back(var(name + "_lhs_" + std::to_string(i), dtype));
      rhs.push_back(var(name + "_rhs_" + std::to_string(i), dtype));
    }

    auto result = fcombine(lhs, rhs);
    auto id_elem = fidentity(dtypes);
    // A null condition means the reduction applies unconditionally.
    auto cond = condition != nullptr ? *condition : tir::const_true();

    auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem);
    Array<PrimExpr> outputs;
    // One Reduce node per tuple component; the value_index argument selects
    // which component of the tuple each output yields.
    for (size_t i = 0; i < exprs.size(); ++i) {
      outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast<int>(i), {}));
    }
    return outputs;
  };
}
294 | |
/*!
 * \brief Wrap tvm::min to ensure we get the correct overload.
 * Matches the FReduce signature so it can be passed to CommReduce.
 */
inline PrimExpr MinOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {},
                      Span span = Span()) {
  return tvm::min(source, axis, init, span);
}

/*!
 * \brief Wrap tvm::max to ensure we get the correct overload.
 * Matches the FReduce signature so it can be passed to CommReduce.
 */
inline PrimExpr MaxOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {},
                      Span span = Span()) {
  return tvm::max(source, axis, init, span);  // NOLINT(*)
}

/*!
 * \brief Wrap tvm::prod to ensure we get the correct overload.
 * Matches the FReduce signature so it can be passed to CommReduce.
 */
inline PrimExpr ProdOp(PrimExpr source, Array<IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span()) {
  return tvm::prod(source, axis, init, span);  // NOLINT(*)
}
312 | |
313 | /*! |
314 | * \brief Creates an operation that sums array elements over a given axis |
315 | * |
316 | * \param data The input tensor |
317 | * \param axis The axis to sum over. If axis is empty, the operation will |
318 | * sum over all elements of the array. |
319 | * \param keepdims If this is set to true, the axes which are reduced are |
320 | * left in the result as dimensions with size one. This enables the result |
321 | * to broadcast correctly against the input array. |
322 | * \param atleast1d Whether the output need to be atleast1d. |
323 | * |
324 | * \return A Tensor whose op member is the sum operation |
325 | */ |
326 | inline Tensor sum(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
327 | bool atleast1d = false) { |
328 | return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); |
329 | } |
330 | |
/*!
 * \brief Sum `data` down to `target_shape`, undoing a broadcast: trailing
 * dimensions are aligned (numpy broadcast rules), and input dims that have no
 * matching output extent are summed out. Both shapes must be compile-time
 * constants (enforced by GetConstIntValues).
 *
 * \param data The input tensor.
 * \param target_shape The broadcast-compatible shape to collapse to.
 *
 * \return A Tensor of shape `target_shape` produced by a sum reduction, or an
 * identity copy if no dimension needs reducing.
 */
inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
  ICHECK_GE(data->shape.size(), target_shape.size());
  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
  auto oshape = detail::GetConstIntValues(target_shape, "oshape");

  std::vector<int> reduce_axes;   // input dims to sum over
  std::vector<int> squeeze_axes;  // summed dims with no corresponding output dim
  // Walk both shapes from the innermost dimension outwards, since broadcasting
  // aligns trailing dimensions.
  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
      // Extents match: this dimension is kept as-is.
      --o_ax;
      continue;
    }
    reduce_axes.push_back(i_ax);
    if (o_ax < 0) {  // squeeze o_ax if was added during expansion
      squeeze_axes.push_back(i_ax);
    } else if (oshape[o_ax] == 1) {
      // Output extent 1: the dim was broadcast, reduce it but keep a size-1 slot.
      --o_ax;
    }
  }

  // Nothing to reduce: shapes already agree, return an identity copy.
  if (reduce_axes.size() == 0) return topi::identity(data, "tensor", kCommReduce);

  // Axes were collected back-to-front; restore ascending order for DoCommReduce.
  std::reverse(reduce_axes.begin(), reduce_axes.end());
  std::reverse(squeeze_axes.begin(), squeeze_axes.end());
  return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
}
357 | |
358 | /*! |
359 | * \brief Creates an operation that computes the logical AND of elements |
360 | * over a given axis |
361 | * |
362 | * \param data The input boolean tensor |
363 | * \param axis The axes to reduce. If axis is empty, the operation will |
364 | * perform logical AND over all elements of the array. |
365 | * \param keepdims If this is set to true, the axes which are reduced are |
366 | * left in the result as dimensions with size one. This enables the result |
367 | * to broadcast correctly against the input array. |
368 | * \param atleast1d Whether the output need to be atleast1d. |
369 | * |
370 | * \return A Tensor whose op member is the all operation |
371 | */ |
372 | inline Tensor all(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
373 | bool atleast1d = false) { |
374 | return CommReduce(data, axis, tvm::all, keepdims, atleast1d); |
375 | } |
376 | |
377 | /*! |
378 | * \brief Creates an operation that computes the logical OR of elements |
379 | * over a given axis |
380 | * |
381 | * \param data The input boolean tensor |
382 | * \param axis The axes to reduce. If axis is empty, the operation will |
383 | * perform logical OR over all elements of the array. |
384 | * \param keepdims If this is set to true, the axes which are reduced are |
385 | * left in the result as dimensions with size one. This enables the result |
386 | * to broadcast correctly against the input array. |
387 | * \param atleast1d Whether the output need to be atleast1d. |
388 | * |
 * \return A Tensor whose op member is the any operation
390 | */ |
391 | inline Tensor any(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
392 | bool atleast1d = false) { |
393 | return CommReduce(data, axis, tvm::any, keepdims, atleast1d); |
394 | } |
395 | |
396 | /*! |
397 | * \brief Creates an operation that finds the minimum of elements over |
398 | * a given axis. |
399 | * |
400 | * \param data The input tensor |
401 | * \param axis The axis to find the minimum over. If axis is empty, the |
402 | * operation will find the minimum over all elements of the array. |
403 | * \param keepdims If this is set to true, the axes which are reduced are |
404 | * left in the result as dimensions with size one. This enables the result |
405 | * to broadcast correctly against the input array. |
406 | * \param atleast1d Whether the output need to be atleast1d. |
407 | * |
408 | * \return A Tensor whose op member is the min operation |
409 | */ |
410 | inline Tensor min(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
411 | bool atleast1d = false) { |
412 | return CommReduce(data, axis, MinOp, keepdims, atleast1d); |
413 | } |
414 | |
415 | /*! |
416 | * \brief Creates an operation that finds the maximum of elements over |
417 | * a given axis. |
418 | * |
419 | * \param data The input tensor |
420 | * \param axis The axis to find the maximum over. If axis is empty, the |
421 | * operation will find the maximum over all elements of the array. |
422 | * \param keepdims If this is set to true, the axes which are reduced are |
423 | * left in the result as dimensions with size one. This enables the result |
424 | * to broadcast correctly against the input array. |
425 | * \param atleast1d Whether the output need to be atleast1d. |
426 | * |
427 | * \return A Tensor whose op member is the max operation |
428 | */ |
429 | inline Tensor max(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
430 | bool atleast1d = false) { |
431 | return CommReduce(data, axis, MaxOp, keepdims, atleast1d); |
432 | } |
433 | |
/*!
 * \brief Create a reducer that tracks the (index, value) pair holding the
 * minimum value.
 *
 * \param select_last_index If true, ties between equal values resolve to the
 * larger index; otherwise to the smaller index.
 *
 * \return An FCommReduce suitable for CommReduceIdx.
 */
inline FCommReduce MakeArgminReducer(bool select_last_index = false) {
  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;

    // Casting to avoid operator ambiguity
    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);

    // These variables compare the actual values of the array
    auto is_smaller = lhs_val < rhs_val;
    auto is_same = lhs_val == rhs_val;

    // This checks if the indices are correct for the reduction. E.g. for select_last_index
    // it gives precedence for later indices of the same element and precedence for sooner
    // indices if not select_last_index;
    PrimExpr proper_index;
    if (select_last_index) {
      proper_index = lhs_idx > rhs_idx;
    } else {
      proper_index = lhs_idx < rhs_idx;
    }

    // Take lhs's index when it is strictly smaller, or on a value tie when its
    // index wins under the chosen tie-breaking rule.
    PrimExpr update_index = is_smaller || (is_same && proper_index);
    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
    result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1]));    // val
    return result;
  };
  auto fidentity = [&](std::vector<DataType> types) {
    Array<PrimExpr> result;
    result.push_back(tvm::tir::make_const(types[0], -1));  // idx: -1 marks "no element seen yet"
    result.push_back(tvm::max_value(types[1]));            // val: max value, so any element wins
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "argmin");
}
472 | |
473 | /*! |
474 | * \brief Creates an operation that finds the indices of the minimum |
475 | * values over a given axis. |
476 | * |
477 | * \param data The input tensor |
478 | * \param axis The axis along which the argmin is performed. If axis is empty, |
479 | * the operation will find the minimum index over all elements of the array. |
480 | * \param keepdims If this is set to true, the axes which are reduced are |
481 | * left in the result as dimensions with size one. This enables the result |
482 | * to broadcast correctly against the input array. |
483 | * \param atleast1d Whether the output need to be atleast1d. |
484 | * \param select_last_index Whether to select the last index if the minimum element |
485 | * appears multiple times, else select the first index. |
486 | * |
487 | * \return A Tensor whose op member is the argmin operation |
488 | */ |
489 | inline Tensor argmin(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
490 | bool atleast1d = false, bool select_last_index = false) { |
491 | auto reducer = MakeArgminReducer(select_last_index); |
492 | return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); |
493 | } |
494 | |
/*!
 * \brief Create a reducer that tracks the (index, value) pair holding the
 * maximum value.
 *
 * \param select_last_index If true, ties between equal values resolve to the
 * larger index; otherwise to the smaller index.
 *
 * \return An FCommReduce suitable for CommReduceIdx.
 */
inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) {
  // Create a Commutative Reducer with a comparison operation, and method to get the initial value.
  auto fcombine = [=](Array<Var> lhs, Array<Var> rhs) {
    Array<PrimExpr> result;

    // Casting to avoid operator ambiguity
    PrimExpr lhs_idx = static_cast<PrimExpr>(lhs[0]);
    PrimExpr rhs_idx = static_cast<PrimExpr>(rhs[0]);
    PrimExpr lhs_val = static_cast<PrimExpr>(lhs[1]);
    PrimExpr rhs_val = static_cast<PrimExpr>(rhs[1]);

    // These variables compare the actual values of the array
    auto is_bigger = lhs_val > rhs_val;
    auto is_same = lhs_val == rhs_val;

    // This checks if the indices are correct for the reduction. E.g. for select_last_index
    // it gives precedence for later indices of the same element and precedence for sooner
    // indices if not select_last_index;
    PrimExpr proper_index;
    if (select_last_index) {
      proper_index = lhs_idx > rhs_idx;
    } else {
      proper_index = lhs_idx < rhs_idx;
    }

    // Take lhs's index when it is strictly bigger, or on a value tie when its
    // index wins under the chosen tie-breaking rule.
    PrimExpr update_index = is_bigger || (is_same && proper_index);
    result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0]));  // idx
    result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1]));     // val
    return result;
  };
  auto fidentity = [&](std::vector<DataType> types) {
    Array<PrimExpr> result;
    result.push_back(tvm::tir::make_const(types[0], -1));  // idx: -1 marks "no element seen yet"
    result.push_back(tvm::min_value(types[1]));            // val: min value, so any element wins
    return result;
  };
  return MakeCommReducer(fcombine, fidentity, "argmax");
}
533 | |
534 | /*! |
535 | * \brief Creates an operation that finds the indices of the maximum |
536 | * values over a given axis. |
537 | * |
538 | * \param data The input tensor |
539 | * \param axis The axis along which the argmax is performed. If axis is empty, |
540 | * the operation will find the maximum index over all elements of the array. |
541 | * \param keepdims If this is set to true, the axes which are reduced are |
542 | * left in the result as dimensions with size one. This enables the result |
543 | * to broadcast correctly against the input array. |
544 | * \param atleast1d Whether the output need to be atleast1d. |
545 | * \param select_last_index Whether to select the last index if the maximum element |
546 | * appears multiple times, else select the first index. |
547 | * \return A Tensor whose op member is the argmax operation |
548 | */ |
549 | inline Tensor argmax(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
550 | bool atleast1d = false, bool select_last_index = false) { |
551 | auto reducer = MakeArgmaxReducer(select_last_index); |
552 | return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); |
553 | } |
554 | |
555 | /*! |
556 | * \brief Creates product operation over given axis. |
557 | * |
558 | * \param data The input tensor |
559 | * \param axis The axis to do product over. If axis is empty, the |
560 | * operation will do the product over all elements of the array. |
561 | * \param keepdims If this is set to true, the axes which are reduced are |
562 | * left in the result as dimensions with size one. This enables the result |
563 | * to broadcast correctly against the input array. |
564 | * \param atleast1d Whether the output need to be atleast1d. |
565 | * |
566 | * \return A Tensor whose op member is the prod operation |
567 | */ |
568 | inline Tensor prod(const Tensor& data, const Array<Integer>& axis, bool keepdims = false, |
569 | bool atleast1d = false) { |
570 | return CommReduce(data, axis, ProdOp, keepdims, atleast1d); |
571 | } |
572 | |
/*!
 * \brief Create a commutative reducer that sums over tuples elementwise.
 */
576 | inline FCommReduce MakeTupleSumReducer() { |
577 | auto fcombine = [](Array<Var> lhs, Array<Var> rhs) { |
578 | Array<PrimExpr> result; |
579 | ICHECK_EQ(lhs.size(), rhs.size()); |
580 | result.reserve(lhs.size()); |
581 | for (size_t i = 0; i < lhs.size(); ++i) { |
582 | result.push_back(lhs[i] + rhs[i]); |
583 | } |
584 | return result; |
585 | }; |
586 | auto fidentity = [](std::vector<DataType> types) { |
587 | Array<PrimExpr> result; |
588 | for (size_t i = 0; i < types.size(); ++i) { |
589 | result.push_back(tvm::tir::make_const(types[i], 0)); |
590 | } |
591 | return result; |
592 | }; |
593 | return MakeCommReducer(fcombine, fidentity, "tuple_sum" ); |
594 | } |
595 | |
596 | } // namespace topi |
597 | } // namespace tvm |
598 | #endif // TVM_TOPI_REDUCTION_H_ |
599 | |