1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief NN op constructions
22 * \file topi/nn.h
23 */
24#ifndef TVM_TOPI_NN_H_
25#define TVM_TOPI_NN_H_
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/te/operation.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/op.h>
31#include <tvm/topi/detail/constant_utils.h>
32#include <tvm/topi/reduction.h>
33#include <tvm/topi/tags.h>
34#include <tvm/topi/transform.h>
35
36#include <algorithm>
37#include <string>
38
39namespace tvm {
40namespace topi {
41
42using namespace tvm::te;
43
44/*!
45 * \brief Creates an operation that performs a rectified linear unit
46 *
47 * \param t The input tensor
48 * \param threshold The relu threshold (default 0)
49 * \param name The name of the operation
50 * \param tag The tag to mark the operation
51 *
52 * \return A Tensor whose op member is the relu operation
53 */
54template <typename T>
55inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast<T>(0),
56 std::string name = "T_relu", std::string tag = kElementWise) {
57 return tvm::te::compute(
58 t->shape,
59 [&](const tvm::Array<tvm::tir::Var>& i) {
60 auto threshold_const = tvm::tir::make_const(t->dtype, threshold);
61 return tvm::max(t(i), threshold_const);
62 },
63 name, tag);
64}
65
66/*!
67 * \brief Creates an operation that performs a leaky rectified linear unit
68 *
69 * \param t The input tensor
70 * \param alpha The slope for the small gradient when t < 0
71 * \param name The name of the operation
72 * \param tag The tag to mark the operation
73 *
74 * \return A Tensor whose op member is the leaky relu operation
75 */
76inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1,
77 std::string name = "T_leaky_relu",
78 std::string tag = kElementWise) {
79 return tvm::te::compute(
80 t->shape,
81 [&](const tvm::Array<tvm::tir::Var>& i) {
82 auto value = t(i);
83 auto calpha = tvm::tir::make_const(value.dtype(), alpha);
84 return tvm::tir::Select(value > 0, value, value * calpha);
85 },
86 name, tag);
87}
88
89/*!
90 * \brief Creates an operation that performs a parametric rectified linear unit
91 *
92 * \param x The input data tensor
93 * \param slope The channel-wise slope tensor
94 * \param axis The axis where the channel data needs to be applied
95 * \param name The name of the operation
96 * \param tag The tag to mark the operation
97 *
98 * \return A Tensor whose op member is the parametric relu operation
99 */
100inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope,
101 const int axis = 1, std::string name = "T_prelu",
102 std::string tag = kBroadcast) {
103 ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. ";
104 ICHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis]))
105 << "Wrong slope shape received.";
106
107 return tvm::te::compute(
108 x->shape,
109 [&](const tvm::Array<tvm::tir::Var>& indices) {
110 auto xval = x(indices);
111 return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis]));
112 },
113 name, tag);
114}
115
116/*!
117 * \brief Creates an operation that performs padding
118 *
119 * \param t The input tensor
120 * \param pad_before An Array of Expr describing the padding before the
121 * respective iterator
122 * \param pad_after An Array of Expr describing the padding after the
123 * respective iterator
124 * \param pad_value The value to fill padding elements with
125 * \param pad_mode Padding type to use.
126 * "constant" pads with constant_value;
127 * "edge" pads using the edge values of the input array;
128 * "reflect" pads by reflecting values with respect to the edges.
129 * \param dyn_output_shape Output shape of the pad op, default nullptr.
130 * You only need to pass this in if the shape was evaluated dynamically.
131 * \param name The name of the operation
132 * \param tag The tag to mark the operation
133 *
134 * \return A Tensor whose op member is the padding operation
135 *
136 * \note
137 * The pad_after Array must either be empty or have the same length as
138 * pad_before
139 * When pad_after is empty, it takes the same values as pad_before (symmetric
140 * padding)
141 * The pad Array applies from the leading dimensions and skips missing
142 * trailing dimensions:
143 *
144 * pad(t(i, j, k), {1}, {0}) returns the equivalent operation for
145 * the following pseudocode:
 *       for i in [1, t.shape[0] + 2]:
 *         for j in [1, t.shape[1] + 2]:
 *           for k in [1, t.shape[2] + 2]:
149 * name(i,j,k) =
150 * (1 <= i <= t.shape[0] + 1) ?
151 * t(i-1, j, k) : 0;
152 *
153 *
154 */
inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array<tvm::PrimExpr>& pad_before,
                           tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
                           PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
                           std::string tag = kElementWise, std::string pad_mode = "constant",
                           const Array<PrimExpr>* dyn_output_shape = nullptr) {
  // A missing/short pad_after mirrors pad_before (symmetric padding).
  if (pad_after.size() < pad_before.size()) {
    for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
      pad_after.push_back(pad_before[i]);
    }
  }

  arith::Analyzer analyzer;
  ICHECK_GE(pad_before.size(), 1);
  ICHECK_EQ(pad_before.size(), pad_after.size());
  // Normalize all pad widths to int32 so they mix cleanly with shape arithmetic.
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // Output shape: padded extent for the leading (padded) dims, unchanged
  // trailing dims — unless the caller supplied a precomputed dynamic shape.
  tvm::Array<tvm::PrimExpr> output_shape;
  if (dyn_output_shape == nullptr) {
    for (size_t i = 0; i < t->shape.size(); ++i) {
      if (i >= pad_before.size()) {
        output_shape.push_back(t->shape[i]);
      } else {
        output_shape.push_back(
            analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
      }
    }
  } else {
    for (size_t i = 0; i < dyn_output_shape->size(); i++) {
      output_shape.push_back((*dyn_output_shape)[i]);
    }
  }

  // Default fill value for "constant" mode is 0 of the input's dtype.
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(t->dtype, 0);
  }

  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
    tvm::Array<tvm::PrimExpr> indices;  // input coords for the interior region
    tvm::Array<tvm::PrimExpr> sel;      // conditions that hold only in the interior
    tvm::Array<tvm::PrimExpr> pad_idx;  // remapped input coords for edge/reflect modes
    for (size_t i = 0; i < t->shape.size(); ++i) {
      // Dims beyond the pad arrays are passed through untouched.
      if (i >= pad_before_int32.size()) {
        indices.push_back(ovars[i]);
        continue;
      }
      if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
        sel.push_back(ovars[i] >= pad_before_int32[i]);
        indices.push_back(ovars[i] - pad_before_int32[i]);
      } else {
        indices.push_back(ovars[i]);
      }
      if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
        sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
      }
      if (pad_mode == "edge") {
        // Clamp out-of-range coordinates to the nearest input border.
        pad_idx.push_back(
            tvm::if_then_else(ovars[i] < pad_before[i], 0,
                              tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                                                t->shape[i] - 1, ovars[i] - pad_before[i])));
      } else if (pad_mode == "reflect") {
        // Mirror out-of-range coordinates about the borders (border excluded).
        pad_idx.push_back(
            tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i],
                              tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                                                t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
                                                ovars[i] - pad_before[i])));
      }
    }
    if (sel.size() != 0) {
      if (pad_mode == "constant") {
        // Interior -> input value; padding region -> pad_value.
        return tvm::if_then_else(
            foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
                  const_true(1), sel),
            t(indices), pad_value);
      } else if (pad_mode == "edge" || pad_mode == "reflect") {
        // Interior -> input value; padding region -> remapped border/mirror value.
        return tvm::if_then_else(
            foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
                  const_true(1), sel),
            t(indices), t(pad_idx));
      }
    }
    // No padding on any dimension: plain copy of the input.
    return t(indices);
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
247
248/*!
249 * \brief Creates an operation that performs a 2-D convolution with an
250 * NCHW-layout
251 *
252 * \param I The 4-D input tensor
253 * \param W The 4-D weight tensor
254 * \param pad_h A static constant padding amount applied to the height of the
255 * image, before and after (symmetric padding)
256 * \param pad_w A static constant padding amount applied to the width of the
257 * image, before and after (symmetric padding)
258 * \param stride_h A static constant striding amount applied to the height of
259 * the image, before and after (symmetric padding)
 * \param stride_w A static constant striding amount applied to the width of
 * the image, before and after (symmetric padding)
262 * \param name The name of the operation
263 * \param tag The tag to mark the operation
264 *
265 * \return A Tensor whose op member is the 2-D convolution operation (NCHW
266 * layout)
267 */
268inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
269 int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
270 std::string name = "T_conv2d_nchw",
271 std::string tag = kConv2dNCHW) {
272 ICHECK_EQ(4, I->shape.size());
273 ICHECK_EQ(4, W->shape.size());
274 auto pH = I->shape[2];
275 auto pW = I->shape[3];
276 tvm::Array<tvm::PrimExpr> output_shape{
277 I->shape[0], // B
278 W->shape[0], // O
279 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
280 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W
281 };
282 auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
283 auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
284 auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
285 auto T =
286 (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
287 auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
288 return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw});
289 };
290 return tvm::te::compute(output_shape, l, name, tag);
291}
292
293/*!
294 * \brief Creates an operation for 2-D convolution layer with an HWCN-layout
295 *
296 * \param I The 4-D input tensor
297 * \param W The 4-D weight tensor
298 * \param pad_h A static constant padding amount applied to the height of the
299 * image, before and after (symmetric padding)
300 * \param pad_w A static constant padding amount applied to the width of the
301 * image, before and after (symmetric padding)
302 * \param stride_h A static constant striding amount applied to the height of
303 * the image, before and after (symmetric padding)
 * \param stride_w A static constant striding amount applied to the width of
 * the image, before and after (symmetric padding)
306 * \param name The name of the operation
307 * \param tag The tag to mark the operation
308 *
309 * \return A Tensor whose op member is the 2-D convolution operation
310 * (HWCN layout)
311 */
312inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
313 int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1,
314 std::string name = "T_conv2d_hwcn",
315 std::string tag = kConv2dHWCN) {
316 ICHECK_EQ(4, I->shape.size());
317 ICHECK_EQ(4, W->shape.size());
318 auto pH = I->shape[2];
319 auto pW = I->shape[3];
320 tvm::Array<tvm::PrimExpr> output_shape{
321 indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H
322 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W
323 I->shape[2], // B
324 W->shape[3] // O
325 };
326 auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
327 auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
328 auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
329 auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w});
330 auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
331 return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw});
332 };
333 return tvm::te::compute(output_shape, l, name, tag);
334}
335
336/*!
337 * \brief Creates an operation that performs a 2-D depthwise convolution with
338 * an NCHW-layout
339 *
340 * \param I The 4-D input tensor
341 * \param W The 4-D weight tensor
342 * \param pad_h A static constant padding amount applied to the height of the
343 * image, before and after (symmetric padding)
344 * \param pad_w A static constant padding amount applied to the width of the
345 * image, before and after (symmetric padding)
346 * \param stride_h A static constant striding amount applied to the height of
347 * the image, before and after (symmetric padding)
 * \param stride_w A static constant striding amount applied to the width of
 * the image, before and after (symmetric padding)
350 * \param name The name of the operation
351 * \param tag The tag to mark the operation
352 *
353 * \return A Tensor whose op member is the 2-D depthwise convolution operation
354 * (NCHW layout)
355 */
inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
                                             int pad_h = 0, int pad_w = 0, int stride_h = 1,
                                             int stride_w = 1,
                                             std::string name = "T_depthwise_conv2d_nchw",
                                             std::string tag = kDepthwiseConv2dNCHW) {
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[2];   // input height (currently unused)
  auto pW = I->shape[3];   // input width (currently unused)
  auto pCM = W->shape[1];  // channel_multiplier
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],                                                    // B
      W->shape[1],                                                    // O
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
  };
  // NOTE(review): the output-channel extent above is W->shape[1] (the channel
  // multiplier alone) and the reduction below sums over every input channel i.
  // A conventional depthwise conv produces C * multiplier output channels and
  // reduces only over the kernel window — confirm the intended semantics
  // against callers before relying on this kernel.
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
  // Zero-pad only the two spatial dims; skip the pad op entirely when unpadded.
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
    return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
                        W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
                    {i, kh, kw});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
384
/*!
 * \brief Creates an operation that performs a 2-D depthwise convolution with
 * an NHWC-layout
 *
 * \param I The 4-D input tensor
 * \param W The 4-D weight tensor
 * \param pad_h A static constant padding amount applied to the height of the
 * image, before and after (symmetric padding)
 * \param pad_w A static constant padding amount applied to the width of the
 * image, before and after (symmetric padding)
 * \param stride_h A static constant striding amount applied to the height of
 * the image
 * \param stride_w A static constant striding amount applied to the width of
 * the image
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the 2-D depthwise convolution operation
 * (NHWC layout)
 */
inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
                                             int pad_h = 0, int pad_w = 0, int stride_h = 1,
                                             int stride_w = 1,
                                             std::string name = "T_depthwise_conv2d_nhwc",
                                             std::string tag = kDepthwiseConv2dNHWC) {
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[1];   // input height (currently unused)
  auto pW = I->shape[2];   // input width (currently unused)
  auto pCM = W->shape[1];  // channel_multiplier
  // NOTE(review): with an (kh, kw, C, M) kernel layout — implied by the W
  // indexing in the lambda below — the channel multiplier would be W->shape[3],
  // not W->shape[1]; likewise the H output extent subtracts W->shape[1] while
  // the kh reduce axis uses W->shape[0], and the reduction sums over all input
  // channels. Confirm the intended kernel layout before relying on this op.
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],                                                    // B
      indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,  // W
      W->shape[3],                                                    // O
  };
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
  // Zero-pad only the two spatial dims (H, W); skip the pad op when unpadded.
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) {
    return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
                        W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
                    {kh, kw, i});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
413
414/*!
415 * \brief Creates an operation that performs a 2-D group convolution with
416 * an NGCHW-layout
417 *
418 * \param I The 5-D input tensor
419 * \param W The 5-D weight tensor
420 * \param pad_h A static constant padding amount applied to the height of the
421 * image, before and after (symmetric padding)
422 * \param pad_w A static constant padding amount applied to the width of the
423 * image, before and after (symmetric padding)
424 * \param stride_h A static constant striding amount applied to the height of
425 * the image, before and after (symmetric padding)
 * \param stride_w A static constant striding amount applied to the width of
 * the image, before and after (symmetric padding)
428 * \param name The name of the operation
429 * \param tag The tag to mark the operation
430 *
 * \return A Tensor whose op member is the 2-D group convolution operation
 * (NGCHW layout)
433 */
434inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
435 int pad_h = 0, int pad_w = 0, int stride_h = 1,
436 int stride_w = 1,
437 std::string name = "T_group_conv2d_ngchw",
438 std::string tag = kGroupConv2d) {
439 ICHECK_EQ(5, I->shape.size());
440 ICHECK_EQ(5, W->shape.size());
441 auto pH = I->shape[2];
442 auto pW = I->shape[3];
443 tvm::Array<tvm::PrimExpr> output_shape{
444 I->shape[0], // B
445 I->shape[1], // G
446 W->shape[2], // O
447 indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H
448 indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W
449 };
450 auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[2]}, "i");
451 auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kh");
452 auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[4]}, "kw");
453
454 auto T = (pad_h == 0 && pad_w == 0)
455 ? I
456 : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
457 auto l = [&](tvm::Array<tvm::tir::Var> args) {
458 tvm::tir::Var b = args[0];
459 tvm::tir::Var g = args[1];
460 tvm::tir::Var o = args[2];
461 tvm::tir::Var h = args[3];
462 tvm::tir::Var w = args[4];
463 return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw),
464 {i, kh, kw});
465 };
466 return tvm::te::compute(output_shape, l, name, tag);
467}
468
469/*!
470 * \brief Divide spatial dimensions of the input into a grid of blocks.
471 *
472 * \param data The input tensor.
473 * \param block_shape The size of the spatial block.
474 * \param pad_before The zero-padding size before each spatial dimension.
475 * \param pad_after The zero-padding size after each spatial dimension.
476 * \param pad_value The value used for padding.
477 * \param name The name of the operation.
478 * \param tag The tag to mark the operation.
479 *
480 * \return A Tensor whose op member is the space_to_batch_nd operation
481 */
inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& pad_before,
                                         const tvm::Array<tvm::PrimExpr>& pad_after,
                                         PrimExpr pad_value = PrimExpr(),
                                         std::string name = "space_to_batch_nd",
                                         std::string tag = kInjective) {
  tvm::te::Tensor padded_t;
  CHECK_EQ(pad_before.size(), pad_after.size());
  CHECK_EQ(block_shape.size(), pad_before.size())
      << "Paddings must be provided for each spatial dimension";
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  // pad size for batch dimension is 0
  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  // insert pad sizes given for spatial dimensions
  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // pad the input with paddings provided (default fill is 0 of data's dtype)
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(data->dtype, 0);
  }
  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);

  auto input_shape = data->shape;
  auto padded_shape = padded_t->shape;

  // infer shapes for the reshape -> transpose -> reshape sequence below
  tvm::Array<PrimExpr> r_shape;  // intermediate shape: each spatial dim split by its block size
  tvm::Array<Integer> axis;      // transpose order moving the block factors next to batch
  tvm::Array<PrimExpr> o_shape;  // final output shape

  size_t num_block_dims = block_shape.size();
  int batch = static_cast<int>(GetConstInt(input_shape[0]));
  tvm::PrimExpr block_shape_prod(1);
  r_shape.push_back(batch);

  for (size_t i = 1; i <= num_block_dims; i++) {
    int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
    int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
    // each padded spatial extent must split evenly into blocks
    CHECK_EQ((padded_input % block_size), 0)
        << "(" << i
        << ")th "
           "Input dimension after padding ("
        << padded_input << ")"
        << " must be divisible by its block size (" << block_size << ")";

    r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
    r_shape.push_back(block_shape[i - 1]);
    block_shape_prod *= block_shape[i - 1];
    axis.push_back(Integer(r_shape.size() - 1));  // index of block_shape[i - 1]
  }

  size_t n = axis.size();
  axis.push_back(0);  // batch is at index 0
  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape
  for (size_t i = 0; i < n; i++) {
    axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
  }
  // the new batch dim absorbs one factor of every block dimension
  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
  }
  // append remaining shape
  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
    r_shape.push_back(input_shape[i]);
    axis.push_back(Integer(r_shape.size() - 1));  // index of remaining shape in r_shape
    o_shape.push_back(input_shape[i]);
  }

  // reshape -> transpose -> reshape implements the space-to-batch data movement
  tvm::te::Tensor output = reshape(padded_t, r_shape);
  output = transpose(output, axis);
  output = reshape(output, o_shape);

  return output;
}
565
566/*!
567 * \brief Reshape the batch dimension into spatial dimensions.
568 *
569 * \param data The input tensor.
570 * \param block_shape The size of the spatial block.
571 * \param crop_begin_list The begin crop size for each spatial dimension.
572 * \param crop_end_list The end crop size for each spatial dimension.
573 * \param name The name of the operation.
574 * \param tag The tag to mark the operation.
575 *
576 * \return A Tensor whose op member is the batch_to_space_nd operation
577 */
inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& crop_begin_list,
                                         const tvm::Array<tvm::PrimExpr>& crop_end_list,
                                         std::string name = "batch_to_space_nd",
                                         std::string tag = kInjective) {
  // Construct shapes for reshape and transpose operation
  Array<PrimExpr> in_shape = data->shape;
  Array<PrimExpr> r_shape;  // intermediate shape: block factors peeled off the batch dim
  Array<Integer> axis;      // transpose order interleaving block factors with spatial dims
  size_t num_block_dims = block_shape.size();
  size_t num_input_dims = in_shape.size();
  tvm::PrimExpr block_shape_prod(1);
  int batch = static_cast<int>(GetConstInt(in_shape[0]));

  // leading dims of r_shape are the individual block factors
  for (size_t i = 0; i < num_block_dims; i++) {
    r_shape.push_back(block_shape[i]);
    block_shape_prod *= block_shape[i];
  }
  axis.push_back(Integer(r_shape.size()));  // axis of (batch / block_shape_prod)
  r_shape.push_back(batch / block_shape_prod);

  // interleave each input dim with its block factor (spatial dims only)
  for (size_t i = 1; i < num_input_dims; i++) {
    axis.push_back(Integer(r_shape.size()));  // axis of in_shape[i]
    if (axis.size() < (num_block_dims + num_input_dims)) {
      axis.push_back(Integer(r_shape.size() - (num_block_dims + 1)));  // axis of block_shape[i]
    }
    r_shape.push_back(in_shape[i]);
  }

  // shape after the transpose: each spatial dim regains its block factor
  Array<PrimExpr> r_p_shape;
  r_p_shape.push_back(batch / block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
  }
  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
    r_p_shape.push_back(in_shape[i]);
  }

  // reshape -> transpose -> reshape implements the batch-to-space data movement
  tvm::te::Tensor out;
  out = reshape(data, r_shape);
  out = transpose(out, axis);
  out = reshape(out, r_p_shape);

  // Crop the start and end of dimensions of out
  Array<Integer> begin_idx, end_idx, strides;
  for (size_t i = 0; i < r_p_shape.size(); ++i) {
    strides.push_back(Integer(1));
    if (i > 0 && i <= num_block_dims) {
      // prepare begin and end index for spatial dimensions
      int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
      int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
      int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
      CHECK_GT(out_i, (begin_i + end_i))
          << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
          << " output size" << out_i << " vs " << (begin_i + end_i);
      begin_idx.push_back(begin_i);
      end_idx.push_back(out_i - end_i);
    } else {
      // ignore the batch and remaining dimension
      begin_idx.push_back(Integer(0));
      end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
    }
  }

  out = strided_slice(out, begin_idx, end_idx, strides);
  return out;
}
646
647/*!
648 * \brief Negative log likelihood loss.
649 *
650 * \param predictions The prediction tensor.
651 * \param targets The target tensor.
652 * \param weights A manual rescaling weight given to each class.
653 * \param reduction The reduction method to apply to the output.
654 * \param ignore_index The target value to ignore.
655 * \param name The name of the operation.
656 * \param tag The tag to mark the operation.
657 *
658 * \return The negative log likelihood loss of the predictions and targets.
659 */
660inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights,
661 std::string reduction = "mean", int ignore_index = -100,
662 const std::string name = "nll_loss", const std::string tag = kBroadcast) {
663 auto T = tvm::te::compute(
664 targets->shape,
665 [&](const tvm::Array<tvm::tir::Var>& target_indices) {
666 auto c = targets(target_indices);
667 tvm::Array<tvm::PrimExpr> pred_indices;
668 pred_indices.push_back(target_indices[0]); // batch index
669 pred_indices.push_back(c); // class index
670 for (size_t i = 1; i < target_indices.size(); i++) {
671 pred_indices.push_back(target_indices[i]); // indices for multidimensional loss
672 }
673 return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c),
674 tvm::tir::make_const(predictions->dtype, 0));
675 },
676 name, tag);
677 if (reduction == "mean") {
678 auto W = tvm::te::compute(
679 targets->shape,
680 [&](const tvm::Array<tvm::tir::Var>& target_indices) {
681 auto c = targets(target_indices);
682 return tvm::tir::Select(c != ignore_index, weights(c),
683 tvm::tir::make_const(predictions->dtype, 0));
684 },
685 name, tag);
686 return topi::divide(topi::sum(T, {}), topi::sum(W, {}));
687 } else if (reduction == "sum") {
688 return topi::sum(T, {});
689 } else { // reduction == "none"
690 return T;
691 }
692}
693} // namespace topi
694} // namespace tvm
695#endif // TVM_TOPI_NN_H_
696