1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/op/call/call.cc |
22 | * \brief Operators for calling lowered functions. |
23 | */ |
24 | |
25 | #include "./call.h" |
26 | |
27 | #include <tvm/relay/attrs/call.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/op.h> |
30 | #include <tvm/relay/op_attr_types.h> |
31 | |
32 | #include "../../transforms/infer_layout_utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | |
// Register the CallLoweredAttrs node type; it is the attrs type declared for the
// "call_lowered" operator later in this file (see set_attrs_type<CallLoweredAttrs>()).
TVM_REGISTER_NODE_TYPE(CallLoweredAttrs);
38 | |
39 | // call_lowered |
40 | bool CallLoweredRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
41 | const TypeReporter& reporter) { |
42 | // Types = [func, call_args, ret_type] |
43 | if (types.size() != 3u) { |
44 | return false; |
45 | } |
46 | const auto* func_type = types[0].as<FuncTypeNode>(); |
47 | if (!func_type) { |
48 | return false; |
49 | } |
50 | |
51 | const auto* tuple_type_node = types[1].as<TupleTypeNode>(); |
52 | if (!tuple_type_node) { |
53 | return false; |
54 | } |
55 | |
56 | // Constraint to ensure function arguments are the same type as the inputs to the function (modulo |
57 | // the Tuple wrapper) |
58 | reporter->Assign(GetRef<TupleType>(tuple_type_node), TupleType(func_type->arg_types, {})); |
59 | // Constraint to ensure the output of call_lowered is the same as the function's return type |
60 | reporter->Assign(types[2], func_type->ret_type); |
61 | return true; |
62 | } |
63 | |
64 | const Op& CallLoweredOp() { return Op::Get("call_lowered" ); } |
65 | |
66 | Call CallLowered(GlobalVar lowered_func, Array<Expr> args, CallLoweredAttrs call_lowered_attrs, |
67 | Span span) { |
68 | auto attrs = make_object<CallLoweredAttrs>(std::move(call_lowered_attrs)); |
69 | return Call(CallLoweredOp(), {std::move(lowered_func), Tuple(std::move(args))}, |
70 | Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span)); |
71 | } |
72 | |
73 | TVM_REGISTER_GLOBAL("relay.op.call_lowered" ) |
74 | .set_body_typed([](Expr lowered_func, Array<Expr> args, Attrs attrs, Span span) { |
75 | const auto* lowered_func_node = lowered_func.as<GlobalVarNode>(); |
76 | ICHECK(lowered_func_node) << "Function to call should be GlobalVarNode, but got:" << std::endl |
77 | << PrettyPrint(lowered_func); |
78 | const auto* call_lowered_attrs = attrs.as<CallLoweredAttrs>(); |
79 | ICHECK(call_lowered_attrs) << "Expected attributes to be CallLoweredAttrs, but got " |
80 | << attrs->GetTypeKey(); |
81 | return CallLowered(GetRef<GlobalVar>(lowered_func_node), std::move(args), *call_lowered_attrs, |
82 | std::move(span)); |
83 | }); |
84 | |
// Registration of the "call_lowered" operator.
// It takes two inputs (the lowered function and a tuple of call arguments), carries
// CallLoweredAttrs, and is typed by CallLoweredRel.
RELAY_REGISTER_OP("call_lowered" )
    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .set_attrs_type<CallLoweredAttrs>()
    // Argument order matches the [func, call_args, ...] layout that CallLoweredRel expects.
    .add_argument("func" , "Function" , "The lowered function to call." )
    .add_argument("call_args" , "Tuple" , "The input tensors." )
    .add_type_rel("CallLoweredRel" , CallLoweredRel)
    .set_support_level(10)
    // Operator attributes: opaque pattern, not stateful, marked non-computational.
    .set_attr<TOpPattern>("TOpPattern" , kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful" , false)
    .set_attr<TNonComputational>("TNonComputational" , true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout);
97 | |
98 | CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { |
99 | if (call_node->op == CallLoweredOp()) { |
100 | ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments." ; |
101 | const auto* function_node = call_node->args[0].as<GlobalVarNode>(); |
102 | ICHECK(function_node) << "Expected first arg to call_lowered to be a GlobalVar. " ; |
103 | |
104 | const auto* tuple_args = call_node->args[1].as<TupleNode>(); |
105 | ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple of input arguments." ; |
106 | |
107 | ICHECK(call_node->attrs.defined()) << "Expecting call_lowered to have attributes." ; |
108 | const auto* call_lowered_attrs = call_node->attrs.as<CallLoweredAttrs>(); |
109 | ICHECK(call_lowered_attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " |
110 | << call_node->attrs->GetTypeKey(); |
111 | // If the call_node has type_args then they are for the polymorphic 'call_lowered' operator |
112 | // itself which expects the function type and argument type as parameters. |
113 | return {GetRef<GlobalVar>(function_node), tuple_args->fields, *call_lowered_attrs}; |
114 | } |
115 | return {}; |
116 | } |
117 | |
118 | Call GetAnyCall(const CallNode* call_node) { |
119 | CallLoweredProps props = GetCallLoweredProps(call_node); |
120 | if (props.lowered_func.defined()) { |
121 | auto call_lowered_attrs = make_object<CallLoweredAttrs>(props.attrs); |
122 | return Call(std::move(props.lowered_func), std::move(props.arguments), |
123 | Attrs(std::move(call_lowered_attrs)), |
124 | /*type_args=*/{}, call_node->span); |
125 | } else { |
126 | return GetRef<Call>(call_node); |
127 | } |
128 | } |
129 | |
130 | bool IsReshapeOnly(const CallLoweredProps& props) { |
131 | if (props.attrs.metadata.count("relay_attrs" )) { |
132 | auto dict_attrs = Downcast<DictAttrs>(props.attrs.metadata["relay_attrs" ]); |
133 | return dict_attrs.HasNonzeroAttr(attr::kReshapeOnly); |
134 | } |
135 | return false; |
136 | } |
137 | |
138 | } // namespace relay |
139 | } // namespace tvm |
140 | |