1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file pad.cc |
22 | * \brief Implementation of dynamic pad |
23 | */ |
24 | #include <tvm/relay/attrs/nn.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/topi/nn.h> |
29 | |
30 | #include <vector> |
31 | |
32 | #include "../../make_op.h" |
33 | #include "../../op_common.h" |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | namespace dyn { |
38 | |
39 | // relay.dyn.nn.pad |
40 | |
41 | bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
42 | const TypeReporter& reporter) { |
43 | // types = [data_type, pad_width_type, pad_value_type, ret_type] |
44 | ICHECK_EQ(types.size(), 4); |
45 | const auto* data = types[0].as<TensorTypeNode>(); |
46 | if (data == nullptr) return false; |
47 | |
48 | const auto* pad_width = types[1].as<TensorTypeNode>(); |
49 | if (pad_width == nullptr) return false; |
50 | |
51 | const auto* pad_value = types[2].as<TensorTypeNode>(); |
52 | if (pad_value == nullptr) return false; |
53 | |
54 | int data_rank = data->shape.size(); |
55 | ICHECK(data_rank) << "Data shape must have static rank" ; |
56 | |
57 | int pad_width_rank = pad_width->shape.size(); |
58 | ICHECK_EQ(pad_width_rank, 2) << "Pad width must be 2D" ; |
59 | |
60 | const PadAttrs* param = attrs.as<PadAttrs>(); |
61 | ICHECK(param != nullptr); |
62 | |
63 | std::vector<IndexExpr> oshape; |
64 | for (int i = 0; i < data_rank; i++) { |
65 | oshape.push_back(Any()); |
66 | } |
67 | |
68 | reporter->Assign(types[3], TensorType(oshape, data->dtype)); |
69 | return true; |
70 | } |
71 | |
72 | Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
73 | const Type& out_type) { |
74 | const auto* param = attrs.as<PadAttrs>(); |
75 | ICHECK(param); |
76 | |
77 | auto data = inputs[0]; |
78 | auto pad_width = inputs[1]; |
79 | |
80 | te::Tensor cast_pad_value = topi::cast(inputs[2], inputs[0]->dtype); |
81 | const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>()); |
82 | |
83 | Array<IndexExpr> pad_before; |
84 | Array<IndexExpr> pad_after; |
85 | |
86 | for (int i = 0; i < pad_width->shape[0].as<IntImmNode>()->value; ++i) { |
87 | pad_before.push_back(pad_width[i][0]); |
88 | pad_after.push_back(pad_width[i][1]); |
89 | } |
90 | |
91 | const auto* out_ttype = out_type.as<TensorTypeNode>(); |
92 | ICHECK(out_ttype != nullptr); |
93 | |
94 | return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad" , |
95 | topi::kElementWise, param->pad_mode, |
96 | &out_type.as<TensorTypeNode>()->shape)}; |
97 | } |
98 | |
99 | // Handler to create a call to the padding op used by front-end FFI |
100 | Expr MakePad(Expr data, Expr pad_width, Expr pad_value, String pad_mode) { |
101 | auto attrs = make_object<PadAttrs>(); |
102 | attrs->pad_mode = std::move(pad_mode); |
103 | static const Op& op = Op::Get("dyn.nn.pad" ); |
104 | return Call(op, {data, pad_width, pad_value}, Attrs(attrs), {}); |
105 | } |
106 | |
107 | TVM_REGISTER_GLOBAL("relay.op.dyn.nn._make.pad" ).set_body_typed(MakePad); |
108 | |
109 | RELAY_REGISTER_OP("dyn.nn.pad" ) |
110 | .describe(R"code(Pad for n-D tensor. |
111 | |
112 | )code" TVM_ADD_FILELINE) |
113 | .set_attrs_type<PadAttrs>() |
114 | .set_num_inputs(3) |
115 | .add_argument("data" , "Tensor" , "Tensor that will be padded" ) |
116 | .add_argument("pad_width" , "Tensor" , "Tensor of how much to pad by" ) |
117 | .add_argument("pad_val" , "double" , "The value to fill the padded area with" ) |
118 | .set_support_level(2) |
119 | .add_type_rel("DynamicPad" , PadRel) |
120 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
121 | .set_attr<FTVMCompute>("FTVMCompute" , PadCompute); |
122 | |
123 | } // namespace dyn |
124 | } // namespace relay |
125 | } // namespace tvm |
126 | |