1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file let_list.h
 * \brief LetList records let bindings and inserts let expressions implicitly.
 * Using it, one can treat an AST as a value instead of an expression,
 * and pass it around freely without fear of AST explosion (or effect duplication).
 * For example, if one writes 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 * If one instead writes 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 * the AST will contain 2 'a', as b and c are now variables.
28 */
29#ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_
30#define TVM_RELAY_TRANSFORMS_LET_LIST_H_
31
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>

#include <functional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "tvm/relay/type.h"
41
42namespace tvm {
43namespace relay {
44
45/*!
46 * \brief LetList allow you to transform expression into variables, so you can copy them around.
47 * one can insert into the LetList by calling Push, and wrap an expression with bindings with Get.
48 * additionally, there is the 'With' function, which automatically call Get.
49 */
50class LetList {
51 public:
52 ~LetList() {
53 if (lets_.size() > 0 && !used_) {
54 LOG(WARNING) << "letlist not used";
55 }
56 }
57 /*!
58 * \brief insert a binding.
59 *
60 * \param pv the var of the binding.
61 *
62 * \param expr the value of the binding.
63 *
64 * \return a Var that hold the inserted expr.
65 */
66 Var Push(Var pv, Expr expr) {
67 ICHECK(!used_);
68 ICHECK(WellFormed(expr)) << "expression:" << std::endl << PrettyPrint(expr);
69 lets_.emplace_back(std::make_pair(pv, expr));
70 return pv;
71 }
72
73 /*!
74 * \brief insert a binding.
75 *
76 * \param expr the value of the binding.
77 *
78 * \param ty the type of the binding.
79 *
80 * \return a Var that hold the inserted expr.
81 */
82 Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); }
83
84 /*!
85 * \brief insert a binding.
86 *
87 * \param expr the value of the binding.
88 *
89 * \return a Var that hold the inserted expr.
90 */
91 Var Push(Expr expr) { return Push(expr, Type()); }
92
93 /*!
94 * \brief wrap an expr around the LetList.
95 *
96 * \param body the Expression to be wrapped around.
97 *
98 * \return the wrapped expr.
99 */
100 Expr Get(const Expr& body) {
101 ICHECK(!used_);
102 Expr ret = body;
103 for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
104 ret = Let(std::get<0>(*rit), std::get<1>(*rit), ret);
105 }
106 used_ = true;
107 return ret;
108 }
109
110 /*! \brief get the number of let bindings in the let list.
111 *
112 * \return the let list size.
113 */
114 size_t size() const { return lets_.size(); }
115
116 /*! \brief generate an LetList and wrap the result automatically.
117 *
118 * \param f a function that generate the unwrapped Expr.
119 *
120 * \code
121 * // Example code that generate `16 * a` using 4 plus instead of 15 plus.
122 * Expr mult_sixteen(const Var& a) {
123 * Op plus = Op::Get("plus");
124 * // Automatically call Get with LetList::With
125 * return LetList::With([&](LetList* ll) {
126 * // Turn a call to plus into a variable to avoid duplication of code
127 * Var b = ll->Push(Call(plus, {a, a}));
128 * Var c = ll->Push(Call(plus, {b, b}));
129 * Var d = ll->Push(Callplus, {c, c}));
130 * return Call(plus, {d, d});
131 * });
132 * }
133 * \endcode
134 *
135 * \return the wrapped Expr.
136 */
137 template <typename F>
138 static Expr With(F&& f) {
139 LetList ll;
140 return ll.Get(f(&ll));
141 }
142
143 static Expr LetBind(const Expr& e, const std::function<Expr(const Var&)>& f) {
144 return With([&](LetList* ll) { return f(ll->Push(e)); });
145 }
146
147 private:
148 std::vector<std::pair<Var, Expr>> lets_;
149 bool used_ = false;
150};
151
152} // namespace relay
153} // namespace tvm
154#endif // TVM_RELAY_TRANSFORMS_LET_LIST_H_
155