1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file partial_eval.cc |
22 | * |
23 | * \brief Perform known computation in compile time. |
24 | * |
 * The partial evaluator tries to do computation at compile time,
 * so it can generate code that does less work.
27 | * Additionally, it might open more chance for further optimization, |
28 | * since the high level, structural part of the code (closure, reference, control flow) |
29 | * might get partially evaluated away, and the subsequent optimization (for example, kernel fusion) |
30 | * can reason across those structural code as it got removed. |
31 | * In the extreme case, partial evaluation can even turn the whole program |
32 | * into pure first order computation with no control flow. |
33 | * In such a case, we can compile the whole computation onto SIMD Instruction/GPU/FPGA, |
34 | * and get huge speedup. |
35 | * |
36 | * It works by making the following modifications to the standard relay interpreter: |
37 | * |
38 | * 0: The values become partially static value. |
39 | * Since we cannot know the value of every term at compile time, |
40 | * Term might get partially evaluated to 'Unknown Value'. |
41 | * Every partially static value is, hence, |
42 | * a static fragment that might not be there (partially static), |
43 | * and a dynamic fragment that is semantically equivalent to the original term, |
44 | * so the unknown part will be computed at runtime, using the dynamic fragment. |
45 | * |
46 | * 1: The interpreter holds a LetList, which preserves A Normal Form for the generated code. |
 * More specifically, we require that every dynamic fragment is an atom.
48 | * This avoids code duplication (which is both inefficient and incorrect), as atom has constant size |
49 | * and allow us to not handle capture-avoidance substitution (as atom has no binder). |
50 | * |
51 | * 2: The map of References to partially static values is reified, as described below. |
52 | * Instead of Reference having mutable field, Reference only has an unique identifier. |
53 | * There will be a mutable mapping of id to partially static value, called the store. |
54 | * This allow us to rollback the store: |
55 | * when a path may or may not be executed (as in a conditional), we copy the store, |
56 | * recurse with the copy, and reinstate the original when the call returns |
57 | * so that the effects of the computation are not preserved. |
58 | * We do this in if else, pattern matching, and in function, |
59 | * as, when we see a function, we partially evaluate it with all the argument as dynamic, |
60 | * to generate efficient dynamic for that function. |
61 | * |
62 | * 3: The generated code reuses bindings (although they are not shadowed), |
63 | * so we have to deduplicate them. |
64 | * |
 * 4: In the generated code, as it calls TypeSubst, multiple VarNodes might have the same Id.
 *    While it is permitted, most passes use ObjectPtrHash for Var,
 *    and having multiple VarNodes for the same Id breaks them.
 *    Thus we remap them to a single Id for now.
69 | * |
 * It will also generate lots of dead code,
71 | * so it is a good idea to feed it through the dead code eliminator after partial evaluation. |
72 | * |
73 | * The partial evaluator makes several assumptions, so there is room for improvement: |
74 | * |
 * 0: Every time an unknown effect happens, we clear the whole store.
 * It is too conservative: if a local reference is created (and does not get passed outside),
 * an unknown global function call/global reference write cannot modify it.
78 | * We can pair PE with escape analysis/alias analysis. |
79 | * |
80 | * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise. |
81 | * |
82 | * 2: When doing pattern matching, we can simplify the match even for dynamic case. |
83 | * Right now it is all or nothing: either a complete match, or the original dynamic code. |
84 | * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form. |
85 | * We then can reify the result. |
86 | * |
87 | * 3: Every time a function is called, its code will get expanded and partially evaluated. |
88 | * We can do a binding time analysis to cache the result and avoid re-partial evaluation. |
89 | * |
90 | * These assumptions do not affect the correctness of the algorithm, however. |
91 | */ |
92 | #include <tvm/ir/type_functor.h> |
93 | #include <tvm/relay/analysis.h> |
94 | #include <tvm/relay/expr_functor.h> |
95 | #include <tvm/relay/feature.h> |
96 | #include <tvm/relay/interpreter.h> |
97 | #include <tvm/relay/pattern_functor.h> |
98 | #include <tvm/relay/transform.h> |
99 | |
100 | #include "let_list.h" |
101 | #include "pass_utils.h" |
102 | |
103 | namespace tvm { |
104 | namespace relay { |
105 | namespace partial_eval { |
106 | |
107 | using namespace runtime; |
108 | |
/*! \brief Hash Var by its id.
 * Different VarNodes might have the same vid, and they are considered to be the same var in such
 * case. Use VarHash to hash Var by id.
 */
113 | struct VarHash { |
114 | size_t operator()(const Var& v) const { return ObjectPtrHash()(v->vid); } |
115 | }; |
116 | |
/*! \brief Compare Var by its id.
 * Different VarNodes might have the same vid, and they are considered to be the same var in such
 * case. Use VarEqual to compare Var by id.
 */
121 | struct VarEqual { |
122 | bool operator()(const Var& l, const Var& r) const { return l->vid.get() == r->vid.get(); } |
123 | }; |
124 | |
125 | Expr PostProcess(const Expr&); |
126 | |
/*! \brief A StaticNode contains some static data that the Partial Evaluator can use.
 *
 * Base of the compile-time value hierarchy; concrete payloads live in the
 * derived nodes (STensorNode, STupleNode, SConstructorNode, SRefNode, SFuncNode).
 */
class StaticNode : public RelayNode {
 public:
  static constexpr const char* _type_key = "relay.Static";
  TVM_DECLARE_BASE_OBJECT_INFO(StaticNode, RelayNode);
};
133 | |
/*! \brief Reference class for StaticNode. May be null, meaning "nothing known statically". */
class Static : public ObjectRef {
 public:
  Static() {}
  explicit Static(ObjectPtr<Object> n) : ObjectRef(n) {}
  const StaticNode* operator->() const { return static_cast<const StaticNode*>(get()); }

  using ContainerType = StaticNode;
};
142 | |
143 | using Time = size_t; |
144 | |
/*! \brief A partially static value: an optional static fragment plus a dynamic
 * expression that is semantically equivalent to the original term.
 */
struct PStaticNode : Object {
  /*! \brief Monotonically increasing creation counter, consumed by FTime fuel. */
  static Time time() {
    static Time time_ = 0;
    Time ret = time_;
    time_++;
    return ret;
  }
  /*! \brief The static fragment; null when nothing is known at compile time. */
  Static pstatic;  // may be null
  /*! \brief Expression computing this value at runtime (kept atomic by the interpreter). */
  Expr dynamic;
  /*! \brief Timestamp of creation; used by the fuel mechanism to ensure termination. */
  Time created_time;
  PStaticNode(const Static& pstatic, const Expr& dynamic)
      : pstatic(pstatic), dynamic(dynamic), created_time(time()) {}
  explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) {}
  static constexpr const char* _type_key = "relay.PStatic";
  TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object);
};
161 | |
/*! \brief Reference class for PStaticNode. */
class PStatic : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PStatic, ObjectRef, PStaticNode);
};
166 | |
/*! \brief A statically known tuple: one partially static value per field. */
struct STupleNode : StaticNode {
  std::vector<PStatic> fields;
  explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) {}
  static constexpr const char* _type_key = "relay.STuple";
  TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode);
};
173 | |
/*! \brief Reference class for STupleNode. */
class STuple : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(STuple, Static, STupleNode);
};
178 | |
179 | Static MkSTuple(const std::vector<PStatic>& fields) { |
180 | return Static(make_object<STupleNode>(fields)); |
181 | } |
182 | |
/*! \brief A statically known tensor, backed by a concrete NDArray. */
struct STensorNode : StaticNode {
  runtime::NDArray data;
  explicit STensorNode(const NDArray& data) : data(data) {}
  static constexpr const char* _type_key = "relay.STensor";
  TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode);
};
189 | |
/*! \brief Reference class for STensorNode. */
class STensor : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode);
};
194 | |
195 | Static MkSTensor(const NDArray& data) { return Static(make_object<STensorNode>(data)); } |
196 | |
/*! \brief A statically known ADT value: a constructor applied to known fields. */
struct SConstructorNode : StaticNode {
  Constructor constructor;
  std::vector<PStatic> fields;
  SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields)
      : constructor(constructor), fields(fields) {}
  static constexpr const char* _type_key = "relay.SConstructor";
  TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode);
};
205 | |
/*! \brief Reference class for SConstructorNode. */
class SConstructor : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SConstructor, Static, SConstructorNode);
};
210 | |
211 | Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>& fields) { |
212 | return Static(make_object<SConstructorNode>(constructor, fields)); |
213 | } |
214 | |
/*! \brief A statically tracked reference cell.
 * Carries no payload; the node's address is the unique id keying into the Store.
 */
struct SRefNode : StaticNode {
  static constexpr const char* _type_key = "relay.SRef";
  // we will use the address as the guid for hashing
  TVM_DECLARE_FINAL_OBJECT_INFO(SRefNode, StaticNode);
};
220 | |
/*! \brief Reference class for SRefNode. */
class SRef : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode);
};
225 | |
226 | Static MkSRef() { return Static(make_object<SRefNode>()); } |
227 | |
228 | using Func = std::function<PStatic(const PStatic&, const std::vector<PStatic>&, const Attrs&, |
229 | const Array<Type>&, LetList*)>; |
230 | |
/*! \brief A statically known function: a C++ closure that partially evaluates an application
 * given (self, args, attrs, type_args, LetList).
 */
struct SFuncNode : StaticNode {
  Func func;
  explicit SFuncNode(const Func& func) : func(func) {}
  static constexpr const char* _type_key = "relay.SFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(SFuncNode, StaticNode);
};
237 | |
/*! \brief Reference class for SFuncNode. */
class SFunc : public Static {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SFunc, Static, SFuncNode);
};
242 | |
243 | Static MkSFunc(const Func& func) { return Static(make_object<SFuncNode>(func)); } |
244 | |
245 | class FuelNode; |
/*! \brief A meet-semilattice with finite descending chain.
 * It means that we can meet two elements to get an element,
 * and for every element, there is only a finite number of meets before getting back the same
 * element.
 *
 * Every time we recurse, we do a meet and require that progress must be made.
 * This ensures we do not recurse infinitely in the Partial Evaluator.
 */
/*! \brief Reference class for FuelNode. Default-constructed Fuel is null. */
class Fuel : public ObjectRef {
 public:
  Fuel() {}
  explicit Fuel(ObjectPtr<Object> n) : ObjectRef(n) {}
  const FuelNode* operator->() const;

  using ContainerType = FuelNode;
};
262 | |
class FuelNode : public RelayNode {
 public:
  virtual ~FuelNode() {}
  // The two Meet overloads delegate to each other by default, so a subclass must
  // override at least one of them or calling either will recurse forever.
  /*! \brief return the new Fuel, and whether progress is made.
   *
   * Note that progress is not symmetric - it only measures progress for (*this).
   *
   * Thus, if the generated Fuel is smaller than the argument of Meet,
   * but is not smaller than (*this),
   * progress should be false.
   */
  virtual std::tuple<Fuel, bool> Meet(const Fuel& f) const {
    bool progress = false;
    auto ret = Meet(f, &progress);
    return std::make_tuple(ret, progress);
  }
  /*! \brief return the new Fuel, and write (*progress | is progress made) to *progress. */
  virtual Fuel Meet(const Fuel& f, bool* progress) const {
    ICHECK(progress);
    auto ret = Meet(f);
    *progress |= std::get<1>(ret);
    return std::get<0>(ret);
  }
  static constexpr const char* _type_key = "relay.Fuel";
  TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode);
};
290 | |
291 | const FuelNode* Fuel::operator->() const { return static_cast<const FuelNode*>(get()); } |
292 | |
293 | Fuel MkFSeq(const std::vector<Fuel>& fuels); |
/*! \brief Product of fuels: Meet is element-wise; progress if any element progresses. */
struct FSeqNode : FuelNode {
  std::vector<Fuel> fuels;
  Fuel Meet(const Fuel& f, bool* progress) const final {
    auto x = f.as<FSeqNode>();
    ICHECK(x);
    // Sequences are only met with sequences of the same arity.
    ICHECK_EQ(fuels.size(), x->fuels.size());
    std::vector<Fuel> new_fuels;
    for (size_t i = 0; i < fuels.size(); ++i) {
      new_fuels.push_back(fuels[i]->Meet(x->fuels[i], progress));
    }
    return MkFSeq(new_fuels);
  }
  explicit FSeqNode(const std::vector<Fuel>& fuels) : fuels(fuels) {}
  static constexpr const char* _type_key = "relay.FSeq";
  TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode);
};
310 | |
/*! \brief Reference class for FSeqNode. */
class FSeq : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode);
};
315 | |
316 | Fuel MkFSeq(const std::vector<Fuel>& fuels) { return Fuel(make_object<FSeqNode>(fuels)); } |
317 | |
318 | Fuel MkFTime(Time time); |
/*! \brief Fuel tracking a creation timestamp: Meet takes the minimum,
 * making progress only when the time strictly decreases.
 */
struct FTimeNode : FuelNode {
  Time time;
  std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
    auto x = f.as<FTimeNode>();
    ICHECK(x);
    Time new_time = std::min(time, x->time);
    return std::make_tuple(MkFTime(new_time), new_time < time);
  }
  explicit FTimeNode(Time time) : time(time) {}
  static constexpr const char* _type_key = "relay.FTime";
  TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode);
};
331 | |
/*! \brief Reference class for FTimeNode. */
class FTime : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode);
};
336 | |
337 | Fuel MkFTime(Time time) { return Fuel(make_object<FTimeNode>(time)); } |
338 | |
339 | Fuel MkFTValue(size_t tvalue); |
/*! \brief If the pstatic holds a positive integer scalar, that number, else 0. */
struct FTValueNode : FuelNode {
  size_t tvalue;
  // Meet takes the minimum; progress only when the value strictly decreases.
  std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
    auto x = f.as<FTValueNode>();
    ICHECK(x);
    size_t new_tvalue = std::min(tvalue, x->tvalue);
    return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue);
  }
  explicit FTValueNode(size_t tvalue) : tvalue(tvalue) {}
  static constexpr const char* _type_key = "relay.FTValue";
  TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode);
};
353 | |
/*! \brief Reference class for FTValueNode. */
class FTValue : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode);
};
358 | |
359 | Fuel MkFTValue(size_t tvalue) { return Fuel(make_object<FTValueNode>(tvalue)); } |
360 | |
/*! \brief Initially every element has Fuel of FTop. It is the largest element.
 *
 * Note that it is illegal to have FTop inside some other Fuel -
 * doing so breaks the finite descending chain property.
 */
struct FTopNode : FuelNode {
  // Meeting top with anything yields the other side; progress unless it is also top.
  std::tuple<Fuel, bool> Meet(const Fuel& f) const final {
    return std::make_tuple(f, !f.as<FTopNode>());
  }
  static constexpr const char* _type_key = "relay.FTop";
  TVM_DECLARE_FINAL_OBJECT_INFO(FTopNode, FuelNode);
};
373 | |
/*! \brief Reference class for FTopNode. */
class FTop : public Fuel {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode);
};
378 | |
379 | Fuel MkFTop() { return Fuel(make_object<FTopNode>()); } |
380 | |
381 | /*! |
382 | * \brief A stack frame in the Relay interpreter. |
383 | * |
384 | * Contains a mapping from relay::Var to relay::Object. |
385 | */ |
386 | struct Frame { |
387 | /*! \brief The set of local variables and arguments for the frame. */ |
388 | std::unordered_map<Var, PStatic, VarHash, VarEqual> locals; |
389 | Frame() = default; |
390 | }; |
391 | |
392 | class Environment { |
393 | public: |
394 | Environment() : env_({Frame()}) {} |
395 | Environment(const Environment&) = delete; |
396 | |
397 | template <typename T> |
398 | T Extend(const std::function<T()>& body) { |
399 | FrameContext fc(this); |
400 | return body(); |
401 | } |
402 | |
403 | void Insert(const Var& v, const PStatic& ps) { |
404 | ICHECK(ps.defined()); |
405 | ICHECK_GT(env_.size(), 0); |
406 | ICHECK_EQ(env_.back().locals.count(v), 0); |
407 | env_.back().locals[v] = ps; |
408 | } |
409 | |
410 | PStatic Lookup(const Var& v) { |
411 | auto rit = env_.rbegin(); |
412 | while (rit != env_.rend()) { |
413 | if (rit->locals.find(v) != rit->locals.end()) { |
414 | return rit->locals.find(v)->second; |
415 | } |
416 | ++rit; |
417 | } |
418 | LOG(FATAL) << "Unknown Variable: " << v; |
419 | throw; |
420 | } |
421 | |
422 | private: |
423 | std::list<Frame> env_; |
424 | |
425 | struct FrameContext { |
426 | Environment* env_; |
427 | explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); } |
428 | ~FrameContext() { env_->env_.pop_back(); } |
429 | }; |
430 | }; |
431 | |
432 | /*! |
433 | * \brief As our store require rollback, we implement it as a frame. |
434 | * |
435 | * Every time we need to copy the store, a new frame is insert. |
436 | * Every time we roll back, a frame is popped. |
437 | */ |
438 | struct StoreFrame { |
439 | std::unordered_map<const SRefNode*, PStatic> store; |
440 | /*! |
441 | * \brief On unknown effect, history_valid is set to true to signal above frame is outdated. |
442 | * |
443 | * It only outdate the frame above it, but not the current frame. |
444 | */ |
445 | bool history_valid = true; |
446 | explicit StoreFrame(const std::unordered_map<const SRefNode*, PStatic>& store) : store(store) {} |
447 | StoreFrame() = default; |
448 | }; |
449 | |
450 | class Store { |
451 | public: |
452 | Store() : store_({StoreFrame()}) {} |
453 | Store(const Store&) = delete; |
454 | |
455 | template <typename T> |
456 | T Extend(const std::function<T()>& body) { |
457 | StoreFrameContext sfc(this); |
458 | return body(); |
459 | } |
460 | |
461 | void Insert(const SRefNode* r, const PStatic& ps) { |
462 | ICHECK(r); |
463 | store_.back().store[r] = ps; |
464 | } |
465 | |
466 | // return null if not found |
467 | PStatic Lookup(const SRefNode* r) { |
468 | auto rit = store_.rbegin(); |
469 | while (rit != store_.rend()) { |
470 | if (rit->store.find(r) != rit->store.end()) { |
471 | return rit->store.find(r)->second; |
472 | } |
473 | if (!rit->history_valid) { |
474 | return PStatic(); |
475 | } |
476 | ++rit; |
477 | } |
478 | return PStatic(); |
479 | } |
480 | |
481 | void Invalidate() { |
482 | StoreFrame sf; |
483 | sf.history_valid = false; |
484 | store_.push_back(sf); |
485 | } |
486 | |
487 | private: |
488 | std::list<StoreFrame> store_; |
489 | |
490 | struct StoreFrameContext { |
491 | Store* store_; |
492 | explicit StoreFrameContext(Store* store) : store_(store) { |
493 | store_->store_.push_back(StoreFrame()); |
494 | } |
495 | ~StoreFrameContext() { |
496 | // push one history valid frame off. |
497 | while (!store_->store_.back().history_valid) { |
498 | store_->store_.pop_back(); |
499 | } |
500 | store_->store_.pop_back(); |
501 | } |
502 | }; |
503 | }; |
504 | |
505 | PStatic HasStatic(const Static& stat, const Expr& dynamic) { |
506 | ICHECK(stat.defined()); |
507 | return PStatic(make_object<PStaticNode>(stat, dynamic)); |
508 | } |
509 | |
510 | PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object<PStaticNode>(dynamic)); } |
511 | |
512 | enum struct MatchStatus { Match, NoMatch, Unknown }; |
513 | |
514 | bool StatefulOp(const Expr& e) { |
515 | static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful" ); |
516 | struct StatefulOpVisitor : ExprVisitor { |
517 | bool stateful = false; |
518 | void VisitExpr_(const OpNode* op) { |
519 | stateful = stateful || op_stateful.get(GetRef<Op>(op), false); |
520 | } |
521 | }; |
522 | StatefulOpVisitor sov; |
523 | sov(e); |
524 | return sov.stateful; |
525 | } |
526 | |
527 | using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>; |
528 | |
529 | Target CPUTarget() { return Target("llvm" ); } |
530 | |
531 | Device CPUDevice() { |
532 | Device dev; |
533 | dev.device_type = kDLCPU; |
534 | dev.device_id = 0; |
535 | return dev; |
536 | } |
537 | |
538 | using FuncId = int; |
539 | |
540 | /*! |
541 | * \brief Annotate a function with a FuncId. |
542 | */ |
543 | struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> { |
544 | FuncId fid; |
545 | |
546 | TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs" ) { |
547 | TVM_ATTR_FIELD(fid).describe("The FuncId that an function is annotated with." ).set_default(-1); |
548 | } |
549 | }; |
550 | |
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);

// Internal annotation op used to tag functions with their FuncId while the
// partial evaluator runs; it is stripped again before the result is returned.
RELAY_REGISTER_OP("annotation.with_funcid")
    .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("func", "Function", "The input data.");

// Cache with_funcid op to reduce lookup overhead during traversal.
// NOTE(review): initialized during static initialization; relies on the
// registration above (same TU, earlier) running first — confirm there is no
// cross-TU use of this file's symbols before main().
static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
560 | |
561 | Expr MkWithFuncId(const Expr& expr, FuncId fid) { |
562 | auto attrs = make_object<WithFuncIdAttrs>(); |
563 | attrs->fid = fid; |
564 | return Call(with_funcid_op, {expr}, Attrs(attrs), {}); |
565 | } |
566 | |
567 | Expr StripWithFuncId(const Expr& e); |
568 | |
569 | Function AsFunc(const Expr& e) { |
570 | if (e.as<FunctionNode>()) { |
571 | return Downcast<Function>(e); |
572 | } else if (const CallNode* c = e.as<CallNode>()) { |
573 | ICHECK(c->op == with_funcid_op); |
574 | ICHECK_EQ(c->args.size(), 1); |
575 | return AsFunc(c->args[0]); |
576 | } else { |
577 | LOG(FATAL) << "Unknown case" ; |
578 | throw; |
579 | } |
580 | } |
581 | |
582 | class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>, |
583 | public PatternFunctor<MatchStatus(const Pattern&, const PStatic&)> { |
584 | public: |
585 | PartialEvaluator(const IRModule& mod) : mod_(mod) {} |
586 | |
  /*! \brief Evaluate `e`, asserting the invariant that every dynamic fragment
   * produced is atomic (A-normal form maintained through the LetList).
   */
  PStatic VisitExpr(const Expr& e, LetList* ll) final {
    PStatic ret = ExprFunctor<PStatic(const Expr&, LetList*)>::VisitExpr(e, ll);
    ICHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
    return ret;
  }
592 | |
  /*! \brief Evaluate `e` under binding name `name`: with_funcid annotations are
   * looked through, and a function value is evaluated via VisitFunc so it can
   * refer to itself through `name`.
   */
  PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
    if (const CallNode* c = e.as<CallNode>()) {
      if (c->op == with_funcid_op) {
        ICHECK_EQ(c->args.size(), 1);
        return VisitExpr(c->args[0], ll, name);
      }
    }
    PStatic ret =
        e.as<FunctionNode>() ? VisitFunc(Downcast<Function>(e), ll, name) : VisitExpr(e, ll);
    ICHECK(IsAtomic(ret->dynamic)) << ret->dynamic;
    return ret;
  }
605 | |
  /*! \brief A constant is fully static: keep its tensor (copied onto device_). */
  PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final {
    return HasStatic(MkSTensor(op->data.CopyTo(device_)), ll->Push(GetRef<Expr>(op)));
  }
609 | |
  /*! \brief Evaluate each field; the tuple is static, holding every field's PStatic. */
  PStatic VisitExpr_(const TupleNode* op, LetList* ll) final {
    std::vector<PStatic> value;
    tvm::Array<Expr> expr;
    for (const Expr& e : op->fields) {
      PStatic ps = VisitExpr(e, ll);
      value.push_back(ps);
      expr.push_back(ps->dynamic);
    }
    // Note: The partial evaluator seems to do some weird stuff with sharing. Changing Tuple(expr)
    // to WithFields(op, expr) causes failures in the partial evaluator tests.
    return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
  }
622 | |
623 | PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final { |
624 | PStatic ps = VisitExpr(op->tuple, ll); |
625 | if (ps->pstatic.defined()) { |
626 | return Downcast<STuple>(ps->pstatic)->fields[op->index]; |
627 | } else { |
628 | return NoStatic(ll->Push(TupleGetItem(ps->dynamic, op->index))); |
629 | } |
630 | } |
631 | |
632 | PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef<Var>(op)); } |
633 | |
  /*! \brief Resolve a global: on first use, partially evaluate its function,
   * cache the static value in gv_map_, and update the module with the
   * specialized body. Non-Function globals stay dynamic.
   */
  PStatic VisitGlobalVar(const GlobalVar& gv) {
    ICHECK(mod_.defined());
    if (gv_map_.count(gv) == 0) {
      BaseFunc base_func = mod_->Lookup(gv);
      if (auto* n = base_func.as<FunctionNode>()) {
        Function func = GetRef<Function>(n);
        InitializeFuncId(func);
        Func f = VisitFuncStatic(func, gv);
        // Insert before VisitFuncDynamic so recursive references resolve.
        gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)});
        func = AsFunc(PostProcess(VisitFuncDynamic(func, f, gv)));
        mod_->Update(gv, func);
        return gv_map_.at(gv);
      } else {
        return NoStatic(gv);
      }
    }
    return gv_map_.at(gv);
  }
652 | |
  /*! \brief Delegate to VisitGlobalVar; the LetList is not needed for globals. */
  PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final {
    return VisitGlobalVar(GetRef<GlobalVar>(op));
  }
656 | |
  /*! \brief Bind the value (passing op->var as the name so a bound function can
   * refer to itself), then evaluate the body.
   */
  PStatic VisitExpr_(const LetNode* op, LetList* ll) final {
    env_.Insert(op->var, VisitExpr(op->value, ll, op->var));
    return VisitExpr(op->body, ll);
  }
661 | |
  /*! \brief If the condition is a known scalar bool, evaluate only the taken
   * branch. Otherwise evaluate both branches under extended stores (rolling
   * back their effects), invalidate the store, and emit a dynamic If.
   */
  PStatic VisitExpr_(const IfNode* op, LetList* ll) final {
    PStatic c = VisitExpr(op->cond, ll);
    if (c->pstatic.defined()) {
      NDArray cpu_array = Downcast<STensor>(c->pstatic)->data.CopyTo(CPUDevice());
      ICHECK_EQ(DataType(cpu_array->dtype), DataType::Bool());
      // bool tensor data is one byte per element.
      if (reinterpret_cast<uint8_t*>(cpu_array->data)[0]) {
        return VisitExpr(op->true_branch, ll);
      } else {
        return VisitExpr(op->false_branch, ll);
      }
    } else {
      Expr t = store_.Extend<Expr>([&]() {
        return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; });
      });
      Expr f = store_.Extend<Expr>([&]() {
        return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; });
      });
      // Either branch may have had effects at runtime; distrust the whole store.
      store_.Invalidate();
      return NoStatic(ll->Push(If(c->dynamic, t, f)));
    }
  }
683 | |
  /*! \brief Creating a reference is tracked statically: allocate a fresh SRef
   * id and record the initial value in the store.
   */
  PStatic VisitExpr_(const RefCreateNode* op, LetList* ll) final {
    PStatic ps = VisitExpr(op->value, ll);
    Static r = MkSRef();
    store_.Insert(r.as<SRefNode>(), ps);
    return HasStatic(r, ll->Push(RefCreate(ps->dynamic)));
  }
690 | |
  /*! \brief Write through a reference. Writing to an unknown reference may
   * alias anything, so the whole store is invalidated in that case.
   * Returns unit (the empty tuple), matching RefWrite semantics.
   */
  PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final {
    PStatic r = VisitExpr(op->ref, ll);
    PStatic v = VisitExpr(op->value, ll);
    if (r->pstatic.defined()) {
      store_.Insert(r->pstatic.as<SRefNode>(), v);
    } else {
      store_.Invalidate();
    }
    return HasStatic(MkSTuple({}), ll->Push(RefWrite(r->dynamic, v->dynamic)));
  }
701 | |
  /*! \brief Read through a reference: use the store entry when both the
   * reference and its value are known; otherwise emit a dynamic RefRead.
   */
  PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final {
    PStatic r = VisitExpr(op->ref, ll);
    if (r->pstatic.defined()) {
      PStatic ret = store_.Lookup(r->pstatic.as<SRefNode>());
      if (ret.defined()) {
        return ret;
      }
    }
    return NoStatic(ll->Push(RefRead(r->dynamic)));
  }
712 | |
  /*! \brief Calls: strip with_funcid, evaluate callee then arguments in order.
   * A statically known callee is applied at compile time; an unknown callee may
   * have arbitrary effects, so the store is invalidated and a dynamic call emitted.
   */
  PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
    if (op->op == with_funcid_op) {
      ICHECK_EQ(op->args.size(), 1);
      return VisitExpr(op->args[0], ll);
    }
    PStatic f = VisitExpr(op->op, ll);
    std::vector<PStatic> x;
    tvm::Array<Expr> x_dyn;
    for (const Expr& e : op->args) {
      PStatic ps = VisitExpr(e, ll);
      x.push_back(ps);
      x_dyn.push_back(ps->dynamic);
    }
    if (f->pstatic.defined()) {
      return Downcast<SFunc>(f->pstatic)->func(f, x, op->attrs, op->type_args, ll);
    } else {
      store_.Invalidate();
      return NoStatic(ll->Push(Call(f->dynamic, x_dyn, op->attrs, op->type_args)));
    }
  }
733 | |
  /*! \brief RAII guard installing `new_fuel` for function `fid` and restoring
   * the previous fuel when destroyed.
   */
  struct FuelFrame {
    PartialEvaluator* pe_;
    FuncId fid_;
    Fuel old_fuel;
    FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) {
      // The caller must have seeded fuel_map_ for this fid already.
      ICHECK_GT(pe_->fuel_map_.count(fid_), 0);
      old_fuel = pe_->fuel_map_[fid_];
      pe_->fuel_map_[fid_] = new_fuel;
    }
    ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; }
  };
745 | |
  /*! \brief If `ps` holds a scalar int32/int64 tensor, return its value clamped
   * to >= 0 (so it fits in size_t); otherwise 0. Feeds the FTValue fuel.
   */
  size_t GetFTValue(const PStatic& ps) {
    if (ps->pstatic.defined()) {
      if (auto* st = ps->pstatic.as<STensorNode>()) {
        if (st->data.Shape().empty()) {
          NDArray cpu_array = st->data.CopyTo(CPUDevice());
          DataType dtype = DataType(cpu_array->dtype);
          if (dtype == DataType::Int(32)) {
            return std::max<int32_t>(0, *static_cast<const int32_t*>(cpu_array->data));
          } else if (dtype == DataType::Int(64)) {
            return std::max<int64_t>(0, *static_cast<const int64_t*>(cpu_array->data));
          }
        }
      }
    }
    return 0;
  }
762 | |
763 | Fuel GetFuel(const PStatic& ps) { |
764 | std::vector<Fuel> fuels; |
765 | fuels.push_back(MkFTime(ps->created_time)); |
766 | fuels.push_back(MkFTValue(GetFTValue(ps))); |
767 | return MkFSeq(fuels); |
768 | } |
769 | |
  /*! \brief Build the compile-time applier for `func`.
   *
   * The returned closure, when applied, either unfolds and partially evaluates
   * the body (only if the fuel meet makes progress, which guarantees
   * termination) or residualizes the application as a dynamic call to `var`.
   * Free variables are captured (with their current PStatic) at closure-creation time.
   */
  Func VisitFuncStatic(const Function& func, const Expr& var) {
    ICHECK(IsAtomic(var));
    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      return ConstEvaluateFunc(func);
    }
    std::vector<std::pair<Var, PStatic>> free_vars;
    for (const auto& v : FreeVars(func)) {
      if (v != var) {
        free_vars.push_back(std::pair<Var, PStatic>(v, env_.Lookup(v)));
      }
    }
    return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
               const tvm::Array<Type>& type_args, LetList* ll) {
      return env_.Extend<PStatic>([&]() {
        ICHECK_EQ(pv.size(), func->params.size());
        ICHECK_GT(func_map_.count(func), 0);
        FuncId fid = func_map_.at(func);
        if (fuel_map_.count(fid) == 0) {
          fuel_map_.insert({fid, MkFTop()});
        }
        std::vector<Fuel> args_fuel;
        for (const auto& v : pv) {
          args_fuel.push_back(GetFuel(v));
        }
        auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel));
        if (std::get<1>(meet_res)) {
          // Progress was made: it is safe to unfold one more level.
          FuelFrame tf(this, fid, std::get<0>(meet_res));
          Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func)));
          Function func = AsFunc(dedup_func);
          if (var.as<VarNode>()) {
            // Bind the function's own name for recursion.
            env_.Insert(Downcast<Var>(var), self);
          }
          for (size_t i = 0; i < pv.size(); ++i) {
            env_.Insert(func->params[i], pv[i]);
          }
          for (const auto& p : free_vars) {
            env_.Insert(p.first, p.second);
          }
          tvm::Map<TypeVar, Type> subst;
          for (size_t i = 0; i < type_args.size(); ++i) {
            subst.Set(func->type_params[i], type_args[i]);
          }
          // Unsupplied type arguments become fresh incomplete types.
          for (size_t i = type_args.size(); i < func->type_params.size(); ++i) {
            subst.Set(func->type_params[i], IncompleteType(kType));
          }
          return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll);
        } else {
          // No progress: residualize to avoid infinite unfolding.
          std::vector<Expr> dyn;
          for (const auto& v : pv) {
            dyn.push_back(v->dynamic);
          }
          return NoStatic(ll->Push(Call(var, dyn, attrs, type_args)));
        }
      });
    };
  }
826 | |
  /*! \brief Produce the residual (dynamic) body of `func` by applying `f` with
   * every parameter unknown. The store is invalidated first, since the function
   * may later be called under arbitrary conditions.
   */
  Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
    return store_.Extend<Expr>([&]() {
      store_.Invalidate();
      return WithFields(
          func, func->params, LetList::With([&](LetList* ll) {
            std::vector<PStatic> pv;
            for (const auto& v : func->params) {
              pv.push_back(NoStatic(v));
            }
            tvm::Array<Type> type_args;
            for (const auto& tp : func->type_params) {
              type_args.push_back(tp);
            }
            return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic;
          }));
    });
  }
844 | |
845 | PStatic VisitFunc(const Function& func, LetList* ll, const Var& name) { |
846 | Func f = VisitFuncStatic(func, name); |
847 | Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); |
848 | // TODO(@M.K.): we seems to reduce landin knot into letrec. |
849 | // restore letrec support across whole relay. |
850 | return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name))); |
851 | } |
852 | |
853 | PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { |
854 | return VisitFunc(GetRef<Function>(op), ll, Var::GenSym()); |
855 | } |
856 | |
  // Thrown by Reflect when a PStatic carries no static value, signalling the
  // caller to fall back to the residual (dynamic) path.
  struct ReflectError : Error {
    ReflectError() : Error("static value not found" ) {}
  };
860 | |
861 | Expr Reflect(const PStatic& st) { |
862 | if (!st->pstatic.defined()) { |
863 | throw ReflectError(); |
864 | } else if (const STensorNode* op = st->pstatic.as<STensorNode>()) { |
865 | return Constant(op->data); |
866 | } else if (const STupleNode* op = st->pstatic.as<STupleNode>()) { |
867 | tvm::Array<Expr> fields; |
868 | for (const PStatic& field : op->fields) { |
869 | fields.push_back(Reflect(field)); |
870 | } |
871 | return Tuple(fields); |
872 | } else { |
873 | LOG(FATAL) << "Unknown case: " << st->dynamic; |
874 | throw; |
875 | } |
876 | } |
877 | |
878 | PStatic Reify(const ObjectRef& v, LetList* ll) const { |
879 | if (v->IsInstance<runtime::NDArray::ContainerType>()) { |
880 | auto nd_array = Downcast<runtime::NDArray>(v); |
881 | return HasStatic(MkSTensor(nd_array), ll->Push(Constant(nd_array))); |
882 | } else if (const runtime::ADTObj* op = v.as<runtime::ADTObj>()) { |
883 | std::vector<PStatic> fields; |
884 | tvm::Array<Expr> fields_dyn; |
885 | auto adt = GetRef<runtime::ADT>(op); |
886 | for (size_t i = 0; i < adt.size(); ++i) { |
887 | PStatic ps = Reify(adt[i], ll); |
888 | fields.push_back(ps); |
889 | fields_dyn.push_back(ps->dynamic); |
890 | } |
891 | return HasStatic(MkSTuple(fields), ll->Push(Tuple(fields_dyn))); |
892 | } else { |
893 | LOG(FATAL) << "Unknown case" ; |
894 | throw; |
895 | } |
896 | } |
897 | |
  // Constant evaluate an expression.
  // Runs the reference interpreter on the given expression and lifts the
  // resulting runtime value back into a PStatic via Reify.
  PStatic ConstEvaluate(const Expr& expr, LetList* ll) {
    // use a fresh build context in case we are already in a build context.
    With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());
    return Reify(Eval(expr, mod_->type_definitions, mod_->Imports(), CPUDevice(), CPUTarget()), ll);
  }
904 | |
  // Wrap a closed expression (e.g. an Op) as a Func that constant-folds calls
  // when all arguments are static, and otherwise emits a residual call.
  Func ConstEvaluateFunc(const Expr& expr) {
    // The expression must be closed; it is captured by value below.
    ICHECK_EQ(FreeVars(expr).size(), 0);
    return [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs,
               const tvm::Array<Type>& type_args, LetList* ll) {
      tvm::Array<Expr> ns_args;
      for (const PStatic& ps : pv) {
        ns_args.push_back(ps->dynamic);
      }
      // Fallback: emit the call with the dynamic argument expressions.
      auto ns = [&]() { return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); };
      if (StatefulOp(expr)) {
        // Stateful ops must not be executed at compile time.
        return ns();
      }
      try {
        // Reflect throws ReflectError if any argument is not fully static.
        tvm::Array<Expr> args;
        for (const PStatic& ps : pv) {
          args.push_back(Reflect(ps));
        }
        return ConstEvaluate(Call(expr, args, attrs, type_args), ll);
      } catch (const ReflectError&) {
        return ns();
      }
    };
  }
928 | |
929 | PStatic VisitExpr_(const OpNode* op, LetList* ll) final { |
930 | return HasStatic(MkSFunc(ConstEvaluateFunc(GetRef<Expr>(op))), GetRef<Expr>(op)); |
931 | } |
932 | |
933 | PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { |
934 | Constructor c = GetRef<Constructor>(op); |
935 | Func f = [=](const PStatic& self, const std::vector<PStatic>& pv, const Attrs& attrs, |
936 | const tvm::Array<Type>& type_args, LetList* ll) { |
937 | tvm::Array<Expr> dyn; |
938 | for (const PStatic& ps : pv) { |
939 | dyn.push_back(ps->dynamic); |
940 | } |
941 | return HasStatic(MkSConstructor(c, pv), ll->Push(Call(c, dyn))); |
942 | }; |
943 | return HasStatic(MkSFunc(f), GetRef<Expr>(op)); |
944 | } |
945 | |
  // Partially evaluate a match expression. Clauses known to fail are dropped;
  // a clause known to succeed short-circuits evaluation to its body; if no
  // decision can be made, the whole match is residualized.
  PStatic VisitExpr_(const MatchNode* op, LetList* ll) final {
    PStatic ps = VisitExpr(op->data, ll);
    return env_.Extend<PStatic>([&]() {
      for (const Clause& c : op->clauses) {
        switch (VisitPattern(c->lhs, ps)) {
          case MatchStatus::Match:
            // Statically known to match: evaluate only this clause's body.
            return VisitExpr(c->rhs, ll);
          case MatchStatus::NoMatch:
            // Statically known to fail: skip this clause entirely.
            continue;
          case MatchStatus::Unknown:
            // Undecidable at compile time: emit a residual match, partially
            // evaluating every clause body with its pattern variables bound
            // to unknown values.
            return [&]() {
              tvm::Array<Clause> clauses;
              for (const Clause& c : op->clauses) {
                Expr expr = store_.Extend<Expr>([&]() {
                  return LetList::With([&](LetList* ll) {
                    for (const Var& v : BoundVars(c->lhs)) {
                      env_.Insert(v, NoStatic(v));
                    }
                    return VisitExpr(c->rhs, ll)->dynamic;
                  });
                });
                clauses.push_back(Clause(c->lhs, expr));
              }
              // Any clause may run at runtime, so the store is unknown now.
              store_.Invalidate();
              return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete)));
            }();
          default:
            LOG(FATAL) << "Unknown MatchStatus" ;
            throw;
        }
      }
      LOG(FATAL) << "No case Match" ;
      throw;
    });
  }
981 | |
  // A wildcard pattern matches any value.
  MatchStatus VisitPattern_(const PatternWildcardNode* op, const PStatic& ps) final {
    return MatchStatus::Match;
  }
985 | |
  // A variable pattern always matches, binding the variable to the value.
  MatchStatus VisitPattern_(const PatternVarNode* op, const PStatic& ps) final {
    env_.Insert(op->var, ps);
    return MatchStatus::Match;
  }
990 | |
991 | MatchStatus VisitPattern_(const PatternConstructorNode* op, const PStatic& ps) final { |
992 | if (ps->pstatic.defined()) { |
993 | SConstructor scn = Downcast<SConstructor>(ps->pstatic); |
994 | ICHECK_NE(op->constructor->tag, -1); |
995 | ICHECK_NE(scn->constructor->tag, -1); |
996 | if (op->constructor->tag == scn->constructor->tag) { |
997 | ICHECK_EQ(op->patterns.size(), scn->fields.size()); |
998 | MatchStatus current_match_status = MatchStatus::Match; |
999 | for (size_t i = 0; i < op->patterns.size(); ++i) { |
1000 | MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]); |
1001 | switch (ms) { |
1002 | case MatchStatus::Match: |
1003 | continue; |
1004 | case MatchStatus::NoMatch: |
1005 | return MatchStatus::NoMatch; |
1006 | case MatchStatus::Unknown: |
1007 | current_match_status = MatchStatus::Unknown; |
1008 | } |
1009 | } |
1010 | return current_match_status; |
1011 | } |
1012 | return MatchStatus::NoMatch; |
1013 | } else { |
1014 | return MatchStatus::Unknown; |
1015 | } |
1016 | } |
1017 | |
1018 | MatchStatus VisitPattern_(const PatternTupleNode* op, const PStatic& ps) final { |
1019 | if (ps->pstatic.defined()) { |
1020 | STuple stn = Downcast<STuple>(ps->pstatic); |
1021 | ICHECK_EQ(op->patterns.size(), stn->fields.size()); |
1022 | MatchStatus current_match_status = MatchStatus::Match; |
1023 | for (size_t i = 0; i < op->patterns.size(); ++i) { |
1024 | MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]); |
1025 | switch (ms) { |
1026 | case MatchStatus::Match: |
1027 | continue; |
1028 | case MatchStatus::NoMatch: |
1029 | return MatchStatus::NoMatch; |
1030 | case MatchStatus::Unknown: |
1031 | current_match_status = MatchStatus::Unknown; |
1032 | } |
1033 | } |
1034 | return current_match_status; |
1035 | } else { |
1036 | return MatchStatus::Unknown; |
1037 | } |
1038 | } |
1039 | |
  // Assign a fresh FuncId to every Function node reachable from `e`.
  void InitializeFuncId(const Expr& e) {
    struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor {
      PartialEvaluator* pe;
      explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {}

      void VisitExpr_(const FunctionNode* op) final {
        Function f = GetRef<Function>(op);
        // Each Function must be seen exactly once during initialization.
        ICHECK_EQ(pe->func_map_.count(f), 0);
        // Ids are allocated densely in visitation order.
        pe->func_map_.insert({f, pe->func_map_.size()});
        VisitExpr(f->body);
      }

      void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
    };
    InitializeFuncIdVisitor(this).VisitExpr(e);
  }
1056 | |
  // Record the with_funcid annotations found in `e` into func_map_, checking
  // that any already-registered Function keeps the same id. Returns `e`.
  Expr RegisterFuncId(const Expr& e) {
    struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor {
      PartialEvaluator* pe;
      explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {}

      void VisitExpr_(const CallNode* op) final {
        if (op->op == with_funcid_op) {
          // A with_funcid call wraps exactly one Function and carries its id
          // in the call attributes.
          ICHECK_EQ(op->args.size(), 1);
          ICHECK(op->attrs.defined());
          ICHECK(op->attrs.as<WithFuncIdAttrs>());
          Function f = AsFunc(op->args[0]);
          FuncId fid = op->attrs.as<WithFuncIdAttrs>()->fid;
          if (pe->func_map_.count(f) != 0) {
            // Re-registration must agree with the existing id.
            ICHECK_EQ(pe->func_map_.at(f), fid);
          }
          pe->func_map_.insert({f, fid});
        }
        ExprVisitor::VisitExpr_(op);
      }

      void VisitExpr_(const FunctionNode* op) final {
        Function f = GetRef<Function>(op);
        // By the time a bare Function is reached it must be registered.
        ICHECK_GT(pe->func_map_.count(f), 0);
        ExprVisitor::VisitExpr_(op);
      }

      void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
    };
    RegisterFuncIdVisitor(this).VisitExpr(e);
    return e;
  }
1088 | |
  // Wrap every Function in `e` with a with_funcid annotation carrying its
  // FuncId previously registered in func_map_.
  Expr AnnotateFuncId(const Expr& e) {
    struct AnnotateFuncIdMutator : ExprMutator, PatternMutator {
      PartialEvaluator* pe;
      explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) {}

      Expr VisitExpr_(const FunctionNode* op) final {
        Function f = GetRef<Function>(op);
        // Every Function must have been assigned an id beforehand.
        ICHECK_GT(pe->func_map_.count(f), 0);
        return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f));
      }

      Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }

      // Pattern-bound variables are kept as-is (no renaming here).
      Var VisitVar(const Var& v) final { return v; }
    };
    return AnnotateFuncIdMutator(this).VisitExpr(e);
  }
1106 | |
 private:
  // Maps in-scope variables to their partially-static values.
  Environment env_;
  // Module being partially evaluated; consulted for globals and type data.
  IRModule mod_;
  // Cache of partially-evaluated global definitions.
  std::unordered_map<GlobalVar, PStatic, ObjectPtrHash, ObjectPtrEqual> gv_map_;
  /*! Termination checking is done as follows:
   * We have finitely many FunctionIds.
   * Each FunctionId maps to a class of semantically equivalent function (ignoring type),
   * as both TypeSubst and DeDup create semantically equivalent function.
   * We partially map each FunctionId to a Fuel.
   * Every time we try to inline a Function,
   * we make sure it either does not have a Fuel,
   * or we meet the existing fuel with the fuel calculated from the argument.
   * If no progress is made, we do not inline.
   * In both case, we remap the mapping to the new Fuel
   * when we PE inside the Function body.
   * Termination is guaranteed because Fuel is finitely descending - there can only be so many
   * meet.
   */
  std::unordered_map<Function, FuncId, ObjectPtrHash, ObjectPtrEqual> func_map_;
  std::unordered_map<FuncId, Fuel> fuel_map_;
  // Models the mutable reference store during evaluation.
  Store store_;
  // Device used for compile-time (constant) evaluation.
  Device device_ = CPUDevice();
};
1130 | |
1131 | /*! \brief Remap multiple Var sharing the same Id into the same Var. */ |
1132 | Expr Remap(const Expr& e) { |
1133 | class RemapMutator : public ExprMutator, public PatternMutator { |
1134 | Expr VisitExpr_(const VarNode* op) final { |
1135 | Var v = GetRef<Var>(op); |
1136 | if (remap_.count(v) == 0) { |
1137 | remap_.insert({v, v}); |
1138 | } |
1139 | return remap_.at(v); |
1140 | } |
1141 | |
1142 | Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); } |
1143 | |
1144 | private: |
1145 | std::unordered_map<Var, Var, VarHash, VarEqual> remap_; |
1146 | }; |
1147 | return RemapMutator().VisitExpr(e); |
1148 | } |
1149 | |
1150 | Expr StripWithFuncId(const Expr& e) { |
1151 | struct StripWithFuncIdMutator : ExprMutator, PatternMutator { |
1152 | Expr VisitExpr_(const CallNode* op) final { |
1153 | if (op->op == with_funcid_op) { |
1154 | ICHECK_EQ(op->args.size(), 1); |
1155 | return VisitExpr(op->args[0]); |
1156 | } else { |
1157 | return ExprMutator::VisitExpr_(op); |
1158 | } |
1159 | } |
1160 | |
1161 | Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } |
1162 | |
1163 | Var VisitVar(const Var& v) final { return v; } |
1164 | }; |
1165 | return StripWithFuncIdMutator().VisitExpr(e); |
1166 | } |
1167 | |
1168 | Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); } |
1169 | |
1170 | } // namespace partial_eval |
1171 | |
1172 | IRModule PartialEval(const IRModule& m) { |
1173 | CheckFeature(m, FeatureSet::All() - fGraph); |
1174 | relay::partial_eval::PartialEvaluator pe(m); |
1175 | std::vector<GlobalVar> gvs; |
1176 | for (const auto& p : m->functions) { |
1177 | gvs.push_back(p.first); |
1178 | } |
1179 | for (const auto& gv : gvs) { |
1180 | pe.VisitGlobalVar(gv); |
1181 | } |
1182 | CheckFeature(m, FeatureSet::All() - fGraph); |
1183 | return m; |
1184 | } |
1185 | |
1186 | namespace transform { |
1187 | |
// Pass-manager wrapper: exposes relay::PartialEval as a module pass
// (opt level 1, no required passes).
Pass PartialEval() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
  return CreateModulePass(pass_func, 1, "PartialEval" , {});
}
1193 | |
// Expose the pass constructor to the frontend via the global registry.
TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate" ).set_body_typed(PartialEval);
1195 | |
1196 | } // namespace transform |
1197 | |
1198 | } // namespace relay |
1199 | } // namespace tvm |
1200 | |