1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file let_list.h
 * \brief LetList records let bindings and inserts let expressions implicitly.
 * Using it, one can treat an AST as a value instead of an expression,
 * and pass it around freely without fear of AST explosion (or effect duplication).
 * For example, if one writes 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a's.
 * If one instead writes 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 * the AST will contain only 2 'a's, as b and c are now variables.
 */
29 | #ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_ |
30 | #define TVM_RELAY_TRANSFORMS_LET_LIST_H_ |
31 | |
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>

#include <functional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "tvm/relay/type.h"
41 | |
42 | namespace tvm { |
43 | namespace relay { |
44 | |
45 | /*! |
46 | * \brief LetList allow you to transform expression into variables, so you can copy them around. |
47 | * one can insert into the LetList by calling Push, and wrap an expression with bindings with Get. |
48 | * additionally, there is the 'With' function, which automatically call Get. |
49 | */ |
50 | class LetList { |
51 | public: |
52 | ~LetList() { |
53 | if (lets_.size() > 0 && !used_) { |
54 | LOG(WARNING) << "letlist not used" ; |
55 | } |
56 | } |
57 | /*! |
58 | * \brief insert a binding. |
59 | * |
60 | * \param pv the var of the binding. |
61 | * |
62 | * \param expr the value of the binding. |
63 | * |
64 | * \return a Var that hold the inserted expr. |
65 | */ |
66 | Var Push(Var pv, Expr expr) { |
67 | ICHECK(!used_); |
68 | ICHECK(WellFormed(expr)) << "expression:" << std::endl << PrettyPrint(expr); |
69 | lets_.emplace_back(std::make_pair(pv, expr)); |
70 | return pv; |
71 | } |
72 | |
73 | /*! |
74 | * \brief insert a binding. |
75 | * |
76 | * \param expr the value of the binding. |
77 | * |
78 | * \param ty the type of the binding. |
79 | * |
80 | * \return a Var that hold the inserted expr. |
81 | */ |
82 | Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); } |
83 | |
84 | /*! |
85 | * \brief insert a binding. |
86 | * |
87 | * \param expr the value of the binding. |
88 | * |
89 | * \return a Var that hold the inserted expr. |
90 | */ |
91 | Var Push(Expr expr) { return Push(expr, Type()); } |
92 | |
93 | /*! |
94 | * \brief wrap an expr around the LetList. |
95 | * |
96 | * \param body the Expression to be wrapped around. |
97 | * |
98 | * \return the wrapped expr. |
99 | */ |
100 | Expr Get(const Expr& body) { |
101 | ICHECK(!used_); |
102 | Expr ret = body; |
103 | for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { |
104 | ret = Let(std::get<0>(*rit), std::get<1>(*rit), ret); |
105 | } |
106 | used_ = true; |
107 | return ret; |
108 | } |
109 | |
110 | /*! \brief get the number of let bindings in the let list. |
111 | * |
112 | * \return the let list size. |
113 | */ |
114 | size_t size() const { return lets_.size(); } |
115 | |
116 | /*! \brief generate an LetList and wrap the result automatically. |
117 | * |
118 | * \param f a function that generate the unwrapped Expr. |
119 | * |
120 | * \code |
121 | * // Example code that generate `16 * a` using 4 plus instead of 15 plus. |
122 | * Expr mult_sixteen(const Var& a) { |
123 | * Op plus = Op::Get("plus"); |
124 | * // Automatically call Get with LetList::With |
125 | * return LetList::With([&](LetList* ll) { |
126 | * // Turn a call to plus into a variable to avoid duplication of code |
127 | * Var b = ll->Push(Call(plus, {a, a})); |
128 | * Var c = ll->Push(Call(plus, {b, b})); |
129 | * Var d = ll->Push(Callplus, {c, c})); |
130 | * return Call(plus, {d, d}); |
131 | * }); |
132 | * } |
133 | * \endcode |
134 | * |
135 | * \return the wrapped Expr. |
136 | */ |
137 | template <typename F> |
138 | static Expr With(F&& f) { |
139 | LetList ll; |
140 | return ll.Get(f(&ll)); |
141 | } |
142 | |
143 | static Expr LetBind(const Expr& e, const std::function<Expr(const Var&)>& f) { |
144 | return With([&](LetList* ll) { return f(ll->Push(e)); }); |
145 | } |
146 | |
147 | private: |
148 | std::vector<std::pair<Var, Expr>> lets_; |
149 | bool used_ = false; |
150 | }; |
151 | |
152 | } // namespace relay |
153 | } // namespace tvm |
154 | #endif // TVM_RELAY_TRANSFORMS_LET_LIST_H_ |
155 | |