/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file lazy_gradient_init.cc
 *
 * \brief Lazily instantiate 0-filled or 1-filled tensors.
 * This pass should be used after reverse-mode AD so that gradient tensors
 * are not instantiated until after the forward pass.
 *
 * This pass delays or removes memory allocation by converting tensors into
 * GradCell, an algebraic data type defined in gradient.rly.
 *
 * As a result, memory allocation is delayed or avoided entirely. All calls to
 * ones, ones_like, zeros, and zeros_like call the One or Zero constructor
 * of GradCell, which does not allocate memory until the value is needed. All
 * other cases use the Raw constructor, meaning the tensor is materialized immediately.
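 *
 * For reference, the GradCell ADT in gradient.rly is declared roughly as:
 *
 *   type GradCell[T] {
 *     Raw(T),
 *     One(fn() -> T),
 *     Zero(fn() -> T),
 *   }
 *
 * where One and Zero hold a thunk that produces the filled tensor on demand.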
 *
 * It also overloads the + and * operations, which can improve performance for
 * operations involving tensors whose values are all 0 or 1.
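 *
 * For example, with the definitions in gradient.rly, adding a Zero GradCell to
 * another GradCell can simply return the other cell, skipping the tensor
 * addition entirely.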
 *
 * Note: this pass can only be used with functions whose input/output types are
 * a combination of TupleTypes and TensorTypes.
 *
 * This pass optimizes 6 ops:
 * - add
 * - multiply
 * - ones
 * - ones_like
 * - zeros
 * - zeros_like
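 *
 * For instance, a call to ones is rewritten along the lines of
 *   ones(...)  ==>  One(fn() -> ones(...))
 * so the filled tensor is only materialized when FromGradCell forces the thunk.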
 *
 * This pass makes use of three traversals: the main mutator visits the entire
 * function, WrapExpr wraps the inputs, and UnwrapExpr unwraps the outputs.
 *
 * For example:
 * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
 *
 * After this pass:
 * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
 *
 * Thus, it is necessary to wrap this outer function so that the input/output types
 * remain the same.
 */

#include <tvm/ir/type_functor.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>

#include "let_list.h"

namespace tvm {
namespace relay {

class LazyGradientInitializer : public ExprMutator, public TypeMutator {
 public:
  explicit LazyGradientInitializer(IRModule module) : module_(module) {
    module_->ImportFromStd("gradient.rly");
  }

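  /*!
   * \brief Recursively wrap a value whose type is a combination of TensorTypes and
   * TupleTypes into GradCell values, using the Raw constructor for each tensor leaf.
   */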
  Expr WrapExpr(const Var& var, const Type& type, LetList* ll) {
    if (type.as<TensorTypeNode>()) {
      return Call(module_->GetConstructor("GradCell", "Raw"), {var}, Attrs(), {type});
    } else if (auto* type_anno = type.as<TupleTypeNode>()) {
      tvm::Array<Expr> fields;
      for (size_t i = 0; i < type_anno->fields.size(); i++) {
        const Type& t = type_anno->fields[i];
        fields.push_back(WrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
      }
      Expr tuple = Tuple(fields);
      return tuple;
    }

    return var;
  }

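  /*!
   * \brief Recursively unwrap GradCell values back into tensors by calling FromGradCell
   * on each GradCell leaf; tuples are unwrapped field by field.
   */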
  Expr UnwrapExpr(const Var& var, const Type& type, LetList* ll) {
    if (auto* type_call = type.as<TypeCallNode>()) {
      if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
        return Call(module_->GetGlobalVar("FromGradCell"), {var});
      }
      return var;
    } else if (auto* type_anno = type.as<TupleTypeNode>()) {
      tvm::Array<Expr> fields;
      for (size_t i = 0; i < type_anno->fields.size(); i++) {
        const Type& t = type_anno->fields[i];
        fields.push_back(UnwrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
      }
      Expr tuple = Tuple(fields);
      return tuple;
    }

    return var;
  }

  // Turn off the memo for constant nodes, so each occurrence is wrapped anew.
  Expr VisitExpr(const Expr& e) final {
    if (e.as<ConstantNode>()) {
      return ExprFunctor::VisitExpr(e);
    } else {
      return ExprMutator::VisitExpr(e);
    }
  }

  /*!
   * \brief Apply the LazyGradientInit transformation and wrap the function
   * so that the function type stays the same.
   *
   * The input/output types should only be a combination of TupleTypes and TensorTypes.
   */
  Expr Transform(const Expr& e) {
    auto* f = e.as<FunctionNode>();
    auto* transformed = this->Mutate(e).as<FunctionNode>();

    ICHECK(f);
    ICHECK(transformed);

    if (e.same_as(GetRef<Function>(transformed))) {
      return GetRef<Function>(transformed);
    }

    auto tensorOutput = LetList::With([&](LetList* ll) {
      // wrap inputs of Tensor type into GradCell using WrapExpr
      tvm::Array<Expr> args;
      for (const Var& var : f->params) {
        args.push_back(WrapExpr(var, var->checked_type(), ll));
      }
      Expr transformedExpr = Call(GetRef<Function>(transformed), args);
      // unwrap outputs of GradCell type back into Tensor type using UnwrapExpr
      return UnwrapExpr(ll->Push(transformedExpr), transformed->ret_type, ll);
    });
    return Function(f->params, tensorOutput, f->ret_type, Array<TypeVar>());
  }

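  /*! \brief Wrap constants with the Raw constructor of GradCell. */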
  Expr VisitExpr_(const ConstantNode* op) final {
    return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef<Constant>(op)}, Attrs(),
                {op->checked_type()});
  }

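  /*!
   * \brief Rewrite calls: add and multiply use the overloaded GradCell functions,
   * ones/zeros/ones_like/zeros_like use the One/Zero constructors, and all other
   * ops fall back to the Raw constructor.
   */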
  Expr VisitExpr_(const CallNode* call_node) final {
    if (auto* op = (call_node->op).as<OpNode>()) {
      Expr op_expr = GetRef<Op>(op);

      if (op_expr == Op::Get("add")) {
        return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell"));
      }

      if (op_expr == Op::Get("multiply")) {
        return CallGradCellFunction(call_node, module_->GetGlobalVar("MultiplyGradCell"));
      }

      if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
        // ones and zeros need TensorType input
        Expr result = CallPrimitiveOp(call_node);
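        // fn() -> T, function returns result of operation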
        Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
        // call appropriate GradCell constructor
        std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
        return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
                    {call_node->checked_type()});
      }

      if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
        // ones_like and zeros_like need TensorType input
        Expr result = CallPrimitiveOp(call_node);
        // fn() -> T, function returns result of operation
        Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
        // call appropriate GradCell constructor
        std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
        return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
                    {call_node->checked_type()});
      }

      // handle all other ops
      Expr result = CallPrimitiveOp(call_node);
      // wrap result with Raw constructor
      return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
                  {call_node->checked_type()});
    }
    // not an op
    return ExprMutator::VisitExpr_(call_node);
  }

  Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }

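  /*! \brief Rewrite a TensorType T into GradCell[T]. */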
  Type VisitType_(const TensorTypeNode* op) {
    GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
    tvm::Array<Type> args;
    args.push_back(GetRef<TensorType>(op));
    return TypeCall(gradCell, args);
  }

 private:
  // module that defines the GradCell type and its helper functions
  IRModule module_;

  /*!
   * \brief Convert a call to the add/multiply op into a call to the overloaded
   * function for the GradCell type.
   */
  Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
    // can only use the overloaded function if there are 2 arguments of the same type
    if (call_node->args.size() != 2 ||
        !tvm::StructuralEqual()(call_node->args[0]->checked_type(),
                                call_node->args[1]->checked_type())) {
      Expr result = CallPrimitiveOp(call_node);
      return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
                  {call_node->checked_type()});
    }

    tvm::Array<Expr> args;
    // create the "fallback" function for the overloaded function
    Type paramType = call_node->args[0]->checked_type();
    tvm::Array<Var> params = {Var("lhs", paramType), Var("rhs", paramType)};
    // the fallback uses the primitive op directly
    Expr callOp = Call(call_node->op, {params[0], params[1]});
    Expr func = Function(params, callOp, paramType, Array<TypeVar>());

    // pass the "fallback" function and the tensors as arguments
    args.push_back(func);
    for (Expr expr : call_node->args) {
      args.push_back(VisitExpr(expr));
    }
    // return a new call to the overloaded function
    return Call(overloaded_op, args, Attrs(), {paramType});
  }

  /*!
   * \brief Forward a call to any other op by converting its GradCell args back
   * into TensorType with FromGradCell.
   * \return Call expr returning the result of the op.
   */
  Expr CallPrimitiveOp(const CallNode* call_node) {
    const auto fromFunc = module_->GetGlobalVar("FromGradCell");
    tvm::Array<Expr> args;
    // use FromGradCell to convert each arg to a Tensor
    for (Expr expr : call_node->args) {
      args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
    }
    // result of the operation
    return Call(call_node->op, args, call_node->attrs);
  }
};

Expr LazyGradientInit(const Expr& e, IRModule mod) {
  CheckFeature(e, mod, FeatureSet::All() - fGraph);
  auto ret = LazyGradientInitializer(mod).Transform(e);
  CheckFeature(ret, mod, FeatureSet::All() - fGraph);
  return ret;
}

namespace transform {
Pass LazyGradientInit() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(LazyGradientInit(f, m));
      };
  return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
}
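
// A minimal usage sketch (hypothetical `mod`): the returned Pass can be applied
// directly to an IRModule, e.g.
//   IRModule mod = /* module containing the gradient function */;
//   mod = transform::LazyGradientInit()(mod);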

TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit);

}  // namespace transform

}  // namespace relay
}  // namespace tvm