1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/interpreter.cc |
22 | * \brief An interpreter for the Relay IR. |
23 | */ |
24 | |
25 | #include <tvm/driver/driver_api.h> |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/attrs/annotation.h> |
28 | #include <tvm/relay/attrs/call.h> |
29 | #include <tvm/relay/attrs/debug.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/relay/feature.h> |
32 | #include <tvm/relay/interpreter.h> |
33 | #include <tvm/relay/pattern_functor.h> |
34 | #include <tvm/relay/qnn/transform.h> |
35 | #include <tvm/relay/transform.h> |
36 | #include <tvm/runtime/container/map.h> |
37 | #include <tvm/runtime/device_api.h> |
38 | #include <tvm/runtime/object.h> |
39 | #include <tvm/target/compilation_config.h> |
40 | |
41 | #include "../op/annotation/annotation.h" |
42 | #include "../op/call/call.h" |
43 | #include "../op/memory/device_copy.h" |
44 | #include "../transforms/pass_utils.h" |
45 | #include "te_compiler.h" |
46 | |
47 | namespace tvm { |
48 | namespace relay { |
49 | |
50 | using runtime::ADT; |
51 | using runtime::ADTObj; |
52 | using runtime::NDArray; |
53 | using runtime::TVMArgsSetter; |
54 | using runtime::operator<<; |
55 | |
56 | namespace { |
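// Hash for (Target, global function name) pairs, used to key the interpreter's cache of
// already-compiled PackedFuncs (see Interpreter::compiled_packed_funcs_ below).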
57 | // TODO(mbs): Centralize. |
58 | struct PairHash { |
59 | template <typename T1, typename T2> |
60 | std::size_t operator()(const std::pair<T1, T2>& k) const { |
61 | return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second)); |
62 | } |
63 | template <typename T2> |
64 | std::size_t operator()(const std::pair<Target, T2>& k) const { |
65 | return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second)); |
66 | } |
67 | }; |
68 | |
69 | // Analogue of FlattenTupleType for runtime ADT vs NDArray values. |
70 | // TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? |
71 | void FlattenADTAux(const ObjectRef& object_ref, std::vector<NDArray>* out) { |
72 | if (const NDArray::ContainerType* ndarray = object_ref.as<NDArray::ContainerType>()) { |
73 | out->push_back(GetRef<NDArray>(ndarray)); |
74 | } else if (const ADTObj* adt = object_ref.as<ADTObj>()) { |
75 | for (size_t i = 0; i < adt->size; ++i) { |
76 | FlattenADTAux((*adt)[i], out); |
77 | } |
78 | } else { |
79 | LOG(FATAL) << "unsupported " << object_ref; |
80 | } |
81 | } |
82 | |
83 | std::vector<NDArray> FlattenADT(const ObjectRef& object_ref) { |
84 | std::vector<NDArray> out; |
85 | FlattenADTAux(object_ref, &out); |
86 | return out; |
87 | } |
88 | |
89 | std::vector<NDArray> FlattenADTs(const std::vector<ObjectRef>& object_refs) { |
90 | std::vector<NDArray> out; |
91 | for (const auto& object_ref : object_refs) { |
92 | FlattenADTAux(object_ref, &out); |
93 | } |
94 | return out; |
95 | } |
96 | |
97 | // Analogue of ToTupleType for runtime ADT vs NDArray values. |
98 | // TODO(mbs): Hoist somewhere sensible, maybe op/memory.h? |
99 | void ToADTOrNDArrayAux(const Type& type, const std::vector<NDArray>& nd_arrays, int* index, |
100 | std::vector<ObjectRef>* out) { |
101 | if (type.as<TensorTypeNode>()) { |
102 | out->push_back(nd_arrays[*index]); |
103 | *index += 1; |
104 | } else if (const TupleTypeNode* ttn = type.as<TupleTypeNode>()) { |
105 | std::vector<ObjectRef> tuple_out; |
106 | for (size_t i = 0; i < ttn->fields.size(); i++) { |
107 | ToADTOrNDArrayAux(ttn->fields[i], nd_arrays, index, &tuple_out); |
108 | } |
109 | out->push_back(ADT::Tuple(tuple_out)); |
110 | } else { |
111 | LOG(FATAL) << "unsupported " << type; |
112 | } |
113 | } |
114 | |
115 | ObjectRef ToADTOrNDArray(const Type& type, const std::vector<NDArray>& nd_arrays) { |
116 | if (type.as<TensorTypeNode>() && nd_arrays.size() == 1) { |
117 | return nd_arrays[0]; |
118 | } else { |
119 | std::vector<ObjectRef> out; |
120 | int index = 0; |
121 | ToADTOrNDArrayAux(type, nd_arrays, &index, &out); |
122 | return out[0]; |
123 | } |
124 | } |
125 | |
126 | } // namespace |
127 | |
128 | InterpreterClosure::InterpreterClosure(Map<Var, ObjectRef> env, Function func) { |
129 | ObjectPtr<InterpreterClosureObj> n = make_object<InterpreterClosureObj>(); |
130 | n->env = std::move(env); |
131 | n->func = std::move(func); |
132 | data_ = std::move(n); |
133 | } |
134 | |
135 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
136 | .set_dispatch<InterpreterClosureObj>([](const ObjectRef& ref, ReprPrinter* p) { |
137 | auto* node = static_cast<const InterpreterClosureObj*>(ref.get()); |
p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
139 | }); |
140 | |
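// Returns the packed function registered globally under 'name', failing if it is not registered.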
141 | inline const PackedFunc& GetPackedFunc(const std::string& name) { |
142 | const PackedFunc* pf = runtime::Registry::Get(name); |
ICHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
144 | return *pf; |
145 | } |
146 | |
147 | // TODO(@jroesch): this doesn't support mutual letrec |
148 | /* Object Implementation */ |
149 | RecClosure::RecClosure(InterpreterClosure clos, Var bind) { |
150 | ObjectPtr<RecClosureObj> n = make_object<RecClosureObj>(); |
151 | n->clos = std::move(clos); |
152 | n->bind = std::move(bind); |
153 | data_ = std::move(n); |
154 | } |
155 | |
156 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
157 | .set_dispatch<RecClosureObj>([](const ObjectRef& ref, ReprPrinter* p) { |
158 | auto* node = static_cast<const RecClosureObj*>(ref.get()); |
p->stream << "RecClosureObj(" << node->clos << ")";
160 | }); |
161 | |
162 | RefValue::RefValue(ObjectRef value) { |
163 | ObjectPtr<RefValueObj> n = make_object<RefValueObj>(); |
164 | n->value = value; |
165 | data_ = std::move(n); |
166 | } |
167 | |
TVM_REGISTER_GLOBAL("relay._make.RefValue").set_body_typed([](ObjectRef value) {
169 | return RefValue(value); |
170 | }); |
171 | |
172 | TVM_REGISTER_NODE_TYPE(RefValueObj); |
173 | |
174 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
175 | .set_dispatch<RefValueObj>([](const ObjectRef& ref, ReprPrinter* p) { |
176 | auto* node = static_cast<const RefValueObj*>(ref.get()); |
p->stream << "RefValueObj(" << node->value << ")";
178 | }); |
179 | |
180 | ConstructorValue::ConstructorValue(int32_t tag, Array<ObjectRef> fields, Constructor constructor) { |
181 | ObjectPtr<ConstructorValueObj> n = make_object<ConstructorValueObj>(); |
182 | n->tag = tag; |
183 | n->fields = fields; |
184 | n->constructor = constructor; |
185 | data_ = std::move(n); |
186 | } |
187 | |
TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
189 | .set_body_typed([](int32_t tag, Array<ObjectRef> fields, Constructor constructor) { |
190 | return ConstructorValue(tag, fields, constructor); |
191 | }); |
192 | |
193 | TVM_REGISTER_NODE_TYPE(ConstructorValueObj); |
194 | |
195 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
196 | .set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, ReprPrinter* p) { |
197 | auto* node = static_cast<const ConstructorValueObj*>(ref.get()); |
p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")";
199 | }); |
200 | |
201 | /*! |
202 | * \brief A stack frame in the Relay interpreter. |
203 | * |
204 | * Contains a mapping from relay::Var to relay::ObjectRef. |
205 | */ |
206 | struct Frame { |
207 | /*! \brief The set of local variables and arguments for the frame. */ |
208 | Map<Var, ObjectRef> locals; |
209 | |
210 | explicit Frame(Map<Var, ObjectRef> locals) : locals(locals) {} |
211 | }; |
212 | |
213 | /*! |
214 | * \brief The call stack in the Relay interpreter. |
215 | * |
* Contains a stack of frames, each corresponding to
* a function call.
218 | */ |
219 | struct Stack { |
220 | /*! \brief The stack frames. */ |
221 | std::vector<Frame> frames; |
222 | Stack() : frames() { frames.push_back(Frame({})); } |
223 | |
224 | Frame& current_frame() { return frames.back(); } |
225 | |
226 | ObjectRef Lookup(const Var& local) { |
227 | for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) { |
228 | auto elem = frame->locals.find(local); |
229 | if (elem != frame->locals.end()) { |
230 | return (*elem).second; |
231 | } |
232 | } |
233 | |
234 | LOG(FATAL) << "could not find variable binding for " << local |
<< ", address=" << local.operator->();
236 | return ObjectRef(); |
237 | } |
238 | /*! |
239 | * A wrapper around Frame to add RAII semantics to pushing and popping |
240 | * stack frames. |
241 | */ |
242 | struct LocalFrame { |
243 | Stack& st; |
244 | explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { st.frames.push_back(fr); } |
245 | ~LocalFrame() { st.frames.pop_back(); } |
246 | }; |
247 | }; |
248 | |
249 | /*! \brief A representation of the interpreter state which can be passed back to Python. */ |
250 | class InterpreterState; |
251 | |
252 | /*! \brief A container capturing the state of the interpreter. */ |
253 | class InterpreterStateObj : public Object { |
254 | public: |
255 | using Frame = Map<Var, ObjectRef>; |
256 | using Stack = Array<Frame>; |
257 | |
258 | /*! \brief The current expression under evaluation. */ |
259 | Expr current_expr; |
260 | |
261 | /*! \brief The call stack of the interpreter. */ |
262 | Stack stack; |
263 | |
264 | void VisitAttrs(AttrVisitor* v) { |
v->Visit("current_expr", &current_expr);
v->Visit("stack", &stack);
267 | } |
268 | |
static constexpr const char* _type_key = "relay.InterpreterState";
270 | TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object); |
271 | }; |
272 | |
273 | class InterpreterState : public ObjectRef { |
274 | public: |
275 | using Frame = Map<Var, ObjectRef>; |
276 | using Stack = Array<Frame>; |
277 | |
278 | InterpreterState(Expr current_expr, Stack stack); |
279 | |
280 | TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj); |
281 | }; |
282 | |
283 | InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack stack) { |
284 | ObjectPtr<InterpreterStateObj> n = make_object<InterpreterStateObj>(); |
285 | n->current_expr = std::move(current_expr); |
286 | n->stack = std::move(stack); |
287 | data_ = std::move(n); |
288 | } |
289 | |
// NOTE: the current interpreter assumes A-normal form (ANF),
// which is better suited for execution.
//
// It will redundantly re-evaluate shared sub-expressions when given a program
// whose dataflow forms a DAG rather than a tree.
//
// Conversion to ANF is recommended before interpretation.
297 | class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, |
298 | PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> { |
299 | public: |
300 | Interpreter(IRModule unified_mod, CompilationConfig config, Device device) |
301 | : unified_mod_(unified_mod), |
302 | config_(std::move(config)), |
303 | device_(device), |
debug_op_(Op::Get("debug")) {}
305 | |
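/*!
* \brief Pushes frame \p fr onto the call stack, evaluates \p f, then pops the frame again
* (via LocalFrame's RAII) and returns \p f's result.
*/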
306 | template <typename T> |
307 | T WithFrame(const Frame& fr, const std::function<T()>& f) { |
308 | Stack::LocalFrame lf(stack_, fr); |
309 | return f(); |
310 | } |
311 | |
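/*! \brief Binds \p id to value \p v in the current (innermost) stack frame. */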
312 | void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); } |
313 | |
314 | ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); } |
315 | |
316 | ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); } |
317 | |
318 | ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef<Var>(var_node)); } |
319 | |
320 | ObjectRef VisitExpr_(const GlobalVarNode* op) final { |
321 | return Eval(unified_mod_->Lookup(GetRef<GlobalVar>(op))); |
322 | } |
323 | |
324 | ObjectRef VisitExpr_(const OpNode* id) override { |
325 | // TODO(@jroesch): Eta-expand and return in this case. |
LOG(FATAL) << "internal error: bare operator references must be wrapped in a synthetic "
<< "call node and eta-expanded before evaluation";
328 | return ObjectRef(); |
329 | } |
330 | |
331 | ObjectRef VisitExpr_(const ConstantNode* op) final { return op->data.CopyTo(device_); } |
332 | |
333 | ObjectRef VisitExpr_(const TupleNode* op) final { |
334 | std::vector<ObjectRef> values; |
335 | |
336 | for (const auto& field : op->fields) { |
337 | ObjectRef field_value = Eval(field); |
338 | values.push_back(field_value); |
339 | } |
340 | |
341 | return ADT::Tuple(values); |
342 | } |
343 | |
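/*!
* \brief Captures the values of \p func's free variables to build a closure over \p func.
* If \p letrec_name is defined the closure is recursive: the name is not captured here but
* instead re-bound to the closure itself when it is invoked (see Invoke).
*/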
344 | ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) { |
345 | Map<Var, ObjectRef> captured_mod; |
346 | Array<Var> free_vars = FreeVars(func); |
347 | |
348 | for (const auto& var : free_vars) { |
// Evaluate the free var (which could be a function call) unless it is the let-bound
// name of this closure itself (the recursive case, handled via RecClosure at call time).
351 | if (letrec_name.defined() && letrec_name == var) { |
352 | continue; |
353 | } |
354 | |
355 | captured_mod.Set(var, Eval(var)); |
356 | } |
357 | |
// A recursive closure is wrapped in a RecClosure so the let-bound name can be re-bound
// to the closure itself when it is invoked (see Invoke).
359 | InterpreterClosure closure(captured_mod, func); |
360 | if (letrec_name.defined()) { |
361 | return RecClosure(closure, letrec_name); |
362 | } |
363 | return std::move(closure); |
364 | } |
365 | |
366 | ObjectRef VisitExpr_(const FunctionNode* func_node) final { |
367 | auto func = GetRef<Function>(func_node); |
368 | return MakeClosure(func); |
369 | } |
370 | |
371 | /*! |
372 | * \brief Returns the packed function implementing the TIR function bound to \p tir_fn_var. |
373 | * |
374 | * \param tir_fn_var Global var for the already lowered TIR function. |
375 | * \param all_tir_fn_vars Global vars for all lowered TIR functions the above |
376 | * may reference, plus \p tir_fn_var itself. |
* \param target Target for which the TIR function should be compiled. For primitives this
* will be the interpreter's target. However for shape functions this will be the generic
* 'cpu' target, since shape functions are always executed on the host CPU.
380 | */ |
381 | PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars, |
382 | Target target) { |
383 | std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint); |
384 | auto packed_itr = compiled_packed_funcs_.find(packed_func_key); |
385 | if (packed_itr != compiled_packed_funcs_.end()) { |
386 | // Already compiled. |
387 | return packed_itr->second; |
388 | } |
389 | |
390 | // Project out just the function(s) we need. |
391 | IRModule lowered_projected_mod; |
392 | Map<Target, IRModule> per_target_module = tec::GetPerTargetModules(unified_mod_); |
393 | std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual> |
394 | per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module); |
395 | auto mod_itr = per_target_module_std_map.find(target); |
396 | ICHECK(mod_itr != per_target_module_std_map.end()) |
397 | << "No target module for target " << target->ToDebugString(); |
398 | const IRModule& target_module = (*mod_itr).second; |
399 | for (const auto& var : all_tir_fn_vars) { |
400 | ICHECK(target_module->ContainGlobalVar(var->name_hint)) |
401 | << "No global var for '" << var->name_hint << "' in module for target " |
402 | << target->ToDebugString(); |
403 | lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint)); |
404 | } |
405 | |
406 | // Compile (aka 'build') the projected module into a runtime module of packed functions. |
407 | runtime::Module runtime_module; |
if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
409 | // TODO(mbs): Cleanup hooks. |
410 | runtime_module = (*f)(lowered_projected_mod, target); |
411 | } else { |
412 | runtime_module = build(lowered_projected_mod, target, /*target_host=*/Target(nullptr)); |
413 | } |
414 | |
415 | // Extract all the packed functions. |
416 | for (const auto& var : all_tir_fn_vars) { |
417 | PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); |
418 | ICHECK(packed_func != nullptr) |
419 | << "No packed function for global var '" << var->name_hint |
420 | << "' in compiled module for target " << target->ToDebugString(); |
421 | compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); |
422 | } |
423 | |
424 | // Return just what we need for this call. |
425 | packed_itr = compiled_packed_funcs_.find(packed_func_key); |
426 | ICHECK(packed_itr != compiled_packed_funcs_.end()) << " " << tir_fn_var->name_hint; |
427 | ICHECK_NOTNULL(packed_itr->second); |
428 | return packed_itr->second; |
429 | } |
430 | |
431 | /*! |
432 | * \brief Call the dynamic shape function bound to \p prim_shape_fn_var passing the |
433 | * shapes of args, and return the resulting shapes. |
434 | * |
435 | * \param prim_shape_fn_var Global var bound to lowered shape function. |
436 | * \param all_prim_shape_fn_vars All the global vars needed to build the above, including |
437 | * the shape function itself. |
438 | * \param prim_shape_fn_states For each primitive arg, indicate whether the primitive shape |
439 | * function requires the shape of the argument and/or the actual argument tensor. |
440 | * \param num_shape_inputs The number of inputs, after accounting for both shapes vs data |
441 | * inputs and unfolding of tuple types. |
442 | * \param num_shape_outputs The number of outputs, after accounting for flattening of |
443 | * tuple types. |
* \param prim_shape_target Target for which the shape function should be compiled (the host
* target, since shape functions are always executed on the host).
* \param args Arguments to the primitive this shape function is for.
445 | * \return Expected shapes of the underlying primitive's flattened outputs. |
446 | */ |
447 | Array<Shape> ComputeDynamicShape(const GlobalVar& prim_shape_fn_var, |
448 | const Array<GlobalVar>& all_prim_shape_fn_vars, |
449 | const Array<Integer>& prim_shape_fn_states, |
450 | size_t num_shape_inputs, size_t num_shape_outputs, |
451 | Target prim_shape_target, const std::vector<ObjectRef>& args) { |
VLOG_CONTEXT << "ComputeDynamicShape";
453 | ICHECK(prim_shape_fn_var.defined()); |
454 | ICHECK(prim_shape_fn_var->checked_type().defined()); |
455 | VLOG(1) << "prim_shape_fn_var:" << std::endl << PrettyPrint(prim_shape_fn_var); |
456 | ICHECK(prim_shape_fn_states.defined()); |
457 | for (size_t i = 0; i < prim_shape_fn_states.size(); ++i) { |
458 | VLOG(1) << "prim_shape_fn_states[" << i << "]: " << prim_shape_fn_states[i]; |
459 | } |
460 | VLOG(1) << "num_shape_inputs: " << num_shape_inputs; |
461 | VLOG(1) << "num_shape_outputs: " << num_shape_outputs; |
462 | VLOG(1) << "args.size(): " << args.size(); |
463 | VLOG(1) << "prim_shape_target: " << prim_shape_target->ToDebugString(); |
464 | |
465 | // The function type is that of the shape function rather than the original primitive the shape |
466 | // function is for. |
467 | const auto* func_type_node = prim_shape_fn_var->checked_type().as<FuncTypeNode>(); |
468 | ICHECK(func_type_node); |
469 | // The shape function states are w.r.t. the original primitive's arguments in |
470 | // non-flattened form. |
471 | // TODO(mbs): Clean this up so we don't mix flattened vs original conventions. |
472 | ICHECK_EQ(args.size(), prim_shape_fn_states.size()); |
473 | |
474 | // num_shape_inputs will account for which primitive function arguments are dynamic, |
475 | // whether the shape and or data needs to be passed, and flattening of tuples. |
476 | // Similarly, num_shape_outputs will account for flattening of tuples. |
477 | |
478 | // TODO(mbs): Take this from the host_virtual_device. |
479 | Device shape_device; |
480 | shape_device.device_type = static_cast<DLDeviceType>(prim_shape_target->GetTargetDeviceType()); |
481 | shape_device.device_id = 0; |
482 | |
483 | // 'Compile' the TIR shape function to appropriate callable form. |
484 | PackedFunc packed_shape_func = |
485 | TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_target); |
486 | |
487 | size_t arity = num_shape_inputs + num_shape_outputs; |
488 | std::vector<TVMValue> values(arity); |
489 | std::vector<int> codes(arity); |
490 | TVMArgsSetter setter(values.data(), codes.data()); |
491 | std::vector<NDArray> inputs(num_shape_inputs); |
492 | std::vector<NDArray> outputs(num_shape_outputs); |
493 | |
494 | // Collect the shapes and/or data needed by the shape function from |
495 | // the primitive's arguments. |
496 | size_t arg_counter = 0; |
497 | for (size_t i = 0; i < args.size(); ++i) { |
498 | // TODO(mbs): The same need data/need shape arg state applies to everything in the |
499 | // flattened form of this arg. Does that match what lowering actually does? |
500 | int64_t state = prim_shape_fn_states[i]->value; |
501 | for (const auto& nd_array : FlattenADT(args[i])) { |
502 | if (state & tec::kNeedInputData) { |
503 | auto arr = nd_array.CopyTo(shape_device); |
504 | inputs[arg_counter] = arr; |
505 | setter(arg_counter, arr); |
506 | ++arg_counter; |
507 | } |
508 | if (state & tec::kNeedInputShape) { |
509 | int64_t ndim = nd_array.Shape().size(); |
510 | NDArray shape_arr; |
511 | if (ndim == 0) { |
512 | shape_arr = NDArray::Empty({}, DataType::Int(64), shape_device); |
513 | } else { |
514 | shape_arr = NDArray::Empty({ndim}, DataType::Int(64), shape_device); |
515 | int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data); |
516 | for (auto j = 0; j < ndim; ++j) { |
517 | data[j] = nd_array.Shape()[j]; |
518 | } |
519 | } |
520 | inputs[arg_counter] = shape_arr; |
521 | setter(arg_counter, shape_arr); |
522 | ++arg_counter; |
523 | } |
524 | } |
525 | } |
ICHECK_EQ(arg_counter, num_shape_inputs) << "Shape function input sizes mismatch";
527 | |
528 | // Prepare NDArrays to hold the output shapes. |
529 | size_t out_cnt = 0; |
530 | for (const auto& ttype : FlattenTupleType(func_type_node->ret_type)) { |
531 | ICHECK(out_cnt < num_shape_outputs); |
532 | std::vector<int64_t> concrete_shape; |
533 | for (const auto& dim : ttype->shape) { |
534 | const auto* ivalue = tir::as_const_int(dim); |
ICHECK(ivalue) << "expected concrete dimensions";
536 | concrete_shape.push_back(ivalue[0]); |
537 | } |
538 | auto arr = NDArray::Empty(concrete_shape, ttype->dtype, shape_device); |
539 | outputs[out_cnt] = arr; |
540 | setter(arg_counter + out_cnt, arr); |
541 | ++out_cnt; |
542 | } |
ICHECK_EQ(out_cnt, num_shape_outputs) << "Shape function output sizes mismatch";
544 | |
545 | // Call the dynamic shape function. |
546 | TVMRetValue rv; // ignored |
547 | packed_shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); |
548 | |
549 | // Convert result tensors back to shapes. |
550 | Array<Shape> out_shapes; |
551 | for (auto out_tensor : outputs) { |
552 | int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data); |
553 | Shape out_shape; |
554 | for (int i = 0; i < out_tensor->shape[0]; ++i) { |
555 | out_shape.push_back(Integer(shape_data[i])); |
556 | } |
557 | out_shapes.push_back(out_shape); |
558 | } |
559 | return out_shapes; |
560 | } |
561 | |
/*!
* \brief Call the primitive bound to \p prim_fn_var with \p args. If necessary, evaluates the
* dynamic shape function bound to \p prim_shape_fn_var to calculate the shapes of the result
* tensors.
*
* \param prim_fn_var Global bound to the lowered primitive.
* \param all_prim_fn_vars All globals referenced by the lowered primitive, plus \p prim_fn_var
* itself.
* \param prim_target Target for which the primitive should be compiled.
* \param prim_shape_fn_var Global bound to the lowered shape function for the primitive, if
* needed.
* \param all_prim_shape_fn_vars All globals referenced by the lowered shape function, plus
* \p prim_shape_fn_var itself.
* \param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic
* shape function (if any) for each (flattened) argument.
* \param num_shape_inputs Number of arguments to the dynamic shape function (if any).
* \param num_shape_outputs Number of outputs from the dynamic shape function (if any).
* \param prim_shape_target Target for which the shape function (if any) should be compiled.
* \param args Already evaluated arguments to the primitive.
* \return Result of the primitive.
*/
578 | ObjectRef InvokePrimitiveOp(const GlobalVar& prim_fn_var, const Array<GlobalVar> all_prim_fn_vars, |
579 | Target prim_target, const GlobalVar& prim_shape_fn_var, |
580 | const Array<GlobalVar>& all_prim_shape_fn_vars, |
581 | const Array<Integer>& prim_shape_fn_states, size_t num_shape_inputs, |
582 | size_t num_shape_outputs, Target prim_shape_target, |
583 | const std::vector<ObjectRef>& args) { |
584 | ICHECK(prim_fn_var->checked_type().defined()); |
585 | const FuncTypeNode* ftn = prim_fn_var->checked_type().as<FuncTypeNode>(); |
586 | ICHECK(ftn); |
587 | |
588 | // 'Compile' the TIR primitive to appropriate callable form (on the desired target). |
589 | PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, prim_target); |
590 | |
591 | // Argument tuples are flattened. |
592 | std::vector<NDArray> arg_nd_arrays = FlattenADTs(args); |
593 | const size_t num_inputs = arg_nd_arrays.size(); |
594 | // num_inputs should equal size(concat(map(FlattenTupleType, function arg types))) |
595 | |
596 | // TVM's primitive calling convention is for the final arguments to be for output |
597 | // buffers. We must allocate space for those buffers based on the return type. |
598 | std::vector<TensorType> result_tensor_types = FlattenTupleType(ftn->ret_type); |
599 | const size_t arg_len = num_inputs + result_tensor_types.size(); |
600 | |
601 | std::vector<TVMValue> values(arg_len); |
602 | std::vector<int> codes(arg_len); |
603 | TVMArgsSetter setter(values.data(), codes.data()); |
604 | |
605 | // Marshall the call's arguments in flattened form. |
606 | int arg_counter = 0; |
607 | for (const auto& nd_array : arg_nd_arrays) { |
608 | setter(arg_counter++, nd_array); |
609 | Device arg_dev = nd_array->device; |
610 | ICHECK(arg_dev.device_type == device_.device_type && arg_dev.device_id == device_.device_id) |
<< "Interpreter expects arguments to be on device " << device_ << ", but got " << arg_dev;
612 | } |
613 | |
614 | // If necessary, retrieve concrete shapes for outputs from shape function rather |
615 | // than relying on TensorType shapes. |
616 | Array<Shape> runtime_shapes; |
617 | bool is_dyn = IsDynamic(ftn->ret_type); |
618 | if (is_dyn) { |
619 | ICHECK(prim_shape_fn_var.defined()); |
620 | ICHECK(prim_shape_fn_states.defined()); |
621 | runtime_shapes = |
622 | ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, |
623 | num_shape_inputs, num_shape_outputs, prim_shape_target, args); |
624 | ICHECK_EQ(runtime_shapes.size(), result_tensor_types.size()); |
625 | } |
626 | |
627 | // Prepare the result tensors for the call. |
628 | TVMRetValue rv; // ignored |
629 | std::vector<NDArray> result_nd_arrays; |
630 | for (size_t i = 0; i < result_tensor_types.size(); ++i) { |
631 | const auto& ttype = result_tensor_types[i]; |
632 | const Shape& shape = is_dyn ? runtime_shapes[i] : ttype->shape; |
633 | // Allocate output tensor of appropriate shape. |
634 | std::vector<int64_t> concrete_shape; |
635 | for (const auto& dim : shape) { |
636 | const auto* ivalue = tir::as_const_int(dim); |
ICHECK(ivalue) << "expected concrete dimensions";
638 | concrete_shape.push_back(ivalue[0]); |
639 | } |
640 | NDArray nd_array = NDArray::Empty(concrete_shape, ttype->dtype, device_); |
641 | setter(num_inputs + i, nd_array); |
642 | result_nd_arrays.emplace_back(nd_array); |
643 | } |
644 | |
645 | // Call the primitive. |
646 | packed_func.CallPacked(TVMArgs(values.data(), codes.data(), static_cast<int>(arg_len)), &rv); |
647 | |
648 | // Unflatten the results. |
649 | return ToADTOrNDArray(ftn->ret_type, result_nd_arrays); |
650 | } |
651 | |
652 | /*! |
653 | * \brief Invoke \p closure with \p args. If \p bind is defined then this is a recursive |
654 | * closure and \p bind should refer to itself. |
655 | */ |
656 | ObjectRef Invoke(const InterpreterClosure& closure, const Array<ObjectRef>& args, |
657 | const Var& bind = Var()) { |
658 | // Get a reference to the function inside the closure. |
659 | Function func = closure->func; |
660 | ICHECK_EQ(func->params.size(), args.size()); |
661 | |
662 | if (func->HasNonzeroAttr(attr::kPrimitive)) { |
663 | if (const CallNode* call_node = closure->func->body.as<CallNode>()) { |
664 | if (call_node->op == debug_op_) { |
665 | // Special case: Calling the debug tracing function. |
666 | auto dattrs = call_node->attrs.as<DebugAttrs>(); |
667 | auto interp_state = get_state(call_node->args[0]); |
668 | |
669 | if (dattrs->debug_func.defined()) { |
670 | dattrs->debug_func(interp_state); |
671 | } else { |
672 | RELAY_DEBUG_INTERP(interp_state); |
673 | } |
674 | |
675 | return args[0]; |
676 | } |
677 | } |
678 | } |
679 | |
680 | ICHECK(!func->HasNonzeroAttr(attr::kPrimitive)) |
<< "Calls to primitive functions should have been removed by lowering";
682 | |
683 | // Allocate a frame with the parameters and free variables. |
684 | Map<Var, ObjectRef> locals; |
685 | for (size_t i = 0; i < func->params.size(); i++) { |
686 | ICHECK_EQ(locals.count(func->params[i]), 0); |
687 | locals.Set(func->params[i], args[i]); |
688 | } |
689 | |
690 | // Add the var to value mappings from the Closure's environment. |
691 | for (auto it = closure->env.begin(); it != closure->env.end(); ++it) { |
692 | ICHECK_EQ(locals.count((*it).first), 0); |
693 | locals.Set((*it).first, (*it).second); |
694 | } |
695 | |
696 | if (bind.defined()) { |
697 | locals.Set(bind, RecClosure(closure, bind)); |
698 | } |
699 | |
700 | return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); }); |
701 | } |
702 | |
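/*!
* \brief Evaluates a call. The callee is dispatched on its form: calls to lowered TIR
* primitives (call_lowered), on_device annotations, ADT constructors, and ordinary
* (possibly recursive) closures are each handled below; device_copy is not supported.
*/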
703 | ObjectRef VisitExpr_(const CallNode* call_node) final { |
704 | DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); |
705 | CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); |
706 | |
707 | if (device_copy_props.body.defined()) { |
708 | // TODO(mbs): device_copy cleanup |
LOG(FATAL) << "The interpreter does not support device_copy";
710 | } else if (call_lowered_props.lowered_func.defined()) { |
711 | // Special case: Call a lowered TIR function. |
712 | |
713 | // Evaluate only function args |
714 | std::vector<ObjectRef> args; |
715 | for (auto arg : call_lowered_props.arguments) { |
716 | args.push_back(Eval(arg)); |
717 | } |
718 | |
719 | // TODO(mbs): Make calling convention first-class in Relay. |
720 | Array<GlobalVar> all_prim_fn_vars; |
if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) {
722 | all_prim_fn_vars = |
Downcast<Array<GlobalVar>>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars"));
724 | } |
725 | GlobalVar prim_shape_fn_var; |
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) {
727 | prim_shape_fn_var = |
Downcast<GlobalVar>(call_lowered_props.attrs.metadata.at("prim_shape_fn_var"));
729 | } |
730 | Array<GlobalVar> all_prim_shape_fn_vars; |
if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) {
732 | all_prim_shape_fn_vars = Downcast<Array<GlobalVar>>( |
call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars"));
734 | } |
735 | Array<Integer> prim_shape_fn_states; |
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) {
737 | prim_shape_fn_states = |
Downcast<Array<Integer>>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states"));
739 | } |
740 | |
741 | size_t num_shape_inputs = 0; |
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs")) {
743 | num_shape_inputs = static_cast<size_t>( |
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs"))
745 | ->value); |
746 | } |
747 | size_t num_shape_outputs = 0; |
if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) {
749 | num_shape_outputs = static_cast<size_t>( |
Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs"))
751 | ->value); |
752 | } |
753 | ICHECK(config_->optional_homogeneous_target.defined()); |
754 | return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars, |
755 | config_->optional_homogeneous_target, prim_shape_fn_var, |
756 | all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, |
757 | num_shape_outputs, config_->host_virtual_device->target, args); |
758 | } else { // All other calls |
759 | // Evaluate all arguments |
760 | std::vector<ObjectRef> args; |
761 | for (auto arg : call_node->args) { |
762 | args.push_back(Eval(arg)); |
763 | } |
764 | |
765 | if (call_node->op == OnDeviceOp()) { |
766 | // Special case: The call 'on_device(expr)' denotes that expr should be executed on |
767 | // a particular device. We can ignore this during interpretation. |
768 | ICHECK_EQ(call_node->args.size(), 1UL); |
769 | return args[0]; |
770 | } |
771 | if (const ConstructorNode* con = call_node->op.as<ConstructorNode>()) { |
772 | // Special case: ADT constructor |
773 | |
774 | return ConstructorValue(con->tag, args, GetRef<Constructor>(con)); |
775 | } |
776 | |
777 | if (const OpNode* op_node = call_node->op.as<OpNode>()) { |
778 | // Except for call_lowered and on_device, we should not find calls to operators after |
779 | // running fusion and lowering. |
780 | LOG(FATAL) << "found " << op_node->name |
781 | << "; operators should have been removed by previous passes; try " |
"fusing and lowering";
783 | } |
784 | |
785 | // Now we just evaluate and expect to find a closure. |
786 | // TODO(@electriclilies): How should call_lowered behave with closures? |
787 | ObjectRef fn_val = Eval(call_node->op); |
788 | if (const InterpreterClosureObj* closure_node = fn_val.as<InterpreterClosureObj>()) { |
789 | auto closure = GetRef<InterpreterClosure>(closure_node); |
790 | return Invoke(closure, args); |
791 | } else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) { |
792 | return Invoke(closure_node->clos, args, closure_node->bind); |
793 | } else { |
794 | LOG(FATAL) << "internal error: type error, expected function value in the call " |
<< "position";
796 | return ObjectRef(); |
797 | } |
798 | } |
799 | } |
800 | |
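/*!
* \brief Evaluates a let binding. Function-valued bindings become (possibly recursive)
* closures; all other bound values are evaluated eagerly before the body.
*/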
801 | ObjectRef VisitExpr_(const LetNode* let) final { |
802 | if (auto func = let->value.as<FunctionNode>()) { |
803 | auto clo = MakeClosure(GetRef<Function>(func), let->var); |
804 | this->extend(let->var, clo); |
805 | } else { |
806 | auto value = Eval(let->value); |
807 | this->extend(let->var, value); |
808 | } |
809 | |
810 | return Eval(let->body); |
811 | } |
812 | |
813 | ObjectRef VisitExpr_(const TupleGetItemNode* op) final { |
814 | ObjectRef val = Eval(op->tuple); |
815 | const auto* adt_obj = val.as<ADTObj>(); |
ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
817 | auto adt = GetRef<ADT>(adt_obj); |
ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
819 | return adt[op->index]; |
820 | } |
821 | |
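/*!
* \brief Evaluates a conditional. The condition must evaluate to a boolean tensor; it is
* copied to the CPU and its (scalar) value read there.
*/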
822 | ObjectRef VisitExpr_(const IfNode* op) final { |
823 | ObjectRef v = Eval(op->cond); |
824 | if (v->IsInstance<NDArray::ContainerType>()) { |
825 | auto nd_array = Downcast<NDArray>(v); |
826 | Device cpu_dev; |
827 | cpu_dev.device_type = kDLCPU; |
828 | cpu_dev.device_id = 0; |
829 | NDArray cpu_array = nd_array.CopyTo(cpu_dev); |
830 | ICHECK_EQ(DataType(cpu_array->dtype), DataType::Bool()); |
831 | // TODO(@jroesch, @MK): Refactor code into helper from DCE. |
832 | if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) { |
833 | return Eval(op->true_branch); |
834 | } else { |
835 | return Eval(op->false_branch); |
836 | } |
837 | } else { |
LOG(FATAL) << "type error, type system should have caught this";
839 | } |
840 | } |
841 | |
842 | ObjectRef VisitExpr_(const RefWriteNode* op) final { |
843 | ObjectRef r = Eval(op->ref); |
844 | if (const RefValueObj* rv = r.as<RefValueObj>()) { |
845 | rv->value = Eval(op->value); |
846 | return ADT::Tuple(std::vector<ObjectRef>()); |
847 | } else { |
LOG(FATAL) << "type error, type system should have caught this";
849 | } |
850 | } |
851 | |
852 | ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValue(Eval(op->value)); } |
853 | |
854 | ObjectRef VisitExpr_(const RefReadNode* op) final { |
855 | ObjectRef r = Eval(op->ref); |
856 | if (const RefValueObj* rv = r.as<RefValueObj>()) { |
857 | return rv->value; |
858 | } else { |
LOG(FATAL) << "type error, type system should have caught this";
860 | } |
861 | } |
862 | |
863 | ObjectRef VisitExpr_(const MatchNode* op) final { |
864 | ObjectRef v = Eval(op->data); |
865 | for (const Clause& c : op->clauses) { |
866 | if (VisitPattern(c->lhs, v)) { |
867 | return VisitExpr(c->rhs); |
868 | } |
869 | } |
LOG(FATAL) << "did not find any match";
871 | } |
872 | |
873 | bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final { |
874 | const ConstructorValueObj* cvn = v.as<ConstructorValueObj>(); |
ICHECK(cvn) << "match subject must be a constructor value";
876 | ICHECK_NE(op->constructor->tag, -1); |
877 | ICHECK_NE(cvn->tag, -1); |
878 | if (op->constructor->tag == cvn->tag) { |
879 | ICHECK_EQ(op->patterns.size(), cvn->fields.size()); |
880 | for (size_t i = 0; i < op->patterns.size(); ++i) { |
881 | if (!VisitPattern(op->patterns[i], cvn->fields[i])) { |
882 | return false; |
883 | } |
884 | } |
885 | return true; |
886 | } |
887 | return false; |
888 | } |
889 | |
890 | bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final { |
891 | auto adt = Downcast<ADT>(v); |
892 | ICHECK_EQ(op->patterns.size(), adt.size()); |
893 | for (size_t i = 0; i < op->patterns.size(); ++i) { |
894 | if (!VisitPattern(op->patterns[i], adt[i])) { |
895 | return false; |
896 | } |
897 | } |
898 | return true; |
899 | } |
900 | |
901 | bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; } |
902 | |
903 | bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final { |
904 | extend(op->var, v); |
905 | return true; |
906 | } |
907 | |
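/*!
* \brief Captures the current call stack as an InterpreterState, e.g. for the 'debug'
* operator's tracing function.
*/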
908 | InterpreterState get_state(Expr e = Expr()) const { |
909 | InterpreterStateObj::Stack stack; |
910 | for (auto fr : this->stack_.frames) { |
911 | InterpreterStateObj::Frame frame = fr.locals; |
912 | stack.push_back(frame); |
913 | } |
914 | auto state = InterpreterState(e, stack); |
915 | return state; |
916 | } |
917 | |
918 | private: |
// Unified module. Functions are annotated with their target.
// All expressions are evaluated w.r.t. the definitions in this module.
// This single module holds both the Relay functions that used to live in the 'main' module
// and the lowered TIR functions that used to live in per-target modules.
923 | IRModule unified_mod_; |
924 | // Cached packed functions for the primitives and shape functions, keyed by target and |
925 | // global var name. |
926 | std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_; |
927 | /*! \brief Compilation config describing the available targets. */ |
928 | CompilationConfig config_; |
929 | // Unique device on which primitives (but not shape functions) will be executed. |
930 | // (For simplicity we only run the interpreter on a single device.) |
931 | Device device_; |
932 | // Call stack. |
933 | Stack stack_; |
934 | // The distinguished 'debug' operator, which is handled specially. |
935 | const Op& debug_op_; |
936 | }; |
937 | |
/*!
* Lowers all calls to primitives in \p mod as appropriate for \p config. Returns the
* rewritten module, which also contains the lowered TIR functions needed by the rewritten
* Relay functions.
*/
943 | IRModule Prepare(IRModule mod, const CompilationConfig& config) { |
944 | // Run minimal transforms on module to establish invariants needed by interpreter. |
945 | transform::Sequential seq( |
946 | {transform::SimplifyInference(), qnn::transform::Legalize(), |
947 | // Figure out which devices should be used to execute. |
948 | // TODO(mbs): Should ignore all existing annotations when constant folding |
949 | transform::PlanDevices(config), |
950 | // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' |
951 | // attribute. |
952 | transform::FuseOps(/*fuse_opt_level=*/0), |
953 | // Use ANF to reduce number of cases to handle. |
954 | transform::ToANormalForm(), |
955 | // eta expand to support constructors in argument position. |
956 | transform::EtaExpand( |
957 | /*expand_constructor=*/true, /*expand_global_var=*/false), |
transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});
959 | |
960 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
961 | With<transform::PassContext> ctx(pass_ctx); |
962 | mod = seq(mod); |
963 | |
964 | return mod; |
965 | } |
966 | |
967 | /*! \brief Check if an expression could be changed by \p Prepare. |
968 | * |
969 | * If not we can evaluate it directly and don't need to bind it into a fresh module. |
970 | */ |
971 | class NeedsPreparationVisitor : public ExprVisitor { |
972 | public: |
973 | bool needs_preparation = false; |
974 | |
975 | private: |
976 | void VisitExpr_(const VarNode* vn) override { |
977 | // Could be prim. |
978 | needs_preparation = true; |
979 | } |
980 | // ConstantNode ok |
981 | // GlobalVarNode ok |
982 | void VisitExpr_(const OpNode* op) override { |
983 | // Could be prim. |
984 | needs_preparation = true; |
985 | } |
986 | // TupleNode recurse |
987 | void VisitExpr_(const FunctionNode* op) override { |
988 | // Could be prim. |
989 | needs_preparation = true; |
990 | } |
991 | // CallNode recurse |
992 | void VisitExpr_(const LetNode* ln) override { |
993 | // May bind prim. |
994 | needs_preparation = true; |
995 | } |
996 | // IfNode recurse |
997 | // TupleGetItemNode recurse |
998 | // RefCreateNode recurse |
999 | // RefReadNode recurse |
1000 | // RefWriteNode recurse |
1001 | // ConstructorNode ok |
1002 | void VisitExpr_(const MatchNode* op) override { |
1003 | // Needs eta-expansion. |
1004 | needs_preparation = true; |
1005 | } |
1006 | }; |
1007 | |
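/*!
* \brief Evaluates \p expr (which must have function type) relative to the (optional) \p mod,
* and returns a packed function which applies the resulting closure to its arguments. The
* arguments themselves must not require preparation (see NeedsPreparationVisitor).
*/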
1008 | TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device, |
1009 | Target target) { |
VLOG_CONTEXT << "EvalFunction";
1011 | VLOG(1) << "evaling module:" << std::endl |
1012 | << PrettyPrint(mod) << "and expression:" << std::endl |
1013 | << PrettyPrint(expr); |
1014 | |
1015 | ICHECK_EQ(device.device_type, target->GetTargetDeviceType()); |
1016 | Array<Target> raw_targets = {target}; |
1017 | CompilationConfig config(transform::PassContext::Current(), raw_targets); |
1018 | |
1019 | // |
1020 | // Step 1: Prepare mod. |
1021 | // |
1022 | |
1023 | // If expr is simple enough we can avoid binding it into the module and |
1024 | // just eval it directly. |
1025 | NeedsPreparationVisitor visitor; |
1026 | visitor.VisitExpr(expr); |
1027 | |
1028 | Expr expr_to_eval; |
1029 | IRModule mod_with_expr; // default empty |
1030 | if (visitor.needs_preparation) { |
1031 | GlobalVar main; |
1032 | // Bind expr to a new zero-argument function so it can be prepared along with the module |
1033 | // (if any). |
1034 | std::pair<IRModule, GlobalVar> mod_and_global; |
1035 | if (mod.defined()) { |
// TODO(mbs): Type inference currently assumes all global functions in modules have
// known result types, and so each global function has its body type inferred independently
// and in arbitrary order. However, the interpreter may be called with an expression relative
// to a 'main' which has no result type annotation, and that expression will be bound into a
// fresh global below. Type inference then fails since 'main' has unknown type. We should
// allow inference on mutually recursive global functions. As a workaround, infer the type
// of mod now. Obviously that won't work if 'main' itself calls other global functions with
// partial types, but it at least maintains legacy behavior.
1044 | transform::PassContext pass_ctx = transform::PassContext::Current(); |
1045 | With<transform::PassContext> ctx(pass_ctx); |
1046 | mod = transform::InferType()(mod); |
1047 | mod_and_global = |
1048 | IRModule::FromExprInContext(expr, mod->functions, mod->type_definitions, mod->Imports()); |
1049 | } else { |
1050 | mod_and_global = IRModule::FromExprInContext(expr); |
1051 | } |
1052 | mod_with_expr = mod_and_global.first; |
1053 | expr_to_eval = mod_and_global.second; |
1054 | } else { |
1055 | if (mod.defined()) { |
1056 | mod_with_expr = mod; |
1057 | } |
1058 | // Prepare won't change expr, so we don't need to worry about binding it into a module |
1059 | // and can just eval it directly. |
1060 | expr_to_eval = expr; |
1061 | } |
1062 | IRModule lowered_mod = Prepare(mod_with_expr, config); |
1063 | |
1064 | std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(lowered_mod, config, device); |
1065 | |
1066 | // |
1067 | // Step 2: Evaluate target function to a closure. |
1068 | // |
1069 | ObjectRef object_ref = intrp->Eval(expr_to_eval); |
1070 | if (const InterpreterClosureObj* closure_obj = object_ref.as<InterpreterClosureObj>()) { |
1071 | InterpreterClosure closure = GetRef<InterpreterClosure>(closure_obj); |
1072 | ICHECK(closure.defined()); |
1073 | ICHECK(closure->func.defined()); |
1074 | |
1075 | return TypedPackedFunc<ObjectRef(Array<Expr>)>([intrp, closure](Array<Expr> args) { |
VLOG_CONTEXT << "EvalFunction::Apply";
VLOG(1) << "evaling closure with " << args.size() << " arguments";
1078 | // |
1079 | // Step 3: Apply closure to arguments. |
1080 | // |
1081 | ICHECK_NOTNULL(intrp); |
1082 | ICHECK(closure.defined()); |
1083 | ICHECK(closure->func.defined()); |
1084 | Array<ObjectRef> evaled_args; |
1085 | for (auto arg : args) { |
1086 | NeedsPreparationVisitor visitor; |
1087 | visitor.VisitExpr(arg); |
1088 | ICHECK(!visitor.needs_preparation) |
1089 | << "attempting to apply closure to expression which needs preparation: " |
1090 | << PrettyPrint(arg); |
1091 | evaled_args.push_back(intrp->Eval(arg)); |
1092 | } |
1093 | return intrp->Invoke(closure, evaled_args); |
1094 | }); |
1095 | } else { |
LOG(FATAL) << "expecting expression to have function type and evaluate to a closure";
1097 | } |
1098 | } |
1099 | |
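/*!
* \brief Evaluates \p expr in a fresh module built from \p type_definitions and \p import_set
* and returns the result. If \p expr is not itself a function, the zero-argument function it
* is bound to by IRModule::FromExprInContext is immediately called to recover its value.
*/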
1100 | ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions, |
1101 | std::unordered_set<String> import_set, Device device, Target target, |
1102 | Map<String, ObjectRef> attrs) { |
1103 | ICHECK_EQ(device.device_type, target->GetTargetDeviceType()); |
1104 | Array<Target> raw_targets = {target}; |
1105 | CompilationConfig config(transform::PassContext::Current(), raw_targets); |
1106 | |
1107 | std::pair<IRModule, GlobalVar> mod_and_global = |
1108 | IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); |
1109 | |
1110 | IRModule mod = Prepare(WithAttrs(mod_and_global.first, {attrs}), config); |
1111 | |
1112 | Interpreter intrp(mod, config, device); |
1113 | Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint); |
1114 | if (expr.as<BaseFuncNode>() == nullptr) { |
1115 | // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr |
1116 | // unless it is a function, so we must reverse that in the expression to eval. |
// This should be done more systematically.
1118 | expr_to_eval = Call(expr_to_eval, {}); |
1119 | } |
1120 | return intrp.Eval(expr_to_eval); |
1121 | } |
1122 | |
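// Entry point used from the Python side (e.g. by the interpreter/'debug' executor created
// via relay.create_executor).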
TVM_REGISTER_GLOBAL("relay.backend.EvalFunction").set_body_typed(EvalFunction);
1124 | |
1125 | } // namespace relay |
1126 | } // namespace tvm |
1127 | |