1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/op/vm/vm.cc |
22 | * \brief Dialect operators for Relay VM. |
23 | */ |
24 | |
25 | #include "vm.h" |
26 | |
27 | #include <tvm/relay/attrs/memory.h> |
28 | #include <tvm/relay/attrs/vm.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/op.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/runtime/data_type.h> |
33 | #include <tvm/topi/elemwise.h> |
34 | |
35 | #include <utility> |
36 | |
37 | #include "../../transforms/infer_layout_utils.h" |
38 | #include "../op_common.h" |
39 | #include "../type_relations.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | |
// shape_of
// Register the ShapeOfAttrs node type in this translation unit so that it is
// guaranteed to be registered by the time vm.shape_of (whose attrs type key is
// "relay.attrs.ShapeOfAttrs") uses it.
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);

// vm.shape_func
// Register the ShapeFuncAttrs node type. NOTE(review): no operator registered in
// this file references it; presumably it backs vm.shape_func, which is registered
// elsewhere — confirm.
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
50 | |
// Register the VM dialect operator vm.shape_of. It takes a single tensor, and its
// output type is computed by the ShapeOfRel type relation (defined outside this
// file). The result dtype is carried in ShapeOfAttrs.
RELAY_REGISTER_OP("vm.shape_of")
    .describe(R"code(Get the shape of an input tensor.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("tensor", "Tensor", "The input tensor")
    .add_type_rel("ShapeOf", ShapeOfRel)
    .set_attrs_type_key("relay.attrs.ShapeOfAttrs")
    .set_support_level(10)
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    // Declared free of side effects.
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
63 | |
64 | Expr ShapeOf(Expr expr) { |
65 | auto attrs = make_object<ShapeOfAttrs>(); |
66 | attrs->dtype = DataType::Int(64); |
67 | static const Op& op = Op::Get("vm.shape_of" ); |
68 | return Call(op, {std::move(expr)}, Attrs(std::move(attrs)), {}); |
69 | } |
70 | |
71 | TVM_REGISTER_GLOBAL("relay.op.vm.shape_of" ).set_body_typed(ShapeOf); |
72 | |
73 | // vm.invoke_tvm_op |
74 | bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
75 | const TypeReporter& reporter) { |
76 | ICHECK_EQ(types.size(), 4u); |
77 | auto func_type = types[0].as<FuncTypeNode>(); |
78 | ICHECK(func_type != nullptr) << "input must be operator with known type" ; |
79 | auto input_type = types[1].as<TupleTypeNode>(); |
80 | auto output_type = types[2].as<TupleTypeNode>(); |
81 | ICHECK(input_type != nullptr) |
82 | << "internal invariant violated: invoke_tvm_op inputs must be a tuple" ; |
83 | ICHECK(output_type != nullptr) |
84 | << "internal invariant violated: invoke_tvm_op outputs must be a tuple" ; |
85 | Type ex_output; |
86 | if (func_type->ret_type.as<TensorTypeNode>()) { |
87 | ex_output = TupleType({func_type->ret_type}); |
88 | } else { |
89 | ICHECK(func_type->ret_type.as<TupleTypeNode>()) |
90 | << "expecting function result to be tuple type. Types:" << std::endl |
91 | << PrettyPrint(types); |
92 | ex_output = func_type->ret_type; |
93 | } |
94 | auto ex_input = TupleType(func_type->arg_types); |
95 | reporter->Assign(ex_input, GetRef<Type>(input_type)); |
96 | reporter->Assign(ex_output, GetRef<Type>(output_type)); |
97 | reporter->Assign(types[3], TupleType::Empty()); |
98 | return true; |
99 | } |
100 | |
101 | Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs, DictAttrs attrs) { |
102 | static const Op& op = Op::Get("vm.invoke_tvm_op" ); |
103 | return Call(op, {std::move(func), std::move(inputs), std::move(outputs)}, std::move(attrs)); |
104 | } |
105 | |
106 | TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op" ) |
107 | .set_body_typed([](Expr func, Expr inputs, Expr outputs, DictAttrs attrs) { |
108 | return InvokeTVMOp(std::move(func), std::move(inputs), std::move(outputs), std::move(attrs)); |
109 | }); |
110 | |
// Register the VM dialect operator vm.invoke_tvm_op. It takes the operator to
// call plus input and output tuples. Its type is checked by InvokeTVMOpRel,
// which assigns the call itself the empty tuple type.
RELAY_REGISTER_OP("vm.invoke_tvm_op")
    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
    .set_num_inputs(3)
    .add_argument("op", "Function", "The operation to call")
    .add_argument("ins", "Tuple", "The input tensors.")
    .add_argument("outs", "Tuple", "The output tensors.")
    .add_type_rel("InvokeTVMOp", InvokeTVMOpRel)
    .set_attrs_type_key("DictAttrs")
    .set_support_level(10)
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    // Marked stateful. NOTE(review): presumably because results are written
    // into the `outs` tensors rather than returned — confirm.
    .set_attr<TOpIsStateful>("TOpIsStateful", true)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
124 | |
// vm.reshape_tensor
// Register the attribute node type used by vm.reshape_tensor.
TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs);
127 | |
128 | bool ReshapeTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
129 | const TypeReporter& reporter) { |
130 | ICHECK_EQ(types.size(), 3u); |
131 | auto reshape_attrs = attrs.as<ReshapeTensorAttrs>(); |
132 | ICHECK(reshape_attrs); |
133 | auto tt = types[0].as<TensorTypeNode>(); |
134 | ICHECK(tt) << "input must be tensor type" ; |
135 | reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype)); |
136 | return true; |
137 | } |
138 | |
// Register the VM dialect operator vm.reshape_tensor. It takes the data tensor
// and a shape tensor. Its result type is computed by ReshapeTensorRel from the
// newshape stored in ReshapeTensorAttrs.
RELAY_REGISTER_OP("vm.reshape_tensor")
    .describe(R"code(Use VM reshape_tensor instruction to reshape the tensor.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor")
    .add_argument("shape", "Tensor", "The output shape tensor")
    .add_type_rel("ReshapeTensor", ReshapeTensorRel)
    .set_attrs_type_key("relay.attrs.ReshapeTensorAttrs")
    .set_support_level(10)
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    // Declared free of side effects.
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
152 | |
153 | Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape) { |
154 | static const Op& op = Op::Get("vm.reshape_tensor" ); |
155 | auto attrs = make_object<ReshapeTensorAttrs>(); |
156 | attrs->newshape = std::move(newshape); |
157 | return Call(op, {std::move(data), std::move(shape)}, Attrs(std::move(attrs)), {}); |
158 | } |
159 | |
160 | TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor" ).set_body_typed(ReshapeTensor); |
161 | |
162 | } // namespace relay |
163 | } // namespace tvm |
164 | |