1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/op/vm/vm.cc
22 * \brief Dialect operators for Relay VM.
23 */
24
25#include "vm.h"
26
27#include <tvm/relay/attrs/memory.h>
28#include <tvm/relay/attrs/vm.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/op.h>
31#include <tvm/relay/op_attr_types.h>
32#include <tvm/runtime/data_type.h>
33#include <tvm/topi/elemwise.h>
34
35#include <utility>
36
37#include "../../transforms/infer_layout_utils.h"
38#include "../op_common.h"
39#include "../type_relations.h"
40
41namespace tvm {
42namespace relay {
43
44// shape_of
45// register ShapeOfAttrs here to make sure it has been registered when vm.shape_of uses it
46TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);
47
48// vm.shape_func
49TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
50
51RELAY_REGISTER_OP("vm.shape_of")
52 .describe(R"code(Get the shape of an input tensor.
53)code" TVM_ADD_FILELINE)
54 .set_num_inputs(1)
55 .add_argument("tensor", "Tensor", "The input tensor")
56 .add_type_rel("ShapeOf", ShapeOfRel)
57 .set_attrs_type_key("relay.attrs.ShapeOfAttrs")
58 .set_support_level(10)
59 .set_attr<TOpPattern>("TOpPattern", kOpaque)
60 .set_attr<TOpIsStateful>("TOpIsStateful", false)
61 .set_attr<TNonComputational>("TNonComputational", true)
62 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
63
64Expr ShapeOf(Expr expr) {
65 auto attrs = make_object<ShapeOfAttrs>();
66 attrs->dtype = DataType::Int(64);
67 static const Op& op = Op::Get("vm.shape_of");
68 return Call(op, {std::move(expr)}, Attrs(std::move(attrs)), {});
69}
70
71TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed(ShapeOf);
72
73// vm.invoke_tvm_op
74bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
75 const TypeReporter& reporter) {
76 ICHECK_EQ(types.size(), 4u);
77 auto func_type = types[0].as<FuncTypeNode>();
78 ICHECK(func_type != nullptr) << "input must be operator with known type";
79 auto input_type = types[1].as<TupleTypeNode>();
80 auto output_type = types[2].as<TupleTypeNode>();
81 ICHECK(input_type != nullptr)
82 << "internal invariant violated: invoke_tvm_op inputs must be a tuple";
83 ICHECK(output_type != nullptr)
84 << "internal invariant violated: invoke_tvm_op outputs must be a tuple";
85 Type ex_output;
86 if (func_type->ret_type.as<TensorTypeNode>()) {
87 ex_output = TupleType({func_type->ret_type});
88 } else {
89 ICHECK(func_type->ret_type.as<TupleTypeNode>())
90 << "expecting function result to be tuple type. Types:" << std::endl
91 << PrettyPrint(types);
92 ex_output = func_type->ret_type;
93 }
94 auto ex_input = TupleType(func_type->arg_types);
95 reporter->Assign(ex_input, GetRef<Type>(input_type));
96 reporter->Assign(ex_output, GetRef<Type>(output_type));
97 reporter->Assign(types[3], TupleType::Empty());
98 return true;
99}
100
101Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs, DictAttrs attrs) {
102 static const Op& op = Op::Get("vm.invoke_tvm_op");
103 return Call(op, {std::move(func), std::move(inputs), std::move(outputs)}, std::move(attrs));
104}
105
106TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op")
107 .set_body_typed([](Expr func, Expr inputs, Expr outputs, DictAttrs attrs) {
108 return InvokeTVMOp(std::move(func), std::move(inputs), std::move(outputs), std::move(attrs));
109 });
110
111RELAY_REGISTER_OP("vm.invoke_tvm_op")
112 .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
113 .set_num_inputs(3)
114 .add_argument("op", "Function", "The operation to call")
115 .add_argument("ins", "Tuple", "The input tensors.")
116 .add_argument("outs", "Tuple", "The output tensors.")
117 .add_type_rel("InvokeTVMOp", InvokeTVMOpRel)
118 .set_attrs_type_key("DictAttrs")
119 .set_support_level(10)
120 .set_attr<TOpPattern>("TOpPattern", kOpaque)
121 .set_attr<TOpIsStateful>("TOpIsStateful", true)
122 .set_attr<TNonComputational>("TNonComputational", true)
123 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
124
125// vm.reshape
126TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs);
127
128bool ReshapeTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
129 const TypeReporter& reporter) {
130 ICHECK_EQ(types.size(), 3u);
131 auto reshape_attrs = attrs.as<ReshapeTensorAttrs>();
132 ICHECK(reshape_attrs);
133 auto tt = types[0].as<TensorTypeNode>();
134 ICHECK(tt) << "input must be tensor type";
135 reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype));
136 return true;
137}
138
139RELAY_REGISTER_OP("vm.reshape_tensor")
140 .describe(R"code(Use VM reshape_tensor instruction to reshape the tensor.
141)code" TVM_ADD_FILELINE)
142 .set_num_inputs(2)
143 .add_argument("data", "Tensor", "The input tensor")
144 .add_argument("shape", "Tensor", "The output shape tensor")
145 .add_type_rel("ReshapeTensor", ReshapeTensorRel)
146 .set_attrs_type_key("relay.attrs.ReshapeTensorAttrs")
147 .set_support_level(10)
148 .set_attr<TOpPattern>("TOpPattern", kOpaque)
149 .set_attr<TOpIsStateful>("TOpIsStateful", false)
150 .set_attr<TNonComputational>("TNonComputational", true)
151 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
152
153Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape) {
154 static const Op& op = Op::Get("vm.reshape_tensor");
155 auto attrs = make_object<ReshapeTensorAttrs>();
156 attrs->newshape = std::move(newshape);
157 return Call(op, {std::move(data), std::move(shape)}, Attrs(std::move(attrs)), {});
158}
159
160TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor").set_body_typed(ReshapeTensor);
161
162} // namespace relay
163} // namespace tvm
164