1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/op/call/call.cc
22 * \brief Operators for calling lowered functions.
23 */
24
25#include "./call.h"
26
27#include <tvm/relay/attrs/call.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/op.h>
30#include <tvm/relay/op_attr_types.h>
31
32#include "../../transforms/infer_layout_utils.h"
33
34namespace tvm {
35namespace relay {
36
// Register the CallLoweredAttrs node type with TVM so it can be used as the attribute type
// of the `call_lowered` operator registered below.
TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
38
39// call_lowered
40bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
41 const TypeReporter& reporter) {
42 // Types = [func, call_args, ret_type]
43 if (types.size() != 3u) {
44 return false;
45 }
46 const auto* func_type = types[0].as<FuncTypeNode>();
47 if (!func_type) {
48 return false;
49 }
50
51 const auto* tuple_type_node = types[1].as<TupleTypeNode>();
52 if (!tuple_type_node) {
53 return false;
54 }
55
56 // Constraint to ensure function arguments are the same type as the inputs to the function (modulo
57 // the Tuple wrapper)
58 reporter->Assign(GetRef<TupleType>(tuple_type_node), TupleType(func_type->arg_types, {}));
59 // Constraint to ensure the output of call_lowered is the same as the function's return type
60 reporter->Assign(types[2], func_type->ret_type);
61 return true;
62}
63
64const Op& CallLoweredOp() { return Op::Get("call_lowered"); }
65
66Call CallLowered(GlobalVar lowered_func, Array<Expr> args, CallLoweredAttrs call_lowered_attrs,
67 Span span) {
68 auto attrs = make_object<CallLoweredAttrs>(std::move(call_lowered_attrs));
69 return Call(CallLoweredOp(), {std::move(lowered_func), Tuple(std::move(args))},
70 Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span));
71}
72
73TVM_REGISTER_GLOBAL("relay.op.call_lowered")
74 .set_body_typed([](Expr lowered_func, Array<Expr> args, Attrs attrs, Span span) {
75 const auto* lowered_func_node = lowered_func.as<GlobalVarNode>();
76 ICHECK(lowered_func_node) << "Function to call should be GlobalVarNode, but got:" << std::endl
77 << PrettyPrint(lowered_func);
78 const auto* call_lowered_attrs = attrs.as<CallLoweredAttrs>();
79 ICHECK(call_lowered_attrs) << "Expected attributes to be CallLoweredAttrs, but got "
80 << attrs->GetTypeKey();
81 return CallLowered(GetRef<GlobalVar>(lowered_func_node), std::move(args), *call_lowered_attrs,
82 std::move(span));
83 });
84
// Register the `call_lowered` operator. It takes two inputs: the lowered function to invoke
// (constructed as a GlobalVar by CallLowered) and a Tuple of that function's arguments. Its
// attributes are CallLoweredAttrs, and its type is derived by CallLoweredRel from the callee's
// function type.
RELAY_REGISTER_OP("call_lowered")
    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .set_attrs_type<CallLoweredAttrs>()
    // Argument order matches the operand order built by CallLowered: (func, call_args).
    .add_argument("func", "Function", "The lowered function to call.")
    .add_argument("call_args", "Tuple", "The input tensors.")
    .add_type_rel("CallLoweredRel", CallLoweredRel)
    .set_support_level(10)
    // NOTE(review): kOpaque presumably keeps this op out of operator fusion; confirm against
    // OpPatternKind documentation.
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
97
98CallLoweredProps GetCallLoweredProps(const CallNode* call_node) {
99 if (call_node->op == CallLoweredOp()) {
100 ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments.";
101 const auto* function_node = call_node->args[0].as<GlobalVarNode>();
102 ICHECK(function_node) << "Expected first arg to call_lowered to be a GlobalVar. ";
103
104 const auto* tuple_args = call_node->args[1].as<TupleNode>();
105 ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple of input arguments.";
106
107 ICHECK(call_node->attrs.defined()) << "Expecting call_lowered to have attributes.";
108 const auto* call_lowered_attrs = call_node->attrs.as<CallLoweredAttrs>();
109 ICHECK(call_lowered_attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found "
110 << call_node->attrs->GetTypeKey();
111 // If the call_node has type_args then they are for the polymorphic 'call_lowered' operator
112 // itself which expects the function type and argument type as parameters.
113 return {GetRef<GlobalVar>(function_node), tuple_args->fields, *call_lowered_attrs};
114 }
115 return {};
116}
117
118Call GetAnyCall(const CallNode* call_node) {
119 CallLoweredProps props = GetCallLoweredProps(call_node);
120 if (props.lowered_func.defined()) {
121 auto call_lowered_attrs = make_object<CallLoweredAttrs>(props.attrs);
122 return Call(std::move(props.lowered_func), std::move(props.arguments),
123 Attrs(std::move(call_lowered_attrs)),
124 /*type_args=*/{}, call_node->span);
125 } else {
126 return GetRef<Call>(call_node);
127 }
128}
129
130bool IsReshapeOnly(const CallLoweredProps& props) {
131 if (props.attrs.metadata.count("relay_attrs")) {
132 auto dict_attrs = Downcast<DictAttrs>(props.attrs.metadata["relay_attrs"]);
133 return dict_attrs.HasNonzeroAttr(attr::kReshapeOnly);
134 }
135 return false;
136}
137
138} // namespace relay
139} // namespace tvm
140