/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *
 * \file lazy_gradient_init.cc
 *
 * \brief Lazily instantiate 0-filled or 1-filled tensors.
 * This pass should be used after reverse-mode AD so that gradient tensors
 * are not instantiated until after the forward pass.
 *
 * This pass delays or removes memory allocation by converting tensors into
 * GradCell, an algebraic data type defined in gradient.rly.
 *
 * This will delay or decrease memory usage. All calls to
 * ones, ones_like, zeros, zeros_like will call the One or Zero constructor
 * of GradCell, which will not instantiate in memory until needed. All other cases result
 * in using the Raw constructor, which means the tensor is instantiated in memory.
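 *
 * As a rough sketch (see gradient.rly for the authoritative definition), GradCell
 * can be thought of as:
 *
 *   type GradCell[T] {
 *     Raw(T),           // tensor already materialized in memory
 *     One(fn() -> T),   // thunk producing a 1-filled tensor on demand
 *     Zero(fn() -> T),  // thunk producing a 0-filled tensor on demand
 *   }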
 *
 * It also overloads the + and * operations, which can increase performance when doing
 * operations involving tensors with values of only 0 or 1.
 *
 * Note: this pass can only be used with functions where the input/output types are
 * a combination of TupleTypes and TensorTypes.
 *
 * This pass optimizes 6 ops:
 * - add
 * - multiply
 * - ones
 * - ones_like
 * - zeros
 * - zeros_like
 *
 * This pass makes use of three components: the main expression mutator visits the
 * entire function, WrapExpr wraps the inputs, and UnwrapExpr unwraps the outputs.
 *
 * For example:
 * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
 *
 * After this pass:
 * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
 *
 * Thus, it is necessary to wrap this outer function so that the input/output types remain the same.
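 *
 * As an illustration (pseudo-Relay, names hypothetical), the generated wrapper roughly looks like:
 *
 *   fn (%x: Tensor[(10, 10), float32]) {
 *     let %g = %transformed_fn(Raw(%x));
 *     FromGradCell(%g)
 *   }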
 */

#include <tvm/ir/type_functor.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>

#include "let_list.h"

namespace tvm {
namespace relay {

class LazyGradientInitializer : public ExprMutator, public TypeMutator {
 public:
  explicit LazyGradientInitializer(IRModule module) : module_(module) {
    module_->ImportFromStd("gradient.rly");
  }

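  /*!
   * \brief Wrap a value of TensorType (or a tuple thereof) into GradCell
   * by applying the Raw constructor to each tensor component.
   */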
  Expr WrapExpr(const Var& var, const Type& type, LetList* ll) {
    if (type.as<TensorTypeNode>()) {
      return Call(module_->GetConstructor("GradCell", "Raw"), {var}, Attrs(), {type});
    } else if (auto* type_anno = type.as<TupleTypeNode>()) {
      tvm::Array<Expr> fields;
      for (size_t i = 0; i < type_anno->fields.size(); i++) {
        const Type& t = type_anno->fields[i];
        fields.push_back(WrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
      }
      Expr tuple = Tuple(fields);
      return tuple;
    }

    return var;
  }

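  /*!
   * \brief Unwrap a GradCell value (or a tuple thereof) back into a tensor
   * by calling FromGradCell on each GradCell component.
   */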
  Expr UnwrapExpr(const Var& var, const Type& type, LetList* ll) {
    if (auto* type_call = type.as<TypeCallNode>()) {
      if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
        return Call(module_->GetGlobalVar("FromGradCell"), {var});
      }
      return var;
    } else if (auto* type_anno = type.as<TupleTypeNode>()) {
      tvm::Array<Expr> fields;
      for (size_t i = 0; i < type_anno->fields.size(); i++) {
        const Type& t = type_anno->fields[i];
        fields.push_back(UnwrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
      }
      Expr tuple = Tuple(fields);
      return tuple;
    }

    return var;
  }

  // Turn off memo for constant node.
  Expr VisitExpr(const Expr& e) final {
    if (e.as<ConstantNode>()) {
      return ExprFunctor::VisitExpr(e);
    } else {
      return ExprMutator::VisitExpr(e);
    }
  }

  /*!
   * \brief Apply the LazyGradientInit transformation and wrap the function
   * so that the function type stays the same.
   *
   * The input/output types should only be a combination of TupleTypes and TensorTypes.
   */
  Expr Transform(const Expr& e) {
    auto* f = e.as<FunctionNode>();
    auto* transformed = this->Mutate(e).as<FunctionNode>();

    ICHECK(f);
    ICHECK(transformed);

    if (e.same_as(GetRef<Function>(transformed))) {
      return GetRef<Function>(transformed);
    }

    auto tensorOutput = LetList::With([&](LetList* ll) {
      // wrap inputs of Tensor type using WrapExpr
      tvm::Array<Expr> args;
      for (const Var& var : f->params) {
        args.push_back(WrapExpr(var, var->checked_type(), ll));
      }
      Expr transformedExpr = Call(GetRef<Function>(transformed), args);
      // unwrap outputs of GradCell type into Tensor type using UnwrapExpr
      return UnwrapExpr(ll->Push(transformedExpr), transformed->ret_type, ll);
    });
    return Function(f->params, tensorOutput, f->ret_type, Array<TypeVar>());
  }

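  /*!
   * \brief Constants are tensors that already exist in memory,
   * so wrap them with the Raw constructor.
   */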
  Expr VisitExpr_(const ConstantNode* op) final {
    return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef<Constant>(op)}, Attrs(),
                {op->checked_type()});
  }

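  /*!
   * \brief Rewrite calls to the optimized ops (add, multiply, ones, ones_like,
   * zeros, zeros_like) to operate on GradCell values. Calls to all other ops are
   * evaluated on tensors and the result is wrapped with the Raw constructor.
   */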
  Expr VisitExpr_(const CallNode* call_node) final {
    if (auto* op = (call_node->op).as<OpNode>()) {
      Expr op_expr = GetRef<Op>(op);

      if (op_expr == Op::Get("add")) {
        return CallGradCellFunction(call_node, module_->GetGlobalVar("AddGradCell"));
      }

      if (op_expr == Op::Get("multiply")) {
        return CallGradCellFunction(call_node, module_->GetGlobalVar("MultiplyGradCell"));
      }

      if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) {
        Expr result = CallPrimitiveOp(call_node);
        // fn() -> T, function returns result of operation
        Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
        // call appropriate GradCell constructor
        std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero";
        return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
                    {call_node->checked_type()});
      }

      if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) {
        // ones_like and zeros_like need TensorType input
        Expr result = CallPrimitiveOp(call_node);
        // fn() -> T, function returns result of operation
        Expr func = Function({}, result, {call_node->checked_type()}, Array<TypeVar>());
        // call appropriate GradCell constructor
        std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero";
        return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(),
                    {call_node->checked_type()});
      }

      // handle all other ops
      Expr result = CallPrimitiveOp(call_node);
      // wrap result with Raw constructor
      return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
                  {call_node->checked_type()});
    }
    // not an op
    return ExprMutator::VisitExpr_(call_node);
  }

  Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }

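  /*!
   * \brief Rewrite TensorType T into GradCell[T].
   */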
  Type VisitType_(const TensorTypeNode* op) {
    GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell");
    tvm::Array<Type> args;
    args.push_back(GetRef<TensorType>(op));
    return TypeCall(gradCell, args);
  }

 private:
  // Module
  IRModule module_;

  /*!
   * \brief Convert a call to the add or multiply op into a call to the
   * overloaded function for the GradCell type.
   */
  Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) {
    // can only use overloaded functions if 2 arguments of same type
    if (call_node->args.size() != 2 ||
        !tvm::StructuralEqual()(call_node->args[0]->checked_type(),
                                call_node->args[1]->checked_type())) {
      Expr result = CallPrimitiveOp(call_node);
      return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
                  {call_node->checked_type()});
    }

    tvm::Array<Expr> args;
    // create "fallback" function for overloaded function
    Type paramType = call_node->args[0]->checked_type();
    tvm::Array<Var> params = {Var("lhs", paramType), Var("rhs", paramType)};
    // use primitive op in this case
    Expr callOp = Call(call_node->op, {params[0], params[1]});
    Expr func = Function(params, callOp, paramType, Array<TypeVar>());

    // pass "fallback" function and tensors as arguments
    args.push_back(func);
    for (Expr expr : call_node->args) {
      args.push_back(VisitExpr(expr));
    }
    // return new call to overloaded function
    return Call(overloaded_op, args, Attrs(), {paramType});
  }

  /*!
   * \brief Convert calls to other ops by converting args into TensorType
   * \return call expr returning result of op
   */
  Expr CallPrimitiveOp(const CallNode* call_node) {
    const auto fromFunc = module_->GetGlobalVar("FromGradCell");
    tvm::Array<Expr> args;
    // use FromGradCell to convert args to Tensor
    for (Expr expr : call_node->args) {
      args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
    }
    // result of operation
    return Call(call_node->op, args, call_node->attrs);
  }
};

Expr LazyGradientInit(const Expr& e, IRModule mod) {
  CheckFeature(e, mod, FeatureSet::All() - fGraph);
  auto ret = LazyGradientInitializer(mod).Transform(e);
  CheckFeature(ret, mod, FeatureSet::All() - fGraph);
  return ret;
}

namespace transform {
Pass LazyGradientInit() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(LazyGradientInit(f, m));
      };
  return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
}

TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit);

}  // namespace transform

}  // namespace relay
}  // namespace tvm