1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/interpreter.h
22 * \brief An interpreter for Relay.
23 *
24 * This file implements a simple reference interpreter for Relay programs.
 * Given a Relay module and a Relay expression, it produces a value.
26 *
27 * The interpreter's values are a naive representation of the values that
28 * can be produced by a Relay program and are exposed via TVM's object
29 * protocol to Python for introspection and debugging.
30 *
31 * The interpreter's intent is to serve as a reference semantics for the Relay IR,
32 * as well as for debugging and testing.
33 */
34#ifndef TVM_RELAY_INTERPRETER_H_
35#define TVM_RELAY_INTERPRETER_H_
36
37#include <tvm/ir/module.h>
38#include <tvm/relay/expr.h>
39#include <tvm/runtime/container/closure.h>
40#include <tvm/runtime/object.h>
41#include <tvm/target/target.h>
42
43#include <unordered_set>
44
45namespace tvm {
46namespace relay {
47
48/*! \brief The container type of Closures used by the interpreter. */
49class InterpreterClosureObj : public runtime::ClosureObj {
50 public:
51 /*! \brief The set of free variables in the closure.
52 *
53 * These are the captured variables which are required for
54 * evaluation when we call the closure.
55 */
56 tvm::Map<Var, ObjectRef> env;
57 /*! \brief The function which implements the closure.
58 *
59 * \note May reference the variables contained in the env.
60 */
61 Function func;
62
63 InterpreterClosureObj() {}
64
65 void VisitAttrs(tvm::AttrVisitor* v) {
66 v->Visit("env", &env);
67 v->Visit("func", &func);
68 }
69
70 static constexpr const char* _type_key = "interpreter.Closure";
71 TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::ClosureObj);
72};
73
74class InterpreterClosure : public runtime::Closure {
75 public:
76 TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
77 TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::Closure, InterpreterClosureObj);
78};
79
80/*! \brief The container type of RecClosure. */
81class RecClosureObj : public Object {
82 public:
83 /*! \brief The closure. */
84 InterpreterClosure clos;
85 /*! \brief variable the closure bind to. */
86 Var bind;
87
88 RecClosureObj() {}
89
90 void VisitAttrs(tvm::AttrVisitor* v) {
91 v->Visit("clos", &clos);
92 v->Visit("bind", &bind);
93 }
94
95 static constexpr const char* _type_key = "interpreter.RecClosure";
96 TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object);
97};
98
99class RecClosure : public ObjectRef {
100 public:
101 TVM_DLL RecClosure(InterpreterClosure clos, Var bind);
102 TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj);
103};
104
105struct RefValueObj : Object {
106 mutable ObjectRef value;
107
108 RefValueObj() {}
109
110 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); }
111
112 static constexpr const char* _type_key = "relay.RefValue";
113 TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object);
114};
115
116class RefValue : public ObjectRef {
117 public:
118 TVM_DLL RefValue(ObjectRef val);
119 TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj);
120};
121
122struct ConstructorValueObj : Object {
123 int32_t tag;
124
125 tvm::Array<ObjectRef> fields;
126
127 /*! \brief Optional field tracking ADT constructor. */
128 Constructor constructor;
129
130 void VisitAttrs(tvm::AttrVisitor* v) {
131 v->Visit("tag", &tag);
132 v->Visit("fields", &fields);
133 v->Visit("constructor", &constructor);
134 }
135
136 static constexpr const char* _type_key = "relay.ConstructorValue";
137 TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object);
138};
139
140class ConstructorValue : public ObjectRef {
141 public:
142 TVM_DLL ConstructorValue(int32_t tag, tvm::Array<ObjectRef> fields, Constructor construtor = {});
143
144 TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
145};
146
147/*!
148 * \brief Returns a packed function over Relay expressions which will evaluate \p expr
149 * applied to those arguments, where \p expr is w.r.t. the definitions in \p mod.
150 *
151 * This function is intended to support the Python 'debug' executor.
152 *
153 * The given \p expr should have function type. The given \p mod may be empty or
154 * undefined if \p expr is self-contained. Relay arguments passed to the result
155 * packed function must be constants, references, or constructors/tuples over such.
156 * As much work as possible is done while constructing the result packed function, and
157 * that function may be reasonably efficiently applied multiple times without redoing
158 * unnecessary work.
159 *
160 * Primitives are lowered and compiled to packed functions for execution on \p device
161 * with properties given by \p target. All other Relay constructs are interpreted.
162 *
163 * The interpreter is intended to be a 'reference' implementation of the Relay semantics
164 * for testing and interactive use. It is not intended to be particularly efficient.
165 *
166 * \param mod A module containing definitions which can be referenced from
167 * \p expr. May be empty or undefined.
168 * \param expr An expression of function type to evaluate. May reference definitions from \p mod.
169 * \param device The device on which all primitives will be executed.
170 * \param target The compiler target flag for compiling primitives.
171 * \return A packed function that takes an array of Relay expressions and returns the
172 * result of applying \p expr to those arguments.
173 */
174TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
175 Target target);
176
177/*!
178 * \brief Evaluates \p expr and returns its result.
179 *
180 * This function is intended to support TVM constant evaluation.
181 *
182 * \param expr An expression to evaluate.
183 * \param type_definitions Global type definitions which \p expr may references.
184 * \param import_set Already imported external modules.
185 * \param device The device on which all primitives will be executed.
186 * \param target The compiler target flag for compiling primitives.
187 * \param attrs Attributes for the expression to be evaluated with
188 * @return The object representing the result.
189 */
190ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
191 std::unordered_set<String> import_set, Device device, Target target,
192 Map<String, ObjectRef> attrs = {});
193
194} // namespace relay
195} // namespace tvm
196
197#endif // TVM_RELAY_INTERPRETER_H_
198