1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file partial_eval.cc
22 *
23 * \brief Perform known computation in compile time.
24 *
 * The partial evaluator tries to do computation at compile time,
 * so it can generate code that does less work.
27 * Additionally, it might open more chance for further optimization,
28 * since the high level, structural part of the code (closure, reference, control flow)
29 * might get partially evaluated away, and the subsequent optimization (for example, kernel fusion)
30 * can reason across those structural code as it got removed.
31 * In the extreme case, partial evaluation can even turn the whole program
32 * into pure first order computation with no control flow.
33 * In such a case, we can compile the whole computation onto SIMD Instruction/GPU/FPGA,
34 * and get huge speedup.
35 *
36 * It works by making the following modifications to the standard relay interpreter:
37 *
38 * 0: The values become partially static value.
39 * Since we cannot know the value of every term at compile time,
40 * Term might get partially evaluated to 'Unknown Value'.
41 * Every partially static value is, hence,
42 * a static fragment that might not be there (partially static),
43 * and a dynamic fragment that is semantically equivalent to the original term,
44 * so the unknown part will be computed at runtime, using the dynamic fragment.
45 *
46 * 1: The interpreter holds a LetList, which preserves A Normal Form for the generated code.
47 * More specifically, we require that all dynamic is an atom.
48 * This avoids code duplication (which is both inefficient and incorrect), as atom has constant size
49 * and allow us to not handle capture-avoidance substitution (as atom has no binder).
50 *
51 * 2: The map of References to partially static values is reified, as described below.
52 * Instead of Reference having mutable field, Reference only has an unique identifier.
53 * There will be a mutable mapping of id to partially static value, called the store.
54 * This allow us to rollback the store:
55 * when a path may or may not be executed (as in a conditional), we copy the store,
56 * recurse with the copy, and reinstate the original when the call returns
57 * so that the effects of the computation are not preserved.
58 * We do this in if else, pattern matching, and in function,
59 * as, when we see a function, we partially evaluate it with all the argument as dynamic,
60 * to generate efficient dynamic for that function.
61 *
62 * 3: The generated code reuses bindings (although they are not shadowed),
63 * so we have to deduplicate them.
64 *
 * 4: In the generated code, as it calls TypeSubst, multiple VarNodes might have the same Id.
 * While this is permitted, most passes use ObjectPtrHash for Var,
 * and having multiple VarNodes with the same Id breaks them.
 * Thus we remap them to a single Id for now.
69 *
70 * Also, It will also generate lots of dead code,
71 * so it is a good idea to feed it through the dead code eliminator after partial evaluation.
72 *
73 * The partial evaluator makes several assumptions, so there is room for improvement:
74 *
75 * 0: Every time an unknown effect happened, we clear the whole store.
76 * It is too conservative: if a local reference is created (and do not get passed outside),
77 * An unknown global function call/global reference write can not modify it.
78 * We can pair PE with escape analysis/alias analysis.
79 *
80 * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise.
81 *
82 * 2: When doing pattern matching, we can simplify the match even for dynamic case.
83 * Right now it is all or nothing: either a complete match, or the original dynamic code.
84 * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form.
85 * We then can reify the result.
86 *
87 * 3: Every time a function is called, its code will get expanded and partially evaluated.
88 * We can do a binding time analysis to cache the result and avoid re-partial evaluation.
89 *
90 * These assumptions do not affect the correctness of the algorithm, however.
91 */
92#include <tvm/ir/type_functor.h>
93#include <tvm/relay/analysis.h>
94#include <tvm/relay/expr_functor.h>
95#include <tvm/relay/feature.h>
96#include <tvm/relay/interpreter.h>
97#include <tvm/relay/pattern_functor.h>
98#include <tvm/relay/transform.h>
99
100#include "let_list.h"
101#include "pass_utils.h"
102
103namespace tvm {
104namespace relay {
105namespace partial_eval {
106
107using namespace runtime;
108
109/*! \brief Hash Var by it's id.
110 * Different VarNode might has same vid, and they are considered to be the same var in such case.
111 * Use VarHash to hash Var by id.
112 */
113struct VarHash {
114 size_t operator()(const Var& v) const { return ObjectPtrHash()(v->vid); }
115};
116
117/*! \brief Compare Var by it's id.
118 * Different VarNode might has same vid, and they are considered to be the same var in such case.
119 * Use VarEqual to compare Var by id.
120 */
121struct VarEqual {
122 bool operator()(const Var& l, const Var& r) const { return l->vid.get() == r->vid.get(); }
123};
124
125Expr PostProcess(const Expr&);
126
/*! \brief A StaticNode contains some static data that the Partial Evaluator can use.
 *
 * This is the abstract base of the static fragment of a partially static value;
 * the concrete subclasses below (STupleNode, STensorNode, SConstructorNode,
 * SRefNode, SFuncNode) hold the actual compile-time data.
 */
class StaticNode : public RelayNode {
 public:
  static constexpr const char* _type_key = "relay.Static";
  TVM_DECLARE_BASE_OBJECT_INFO(StaticNode, RelayNode);
};
133
134class Static : public ObjectRef {
135 public:
136 Static() {}
137 explicit Static(ObjectPtr<Object> n) : ObjectRef(n) {}
138 const StaticNode* operator->() const { return static_cast<const StaticNode*>(get()); }
139
140 using ContainerType = StaticNode;
141};
142
143using Time = size_t;
144
145struct PStaticNode : Object {
146 static Time time() {
147 static Time time_ = 0;
148 Time ret = time_;
149 time_++;
150 return ret;
151 }
152 Static pstatic; // may be null
153 Expr dynamic;
154 Time created_time;
155 PStaticNode(const Static& pstatic, const Expr& dynamic)
156 : pstatic(pstatic), dynamic(dynamic), created_time(time()) {}
157 explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) {}
158 static constexpr const char* _type_key = "relay.PStatic";
159 TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object);
160};
161
/*! \brief Reference class for PStaticNode. */
class PStatic : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PStatic, ObjectRef, PStaticNode);
};
166
/*! \brief Static tuple: the statically known fields of a tuple value. */
struct STupleNode : StaticNode {
  std::vector<PStatic> fields;
  explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) {}
  static constexpr const char* _type_key = "relay.STuple";
  TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode);
};

/*! \brief Reference class for STupleNode. */
class STuple : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(STuple, Static, STupleNode);
};

/*! \brief Construct a static tuple value from the given fields. */
Static MkSTuple(const std::vector<PStatic>& fields) {
  return Static(make_object<STupleNode>(fields));
}
182
/*! \brief Static tensor: a concrete NDArray known at compile time. */
struct STensorNode : StaticNode {
  runtime::NDArray data;
  explicit STensorNode(const NDArray& data) : data(data) {}
  static constexpr const char* _type_key = "relay.STensor";
  TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode);
};

/*! \brief Reference class for STensorNode. */
class STensor : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode);
};

/*! \brief Construct a static tensor value wrapping data. */
Static MkSTensor(const NDArray& data) { return Static(make_object<STensorNode>(data)); }
196
/*! \brief Static ADT value: a constructor applied to partially static fields. */
struct SConstructorNode : StaticNode {
  Constructor constructor;
  std::vector<PStatic> fields;
  SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields)
      : constructor(constructor), fields(fields) {}
  static constexpr const char* _type_key = "relay.SConstructor";
  TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode);
};

/*! \brief Reference class for SConstructorNode. */
class SConstructor : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SConstructor, Static, SConstructorNode);
};

/*! \brief Construct a static ADT value from a constructor and its fields. */
Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>& fields) {
  return Static(make_object<SConstructorNode>(constructor, fields));
}
214
/*! \brief A static reference. It carries no payload: its node identity is the id,
 * and the Store maps SRefNode pointers to partially static values.
 */
struct SRefNode : StaticNode {
  static constexpr const char* _type_key = "relay.SRef";
  // we will use the address as the guid for hashing
  TVM_DECLARE_FINAL_OBJECT_INFO(SRefNode, StaticNode);
};

/*! \brief Reference class for SRefNode. */
class SRef : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode);
};

/*! \brief Allocate a fresh static reference with a unique identity. */
Static MkSRef() { return Static(make_object<SRefNode>()); }
227
/*! \brief Signature of a partially evaluated function:
 * (self, arguments, attrs, type_args, let-list) -> result.
 */
using Func = std::function<PStatic(const PStatic&, const std::vector<PStatic>&, const Attrs&,
                                   const Array<Type>&, LetList*)>;

/*! \brief Static function value wrapping a Func closure. */
struct SFuncNode : StaticNode {
  Func func;
  explicit SFuncNode(const Func& func) : func(func) {}
  static constexpr const char* _type_key = "relay.SFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(SFuncNode, StaticNode);
};

/*! \brief Reference class for SFuncNode. */
class SFunc : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SFunc, Static, SFuncNode);
};

/*! \brief Construct a static function value from a closure. */
Static MkSFunc(const Func& func) { return Static(make_object<SFuncNode>(func)); }
244
245class FuelNode;
/*! \brief A meet-semilattice with finite descending chains.
 * That is, we can meet two elements to get an element,
 * and for every element, only a finite number of meets can be taken
 * before the meet returns the same element.
 *
 * Every time we recurse, we do a meet and require that progress must be made.
 * This ensures we do not recurse infinitely in the Partial Evaluator.
 */
254class Fuel : public ObjectRef {
255 public:
256 Fuel() {}
257 explicit Fuel(ObjectPtr<Object> n) : ObjectRef(n) {}
258 const FuelNode* operator->() const;
259
260 using ContainerType = FuelNode;
261};
262
class FuelNode : public RelayNode {
 public:
  virtual ~FuelNode() {}
  // Subclasses must override at least one of the two Meet overloads below;
  // the default implementations delegate to each other, so overriding neither
  // yields infinite mutual recursion.
  /*! \brief return the new Fuel, and whether progress is made.
   *
   * Note that progress is not symmetric - it only measures progress for (*this).
   *
   * Thus, if the generated result is smaller than the argument of Meet,
   * but the generated result is not smaller than (*this),
   * progress should be false.
   */
  virtual std::tuple<Fuel, bool> Meet(const Fuel& f) const {
    bool progress = false;
    auto ret = Meet(f, &progress);
    return std::make_tuple(ret, progress);
  }
  /*! \brief return the new Fuel, and write (*progress | is progress made) to *progress. */
  virtual Fuel Meet(const Fuel& f, bool* progress) const {
    ICHECK(progress);
    auto ret = Meet(f);
    *progress |= std::get<1>(ret);
    return std::get<0>(ret);
  }
  static constexpr const char* _type_key = "relay.Fuel";
  TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode);
};

// Out-of-line because FuelNode is only forward-declared where Fuel is defined.
const FuelNode* Fuel::operator->() const { return static_cast<const FuelNode*>(get()); }
292
293Fuel MkFSeq(const std::vector<Fuel>& fuels);
294struct FSeqNode : FuelNode {
295 std::vector<Fuel> fuels;
296 Fuel Meet(const Fuel& f, bool* progress) const final {
297 auto x = f.as<FSeqNode>();
298 ICHECK(x);
299 ICHECK_EQ(fuels.size(), x->fuels.size());
300 std::vector<Fuel> new_fuels;
301 for (size_t i = 0; i < fuels.size(); ++i) {
302 new_fuels.push_back(fuels[i]->Meet(x->fuels[i], progress));
303 }
304 return MkFSeq(new_fuels);
305 }
306 explicit FSeqNode(const std::vector<Fuel>& fuels) : fuels(fuels) {}
307 static constexpr const char* _type_key = "relay.FSeq";
308 TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode);
309};
310
311class FSeq : public Fuel {
312 public:
313 TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode);
314};
315
316Fuel MkFSeq(const std::vector<Fuel>& fuels) { return Fuel(make_object<FSeqNode>(fuels)); }
317
Fuel MkFTime(Time time);
/*! \brief Fuel tracking the creation time of a PStatic.
 * Meet takes the minimum; progress is made only when the time strictly decreases.
 */
struct FTimeNode : FuelNode {
  Time time;
  std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
    auto x = f.as<FTimeNode>();
    ICHECK(x);
    Time new_time = std::min(time, x->time);
    return std::make_tuple(MkFTime(new_time), new_time < time);
  }
  explicit FTimeNode(Time time) : time(time) {}
  static constexpr const char* _type_key = "relay.FTime";
  TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode);
};

/*! \brief Reference class for FTimeNode. */
class FTime : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode);
};

/*! \brief Construct a time Fuel. */
Fuel MkFTime(Time time) { return Fuel(make_object<FTimeNode>(time)); }
338
Fuel MkFTValue(size_t tvalue);
/*! \brief If the pstatic holds a positive integer scalar, that number, else 0.
 * Meet takes the minimum; progress is made only when the value strictly decreases.
 */
struct FTValueNode : FuelNode {
  size_t tvalue;
  std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
    auto x = f.as<FTValueNode>();
    ICHECK(x);
    size_t new_tvalue = std::min(tvalue, x->tvalue);
    return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue);
  }
  explicit FTValueNode(size_t tvalue) : tvalue(tvalue) {}
  static constexpr const char* _type_key = "relay.FTValue";
  TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode);
};

/*! \brief Reference class for FTValueNode. */
class FTValue : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode);
};

/*! \brief Construct a tensor-value Fuel. */
Fuel MkFTValue(size_t tvalue) { return Fuel(make_object<FTValueNode>(tvalue)); }
360
/*! \brief Initially every element has Fuel of FTop. It is the largest element.
 *
 * Note that it is illegal to have FTop inside some other Fuel -
 * doing so breaks the finite descending chain property.
 */
struct FTopNode : FuelNode {
  std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
    // Meeting top with x yields x; this is progress unless x is also top.
    return std::make_tuple(f, !f.as<FTopNode>());
  }
  static constexpr const char* _type_key = "relay.FTop";
  TVM_DECLARE_FINAL_OBJECT_INFO(FTopNode, FuelNode);
};

/*! \brief Reference class for FTopNode. */
class FTop : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode);
};

/*! \brief Construct the top Fuel. */
Fuel MkFTop() { return Fuel(make_object<FTopNode>()); }
380
/*!
 * \brief A stack frame in the Relay interpreter.
 *
 * Contains a mapping from relay::Var to PStatic values.
 * Vars are hashed/compared by vid (see VarHash/VarEqual above).
 */
struct Frame {
  /*! \brief The set of local variables and arguments for the frame. */
  std::unordered_map<Var, PStatic, VarHash, VarEqual> locals;
  Frame() = default;
};
391
392class Environment {
393 public:
394 Environment() : env_({Frame()}) {}
395 Environment(const Environment&) = delete;
396
397 template <typename T>
398 T Extend(const std::function<T()>& body) {
399 FrameContext fc(this);
400 return body();
401 }
402
403 void Insert(const Var& v, const PStatic& ps) {
404 ICHECK(ps.defined());
405 ICHECK_GT(env_.size(), 0);
406 ICHECK_EQ(env_.back().locals.count(v), 0);
407 env_.back().locals[v] = ps;
408 }
409
410 PStatic Lookup(const Var& v) {
411 auto rit = env_.rbegin();
412 while (rit != env_.rend()) {
413 if (rit->locals.find(v) != rit->locals.end()) {
414 return rit->locals.find(v)->second;
415 }
416 ++rit;
417 }
418 LOG(FATAL) << "Unknown Variable: " << v;
419 throw;
420 }
421
422 private:
423 std::list<Frame> env_;
424
425 struct FrameContext {
426 Environment* env_;
427 explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); }
428 ~FrameContext() { env_->env_.pop_back(); }
429 };
430};
431
/*!
 * \brief As our store requires rollback, we implement it as a frame.
 *
 * Every time we need to copy the store, a new frame is inserted.
 * Every time we roll back, a frame is popped.
 */
struct StoreFrame {
  std::unordered_map<const SRefNode*, PStatic> store;
  /*!
   * \brief On unknown effect, history_valid is set to false to signal that the
   * frames below (pushed earlier) are outdated.
   *
   * It only outdates the earlier frames, not the current frame.
   */
  bool history_valid = true;
  explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) {}
  StoreFrame() = default;
};
449
450class Store {
451 public:
452 Store() : store_({StoreFrame()}) {}
453 Store(const Store&) = delete;
454
455 template <typename T>
456 T Extend(const std::function<T()>& body) {
457 StoreFrameContext sfc(this);
458 return body();
459 }
460
461 void Insert(const SRefNode* r, const PStatic& ps) {
462 ICHECK(r);
463 store_.back().store[r] = ps;
464 }
465
466 // return null if not found
467 PStatic Lookup(const SRefNode* r) {
468 auto rit = store_.rbegin();
469 while (rit != store_.rend()) {
470 if (rit->store.find(r) != rit->store.end()) {
471 return rit->store.find(r)->second;
472 }
473 if (!rit->history_valid) {
474 return PStatic();
475 }
476 ++rit;
477 }
478 return PStatic();
479 }
480
481 void Invalidate() {
482 StoreFrame sf;
483 sf.history_valid = false;
484 store_.push_back(sf);
485 }
486
487 private:
488 std::list<StoreFrame> store_;
489
490 struct StoreFrameContext {
491 Store* store_;
492 explicit StoreFrameContext(Store* store) : store_(store) {
493 store_->store_.push_back(StoreFrame());
494 }
495 ~StoreFrameContext() {
496 // push one history valid frame off.
497 while (!store_->store_.back().history_valid) {
498 store_->store_.pop_back();
499 }
500 store_->store_.pop_back();
501 }
502 };
503};
504
/*! \brief Build a PStatic carrying both a (defined) static value and its dynamic residual. */
PStatic HasStatic(const Static& stat, const Expr& dynamic) {
  ICHECK(stat.defined());
  return PStatic(make_object<PStaticNode>(stat, dynamic));
}

/*! \brief Build a PStatic with no static knowledge: only the dynamic residual. */
PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object<PStaticNode>(dynamic)); }
511
512enum struct MatchStatus { Match, NoMatch, Unknown };
513
514bool StatefulOp(const Expr& e) {
515 static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
516 struct StatefulOpVisitor : ExprVisitor {
517 bool stateful = false;
518 void VisitExpr_(const OpNode* op) {
519 stateful = stateful || op_stateful.get(GetRef<Op>(op), false);
520 }
521 };
522 StatefulOpVisitor sov;
523 sov(e);
524 return sov.stateful;
525}
526
527using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
528
529Target CPUTarget() { return Target("llvm"); }
530
531Device CPUDevice() {
532 Device dev;
533 dev.device_type = kDLCPU;
534 dev.device_id = 0;
535 return dev;
536}
537
538using FuncId = int;
539
/*!
 * \brief Annotate a function with a FuncId.
 */
struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
  FuncId fid;  // defaults to -1 (see set_default below)

  TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") {
    TVM_ATTR_FIELD(fid).describe("The FuncId that an function is annotated with.").set_default(-1);
  }
};

TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
552
// Register the internal annotation operator used to tag functions with a FuncId.
RELAY_REGISTER_OP("annotation.with_funcid")
    .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("func", "Function", "The input data.");

// Cache with_funcid op to reduce lookup overhead during traversal.
static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
560
561Expr MkWithFuncId(const Expr& expr, FuncId fid) {
562 auto attrs = make_object<WithFuncIdAttrs>();
563 attrs->fid = fid;
564 return Call(with_funcid_op, {expr}, Attrs(attrs), {});
565}
566
567Expr StripWithFuncId(const Expr& e);
568
569Function AsFunc(const Expr& e) {
570 if (e.as<FunctionNode>()) {
571 return Downcast<Function>(e);
572 } else if (const CallNode* c = e.as<CallNode>()) {
573 ICHECK(c->op == with_funcid_op);
574 ICHECK_EQ(c->args.size(), 1);
575 return AsFunc(c->args[0]);
576 } else {
577 LOG(FATAL) << "Unknown case";
578 throw;
579 }
580}
581
582class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>,
583 public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> {
584 public:
585 PartialEvaluator(const IRModule& mod) : mod_(mod) {}
586
587 PStatic VisitExpr(const Expr& e, LetList* ll) final {
588 PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
589 ICHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
590 return ret;
591 }
592
593 PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
594 if (const CallNode* c = e.as<CallNode>()) {
595 if (c->op == with_funcid_op) {
596 ICHECK_EQ(c->args.size(), 1);
597 return VisitExpr(c->args[0], ll, name);
598 }
599 }
600 PStatic ret =
601 e.as<FunctionNode>() ? VisitFunc(Downcast<Function>(e), ll, name) : VisitExpr(e, ll);
602 ICHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
603 return ret;
604 }
605
606 PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
607 return HasStatic(MkSTensor(op->data.CopyTo(device_)), ll->Push(GetRef<Expr>(op)));
608 }
609
610 PStatic VisitExpr_(const TupleNode* op, LetList* ll) final {
611 std::vector<PStatic> value;
612 tvm::Array<Expr> expr;
613 for (const Expr& e : op->fields) {
614 PStatic ps = VisitExpr(e, ll);
615 value.push_back(ps);
616 expr.push_back(ps->dynamic);
617 }
618 // Note: The partial evaluator seems to do some weird stuff with sharing. Changing Tuple(expr)
619 // to WithFields(op, expr) causes failures in the partial evaluator tests.
620 return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
621 }
622
623 PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final {
624 PStatic ps = VisitExpr(op->tuple, ll);
625 if (ps->pstatic.defined()) {
626 return Downcast<STuple>(ps->pstatic)->fields[op->index];
627 } else {
628 return NoStatic(ll->Push(TupleGetItem(ps->dynamic, op->index)));
629 }
630 }
631
632 PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef<Var>(op)); }
633
634 PStatic VisitGlobalVar(const GlobalVar& gv) {
635 ICHECK(mod_.defined());
636 if (gv_map_.count(gv) == 0) {
637 BaseFunc base_func = mod_->Lookup(gv);
638 if (auto* n = base_func.as<FunctionNode>()) {
639 Function func = GetRef<Function>(n);
640 InitializeFuncId(func);
641 Func f = VisitFuncStatic(func, gv);
642 gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
643 func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
644 mod_->Update(gv, func);
645 return gv_map_.at(gv);
646 } else {
647 return NoStatic(gv);
648 }
649 }
650 return gv_map_.at(gv);
651 }
652
653 PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
654 return VisitGlobalVar(GetRef<GlobalVar>(op));
655 }
656
657 PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
658 env_.Insert(op->var, VisitExpr(op->value, ll, op->var));
659 return VisitExpr(op->body, ll);
660 }
661
662 PStatic VisitExpr_(const IfNode* op, LetList* ll) final {
663 PStatic c = VisitExpr(op->cond, ll);
664 if (c->pstatic.defined()) {
665 NDArray cpu_array = Downcast<STensor>(c->pstatic)->data.CopyTo(CPUDevice());
666 ICHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
667 if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
668 return VisitExpr(op->true_branch, ll);
669 } else {
670 return VisitExpr(op->false_branch, ll);
671 }
672 } else {
673 Expr t = store_.Extend<Expr>([&]() {
674 return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; });
675 });
676 Expr f = store_.Extend<Expr>([&]() {
677 return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; });
678 });
679 store_.Invalidate();
680 return NoStatic(ll->Push(If(c->dynamic, t, f)));
681 }
682 }
683
684 PStatic VisitExpr_(const RefCreateNode* op, LetList* ll) final {
685 PStatic ps = VisitExpr(op->value, ll);
686 Static r = MkSRef();
687 store_.Insert(r.as<SRefNode>(), ps);
688 return HasStatic(r, ll->Push(RefCreate(ps->dynamic)));
689 }
690
691 PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final {
692 PStatic r = VisitExpr(op->ref, ll);
693 PStatic v = VisitExpr(op->value, ll);
694 if (r->pstatic.defined()) {
695 store_.Insert(r->pstatic.as<SRefNode>(), v);
696 } else {
697 store_.Invalidate();
698 }
699 return HasStatic(MkSTuple({}), ll->Push(RefWrite(r->dynamic, v->dynamic)));
700 }
701
702 PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final {
703 PStatic r = VisitExpr(op->ref, ll);
704 if (r->pstatic.defined()) {
705 PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
706 if (ret.defined()) {
707 return ret;
708 }
709 }
710 return NoStatic(ll->Push(RefRead(r->dynamic)));
711 }
712
713 PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
714 if (op->op == with_funcid_op) {
715 ICHECK_EQ(op->args.size(), 1);
716 return VisitExpr(op->args[0], ll);
717 }
718 PStatic f = VisitExpr(op->op, ll);
719 std::vector<PStatic> x;
720 tvm::Array<Expr> x_dyn;
721 for (const Expr& e : op->args) {
722 PStatic ps = VisitExpr(e, ll);
723 x.push_back(ps);
724 x_dyn.push_back(ps->dynamic);
725 }
726 if (f->pstatic.defined()) {
727 return Downcast<SFunc>(f->pstatic)->func(f, x, op->attrs, op->type_args, ll);
728 } else {
729 store_.Invalidate();
730 return NoStatic(ll->Push(Call(f->dynamic, x_dyn, op->attrs, op->type_args)));
731 }
732 }
733
734 struct FuelFrame {
735 PartialEvaluator* pe_;
736 FuncId fid_;
737 Fuel old_fuel;
738 FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) {
739 ICHECK_GT(pe_->fuel_map_.count(fid_), 0);
740 old_fuel = pe_->fuel_map_[fid_];
741 pe_->fuel_map_[fid_] = new_fuel;
742 }
743 ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; }
744 };
745
746 size_t GetFTValue(const PStatic& ps) {
747 if (ps->pstatic.defined()) {
748 if (auto* st = ps->pstatic.as<STensorNode>()) {
749 if (st->data.Shape().empty()) {
750 NDArray cpu_array = st->data.CopyTo(CPUDevice());
751 DataType dtype = DataType(cpu_array->dtype);
752 if (dtype == DataType::Int(32)) {
753 return std::max<int32_t>(0, *static_cast<const int32_t*>(cpu_array->data));
754 } else if (dtype == DataType::Int(64)) {
755 return std::max<int64_t>(0, *static_cast<const int64_t*>(cpu_array->data));
756 }
757 }
758 }
759 }
760 return 0;
761 }
762
763 Fuel GetFuel(const PStatic& ps) {
764 std::vector<Fuel> fuels;
765 fuels.push_back(MkFTime(ps->created_time));
766 fuels.push_back(MkFTValue(GetFTValue(ps)));
767 return MkFSeq(fuels);
768 }
769
770 Func VisitFuncStatic(const Function& func, const Expr& var) {
771 ICHECK(IsAtomic(var));
772 if (func->HasNonzeroAttr(attr::kPrimitive)) {
773 return ConstEvaluateFunc(func);
774 }
775 std::vector<std::pair<Var, PStatic>> free_vars;
776 for (const auto& v : FreeVars(func)) {
777 if (v != var) {
778 free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
779 }
780 }
781 return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
782 const tvm::Array<Type>& type_args, LetList* ll) {
783 return env_.Extend<PStatic>([&]() {
784 ICHECK_EQ(pv.size(), func->params.size());
785 ICHECK_GT(func_map_.count(func), 0);
786 FuncId fid = func_map_.at(func);
787 if (fuel_map_.count(fid) == 0) {
788 fuel_map_.insert({fid, MkFTop()});
789 }
790 std::vector<Fuel> args_fuel;
791 for (const auto& v : pv) {
792 args_fuel.push_back(GetFuel(v));
793 }
794 auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
795 if (std::get<1>(meet_res)) {
796 FuelFrame tf(this, fid, std::get<0>(meet_res));
797 Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
798 Function func = AsFunc(dedup_func);
799 if (var.as<VarNode>()) {
800 env_.Insert(Downcast<Var>(var), self);
801 }
802 for (size_t i = 0; i < pv.size(); ++i) {
803 env_.Insert(func->params[i], pv[i]);
804 }
805 for (const auto& p : free_vars) {
806 env_.Insert(p.first, p.second);
807 }
808 tvm::Map<TypeVar, Type> subst;
809 for (size_t i = 0; i < type_args.size(); ++i) {
810 subst.Set(func->type_params[i], type_args[i]);
811 }
812 for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
813 subst.Set(func->type_params[i], IncompleteType(kType));
814 }
815 return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
816 } else {
817 std::vector<Expr> dyn;
818 for (const auto& v : pv) {
819 dyn.push_back(v->dynamic);
820 }
821 return NoStatic(ll->Push(Call(var, dyn, attrs, type_args)));
822 }
823 });
824 };
825 }
826
827 Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
828 return store_.Extend<Expr>([&]() {
829 store_.Invalidate();
830 return WithFields(
831 func, func->params, LetList::With([&](LetList* ll) {
832 std::vector<PStatic> pv;
833 for (const auto& v : func->params) {
834 pv.push_back(NoStatic(v));
835 }
836 tvm::Array<Type> type_args;
837 for (const auto& tp : func->type_params) {
838 type_args.push_back(tp);
839 }
840 return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
841 }));
842 });
843 }
844
845 PStatic VisitFunc(const Function& func, LetList* ll, const Var& name) {
846 Func f = VisitFuncStatic(func, name);
847 Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
848 // TODO(@M.K.): we seems to reduce landin knot into letrec.
849 // restore letrec support across whole relay.
850 return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name)));
851 }
852
853 PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final {
854 return VisitFunc(GetRef<Function>(op), ll, Var::GenSym());
855 }
856
857 struct ReflectError : Error {
858 ReflectError() : Error("static value not found") {}
859 };
860
861 Expr Reflect(const PStatic& st) {
862 if (!st->pstatic.defined()) {
863 throw ReflectError();
864 } else if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
865 return Constant(op->data);
866 } else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
867 tvm::Array<Expr> fields;
868 for (const PStatic& field : op->fields) {
869 fields.push_back(Reflect(field));
870 }
871 return Tuple(fields);
872 } else {
873 LOG(FATAL) << "Unknown case: " << st->dynamic;
874 throw;
875 }
876 }
877
878 PStatic Reify(const ObjectRef& v, LetList* ll) const {
879 if (v->IsInstance<runtime::NDArray::ContainerType>()) {
880 auto nd_array = Downcast<runtime::NDArray>(v);
881 return HasStatic(MkSTensor(nd_array), ll->Push(Constant(nd_array)));
882 } else if (const runtime::ADTObj* op = v.as<runtime::ADTObj>()) {
883 std::vector<PStatic> fields;
884 tvm::Array<Expr> fields_dyn;
885 auto adt = GetRef<runtime::ADT>(op);
886 for (size_t i = 0; i < adt.size(); ++i) {
887 PStatic ps = Reify(adt[i], ll);
888 fields.push_back(ps);
889 fields_dyn.push_back(ps->dynamic);
890 }
891 return HasStatic(MkSTuple(fields), ll->Push(Tuple(fields_dyn)));
892 } else {
893 LOG(FATAL) << "Unknown case";
894 throw;
895 }
896 }
897
898 // Constant evaluate an expression.
899 PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
900 // use a fresh build context in case we are already in a build context.
901 With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());
902 return Reify(Eval(expr, mod_->type_definitions, mod_->Imports(), CPUDevice(), CPUTarget()), ll);
903 }
904
905 Func ConstEvaluateFunc(const Expr& expr) {
906 ICHECK_EQ(FreeVars(expr).size(), 0);
907 return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
908 const tvm::Array<Type>& type_args, LetList* ll) {
909 tvm::Array<Expr> ns_args;
910 for (const PStatic& ps : pv) {
911 ns_args.push_back(ps->dynamic);
912 }
913 auto ns = [&]() { return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); };
914 if (StatefulOp(expr)) {
915 return ns();
916 }
917 try {
918 tvm::Array<Expr> args;
919 for (const PStatic& ps : pv) {
920 args.push_back(Reflect(ps));
921 }
922 return ConstEvaluate(Call(expr, args, attrs, type_args), ll);
923 } catch (const ReflectError&) {
924 return ns();
925 }
926 };
927 }
928
929 PStatic VisitExpr_(const OpNode* op, LetList* ll) final {
930 return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op));
931 }
932
933 PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final {
934 Constructor c = GetRef<Constructor>(op);
935 Func f = [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
936 const tvm::Array<Type>& type_args, LetList* ll) {
937 tvm::Array<Expr> dyn;
938 for (const PStatic& ps : pv) {
939 dyn.push_back(ps->dynamic);
940 }
941 return HasStatic(MkSConstructor(c, pv), ll->Push(Call(c, dyn)));
942 };
943 return HasStatic(MkSFunc(f), GetRef<Expr>(op));
944 }
945
  // Partially evaluate a match expression.
  // Clauses are tried in order against the partially static scrutinee:
  //  - a statically known Match evaluates only that clause's body;
  //  - a static NoMatch skips the clause;
  //  - the first Unknown forces a residual Match over ALL clauses.
  PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
    PStatic ps = VisitExpr(op->data, ll);
    // Evaluate under an extended environment so pattern bindings do not
    // leak out of the match expression.
    return env_.Extend<PStatic>([&]() {
      for (const Clause& c : op->clauses) {
        switch (VisitPattern(c->lhs, ps)) {
          case MatchStatus::Match:
            // Statically selected clause: its body replaces the whole match.
            return VisitExpr(c->rhs, ll);
          case MatchStatus::NoMatch:
            continue;
          case MatchStatus::Unknown:
            // Cannot decide at compile time: emit a residual Match.
            return [&]() {
              tvm::Array<Clause> clauses;
              for (const Clause& c : op->clauses) {
                // Each clause body is evaluated in its own store scope,
                // with its pattern variables bound to unknown values.
                Expr expr = store_.Extend<Expr>([&]() {
                  return LetList::With([&](LetList* ll) {
                    for (const Var& v : BoundVars(c->lhs)) {
                      env_.Insert(v, NoStatic(v));
                    }
                    return VisitExpr(c->rhs, ll)->dynamic;
                  });
                });
                clauses.push_back(Clause(c->lhs, expr));
              }
              // We do not know which clause ran, so any reference may have
              // been mutated: conservatively drop all store knowledge.
              store_.Invalidate();
              return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete)));
            }();
          default:
            LOG(FATAL) << "Unknown MatchStatus";
            throw;
        }
      }
      // Reached only if no clause can match, which indicates a malformed match.
      LOG(FATAL) << "No case Match";
      throw;
    });
  }
981
982 MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
983 return MatchStatus::Match;
984 }
985
986 MatchStatus VisitPattern_(const PatternVarNode* op, const PStatic& ps) final {
987 env_.Insert(op->var, ps);
988 return MatchStatus::Match;
989 }
990
991 MatchStatus VisitPattern_(const PatternConstructorNode* op, const PStatic& ps) final {
992 if (ps->pstatic.defined()) {
993 SConstructor scn = Downcast<SConstructor>(ps->pstatic);
994 ICHECK_NE(op->constructor->tag, -1);
995 ICHECK_NE(scn->constructor->tag, -1);
996 if (op->constructor->tag == scn->constructor->tag) {
997 ICHECK_EQ(op->patterns.size(), scn->fields.size());
998 MatchStatus current_match_status = MatchStatus::Match;
999 for (size_t i = 0; i < op->patterns.size(); ++i) {
1000 MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]);
1001 switch (ms) {
1002 case MatchStatus::Match:
1003 continue;
1004 case MatchStatus::NoMatch:
1005 return MatchStatus::NoMatch;
1006 case MatchStatus::Unknown:
1007 current_match_status = MatchStatus::Unknown;
1008 }
1009 }
1010 return current_match_status;
1011 }
1012 return MatchStatus::NoMatch;
1013 } else {
1014 return MatchStatus::Unknown;
1015 }
1016 }
1017
1018 MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final {
1019 if (ps->pstatic.defined()) {
1020 STuple stn = Downcast<STuple>(ps->pstatic);
1021 ICHECK_EQ(op->patterns.size(), stn->fields.size());
1022 MatchStatus current_match_status = MatchStatus::Match;
1023 for (size_t i = 0; i < op->patterns.size(); ++i) {
1024 MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]);
1025 switch (ms) {
1026 case MatchStatus::Match:
1027 continue;
1028 case MatchStatus::NoMatch:
1029 return MatchStatus::NoMatch;
1030 case MatchStatus::Unknown:
1031 current_match_status = MatchStatus::Unknown;
1032 }
1033 }
1034 return current_match_status;
1035 } else {
1036 return MatchStatus::Unknown;
1037 }
1038 }
1039
1040 void InitializeFuncId(const Expr& e) {
1041 struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
1042 PartialEvaluator* pe;
1043 explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {}
1044
1045 void VisitExpr_(const FunctionNode* op) final {
1046 Function f = GetRef<Function>(op);
1047 ICHECK_EQ(pe->func_map_.count(f), 0);
1048 pe->func_map_.insert({f, pe->func_map_.size()});
1049 VisitExpr(f->body);
1050 }
1051
1052 void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
1053 };
1054 InitializeFuncIdVisitor(this).VisitExpr(e);
1055 }
1056
1057 Expr RegisterFuncId(const Expr& e) {
1058 struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
1059 PartialEvaluator* pe;
1060 explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {}
1061
1062 void VisitExpr_(const CallNode* op) final {
1063 if (op->op == with_funcid_op) {
1064 ICHECK_EQ(op->args.size(), 1);
1065 ICHECK(op->attrs.defined());
1066 ICHECK(op->attrs.as<WithFuncIdAttrs>());
1067 Function f = AsFunc(op->args[0]);
1068 FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
1069 if (pe->func_map_.count(f) != 0) {
1070 ICHECK_EQ(pe->func_map_.at(f), fid);
1071 }
1072 pe->func_map_.insert({f, fid});
1073 }
1074 ExprVisitor::VisitExpr_(op);
1075 }
1076
1077 void VisitExpr_(const FunctionNode* op) final {
1078 Function f = GetRef<Function>(op);
1079 ICHECK_GT(pe->func_map_.count(f), 0);
1080 ExprVisitor::VisitExpr_(op);
1081 }
1082
1083 void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
1084 };
1085 RegisterFuncIdVisitor(this).VisitExpr(e);
1086 return e;
1087 }
1088
1089 Expr AnnotateFuncId(const Expr& e) {
1090 struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
1091 PartialEvaluator* pe;
1092 explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) {}
1093
1094 Expr VisitExpr_(const FunctionNode* op) final {
1095 Function f = GetRef<Function>(op);
1096 ICHECK_GT(pe->func_map_.count(f), 0);
1097 return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
1098 }
1099
1100 Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
1101
1102 Var VisitVar(const Var& v) final { return v; }
1103 };
1104 return AnnotateFuncIdMutator(this).VisitExpr(e);
1105 }
1106
 private:
  // Scoped mapping from Var to its partially static value.
  Environment env_;
  // The module being partially evaluated.
  IRModule mod_;
  // Partially static results keyed by GlobalVar.
  std::unordered_map<GlobalVar, PStatic, ObjectPtrHash, ObjectPtrEqual> gv_map_;
  /*! Termination checking is done as follows:
   * We have finitely many FunctionIds.
   * Each FunctionId maps to a class of semantically equivalent function (ignoring type),
   * as both TypeSubst and DeDup create semantically equivalent function.
   * We partially map each FunctionId to a Fuel.
   * Every time we try to inline a Function,
   * we make sure it either does not have a Fuel,
   * or we meet the existing fuel with the fuel calculated from the argument.
   * If no progress is made, we do not inline.
   * In both cases, we remap the mapping to the new Fuel
   * when we PE inside the Function body.
   * Termination is guaranteed because Fuel is finitely descending - there can only be so many
   * meets.
   */
  std::unordered_map<Function, FuncId, ObjectPtrHash, ObjectPtrEqual> func_map_;
  // Fuel currently associated with each FuncId (see termination note above).
  std::unordered_map<FuncId, Fuel> fuel_map_;
  // Models mutable references; invalidated when control flow becomes unknown
  // (e.g. a residual Match), see the MatchNode visitor.
  Store store_;
  // Device used for compile-time evaluation (CPU).
  Device device_ = CPUDevice();
};
1130
1131/*! \brief Remap multiple Var sharing the same Id into the same Var. */
1132Expr Remap(const Expr& e) {
1133 class RemapMutator : public ExprMutator, public PatternMutator {
1134 Expr VisitExpr_(const VarNode* op) final {
1135 Var v = GetRef<Var>(op);
1136 if (remap_.count(v) == 0) {
1137 remap_.insert({v, v});
1138 }
1139 return remap_.at(v);
1140 }
1141
1142 Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }
1143
1144 private:
1145 std::unordered_map<Var, Var, VarHash, VarEqual> remap_;
1146 };
1147 return RemapMutator().VisitExpr(e);
1148}
1149
1150Expr StripWithFuncId(const Expr& e) {
1151 struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
1152 Expr VisitExpr_(const CallNode* op) final {
1153 if (op->op == with_funcid_op) {
1154 ICHECK_EQ(op->args.size(), 1);
1155 return VisitExpr(op->args[0]);
1156 } else {
1157 return ExprMutator::VisitExpr_(op);
1158 }
1159 }
1160
1161 Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
1162
1163 Var VisitVar(const Var& v) final { return v; }
1164 };
1165 return StripWithFuncIdMutator().VisitExpr(e);
1166}
1167
1168Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); }
1169
1170} // namespace partial_eval
1171
1172IRModule PartialEval(const IRModule& m) {
1173 CheckFeature(m, FeatureSet::All() - fGraph);
1174 relay::partial_eval::PartialEvaluator pe(m);
1175 std::vector<GlobalVar> gvs;
1176 for (const auto& p : m->functions) {
1177 gvs.push_back(p.first);
1178 }
1179 for (const auto& gv : gvs) {
1180 pe.VisitGlobalVar(gv);
1181 }
1182 CheckFeature(m, FeatureSet::All() - fGraph);
1183 return m;
1184}
1185
1186namespace transform {
1187
1188Pass PartialEval() {
1189 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
1190 [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
1191 return CreateModulePass(pass_func, 1, "PartialEval", {});
1192}
1193
1194TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
1195
1196} // namespace transform
1197
1198} // namespace relay
1199} // namespace tvm
1200