/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/interpreter.cc
 * \brief An interpreter for the Relay IR.
 */
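
//
// Usage note (illustrative sketch only, not a definitive API reference): the interpreter is
// normally reached through the Eval() and EvalFunction() entry points defined at the bottom of
// this file. Something along these lines evaluates a module's function to a closure and applies
// it to an argument (the names `mod`, `f` and `arg` are placeholders):
//
//   IRModule mod = ...;                            // module defining a global function "f"
//   Device device{kDLCPU, 0};
//   Target target("llvm");
//   auto apply = EvalFunction(mod, mod->GetGlobalVar("f"), device, target);
//   ObjectRef result = apply(Array<Expr>({arg}));  // arg must need no preparation, e.g. a Constant
//
// The same functionality is exported to Python via the "relay.backend.EvalFunction" global
// registered at the end of this file (used, for example, by the debug executor).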

#include <tvm/driver/driver_api.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
#include <tvm/target/compilation_config.h>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/pass_utils.h"
#include "te_compiler.h"

namespace tvm {
namespace relay {

using runtime::ADT;
using runtime::ADTObj;
using runtime::NDArray;
using runtime::TVMArgsSetter;
using runtime::operator<<;

namespace {
// TODO(mbs): Centralize.
struct PairHash {
  template <typename T1, typename T2>
  std::size_t operator()(const std::pair<T1, T2>& k) const {
    return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
  }
  template <typename T2>
  std::size_t operator()(const std::pair<Target, T2>& k) const {
    return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second));
  }
};

// Analogue of FlattenTupleType for runtime ADT vs NDArray values.
// TODO(mbs): Hoist somewhere sensible, maybe op/memory.h?
void FlattenADTAux(const ObjectRef& object_ref, std::vector<NDArray>* out) {
  if (const NDArray::ContainerType* ndarray = object_ref.as<NDArray::ContainerType>()) {
    out->push_back(GetRef<NDArray>(ndarray));
  } else if (const ADTObj* adt = object_ref.as<ADTObj>()) {
    for (size_t i = 0; i < adt->size; ++i) {
      FlattenADTAux((*adt)[i], out);
    }
  } else {
    LOG(FATAL) << "unsupported " << object_ref;
  }
}

std::vector<NDArray> FlattenADT(const ObjectRef& object_ref) {
  std::vector<NDArray> out;
  FlattenADTAux(object_ref, &out);
  return out;
}

std::vector<NDArray> FlattenADTs(const std::vector<ObjectRef>& object_refs) {
  std::vector<NDArray> out;
  for (const auto& object_ref : object_refs) {
    FlattenADTAux(object_ref, &out);
  }
  return out;
}

// Analogue of ToTupleType for runtime ADT vs NDArray values.
// TODO(mbs): Hoist somewhere sensible, maybe op/memory.h?
void ToADTOrNDArrayAux(const Type& type, const std::vector<NDArray>& nd_arrays, int* index,
                       std::vector<ObjectRef>* out) {
  if (type.as<TensorTypeNode>()) {
    out->push_back(nd_arrays[*index]);
    *index += 1;
  } else if (const TupleTypeNode* ttn = type.as<TupleTypeNode>()) {
    std::vector<ObjectRef> tuple_out;
    for (size_t i = 0; i < ttn->fields.size(); i++) {
      ToADTOrNDArrayAux(ttn->fields[i], nd_arrays, index, &tuple_out);
    }
    out->push_back(ADT::Tuple(tuple_out));
  } else {
    LOG(FATAL) << "unsupported " << type;
  }
}

ObjectRef ToADTOrNDArray(const Type& type, const std::vector<NDArray>& nd_arrays) {
  if (type.as<TensorTypeNode>() && nd_arrays.size() == 1) {
    return nd_arrays[0];
  } else {
    std::vector<ObjectRef> out;
    int index = 0;
    ToADTOrNDArrayAux(type, nd_arrays, &index, &out);
    return out[0];
  }
}
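
// Worked example (illustrative): for a value of Relay type
//   (Tensor[(2), float32], (Tensor[(3), float32], Tensor[(4), float32]))
// FlattenADT yields the three leaf NDArrays in left-to-right order, and ToADTOrNDArray applied
// to that flat vector together with the same type rebuilds the nested ADT tuple structure.
// A bare TensorType round-trips to a single NDArray with no tuple wrapping.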

}  // namespace

InterpreterClosure::InterpreterClosure(Map<Var, ObjectRef> env, Function func) {
  ObjectPtr<InterpreterClosureObj> n = make_object<InterpreterClosureObj>();
  n->env = std::move(env);
  n->func = std::move(func);
  data_ = std::move(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<InterpreterClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const InterpreterClosureObj*>(ref.get());
      p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")";
    });

inline const PackedFunc& GetPackedFunc(const std::string& name) {
  const PackedFunc* pf = runtime::Registry::Get(name);
  ICHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
  return *pf;
}

// TODO(@jroesch): this doesn't support mutual letrec
/* Object Implementation */
RecClosure::RecClosure(InterpreterClosure clos, Var bind) {
  ObjectPtr<RecClosureObj> n = make_object<RecClosureObj>();
  n->clos = std::move(clos);
  n->bind = std::move(bind);
  data_ = std::move(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<RecClosureObj>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const RecClosureObj*>(ref.get());
      p->stream << "RecClosureObj(" << node->clos << ")";
    });

RefValue::RefValue(ObjectRef value) {
  ObjectPtr<RefValueObj> n = make_object<RefValueObj>();
  n->value = value;
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay._make.RefValue").set_body_typed([](ObjectRef value) {
  return RefValue(value);
});

TVM_REGISTER_NODE_TYPE(RefValueObj);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<RefValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const RefValueObj*>(ref.get());
      p->stream << "RefValueObj(" << node->value << ")";
    });

ConstructorValue::ConstructorValue(int32_t tag, Array<ObjectRef> fields, Constructor constructor) {
  ObjectPtr<ConstructorValueObj> n = make_object<ConstructorValueObj>();
  n->tag = tag;
  n->fields = fields;
  n->constructor = constructor;
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay._make.ConstructorValue")
    .set_body_typed([](int32_t tag, Array<ObjectRef> fields, Constructor constructor) {
      return ConstructorValue(tag, fields, constructor);
    });

TVM_REGISTER_NODE_TYPE(ConstructorValueObj);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ConstructorValueObj>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const ConstructorValueObj*>(ref.get());
      p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")";
    });

/*!
 * \brief A stack frame in the Relay interpreter.
 *
 * Contains a mapping from relay::Var to relay::ObjectRef.
 */
struct Frame {
  /*! \brief The set of local variables and arguments for the frame. */
  Map<Var, ObjectRef> locals;

  explicit Frame(Map<Var, ObjectRef> locals) : locals(locals) {}
};

/*!
 * \brief The call stack in the Relay interpreter.
 *
 * Contains a stack of frames, each corresponding to a function call.
 */
struct Stack {
  /*! \brief The stack frames. */
  std::vector<Frame> frames;
  Stack() : frames() { frames.push_back(Frame({})); }

  Frame& current_frame() { return frames.back(); }

  ObjectRef Lookup(const Var& local) {
    for (auto frame = frames.rbegin(); frame != frames.rend(); frame++) {
      auto elem = frame->locals.find(local);
      if (elem != frame->locals.end()) {
        return (*elem).second;
      }
    }

    LOG(FATAL) << "could not find variable binding for " << local
               << ", address=" << local.operator->();
    return ObjectRef();
  }

  /*!
   * A wrapper around Frame to add RAII semantics for pushing and popping
   * stack frames.
   */
  struct LocalFrame {
    Stack& st;
    explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { st.frames.push_back(fr); }
    ~LocalFrame() { st.frames.pop_back(); }
  };
};

/*! \brief A representation of the interpreter state which can be passed back to Python. */
class InterpreterState;

/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateObj : public Object {
 public:
  using Frame = Map<Var, ObjectRef>;
  using Stack = Array<Frame>;

  /*! \brief The current expression under evaluation. */
  Expr current_expr;

  /*! \brief The call stack of the interpreter. */
  Stack stack;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("current_expr", &current_expr);
    v->Visit("stack", &stack);
  }

  static constexpr const char* _type_key = "relay.InterpreterState";
  TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object);
};

class InterpreterState : public ObjectRef {
 public:
  using Frame = Map<Var, ObjectRef>;
  using Stack = Array<Frame>;

  InterpreterState(Expr current_expr, Stack stack);

  TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj);
};

InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack stack) {
  ObjectPtr<InterpreterStateObj> n = make_object<InterpreterStateObj>();
  n->current_expr = std::move(current_expr);
  n->stack = std::move(stack);
  data_ = std::move(n);
}

// NOTE: The current interpreter assumes the input is in A-normal form (ANF),
// which is better suited to execution.
//
// If given a program in dataflow form (i.e. one whose expressions form a DAG),
// it will re-run shared sub-computations each time they are used.
//
// Conversion to ANF is therefore recommended before interpretation.
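//
// For example (illustrative sketch): in dataflow form the two uses of the same call below cause
// it to be evaluated twice, whereas the ANF version evaluates it once and reuses the binding:
//
//   add(exp(%x), exp(%x))            // dataflow form: exp(%x) evaluated twice
//   let %t = exp(%x); add(%t, %t)    // A-normal form: exp(%x) evaluated once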
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
                    PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
 public:
  Interpreter(IRModule unified_mod, CompilationConfig config, Device device)
      : unified_mod_(unified_mod),
        config_(std::move(config)),
        device_(device),
        debug_op_(Op::Get("debug")) {}

  template <typename T>
  T WithFrame(const Frame& fr, const std::function<T()>& f) {
    Stack::LocalFrame lf(stack_, fr);
    return f();
  }

  void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); }

  ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); }

  ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); }

  ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef<Var>(var_node)); }

  ObjectRef VisitExpr_(const GlobalVarNode* op) final {
    return Eval(unified_mod_->Lookup(GetRef<GlobalVar>(op)));
  }

  ObjectRef VisitExpr_(const OpNode* id) override {
    // TODO(@jroesch): Eta-expand and return in this case.
    LOG(FATAL) << "internal error: bare operator references must be wrapped in a synthetic "
               << "call node (eta-expanded) before they reach the interpreter";
    return ObjectRef();
  }

  ObjectRef VisitExpr_(const ConstantNode* op) final { return op->data.CopyTo(device_); }

  ObjectRef VisitExpr_(const TupleNode* op) final {
    std::vector<ObjectRef> values;

    for (const auto& field : op->fields) {
      ObjectRef field_value = Eval(field);
      values.push_back(field_value);
    }

    return ADT::Tuple(values);
  }

  ObjectRef MakeClosure(const Function& func, Var letrec_name = Var()) {
    Map<Var, ObjectRef> captured_mod;
    Array<Var> free_vars = FreeVars(func);

    for (const auto& var : free_vars) {
      // Evaluate the free var (which could be a function call) unless it is the let-bound
      // variable naming this closure itself (the recursive case).
      if (letrec_name.defined() && letrec_name == var) {
        continue;
      }

      captured_mod.Set(var, Eval(var));
    }

    // For a recursive let binding, wrap the closure in a RecClosure so the body can refer to
    // itself.
    InterpreterClosure closure(captured_mod, func);
    if (letrec_name.defined()) {
      return RecClosure(closure, letrec_name);
    }
    return std::move(closure);
  }

  ObjectRef VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef<Function>(func_node);
    return MakeClosure(func);
  }

  /*!
   * \brief Returns the packed function implementing the TIR function bound to \p tir_fn_var.
   *
   * \param tir_fn_var Global var for the already lowered TIR function.
   * \param all_tir_fn_vars Global vars for all lowered TIR functions the above
   * may reference, plus \p tir_fn_var itself.
   * \param target Target for which the TIR function should be compiled. For primitives this
   * will be the interpreter's homogeneous target. However, for shape functions this will be the
   * host's target, since shape functions are always executed on the host CPU.
   */
  PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
                             Target target) {
    std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint);
    auto packed_itr = compiled_packed_funcs_.find(packed_func_key);
    if (packed_itr != compiled_packed_funcs_.end()) {
      // Already compiled.
      return packed_itr->second;
    }

    // Project out just the function(s) we need.
    IRModule lowered_projected_mod;
    Map<Target, IRModule> per_target_module = tec::GetPerTargetModules(unified_mod_);
    std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
        per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module);
    auto mod_itr = per_target_module_std_map.find(target);
    ICHECK(mod_itr != per_target_module_std_map.end())
        << "No target module for target " << target->ToDebugString();
    const IRModule& target_module = (*mod_itr).second;
    for (const auto& var : all_tir_fn_vars) {
      ICHECK(target_module->ContainGlobalVar(var->name_hint))
          << "No global var for '" << var->name_hint << "' in module for target "
          << target->ToDebugString();
      lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint));
    }

    // Compile (aka 'build') the projected module into a runtime module of packed functions.
    runtime::Module runtime_module;
    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
      // TODO(mbs): Cleanup hooks.
      runtime_module = (*f)(lowered_projected_mod, target);
    } else {
      runtime_module = build(lowered_projected_mod, target, /*target_host=*/Target(nullptr));
    }

    // Extract all the packed functions.
    for (const auto& var : all_tir_fn_vars) {
      PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
      ICHECK(packed_func != nullptr)
          << "No packed function for global var '" << var->name_hint
          << "' in compiled module for target " << target->ToDebugString();
      compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func);
    }

    // Return just what we need for this call.
    packed_itr = compiled_packed_funcs_.find(packed_func_key);
    ICHECK(packed_itr != compiled_packed_funcs_.end()) << " " << tir_fn_var->name_hint;
    ICHECK_NOTNULL(packed_itr->second);
    return packed_itr->second;
  }

  /*!
   * \brief Calls the dynamic shape function bound to \p prim_shape_fn_var, passing it the
   * shapes (and/or data) of \p args, and returns the resulting output shapes.
   *
   * \param prim_shape_fn_var Global var bound to the lowered shape function.
   * \param all_prim_shape_fn_vars All the global vars needed to build the above, including
   * the shape function itself.
   * \param prim_shape_fn_states For each primitive arg, indicates whether the shape
   * function requires the shape of the argument and/or the actual argument tensor.
   * \param num_shape_inputs The number of inputs, after accounting for both shapes vs data
   * inputs and unfolding of tuple types.
   * \param num_shape_outputs The number of outputs, after accounting for flattening of
   * tuple types.
   * \param prim_shape_target Target on which to compile and run the shape function.
   * \param args Arguments to the primitive this shape function is for.
   * \return Expected shapes of the underlying primitive's flattened outputs.
   */
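  //
  // Calling convention sketch (illustrative, not tied to any particular op): for a primitive of
  // type fn (Tensor[(n, 4), float32]) -> Tensor[(n, 4), float32] whose single argument has state
  // tec::kNeedInputShape, the shape function receives one input (a 1-d int64 tensor holding
  // [n, 4]) and one output (a 1-d int64 tensor into which it writes the concrete output shape).
  // Arguments with tec::kNeedInputData instead (or additionally) pass the argument tensor itself,
  // copied to the shape device.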
  Array<Shape> ComputeDynamicShape(const GlobalVar& prim_shape_fn_var,
                                   const Array<GlobalVar>& all_prim_shape_fn_vars,
                                   const Array<Integer>& prim_shape_fn_states,
                                   size_t num_shape_inputs, size_t num_shape_outputs,
                                   Target prim_shape_target, const std::vector<ObjectRef>& args) {
    VLOG_CONTEXT << "ComputeDynamicShape";
    ICHECK(prim_shape_fn_var.defined());
    ICHECK(prim_shape_fn_var->checked_type().defined());
    VLOG(1) << "prim_shape_fn_var:" << std::endl << PrettyPrint(prim_shape_fn_var);
    ICHECK(prim_shape_fn_states.defined());
    for (size_t i = 0; i < prim_shape_fn_states.size(); ++i) {
      VLOG(1) << "prim_shape_fn_states[" << i << "]: " << prim_shape_fn_states[i];
    }
    VLOG(1) << "num_shape_inputs: " << num_shape_inputs;
    VLOG(1) << "num_shape_outputs: " << num_shape_outputs;
    VLOG(1) << "args.size(): " << args.size();
    VLOG(1) << "prim_shape_target: " << prim_shape_target->ToDebugString();

    // The function type is that of the shape function rather than the original primitive the
    // shape function is for.
    const auto* func_type_node = prim_shape_fn_var->checked_type().as<FuncTypeNode>();
    ICHECK(func_type_node);
    // The shape function states are w.r.t. the original primitive's arguments in
    // non-flattened form.
    // TODO(mbs): Clean this up so we don't mix flattened vs original conventions.
    ICHECK_EQ(args.size(), prim_shape_fn_states.size());

    // num_shape_inputs will account for which primitive function arguments are dynamic,
    // whether the shape and/or data needs to be passed, and flattening of tuples.
    // Similarly, num_shape_outputs will account for flattening of tuples.

    // TODO(mbs): Take this from the host_virtual_device.
    Device shape_device;
    shape_device.device_type = static_cast<DLDeviceType>(prim_shape_target->GetTargetDeviceType());
    shape_device.device_id = 0;

    // 'Compile' the TIR shape function to appropriate callable form.
    PackedFunc packed_shape_func =
        TIRToPackedFunc(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_target);

    size_t arity = num_shape_inputs + num_shape_outputs;
    std::vector<TVMValue> values(arity);
    std::vector<int> codes(arity);
    TVMArgsSetter setter(values.data(), codes.data());
    std::vector<NDArray> inputs(num_shape_inputs);
    std::vector<NDArray> outputs(num_shape_outputs);

    // Collect the shapes and/or data needed by the shape function from
    // the primitive's arguments.
    size_t arg_counter = 0;
    for (size_t i = 0; i < args.size(); ++i) {
      // TODO(mbs): The same need data/need shape arg state applies to everything in the
      // flattened form of this arg. Does that match what lowering actually does?
      int64_t state = prim_shape_fn_states[i]->value;
      for (const auto& nd_array : FlattenADT(args[i])) {
        if (state & tec::kNeedInputData) {
          auto arr = nd_array.CopyTo(shape_device);
          inputs[arg_counter] = arr;
          setter(arg_counter, arr);
          ++arg_counter;
        }
        if (state & tec::kNeedInputShape) {
          int64_t ndim = nd_array.Shape().size();
          NDArray shape_arr;
          if (ndim == 0) {
            shape_arr = NDArray::Empty({}, DataType::Int(64), shape_device);
          } else {
            shape_arr = NDArray::Empty({ndim}, DataType::Int(64), shape_device);
            int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
            for (auto j = 0; j < ndim; ++j) {
              data[j] = nd_array.Shape()[j];
            }
          }
          inputs[arg_counter] = shape_arr;
          setter(arg_counter, shape_arr);
          ++arg_counter;
        }
      }
    }
    ICHECK_EQ(arg_counter, num_shape_inputs) << "Shape function input sizes mismatch";

    // Prepare NDArrays to hold the output shapes.
    size_t out_cnt = 0;
    for (const auto& ttype : FlattenTupleType(func_type_node->ret_type)) {
      ICHECK(out_cnt < num_shape_outputs);
      std::vector<int64_t> concrete_shape;
      for (const auto& dim : ttype->shape) {
        const auto* ivalue = tir::as_const_int(dim);
        ICHECK(ivalue) << "expected concrete dimensions";
        concrete_shape.push_back(ivalue[0]);
      }
      auto arr = NDArray::Empty(concrete_shape, ttype->dtype, shape_device);
      outputs[out_cnt] = arr;
      setter(arg_counter + out_cnt, arr);
      ++out_cnt;
    }
    ICHECK_EQ(out_cnt, num_shape_outputs) << "Shape function output sizes mismatch";

    // Call the dynamic shape function.
    TVMRetValue rv;  // ignored
    packed_shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);

    // Convert result tensors back to shapes.
    Array<Shape> out_shapes;
    for (auto out_tensor : outputs) {
      int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
      Shape out_shape;
      for (int i = 0; i < out_tensor->shape[0]; ++i) {
        out_shape.push_back(Integer(shape_data[i]));
      }
      out_shapes.push_back(out_shape);
    }
    return out_shapes;
  }

  /*!
   * \brief Calls the primitive bound to \p prim_fn_var with \p args. If necessary, evaluates the
   * dynamic shape function bound to \p prim_shape_fn_var to calculate the shapes of the result
   * tensors.
   *
   * \param prim_fn_var Global bound to the lowered primitive.
   * \param all_prim_fn_vars All globals referenced by the lowered primitive, plus \p prim_fn_var
   * itself.
   * \param prim_target Target on which to compile and run the primitive.
   * \param prim_shape_fn_var Global bound to the lowered shape function for the primitive, if
   * needed.
   * \param all_prim_shape_fn_vars All globals referenced by the lowered shape function, plus
   * \p prim_shape_fn_var itself.
   * \param prim_shape_fn_states Records whether shape and/or data is needed by the dynamic
   * shape function (if any) for each (flattened) argument.
   * \param num_shape_inputs Number of arguments to the dynamic shape function (if any).
   * \param num_shape_outputs Number of outputs from the dynamic shape function (if any).
   * \param prim_shape_target Target on which to compile and run the shape function (if any).
   * \param args Already evaluated arguments to the primitive.
   * \return Result of the primitive.
   */
  ObjectRef InvokePrimitiveOp(const GlobalVar& prim_fn_var, const Array<GlobalVar> all_prim_fn_vars,
                              Target prim_target, const GlobalVar& prim_shape_fn_var,
                              const Array<GlobalVar>& all_prim_shape_fn_vars,
                              const Array<Integer>& prim_shape_fn_states, size_t num_shape_inputs,
                              size_t num_shape_outputs, Target prim_shape_target,
                              const std::vector<ObjectRef>& args) {
    ICHECK(prim_fn_var->checked_type().defined());
    const FuncTypeNode* ftn = prim_fn_var->checked_type().as<FuncTypeNode>();
    ICHECK(ftn);

    // 'Compile' the TIR primitive to appropriate callable form (on the desired target).
    PackedFunc packed_func = TIRToPackedFunc(prim_fn_var, all_prim_fn_vars, prim_target);

    // Argument tuples are flattened.
    std::vector<NDArray> arg_nd_arrays = FlattenADTs(args);
    const size_t num_inputs = arg_nd_arrays.size();
    // num_inputs should equal size(concat(map(FlattenTupleType, function arg types)))

    // TVM's primitive calling convention is for the final arguments to be the output
    // buffers. We must allocate space for those buffers based on the return type.
    std::vector<TensorType> result_tensor_types = FlattenTupleType(ftn->ret_type);
    const size_t arg_len = num_inputs + result_tensor_types.size();

    std::vector<TVMValue> values(arg_len);
    std::vector<int> codes(arg_len);
    TVMArgsSetter setter(values.data(), codes.data());

    // Marshal the call's arguments in flattened form.
    int arg_counter = 0;
    for (const auto& nd_array : arg_nd_arrays) {
      setter(arg_counter++, nd_array);
      Device arg_dev = nd_array->device;
      ICHECK(arg_dev.device_type == device_.device_type && arg_dev.device_id == device_.device_id)
          << "Interpreter expects device to be " << device_ << ", but got " << arg_dev;
    }

    // If necessary, retrieve concrete shapes for outputs from the shape function rather
    // than relying on the TensorType shapes.
    Array<Shape> runtime_shapes;
    bool is_dyn = IsDynamic(ftn->ret_type);
    if (is_dyn) {
      ICHECK(prim_shape_fn_var.defined());
      ICHECK(prim_shape_fn_states.defined());
      runtime_shapes =
          ComputeDynamicShape(prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states,
                              num_shape_inputs, num_shape_outputs, prim_shape_target, args);
      ICHECK_EQ(runtime_shapes.size(), result_tensor_types.size());
    }

    // Prepare the result tensors for the call.
    TVMRetValue rv;  // ignored
    std::vector<NDArray> result_nd_arrays;
    for (size_t i = 0; i < result_tensor_types.size(); ++i) {
      const auto& ttype = result_tensor_types[i];
      const Shape& shape = is_dyn ? runtime_shapes[i] : ttype->shape;
      // Allocate an output tensor of the appropriate shape.
      std::vector<int64_t> concrete_shape;
      for (const auto& dim : shape) {
        const auto* ivalue = tir::as_const_int(dim);
        ICHECK(ivalue) << "expected concrete dimensions";
        concrete_shape.push_back(ivalue[0]);
      }
      NDArray nd_array = NDArray::Empty(concrete_shape, ttype->dtype, device_);
      setter(num_inputs + i, nd_array);
      result_nd_arrays.emplace_back(nd_array);
    }

    // Call the primitive.
    packed_func.CallPacked(TVMArgs(values.data(), codes.data(), static_cast<int>(arg_len)), &rv);

    // Unflatten the results.
    return ToADTOrNDArray(ftn->ret_type, result_nd_arrays);
  }

  /*!
   * \brief Invokes \p closure with \p args. If \p bind is defined then this is a recursive
   * closure and \p bind is the let-bound variable naming the closure, so the body can refer
   * to itself.
   */
  ObjectRef Invoke(const InterpreterClosure& closure, const Array<ObjectRef>& args,
                   const Var& bind = Var()) {
    // Get a reference to the function inside the closure.
    Function func = closure->func;
    ICHECK_EQ(func->params.size(), args.size());

    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      if (const CallNode* call_node = closure->func->body.as<CallNode>()) {
        if (call_node->op == debug_op_) {
          // Special case: Calling the debug tracing function.
          auto dattrs = call_node->attrs.as<DebugAttrs>();
          auto interp_state = get_state(call_node->args[0]);

          if (dattrs->debug_func.defined()) {
            dattrs->debug_func(interp_state);
          } else {
            RELAY_DEBUG_INTERP(interp_state);
          }

          return args[0];
        }
      }
    }

    ICHECK(!func->HasNonzeroAttr(attr::kPrimitive))
        << "Calls to primitive functions should have been removed by lowering";

    // Allocate a frame with the parameters and free variables.
    Map<Var, ObjectRef> locals;
    for (size_t i = 0; i < func->params.size(); i++) {
      ICHECK_EQ(locals.count(func->params[i]), 0);
      locals.Set(func->params[i], args[i]);
    }

    // Add the var to value mappings from the closure's environment.
    for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
      ICHECK_EQ(locals.count((*it).first), 0);
      locals.Set((*it).first, (*it).second);
    }

    if (bind.defined()) {
      locals.Set(bind, RecClosure(closure, bind));
    }

    return WithFrame<ObjectRef>(Frame(locals), [&]() { return Eval(func->body); });
  }

  ObjectRef VisitExpr_(const CallNode* call_node) final {
    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

    if (device_copy_props.body.defined()) {
      // TODO(mbs): device_copy cleanup
      LOG(FATAL) << "The interpreter does not support device_copy";
    } else if (call_lowered_props.lowered_func.defined()) {
      // Special case: Call a lowered TIR function.

      // Evaluate only function args
      std::vector<ObjectRef> args;
      for (auto arg : call_lowered_props.arguments) {
        args.push_back(Eval(arg));
      }

      // TODO(mbs): Make calling convention first-class in Relay.
      Array<GlobalVar> all_prim_fn_vars;
      if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) {
        all_prim_fn_vars =
            Downcast<Array<GlobalVar>>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars"));
      }
      GlobalVar prim_shape_fn_var;
      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) {
        prim_shape_fn_var =
            Downcast<GlobalVar>(call_lowered_props.attrs.metadata.at("prim_shape_fn_var"));
      }
      Array<GlobalVar> all_prim_shape_fn_vars;
      if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) {
        all_prim_shape_fn_vars = Downcast<Array<GlobalVar>>(
            call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars"));
      }
      Array<Integer> prim_shape_fn_states;
      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) {
        prim_shape_fn_states =
            Downcast<Array<Integer>>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states"));
      }

      size_t num_shape_inputs = 0;
      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs")) {
        num_shape_inputs = static_cast<size_t>(
            Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs"))
                ->value);
      }
      size_t num_shape_outputs = 0;
      if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) {
        num_shape_outputs = static_cast<size_t>(
            Downcast<Integer>(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs"))
                ->value);
      }
      ICHECK(config_->optional_homogeneous_target.defined());
      return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars,
                               config_->optional_homogeneous_target, prim_shape_fn_var,
                               all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs,
                               num_shape_outputs, config_->host_virtual_device->target, args);
    } else {  // All other calls
      // Evaluate all arguments
      std::vector<ObjectRef> args;
      for (auto arg : call_node->args) {
        args.push_back(Eval(arg));
      }

      if (call_node->op == OnDeviceOp()) {
        // Special case: The call 'on_device(expr)' denotes that expr should be executed on
        // a particular device. We can ignore this during interpretation.
        ICHECK_EQ(call_node->args.size(), 1UL);
        return args[0];
      }
      if (const ConstructorNode* con = call_node->op.as<ConstructorNode>()) {
        // Special case: ADT constructor

        return ConstructorValue(con->tag, args, GetRef<Constructor>(con));
      }

      if (const OpNode* op_node = call_node->op.as<OpNode>()) {
        // Except for call_lowered and on_device, we should not find calls to operators after
        // running fusion and lowering.
        LOG(FATAL) << "found " << op_node->name
                   << "; operators should have been removed by previous passes; try "
                      "fusing and lowering";
      }

      // Now we just evaluate and expect to find a closure.
      // TODO(@electriclilies): How should call_lowered behave with closures?
      ObjectRef fn_val = Eval(call_node->op);
      if (const InterpreterClosureObj* closure_node = fn_val.as<InterpreterClosureObj>()) {
        auto closure = GetRef<InterpreterClosure>(closure_node);
        return Invoke(closure, args);
      } else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) {
        return Invoke(closure_node->clos, args, closure_node->bind);
      } else {
        LOG(FATAL) << "internal error: type error, expected function value in the call "
                   << "position";
        return ObjectRef();
      }
    }
  }

  ObjectRef VisitExpr_(const LetNode* let) final {
    if (auto func = let->value.as<FunctionNode>()) {
      auto clo = MakeClosure(GetRef<Function>(func), let->var);
      this->extend(let->var, clo);
    } else {
      auto value = Eval(let->value);
      this->extend(let->var, value);
    }

    return Eval(let->body);
  }

  ObjectRef VisitExpr_(const TupleGetItemNode* op) final {
    ObjectRef val = Eval(op->tuple);
    const auto* adt_obj = val.as<ADTObj>();
    ICHECK(adt_obj) << "internal error: when evaluating TupleGetItem expected an ADT value";
    auto adt = GetRef<ADT>(adt_obj);
    ICHECK_LT(static_cast<size_t>(op->index), adt.size()) << "internal error: index out of bounds";
    return adt[op->index];
  }

  ObjectRef VisitExpr_(const IfNode* op) final {
    ObjectRef v = Eval(op->cond);
    if (v->IsInstance<NDArray::ContainerType>()) {
      auto nd_array = Downcast<NDArray>(v);
      Device cpu_dev;
      cpu_dev.device_type = kDLCPU;
      cpu_dev.device_id = 0;
      NDArray cpu_array = nd_array.CopyTo(cpu_dev);
      ICHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
      // TODO(@jroesch, @MK): Refactor code into helper from DCE.
      if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
        return Eval(op->true_branch);
      } else {
        return Eval(op->false_branch);
      }
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
    }
  }

  ObjectRef VisitExpr_(const RefWriteNode* op) final {
    ObjectRef r = Eval(op->ref);
    if (const RefValueObj* rv = r.as<RefValueObj>()) {
      rv->value = Eval(op->value);
      return ADT::Tuple(std::vector<ObjectRef>());
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
    }
  }

  ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValue(Eval(op->value)); }

  ObjectRef VisitExpr_(const RefReadNode* op) final {
    ObjectRef r = Eval(op->ref);
    if (const RefValueObj* rv = r.as<RefValueObj>()) {
      return rv->value;
    } else {
      LOG(FATAL) << "type error, type system should have caught this";
    }
  }

  ObjectRef VisitExpr_(const MatchNode* op) final {
    ObjectRef v = Eval(op->data);
    for (const Clause& c : op->clauses) {
      if (VisitPattern(c->lhs, v)) {
        return VisitExpr(c->rhs);
      }
    }
    LOG(FATAL) << "did not find any match";
  }

  bool VisitPattern_(const PatternConstructorNode* op, const ObjectRef& v) final {
    const ConstructorValueObj* cvn = v.as<ConstructorValueObj>();
    ICHECK(cvn) << "need to be a constructor for match";
    ICHECK_NE(op->constructor->tag, -1);
    ICHECK_NE(cvn->tag, -1);
    if (op->constructor->tag == cvn->tag) {
      ICHECK_EQ(op->patterns.size(), cvn->fields.size());
      for (size_t i = 0; i < op->patterns.size(); ++i) {
        if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
          return false;
        }
      }
      return true;
    }
    return false;
  }

  bool VisitPattern_(const PatternTupleNode* op, const ObjectRef& v) final {
    auto adt = Downcast<ADT>(v);
    ICHECK_EQ(op->patterns.size(), adt.size());
    for (size_t i = 0; i < op->patterns.size(); ++i) {
      if (!VisitPattern(op->patterns[i], adt[i])) {
        return false;
      }
    }
    return true;
  }

  bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; }

  bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final {
    extend(op->var, v);
    return true;
  }

  InterpreterState get_state(Expr e = Expr()) const {
    InterpreterStateObj::Stack stack;
    for (auto fr : this->stack_.frames) {
      InterpreterStateObj::Frame frame = fr.locals;
      stack.push_back(frame);
    }
    auto state = InterpreterState(e, stack);
    return state;
  }

 private:
  // Unified module. Functions are annotated with their target.
  // All expressions are eval'ed w.r.t. the definitions in this module.
  // It combines, in one module, the functions that used to be split between the main_module
  // and the per_target_module (TIR functions).
  IRModule unified_mod_;
  // Cached packed functions for the primitives and shape functions, keyed by target and
  // global var name.
  std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
  /*! \brief Compilation config describing the available targets. */
  CompilationConfig config_;
  // Unique device on which primitives (but not shape functions) will be executed.
  // (For simplicity we only run the interpreter on a single device.)
  Device device_;
  // Call stack.
  Stack stack_;
  // The distinguished 'debug' operator, which is handled specially.
  const Op& debug_op_;
};

/*!
 * \brief Lowers all calls to primitives in \p mod as appropriate for \p config. The returned
 * module contains both the rewritten Relay functions and the lowered TIR primitive functions
 * they require.
 */
IRModule Prepare(IRModule mod, const CompilationConfig& config) {
  // Run minimal transforms on module to establish invariants needed by interpreter.
  transform::Sequential seq(
      {transform::SimplifyInference(), qnn::transform::Legalize(),
       // Figure out which devices should be used to execute.
       // TODO(mbs): Should ignore all existing annotations when constant folding
       transform::PlanDevices(config),
       // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
       // attribute.
       transform::FuseOps(/*fuse_opt_level=*/0),
       // Use ANF to reduce number of cases to handle.
       transform::ToANormalForm(),
       // Eta expand to support constructors in argument position.
       transform::EtaExpand(
           /*expand_constructor=*/true, /*expand_global_var=*/false),
       transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});

  transform::PassContext pass_ctx = transform::PassContext::Current();
  With<transform::PassContext> ctx(pass_ctx);
  mod = seq(mod);

  return mod;
}

/*! \brief Check if an expression could be changed by \p Prepare.
 *
 * If not we can evaluate it directly and don't need to bind it into a fresh module.
 */
class NeedsPreparationVisitor : public ExprVisitor {
 public:
  bool needs_preparation = false;

 private:
  void VisitExpr_(const VarNode* vn) override {
    // Could be prim.
    needs_preparation = true;
  }
  // ConstantNode ok
  // GlobalVarNode ok
  void VisitExpr_(const OpNode* op) override {
    // Could be prim.
    needs_preparation = true;
  }
  // TupleNode recurse
  void VisitExpr_(const FunctionNode* op) override {
    // Could be prim.
    needs_preparation = true;
  }
  // CallNode recurse
  void VisitExpr_(const LetNode* ln) override {
    // May bind prim.
    needs_preparation = true;
  }
  // IfNode recurse
  // TupleGetItemNode recurse
  // RefCreateNode recurse
  // RefReadNode recurse
  // RefWriteNode recurse
  // ConstructorNode ok
  void VisitExpr_(const MatchNode* op) override {
    // Needs eta-expansion.
    needs_preparation = true;
  }
};

TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
                                                     Target target) {
  VLOG_CONTEXT << "EvalFunction";
  VLOG(1) << "evaling module:" << std::endl
          << PrettyPrint(mod) << "and expression:" << std::endl
          << PrettyPrint(expr);

  ICHECK_EQ(device.device_type, target->GetTargetDeviceType());
  Array<Target> raw_targets = {target};
  CompilationConfig config(transform::PassContext::Current(), raw_targets);

  //
  // Step 1: Prepare mod.
  //

  // If expr is simple enough we can avoid binding it into the module and
  // just eval it directly.
  NeedsPreparationVisitor visitor;
  visitor.VisitExpr(expr);

  Expr expr_to_eval;
  IRModule mod_with_expr;  // default empty
  if (visitor.needs_preparation) {
    GlobalVar main;
    // Bind expr to a new zero-argument function so it can be prepared along with the module
    // (if any).
    std::pair<IRModule, GlobalVar> mod_and_global;
    if (mod.defined()) {
      // TODO(mbs): Type inference currently assumes all global functions in modules have
      // known result types, and so each global function has its body types inferred independently
      // and in arbitrary order. However, the interpreter may be called with an expression relative
      // to a 'main' which has no result type annotation, and that expression will be bound into a
      // fresh global below. Type inference then fails since 'main' has unknown type. We should
      // allow inference on mutually recursive global functions. To work around this, infer the
      // type of mod now. Obviously that won't work if 'main' itself calls other global functions
      // of partial type, but it at least maintains legacy behavior.
      transform::PassContext pass_ctx = transform::PassContext::Current();
      With<transform::PassContext> ctx(pass_ctx);
      mod = transform::InferType()(mod);
      mod_and_global =
          IRModule::FromExprInContext(expr, mod->functions, mod->type_definitions, mod->Imports());
    } else {
      mod_and_global = IRModule::FromExprInContext(expr);
    }
    mod_with_expr = mod_and_global.first;
    expr_to_eval = mod_and_global.second;
  } else {
    if (mod.defined()) {
      mod_with_expr = mod;
    }
    // Prepare won't change expr, so we don't need to worry about binding it into a module
    // and can just eval it directly.
    expr_to_eval = expr;
  }
  IRModule lowered_mod = Prepare(mod_with_expr, config);

  std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(lowered_mod, config, device);

  //
  // Step 2: Evaluate target function to a closure.
  //
  ObjectRef object_ref = intrp->Eval(expr_to_eval);
  if (const InterpreterClosureObj* closure_obj = object_ref.as<InterpreterClosureObj>()) {
    InterpreterClosure closure = GetRef<InterpreterClosure>(closure_obj);
    ICHECK(closure.defined());
    ICHECK(closure->func.defined());

    return TypedPackedFunc<ObjectRef(Array<Expr>)>([intrp, closure](Array<Expr> args) {
      VLOG_CONTEXT << "EvalFunction::Apply";
      VLOG(1) << "evaling closure with " << args.size() << " arguments";
      //
      // Step 3: Apply closure to arguments.
      //
      ICHECK_NOTNULL(intrp);
      ICHECK(closure.defined());
      ICHECK(closure->func.defined());
      Array<ObjectRef> evaled_args;
      for (auto arg : args) {
        NeedsPreparationVisitor visitor;
        visitor.VisitExpr(arg);
        ICHECK(!visitor.needs_preparation)
            << "attempting to apply closure to expression which needs preparation: "
            << PrettyPrint(arg);
        evaled_args.push_back(intrp->Eval(arg));
      }
      return intrp->Invoke(closure, evaled_args);
    });
  } else {
    LOG(FATAL) << "expecting expression to have function type and evaluate to a closure";
  }
}

ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
               std::unordered_set<String> import_set, Device device, Target target,
               Map<String, ObjectRef> attrs) {
  ICHECK_EQ(device.device_type, target->GetTargetDeviceType());
  Array<Target> raw_targets = {target};
  CompilationConfig config(transform::PassContext::Current(), raw_targets);

  std::pair<IRModule, GlobalVar> mod_and_global =
      IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);

  IRModule mod = Prepare(WithAttrs(mod_and_global.first, {attrs}), config);

  Interpreter intrp(mod, config, device);
  Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint);
  if (expr.as<BaseFuncNode>() == nullptr) {
    // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr
    // unless it is a function, so we must reverse that in the expression to eval.
    // This should be done more systematically.
    expr_to_eval = Call(expr_to_eval, {});
  }
  return intrp.Eval(expr_to_eval);
}

TVM_REGISTER_GLOBAL("relay.backend.EvalFunction").set_body_typed(EvalFunction);

}  // namespace relay
}  // namespace tvm