1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief NN op constructions |
22 | * \file topi/nn.h |
23 | */ |
24 | #ifndef TVM_TOPI_NN_H_ |
25 | #define TVM_TOPI_NN_H_ |
26 | |
27 | #include <tvm/arith/analyzer.h> |
28 | #include <tvm/te/operation.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/op.h> |
31 | #include <tvm/topi/detail/constant_utils.h> |
32 | #include <tvm/topi/reduction.h> |
33 | #include <tvm/topi/tags.h> |
34 | #include <tvm/topi/transform.h> |
35 | |
36 | #include <algorithm> |
37 | #include <string> |
38 | |
39 | namespace tvm { |
40 | namespace topi { |
41 | |
42 | using namespace tvm::te; |
43 | |
44 | /*! |
45 | * \brief Creates an operation that performs a rectified linear unit |
46 | * |
47 | * \param t The input tensor |
48 | * \param threshold The relu threshold (default 0) |
49 | * \param name The name of the operation |
50 | * \param tag The tag to mark the operation |
51 | * |
52 | * \return A Tensor whose op member is the relu operation |
53 | */ |
54 | template <typename T> |
55 | inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast<T>(0), |
56 | std::string name = "T_relu" , std::string tag = kElementWise) { |
57 | return tvm::te::compute( |
58 | t->shape, |
59 | [&](const tvm::Array<tvm::tir::Var>& i) { |
60 | auto threshold_const = tvm::tir::make_const(t->dtype, threshold); |
61 | return tvm::max(t(i), threshold_const); |
62 | }, |
63 | name, tag); |
64 | } |
65 | |
66 | /*! |
67 | * \brief Creates an operation that performs a leaky rectified linear unit |
68 | * |
69 | * \param t The input tensor |
70 | * \param alpha The slope for the small gradient when t < 0 |
71 | * \param name The name of the operation |
72 | * \param tag The tag to mark the operation |
73 | * |
74 | * \return A Tensor whose op member is the leaky relu operation |
75 | */ |
76 | inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, |
77 | std::string name = "T_leaky_relu" , |
78 | std::string tag = kElementWise) { |
79 | return tvm::te::compute( |
80 | t->shape, |
81 | [&](const tvm::Array<tvm::tir::Var>& i) { |
82 | auto value = t(i); |
83 | auto calpha = tvm::tir::make_const(value.dtype(), alpha); |
84 | return tvm::tir::Select(value > 0, value, value * calpha); |
85 | }, |
86 | name, tag); |
87 | } |
88 | |
89 | /*! |
90 | * \brief Creates an operation that performs a parametric rectified linear unit |
91 | * |
92 | * \param x The input data tensor |
93 | * \param slope The channel-wise slope tensor |
94 | * \param axis The axis where the channel data needs to be applied |
95 | * \param name The name of the operation |
96 | * \param tag The tag to mark the operation |
97 | * |
98 | * \return A Tensor whose op member is the parametric relu operation |
99 | */ |
100 | inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope, |
101 | const int axis = 1, std::string name = "T_prelu" , |
102 | std::string tag = kBroadcast) { |
103 | ICHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. " ; |
104 | ICHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis])) |
105 | << "Wrong slope shape received." ; |
106 | |
107 | return tvm::te::compute( |
108 | x->shape, |
109 | [&](const tvm::Array<tvm::tir::Var>& indices) { |
110 | auto xval = x(indices); |
111 | return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis])); |
112 | }, |
113 | name, tag); |
114 | } |
115 | |
116 | /*! |
117 | * \brief Creates an operation that performs padding |
118 | * |
119 | * \param t The input tensor |
120 | * \param pad_before An Array of Expr describing the padding before the |
121 | * respective iterator |
122 | * \param pad_after An Array of Expr describing the padding after the |
123 | * respective iterator |
124 | * \param pad_value The value to fill padding elements with |
125 | * \param pad_mode Padding type to use. |
126 | * "constant" pads with constant_value; |
127 | * "edge" pads using the edge values of the input array; |
128 | * "reflect" pads by reflecting values with respect to the edges. |
129 | * \param dyn_output_shape Output shape of the pad op, default nullptr. |
130 | * You only need to pass this in if the shape was evaluated dynamically. |
131 | * \param name The name of the operation |
132 | * \param tag The tag to mark the operation |
133 | * |
134 | * \return A Tensor whose op member is the padding operation |
135 | * |
136 | * \note |
137 | * The pad_after Array must either be empty or have the same length as |
138 | * pad_before |
139 | * When pad_after is empty, it takes the same values as pad_before (symmetric |
140 | * padding) |
141 | * The pad Array applies from the leading dimensions and skips missing |
142 | * trailing dimensions: |
143 | * |
144 | * pad(t(i, j, k), {1}, {0}) returns the equivalent operation for |
145 | * the following pseudocode: |
 *        for i in [1, t.shape[0] + 2]:
 *          for j in [1, t.shape[1] + 2]:
 *            for k in [1, t.shape[2] + 2]:
149 | * name(i,j,k) = |
150 | * (1 <= i <= t.shape[0] + 1) ? |
151 | * t(i-1, j, k) : 0; |
152 | * |
153 | * |
154 | */ |
inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array<tvm::PrimExpr>& pad_before,
                           tvm::Array<tvm::PrimExpr> pad_after = tvm::Array<tvm::PrimExpr>(),
                           PrimExpr pad_value = PrimExpr(), std::string name = "T_pad",
                           std::string tag = kElementWise, std::string pad_mode = "constant",
                           const Array<PrimExpr>* dyn_output_shape = nullptr) {
  // An empty (or short) pad_after defaults to symmetric padding: missing
  // trailing entries are copied from pad_before.
  if (pad_after.size() < pad_before.size()) {
    for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
      pad_after.push_back(pad_before[i]);
    }
  }

  arith::Analyzer analyzer;
  ICHECK_GE(pad_before.size(), 1);
  ICHECK_EQ(pad_before.size(), pad_after.size());
  // All pad amounts are normalized to int32 so index arithmetic below is done
  // in a single dtype.
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // Output shape: either derived statically (input extent plus both pads for
  // each padded leading dimension, untouched for trailing dimensions), or
  // taken verbatim from the caller-supplied dynamic shape.
  tvm::Array<tvm::PrimExpr> output_shape;
  if (dyn_output_shape == nullptr) {
    for (size_t i = 0; i < t->shape.size(); ++i) {
      if (i >= pad_before.size()) {
        output_shape.push_back(t->shape[i]);
      } else {
        output_shape.push_back(
            analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i]));
      }
    }
  } else {
    for (size_t i = 0; i < dyn_output_shape->size(); i++) {
      output_shape.push_back((*dyn_output_shape)[i]);
    }
  }

  // Default fill value for "constant" mode is a zero of the input dtype.
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(t->dtype, 0);
  }

  auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
    // indices: coordinates into the original tensor for the interior region.
    // sel:     conjunction of conditions that hold iff the output coordinate
    //          lies inside the interior (non-padding) region.
    // pad_idx: clamped/reflected coordinates used for "edge"/"reflect" modes.
    tvm::Array<tvm::PrimExpr> indices;
    tvm::Array<tvm::PrimExpr> sel;
    tvm::Array<tvm::PrimExpr> pad_idx;
    for (size_t i = 0; i < t->shape.size(); ++i) {
      if (i >= pad_before_int32.size()) {
        // Dimension beyond the pad spec: passes through unchanged.
        indices.push_back(ovars[i]);
        continue;
      }
      if (!topi::detail::EqualCheck(pad_before_int32[i], 0)) {
        sel.push_back(ovars[i] >= pad_before_int32[i]);
        indices.push_back(ovars[i] - pad_before_int32[i]);
      } else {
        indices.push_back(ovars[i]);
      }
      if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
        sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
      }
      if (pad_mode == "edge") {
        // Clamp: coordinates left of the interior read index 0, right of it
        // read the last index, otherwise the shifted interior index.
        pad_idx.push_back(
            tvm::if_then_else(ovars[i] < pad_before[i], 0,
                              tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                                                t->shape[i] - 1, ovars[i] - pad_before[i])));
      } else if (pad_mode == "reflect") {
        // Mirror about the edges (excluding the edge element itself).
        pad_idx.push_back(
            tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i],
                              tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
                                                t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
                                                ovars[i] - pad_before[i])));
      }
    }
    if (sel.size() != 0) {
      // Interior -> read the input; padding region -> pad_value ("constant")
      // or the clamped/reflected input element ("edge"/"reflect").
      if (pad_mode == "constant") {
        return tvm::if_then_else(
            foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
                  const_true(1), sel),
            t(indices), pad_value);
      } else if (pad_mode == "edge" || pad_mode == "reflect") {
        return tvm::if_then_else(
            foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); },
                  const_true(1), sel),
            t(indices), t(pad_idx));
      }
    }
    // No padding conditions at all (all pads are zero): identity read.
    return t(indices);
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
247 | |
248 | /*! |
249 | * \brief Creates an operation that performs a 2-D convolution with an |
250 | * NCHW-layout |
251 | * |
252 | * \param I The 4-D input tensor |
253 | * \param W The 4-D weight tensor |
254 | * \param pad_h A static constant padding amount applied to the height of the |
255 | * image, before and after (symmetric padding) |
256 | * \param pad_w A static constant padding amount applied to the width of the |
257 | * image, before and after (symmetric padding) |
258 | * \param stride_h A static constant striding amount applied to the height of |
259 | * the image, before and after (symmetric padding) |
 * \param stride_w A static constant striding amount applied to the width of
261 | * the image, before and after (symmetric padding) |
262 | * \param name The name of the operation |
263 | * \param tag The tag to mark the operation |
264 | * |
265 | * \return A Tensor whose op member is the 2-D convolution operation (NCHW |
266 | * layout) |
267 | */ |
268 | inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, |
269 | int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, |
270 | std::string name = "T_conv2d_nchw" , |
271 | std::string tag = kConv2dNCHW) { |
272 | ICHECK_EQ(4, I->shape.size()); |
273 | ICHECK_EQ(4, W->shape.size()); |
274 | auto pH = I->shape[2]; |
275 | auto pW = I->shape[3]; |
276 | tvm::Array<tvm::PrimExpr> output_shape{ |
277 | I->shape[0], // B |
278 | W->shape[0], // O |
279 | indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H |
280 | indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W |
281 | }; |
282 | auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i" ); |
283 | auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh" ); |
284 | auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw" ); |
285 | auto T = |
286 | (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); |
287 | auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { |
288 | return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw}); |
289 | }; |
290 | return tvm::te::compute(output_shape, l, name, tag); |
291 | } |
292 | |
293 | /*! |
294 | * \brief Creates an operation for 2-D convolution layer with an HWCN-layout |
295 | * |
296 | * \param I The 4-D input tensor |
297 | * \param W The 4-D weight tensor |
298 | * \param pad_h A static constant padding amount applied to the height of the |
299 | * image, before and after (symmetric padding) |
300 | * \param pad_w A static constant padding amount applied to the width of the |
301 | * image, before and after (symmetric padding) |
302 | * \param stride_h A static constant striding amount applied to the height of |
303 | * the image, before and after (symmetric padding) |
 * \param stride_w A static constant striding amount applied to the width of
305 | * the image, before and after (symmetric padding) |
306 | * \param name The name of the operation |
307 | * \param tag The tag to mark the operation |
308 | * |
309 | * \return A Tensor whose op member is the 2-D convolution operation |
310 | * (HWCN layout) |
311 | */ |
312 | inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tensor& W, |
313 | int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, |
314 | std::string name = "T_conv2d_hwcn" , |
315 | std::string tag = kConv2dHWCN) { |
316 | ICHECK_EQ(4, I->shape.size()); |
317 | ICHECK_EQ(4, W->shape.size()); |
318 | auto pH = I->shape[2]; |
319 | auto pW = I->shape[3]; |
320 | tvm::Array<tvm::PrimExpr> output_shape{ |
321 | indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H |
322 | indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W |
323 | I->shape[2], // B |
324 | W->shape[3] // O |
325 | }; |
326 | auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i" ); |
327 | auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh" ); |
328 | auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw" ); |
329 | auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w}); |
330 | auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { |
331 | return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw}); |
332 | }; |
333 | return tvm::te::compute(output_shape, l, name, tag); |
334 | } |
335 | |
336 | /*! |
337 | * \brief Creates an operation that performs a 2-D depthwise convolution with |
338 | * an NCHW-layout |
339 | * |
340 | * \param I The 4-D input tensor |
341 | * \param W The 4-D weight tensor |
342 | * \param pad_h A static constant padding amount applied to the height of the |
343 | * image, before and after (symmetric padding) |
344 | * \param pad_w A static constant padding amount applied to the width of the |
345 | * image, before and after (symmetric padding) |
346 | * \param stride_h A static constant striding amount applied to the height of |
347 | * the image, before and after (symmetric padding) |
 * \param stride_w A static constant striding amount applied to the width of
349 | * the image, before and after (symmetric padding) |
350 | * \param name The name of the operation |
351 | * \param tag The tag to mark the operation |
352 | * |
353 | * \return A Tensor whose op member is the 2-D depthwise convolution operation |
354 | * (NCHW layout) |
355 | */ |
inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
                                             int pad_h = 0, int pad_w = 0, int stride_h = 1,
                                             int stride_w = 1,
                                             std::string name = "T_depthwise_conv2d_nchw",
                                             std::string tag = kDepthwiseConv2dNCHW) {
  // Depthwise 2-D convolution, NCHW input with (C, M, KH, KW) weights where
  // M is the channel multiplier.
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[2];  // NOTE(review): unused local
  auto pW = I->shape[3];  // NOTE(review): unused local
  auto pCM = W->shape[1];  // channel_multiplier
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],                                                    // B
      // NOTE(review): output channels are taken as W->shape[1] (the
      // multiplier M alone); a conventional depthwise conv produces C * M
      // channels — confirm against callers before relying on this.
      W->shape[1],                                                    // O
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1   // W
  };
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw");
  // Zero-pad only H and W when padding is requested.
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) {
    // NOTE(review): the reduction ranges over all input channels i, yet both
    // operands index the channel as indexdiv(i, pCM), so each distinct
    // product is accumulated pCM times; depthwise convs usually reduce only
    // over (kh, kw) with the channel derived from o. Verify before reuse.
    return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) *
                        W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw),
                    {i, kh, kw});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
384 | |
inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm::te::Tensor& W,
                                             int pad_h = 0, int pad_w = 0, int stride_h = 1,
                                             int stride_w = 1,
                                             std::string name = "T_depthwise_conv2d_nhwc",
                                             std::string tag = kDepthwiseConv2dNHWC) {
  // Depthwise 2-D convolution, NHWC input; the weight is indexed below as
  // W(kh, kw, channel, multiplier), i.e. an HWCM-style layout.
  ICHECK_EQ(4, I->shape.size());
  ICHECK_EQ(4, W->shape.size());
  auto pH = I->shape[1];  // NOTE(review): unused local
  auto pW = I->shape[2];  // NOTE(review): unused local
  // NOTE(review): named channel_multiplier but read from W->shape[1]; the
  // W(kh, kw, c, m) indexing in the body suggests the multiplier lives at
  // W->shape[3] — confirm the intended weight layout.
  auto pCM = W->shape[1];  // channel_multiplier
  tvm::Array<tvm::PrimExpr> output_shape{
      I->shape[0],  // B
      // NOTE(review): these extents subtract W->shape[1] and W->shape[2],
      // while kh/kw below range over W->shape[0] and W->shape[1]; the kernel
      // dims used for the output size and for the reduction disagree.
      indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1,  // H
      indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1,  // W
      W->shape[3],                                                    // O
  };
  auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i");
  auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh");
  auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw");
  // Zero-pad only the spatial dims H and W (NHWC: dims 1 and 2).
  auto T =
      (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)});
  auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) {
    // NOTE(review): as in the NCHW variant, the reduction over i with both
    // operands indexed by indexdiv(i, pCM) repeats each product pCM times;
    // verify against a reference depthwise implementation before reuse.
    return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) *
                        W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)),
                    {kh, kw, i});
  };
  return tvm::te::compute(output_shape, l, name, tag);
}
413 | |
414 | /*! |
415 | * \brief Creates an operation that performs a 2-D group convolution with |
416 | * an NGCHW-layout |
417 | * |
418 | * \param I The 5-D input tensor |
419 | * \param W The 5-D weight tensor |
420 | * \param pad_h A static constant padding amount applied to the height of the |
421 | * image, before and after (symmetric padding) |
422 | * \param pad_w A static constant padding amount applied to the width of the |
423 | * image, before and after (symmetric padding) |
424 | * \param stride_h A static constant striding amount applied to the height of |
425 | * the image, before and after (symmetric padding) |
 * \param stride_w A static constant striding amount applied to the width of
427 | * the image, before and after (symmetric padding) |
428 | * \param name The name of the operation |
429 | * \param tag The tag to mark the operation |
430 | * |
 * \return A Tensor whose op member is the 2-D group convolution operation
432 | * (NCHW layout) |
433 | */ |
434 | inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, |
435 | int pad_h = 0, int pad_w = 0, int stride_h = 1, |
436 | int stride_w = 1, |
437 | std::string name = "T_group_conv2d_ngchw" , |
438 | std::string tag = kGroupConv2d) { |
439 | ICHECK_EQ(5, I->shape.size()); |
440 | ICHECK_EQ(5, W->shape.size()); |
441 | auto pH = I->shape[2]; |
442 | auto pW = I->shape[3]; |
443 | tvm::Array<tvm::PrimExpr> output_shape{ |
444 | I->shape[0], // B |
445 | I->shape[1], // G |
446 | W->shape[2], // O |
447 | indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H |
448 | indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W |
449 | }; |
450 | auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[2]}, "i" ); |
451 | auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kh" ); |
452 | auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[4]}, "kw" ); |
453 | |
454 | auto T = (pad_h == 0 && pad_w == 0) |
455 | ? I |
456 | : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); |
457 | auto l = [&](tvm::Array<tvm::tir::Var> args) { |
458 | tvm::tir::Var b = args[0]; |
459 | tvm::tir::Var g = args[1]; |
460 | tvm::tir::Var o = args[2]; |
461 | tvm::tir::Var h = args[3]; |
462 | tvm::tir::Var w = args[4]; |
463 | return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), |
464 | {i, kh, kw}); |
465 | }; |
466 | return tvm::te::compute(output_shape, l, name, tag); |
467 | } |
468 | |
469 | /*! |
470 | * \brief Divide spatial dimensions of the input into a grid of blocks. |
471 | * |
472 | * \param data The input tensor. |
473 | * \param block_shape The size of the spatial block. |
474 | * \param pad_before The zero-padding size before each spatial dimension. |
475 | * \param pad_after The zero-padding size after each spatial dimension. |
476 | * \param pad_value The value used for padding. |
477 | * \param name The name of the operation. |
478 | * \param tag The tag to mark the operation. |
479 | * |
480 | * \return A Tensor whose op member is the space_to_batch_nd operation |
481 | */ |
inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& pad_before,
                                         const tvm::Array<tvm::PrimExpr>& pad_after,
                                         PrimExpr pad_value = PrimExpr(),
                                         std::string name = "space_to_batch_nd",
                                         std::string tag = kInjective) {
  // Implemented as: pad spatial dims -> reshape (split each spatial dim into
  // blocks) -> transpose (move block indices in front of batch) -> reshape
  // (fold block indices into the batch dim).
  tvm::te::Tensor padded_t;
  CHECK_EQ(pad_before.size(), pad_after.size());
  CHECK_EQ(block_shape.size(), pad_before.size())
      << "Paddings must be provided for each spatial dimension";
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  // pad size for batch dimension is 0
  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  // insert pad sizes given for spatial dimensions
  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // pad the input with paddings provided
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(data->dtype, 0);
  }
  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);

  auto input_shape = data->shape;
  auto padded_shape = padded_t->shape;

  // infer shapes
  // r_shape: intermediate reshape that splits each padded spatial dim i into
  //          (padded[i] / block[i-1], block[i-1]).
  // axis:    permutation applied by transpose, built so that all block-size
  //          axes come first, then batch, then the quotient axes, then the
  //          remaining (non-spatial) axes.
  // o_shape: final output shape with the blocks folded into batch.
  tvm::Array<PrimExpr> r_shape;
  tvm::Array<Integer> axis;
  tvm::Array<PrimExpr> o_shape;

  size_t num_block_dims = block_shape.size();
  // NOTE(review): GetConstInt requires a statically known batch size here.
  int batch = static_cast<int>(GetConstInt(input_shape[0]));
  tvm::PrimExpr block_shape_prod(1);
  r_shape.push_back(batch);

  for (size_t i = 1; i <= num_block_dims; i++) {
    int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
    int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
    CHECK_EQ((padded_input % block_size), 0)
        << "(" << i
        << ")th "
           "Input dimension after padding ("
        << padded_input << ")"
        << " must be divisible by its block size (" << block_size << ")";

    r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
    r_shape.push_back(block_shape[i - 1]);
    block_shape_prod *= block_shape[i - 1];
    axis.push_back(Integer(r_shape.size() - 1));  // index of block_shape[i - 1]
  }

  size_t n = axis.size();
  axis.push_back(0);  // batch is at index 0
  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape
  // (each quotient axis sits immediately before its block-size axis).
  for (size_t i = 0; i < n; i++) {
    axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
  }
  // New batch is the old batch times the product of all block sizes.
  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
  }
  // append remaining shape
  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
    r_shape.push_back(input_shape[i]);
    axis.push_back(Integer(r_shape.size() - 1));  // index of remaining shape in r_shape
    o_shape.push_back(input_shape[i]);
  }

  tvm::te::Tensor output = reshape(padded_t, r_shape);
  output = transpose(output, axis);
  output = reshape(output, o_shape);

  return output;
}
565 | |
566 | /*! |
567 | * \brief Reshape the batch dimension into spatial dimensions. |
568 | * |
569 | * \param data The input tensor. |
570 | * \param block_shape The size of the spatial block. |
571 | * \param crop_begin_list The begin crop size for each spatial dimension. |
572 | * \param crop_end_list The end crop size for each spatial dimension. |
573 | * \param name The name of the operation. |
574 | * \param tag The tag to mark the operation. |
575 | * |
576 | * \return A Tensor whose op member is the batch_to_space_nd operation |
577 | */ |
inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& crop_begin_list,
                                         const tvm::Array<tvm::PrimExpr>& crop_end_list,
                                         std::string name = "batch_to_space_nd",
                                         std::string tag = kInjective) {
  // Inverse of space_to_batch_nd: reshape (split block factors out of the
  // batch dim) -> transpose (interleave blocks back into spatial dims) ->
  // reshape (merge) -> strided_slice (crop the spatial dims).
  // Construct shapes for reshape and transpose operation
  Array<PrimExpr> in_shape = data->shape;
  Array<PrimExpr> r_shape;
  Array<Integer> axis;
  size_t num_block_dims = block_shape.size();
  size_t num_input_dims = in_shape.size();
  tvm::PrimExpr block_shape_prod(1);
  // NOTE(review): GetConstInt requires a statically known batch size here.
  int batch = static_cast<int>(GetConstInt(in_shape[0]));

  // r_shape starts with all block factors, then the reduced batch.
  for (size_t i = 0; i < num_block_dims; i++) {
    r_shape.push_back(block_shape[i]);
    block_shape_prod *= block_shape[i];
  }
  axis.push_back(Integer(r_shape.size()));  // axis of (batch / block_shape_prod)
  r_shape.push_back(batch / block_shape_prod);

  // For each spatial dim, pair its axis with the corresponding block-factor
  // axis so the transpose interleaves (spatial, block); remaining dims are
  // appended unpaired.
  for (size_t i = 1; i < num_input_dims; i++) {
    axis.push_back(Integer(r_shape.size()));  // axis of in_shape[i]
    if (axis.size() < (num_block_dims + num_input_dims)) {
      axis.push_back(Integer(r_shape.size() - (num_block_dims + 1)));  // axis of block_shape[i]
    }
    r_shape.push_back(in_shape[i]);
  }

  // r_p_shape: shape after merging each (spatial, block) pair back into a
  // single expanded spatial dimension.
  Array<PrimExpr> r_p_shape;
  r_p_shape.push_back(batch / block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
  }
  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
    r_p_shape.push_back(in_shape[i]);
  }

  tvm::te::Tensor out;
  out = reshape(data, r_shape);
  out = transpose(out, axis);
  out = reshape(out, r_p_shape);

  // Crop the start and end of dimensions of out
  Array<Integer> begin_idx, end_idx, strides;
  for (size_t i = 0; i < r_p_shape.size(); ++i) {
    strides.push_back(Integer(1));
    if (i > 0 && i <= num_block_dims) {
      // prepare begin and end index for spatial dimensions
      int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
      int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
      int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
      CHECK_GT(out_i, (begin_i + end_i))
          << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
          << " output size" << out_i << " vs " << (begin_i + end_i);
      begin_idx.push_back(begin_i);
      end_idx.push_back(out_i - end_i);
    } else {
      // ignore the batch and remaining dimension
      begin_idx.push_back(Integer(0));
      end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
    }
  }

  out = strided_slice(out, begin_idx, end_idx, strides);
  return out;
}
646 | |
647 | /*! |
648 | * \brief Negative log likelihood loss. |
649 | * |
650 | * \param predictions The prediction tensor. |
651 | * \param targets The target tensor. |
652 | * \param weights A manual rescaling weight given to each class. |
653 | * \param reduction The reduction method to apply to the output. |
654 | * \param ignore_index The target value to ignore. |
655 | * \param name The name of the operation. |
656 | * \param tag The tag to mark the operation. |
657 | * |
658 | * \return The negative log likelihood loss of the predictions and targets. |
659 | */ |
660 | inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const Tensor& weights, |
661 | std::string reduction = "mean" , int ignore_index = -100, |
662 | const std::string name = "nll_loss" , const std::string tag = kBroadcast) { |
663 | auto T = tvm::te::compute( |
664 | targets->shape, |
665 | [&](const tvm::Array<tvm::tir::Var>& target_indices) { |
666 | auto c = targets(target_indices); |
667 | tvm::Array<tvm::PrimExpr> pred_indices; |
668 | pred_indices.push_back(target_indices[0]); // batch index |
669 | pred_indices.push_back(c); // class index |
670 | for (size_t i = 1; i < target_indices.size(); i++) { |
671 | pred_indices.push_back(target_indices[i]); // indices for multidimensional loss |
672 | } |
673 | return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c), |
674 | tvm::tir::make_const(predictions->dtype, 0)); |
675 | }, |
676 | name, tag); |
677 | if (reduction == "mean" ) { |
678 | auto W = tvm::te::compute( |
679 | targets->shape, |
680 | [&](const tvm::Array<tvm::tir::Var>& target_indices) { |
681 | auto c = targets(target_indices); |
682 | return tvm::tir::Select(c != ignore_index, weights(c), |
683 | tvm::tir::make_const(predictions->dtype, 0)); |
684 | }, |
685 | name, tag); |
686 | return topi::divide(topi::sum(T, {}), topi::sum(W, {})); |
687 | } else if (reduction == "sum" ) { |
688 | return topi::sum(T, {}); |
689 | } else { // reduction == "none" |
690 | return T; |
691 | } |
692 | } |
693 | } // namespace topi |
694 | } // namespace tvm |
695 | #endif // TVM_TOPI_NN_H_ |
696 | |