/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file pad.cc
 * \brief Implementation of dynamic pad
 */
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>
#include <tvm/topi/nn.h>

#include <vector>

#include "../../make_op.h"
#include "../../op_common.h"

namespace tvm {
namespace relay {
namespace dyn {

// relay.dyn.nn.pad

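/*!
 * \brief Type relation for dyn.nn.pad.
 *
 * Because pad_width is a runtime tensor rather than a static attribute, the
 * concrete output extents cannot be computed here: data of shape (3, 4) padded
 * with pad_width [[1, 1], [2, 2]] yields a (5, 8) result at runtime, but the
 * result type inferred below is (?, ?).
 */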
bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
            const TypeReporter& reporter) {
  // types = [data_type, pad_width_type, pad_value_type, ret_type]
  ICHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const auto* pad_width = types[1].as<TensorTypeNode>();
  if (pad_width == nullptr) return false;

  const auto* pad_value = types[2].as<TensorTypeNode>();
  if (pad_value == nullptr) return false;

  int data_rank = data->shape.size();
  ICHECK(data_rank) << "Data shape must have static rank";

  int pad_width_rank = pad_width->shape.size();
  ICHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D";

  const PadAttrs* param = attrs.as<PadAttrs>();
  ICHECK(param != nullptr);

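  // The output shape depends on the runtime values in pad_width, so every
  // output dimension must be left dynamic.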
  std::vector<IndexExpr> oshape;
  for (int i = 0; i < data_rank; i++) {
    oshape.push_back(Any());
  }

  reporter->Assign(types[3], TensorType(oshape, data->dtype));
  return true;
}

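/*!
 * \brief Compute rule for dyn.nn.pad: reads the per-axis before/after padding
 * amounts out of the pad_width tensor and forwards them to topi::pad along
 * with the dynamic output shape produced by PadRel.
 */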
Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  const auto* param = attrs.as<PadAttrs>();
  ICHECK(param != nullptr);

  auto data = inputs[0];
  auto pad_width = inputs[1];

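  // The pad value arrives as a 0-d tensor: cast it to the data dtype and
  // index it with an empty index array to extract a scalar PrimExpr.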
  te::Tensor cast_pad_value = topi::cast(inputs[2], data->dtype);
  const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>());

  Array<IndexExpr> pad_before;
  Array<IndexExpr> pad_after;

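  // pad_width has shape (rank, 2): column 0 holds the amount of padding added
  // before each axis, column 1 the amount added after it.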
  for (int i = 0; i < pad_width->shape[0].as<IntImmNode>()->value; ++i) {
    pad_before.push_back(pad_width[i][0]);
    pad_after.push_back(pad_width[i][1]);
  }

  const auto* out_ttype = out_type.as<TensorTypeNode>();
  ICHECK(out_ttype != nullptr);

  return Array<te::Tensor>{topi::pad(data, pad_before, pad_after, pad_value, "T_pad",
                                     topi::kElementWise, param->pad_mode, &out_ttype->shape)};
}

// Handler used by the front-end FFI to create a call to the padding op.
Expr MakePad(Expr data, Expr pad_width, Expr pad_value, String pad_mode) {
  auto attrs = make_object<PadAttrs>();
  attrs->pad_mode = std::move(pad_mode);
  static const Op& op = Op::Get("dyn.nn.pad");
  return Call(op, {data, pad_width, pad_value}, Attrs(attrs), {});
}
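
// Illustrative sketch: assuming `data`, `pad_width`, and `pad_value` are
// already-constructed relay Exprs (e.g. vars or constants), a call is built as
//   Expr padded = MakePad(data, pad_width, pad_value, "constant");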

TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.pad").set_body_typed(MakePad);

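// The op is registered as kInjective so it can participate in operator fusion
// with neighboring injective/elementwise ops.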
RELAY_REGISTER_OP("dyn.nn.pad")
    .describe(R"code(Pad an n-D tensor.

The amount of padding is supplied at runtime: `pad_width` is a (n, 2) tensor
whose i-th row gives the padding added before and after axis i, and
`pad_value` is a scalar tensor used to fill the padded region when
`pad_mode` is "constant".

)code" TVM_ADD_FILELINE)
    .set_attrs_type<PadAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Tensor", "The tensor to be padded.")
    .add_argument("pad_width", "Tensor", "A (n, 2) tensor of per-axis before/after padding amounts.")
    .add_argument("pad_value", "Tensor", "The scalar value to fill the padded region with.")
    .set_support_level(2)
    .add_type_rel("DynamicPad", PadRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
122
123} // namespace dyn
124} // namespace relay
125} // namespace tvm
126