1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/interpreter.h |
22 | * \brief An interpreter for Relay. |
23 | * |
24 | * This file implements a simple reference interpreter for Relay programs. |
 * Given a Relay module and a Relay expression, it produces a value.
26 | * |
27 | * The interpreter's values are a naive representation of the values that |
28 | * can be produced by a Relay program and are exposed via TVM's object |
29 | * protocol to Python for introspection and debugging. |
30 | * |
31 | * The interpreter's intent is to serve as a reference semantics for the Relay IR, |
32 | * as well as for debugging and testing. |
33 | */ |
34 | #ifndef TVM_RELAY_INTERPRETER_H_ |
35 | #define TVM_RELAY_INTERPRETER_H_ |
36 | |
37 | #include <tvm/ir/module.h> |
38 | #include <tvm/relay/expr.h> |
39 | #include <tvm/runtime/container/closure.h> |
40 | #include <tvm/runtime/object.h> |
41 | #include <tvm/target/target.h> |
42 | |
43 | #include <unordered_set> |
44 | |
45 | namespace tvm { |
46 | namespace relay { |
47 | |
48 | /*! \brief The container type of Closures used by the interpreter. */ |
49 | class InterpreterClosureObj : public runtime::ClosureObj { |
50 | public: |
51 | /*! \brief The set of free variables in the closure. |
52 | * |
53 | * These are the captured variables which are required for |
54 | * evaluation when we call the closure. |
55 | */ |
56 | tvm::Map<Var, ObjectRef> env; |
57 | /*! \brief The function which implements the closure. |
58 | * |
59 | * \note May reference the variables contained in the env. |
60 | */ |
61 | Function func; |
62 | |
63 | InterpreterClosureObj() {} |
64 | |
65 | void VisitAttrs(tvm::AttrVisitor* v) { |
66 | v->Visit("env" , &env); |
67 | v->Visit("func" , &func); |
68 | } |
69 | |
70 | static constexpr const char* _type_key = "interpreter.Closure" ; |
71 | TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::ClosureObj); |
72 | }; |
73 | |
74 | class InterpreterClosure : public runtime::Closure { |
75 | public: |
76 | TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func); |
77 | TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::Closure, InterpreterClosureObj); |
78 | }; |
79 | |
80 | /*! \brief The container type of RecClosure. */ |
81 | class RecClosureObj : public Object { |
82 | public: |
83 | /*! \brief The closure. */ |
84 | InterpreterClosure clos; |
85 | /*! \brief variable the closure bind to. */ |
86 | Var bind; |
87 | |
88 | RecClosureObj() {} |
89 | |
90 | void VisitAttrs(tvm::AttrVisitor* v) { |
91 | v->Visit("clos" , &clos); |
92 | v->Visit("bind" , &bind); |
93 | } |
94 | |
95 | static constexpr const char* _type_key = "interpreter.RecClosure" ; |
96 | TVM_DECLARE_FINAL_OBJECT_INFO(RecClosureObj, Object); |
97 | }; |
98 | |
99 | class RecClosure : public ObjectRef { |
100 | public: |
101 | TVM_DLL RecClosure(InterpreterClosure clos, Var bind); |
102 | TVM_DEFINE_OBJECT_REF_METHODS(RecClosure, ObjectRef, RecClosureObj); |
103 | }; |
104 | |
105 | struct RefValueObj : Object { |
106 | mutable ObjectRef value; |
107 | |
108 | RefValueObj() {} |
109 | |
110 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value" , &value); } |
111 | |
112 | static constexpr const char* _type_key = "relay.RefValue" ; |
113 | TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); |
114 | }; |
115 | |
116 | class RefValue : public ObjectRef { |
117 | public: |
118 | TVM_DLL RefValue(ObjectRef val); |
119 | TVM_DEFINE_OBJECT_REF_METHODS(RefValue, ObjectRef, RefValueObj); |
120 | }; |
121 | |
122 | struct ConstructorValueObj : Object { |
123 | int32_t tag; |
124 | |
125 | tvm::Array<ObjectRef> fields; |
126 | |
127 | /*! \brief Optional field tracking ADT constructor. */ |
128 | Constructor constructor; |
129 | |
130 | void VisitAttrs(tvm::AttrVisitor* v) { |
131 | v->Visit("tag" , &tag); |
132 | v->Visit("fields" , &fields); |
133 | v->Visit("constructor" , &constructor); |
134 | } |
135 | |
136 | static constexpr const char* _type_key = "relay.ConstructorValue" ; |
137 | TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorValueObj, Object); |
138 | }; |
139 | |
140 | class ConstructorValue : public ObjectRef { |
141 | public: |
142 | TVM_DLL ConstructorValue(int32_t tag, tvm::Array<ObjectRef> fields, Constructor construtor = {}); |
143 | |
144 | TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); |
145 | }; |
146 | |
147 | /*! |
148 | * \brief Returns a packed function over Relay expressions which will evaluate \p expr |
149 | * applied to those arguments, where \p expr is w.r.t. the definitions in \p mod. |
150 | * |
151 | * This function is intended to support the Python 'debug' executor. |
152 | * |
153 | * The given \p expr should have function type. The given \p mod may be empty or |
154 | * undefined if \p expr is self-contained. Relay arguments passed to the result |
155 | * packed function must be constants, references, or constructors/tuples over such. |
156 | * As much work as possible is done while constructing the result packed function, and |
157 | * that function may be reasonably efficiently applied multiple times without redoing |
158 | * unnecessary work. |
159 | * |
160 | * Primitives are lowered and compiled to packed functions for execution on \p device |
161 | * with properties given by \p target. All other Relay constructs are interpreted. |
162 | * |
163 | * The interpreter is intended to be a 'reference' implementation of the Relay semantics |
164 | * for testing and interactive use. It is not intended to be particularly efficient. |
165 | * |
166 | * \param mod A module containing definitions which can be referenced from |
167 | * \p expr. May be empty or undefined. |
168 | * \param expr An expression of function type to evaluate. May reference definitions from \p mod. |
169 | * \param device The device on which all primitives will be executed. |
170 | * \param target The compiler target flag for compiling primitives. |
171 | * \return A packed function that takes an array of Relay expressions and returns the |
172 | * result of applying \p expr to those arguments. |
173 | */ |
174 | TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device, |
175 | Target target); |
176 | |
177 | /*! |
178 | * \brief Evaluates \p expr and returns its result. |
179 | * |
180 | * This function is intended to support TVM constant evaluation. |
181 | * |
182 | * \param expr An expression to evaluate. |
183 | * \param type_definitions Global type definitions which \p expr may references. |
184 | * \param import_set Already imported external modules. |
185 | * \param device The device on which all primitives will be executed. |
186 | * \param target The compiler target flag for compiling primitives. |
187 | * \param attrs Attributes for the expression to be evaluated with |
188 | * @return The object representing the result. |
189 | */ |
190 | ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions, |
191 | std::unordered_set<String> import_set, Device device, Target target, |
192 | Map<String, ObjectRef> attrs = {}); |
193 | |
194 | } // namespace relay |
195 | } // namespace tvm |
196 | |
197 | #endif // TVM_RELAY_INTERPRETER_H_ |
198 | |