/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Dilate op constructions
 * \file nn/dilate.h
 */
#ifndef TVM_TOPI_NN_DILATE_H_
#define TVM_TOPI_NN_DILATE_H_

#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
 * \brief Create a new expression of the logical and of all
 * conditions in the arguments.
 *
 * \param args The arguments to find the logical conjunction of
 *
 * \return The logical conjunction expression
 */
inline PrimExpr all(Array<PrimExpr> args) {
  ICHECK_GT(args.size(), 0) << "all requires at least one argument";

  PrimExpr ret = args[0];
  for (size_t i = 1; i < args.size(); ++i) {
    ret = ret && args[i];
  }
  return ret;
}
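
// A minimal usage sketch (not part of the original header; i and j are
// hypothetical tir::Var index variables):
//
//   PrimExpr cond = all({i >= 0, j >= 0});  // (i >= 0) && (j >= 0)
//
// dilate() below uses all() to combine its per-axis "lands on an original
// element" conditions into a single predicate.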

/*!
 * \brief Dilate data with the given dilation value (typically 0).
 *
 * \param x The input tensor; it can have any number of
 * dimensions and any layout.
 * \param strides Dilation stride for each dimension. Stride 1
 * means no dilation.
 * \param dilation_value The value used to fill the positions inserted by the
 * dilation.
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return The output tensor.
 */
inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_value,
                     std::string name = "tensor", std::string tag = kInjective) {
  auto n = x->shape.size();
  ICHECK_EQ(n, strides.size()) << "strides size (" << strides.size()
                               << ") must match dimension of x (" << n << ")";

  Array<PrimExpr> out_shape;
  arith::Analyzer analyzer;
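  // Each dilated dimension becomes (d - 1) * stride + 1: the d original
  // elements are kept, with stride - 1 new positions inserted between every
  // adjacent pair.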
  for (size_t i = 0; i < n; ++i) {
    out_shape.push_back(
        analyzer.Simplify((x->shape[i] - 1) * cast(DataType::Int(32), strides[i]) + 1));
  }

  return tvm::te::compute(
      out_shape,
      [&](const Array<Var>& indices) {
        Array<PrimExpr> not_zero;
        Array<PrimExpr> index_tuple;
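        // Map each output coordinate back to an input coordinate: along a
        // dilated axis the input index is indices[i] / strides[i], and the
        // coordinate refers to an original element only when
        // indices[i] % strides[i] == 0. Those conditions are collected in
        // not_zero.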
        for (size_t i = 0; i < n; ++i) {
          if (detail::IsConstInt(strides[i]) && detail::GetConstInt(strides[i]) == 1) {
            index_tuple.push_back(indices[i]);
          } else {
            index_tuple.push_back(indexdiv(indices[i], strides[i]));
            not_zero.push_back((indexmod(indices[i], strides[i])) == 0);
          }
        }
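        // Coordinates that fall between original elements (some remainder in
        // not_zero is non-zero) read dilation_value; everything else reads
        // the mapped element of x.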
        if (not_zero.size() > 0) {
          auto all_not_zero = all(not_zero);
          return tvm::if_then_else(all_not_zero, x(index_tuple),
                                   make_const(x->dtype, dilation_value));
        }
        return x(index_tuple);
      },
      name, tag);
}
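
// A minimal usage sketch (not part of the original header; "data" is a
// hypothetical 4x4 float32 placeholder): dilating by 2 along both axes gives
// a (4 - 1) * 2 + 1 = 7 x 7 output whose inserted positions hold 0.
//
//   te::Tensor data = te::placeholder({4, 4}, DataType::Float(32), "data");
//   te::Tensor out = dilate(data, {2, 2}, 0.0);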

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_DILATE_H_