/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/op/call/call.h
 * \brief Operators for calling lowered functions.
 */
24#ifndef TVM_RELAY_OP_CALL_CALL_H_
25#define TVM_RELAY_OP_CALL_CALL_H_
26
27#include <tvm/relay/attrs/call.h>
28#include <tvm/relay/expr.h>
29
30#include <utility>
31
32namespace tvm {
33namespace relay {
34
35/*!
36 * \brief Returns the Relay call_lowered op. Use this helper to avoid extraneous calls to
37 * Registry::Get.
38 */
39const Op& CallLoweredOp();
40
41/*!
42 * \brief Helper to construct a Relay call with the "call_lowered" op.
43 *
44 * The callee must:
45 * - Be a global bound to a PrimFunc or an externally defined functions.
46 * - Accept only tensor arguments and return tensor results.
47 * - Arguments and results correspond to the flattened form (see FlattenTupleType) of the
48 * Relay Function type.
49 * - Return results by output pointer, ie use DPS.
50 * The arguments remain in Relay form (ie not flattened).
51 * The result remains in Relay form (ie returned from the call and not flattened).
52 *
53 * \param lowered_func Lowered function to call with call_lowered.
54 * \param args Arguments to be passed to the function.
55 * \param call_lowered_attrs Function attributes.
56 * \param span TVM span for propagating debugging info.
57 * \return
58 */
59Call CallLowered(GlobalVar lowered_func, Array<Expr> args, CallLoweredAttrs call_lowered_attrs,
60 Span span);
61
62/*!
63 * \brief Lowered function and the arguments to call it with.
64 */
65struct CallLoweredProps {
66 /*! \brief Global variable pointing to the lowered function. */
67 GlobalVar lowered_func;
68 /*! \brief Array of the arguments to call lowered_func with. */
69 Array<Expr> arguments;
70 /*! \brief Attributes from the call_lowered op. */
71 CallLoweredAttrs attrs;
72};
73
74/*!
75 * \brief Helper to extract the lowered function and its arguments from a Call("call_lowered", ...).
76 * Returns the null/empty \p CallLoweredProps if \p call_node is not in that form.
77 */
78CallLoweredProps GetCallLoweredProps(const CallNode* call_node);
79
80/*!
81 * \brief Returns \p call_node in 'standard' Relay form. Ie if \p call_node is a call_lowered
82 * then returns it in un-lowered form, otherwise returns \p call_node directly.
83 *
84 * Useful for passes which can act uniformly on calls irrespective of their form.
85 */
86Call GetAnyCall(const CallNode* call_node);
87
88/*!
89 * \brief Returns true if lowered call described by \p props is to a reshape primitive.
90 */
91bool IsReshapeOnly(const CallLoweredProps& props);
92
93} // namespace relay
94} // namespace tvm
95
96#endif // TVM_RELAY_OP_CALL_CALL_H_
97