1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file higher_order_gradient.cc |
22 | * \brief Higher-order Automatic Differentiation in Relay IR, for non-graph programs. |
23 | */ |
24 | #include <tvm/ir/type_functor.h> |
25 | #include <tvm/relay/analysis.h> |
26 | #include <tvm/relay/expr_functor.h> |
27 | #include <tvm/relay/feature.h> |
28 | #include <tvm/relay/transform.h> |
29 | #include <tvm/te/operation.h> |
30 | |
31 | #include "gradient.h" |
32 | #include "let_list.h" |
33 | #include "pass_utils.h" |
34 | #include "pattern_utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
39 | using namespace tvm::runtime; |
40 | |
/*! What is automatic differentiation (AD) and why is it important?
 * By AD, we roughly mean: given a term which denotes some mathematical function,
 * derive a term which denotes the derivative of that mathematical function.
 * Such a method can be compile-time, i.e. a macro on a completely known function.
 * Formally speaking, this requires that the input function is a closed expression -
 * that is, it only refers to local variables that are its parameters, or defined inside it.
 * Every top-level definition satisfies this criterion.
 * AD can also be run-time, which means it is merely a function term of type
 * AD : (Float[] -> Float[]) -> (Float[] -> Float[]). In Relay we currently only support
 * compile-time AD, but it should be enough for a lot of use cases.
 *
 * In deep learning, the most common way to train a deep neural network is by gradient descent
 * or some of its variants. Such optimization methods require us to compute the gradient of the
 * neural network, which can be obtained easily using AD. In fact, back propagation is
 * essentially reverse-mode automatic differentiation, a kind of AD!
 */
57 | |
58 | /*! In relay, automatic differentiation(AD) is a macro, |
59 | * that transform closed expr(expr without free variable/free type variable) of type |
60 | * (x0, x1, x2, ...) -> Float[] to |
61 | * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)), |
62 | * When x0, x1, x2... are Float of different shape. |
63 | * the return value is a pair, with left hand side as the original value, and right hand side as |
64 | * gradient of the input. WithGradientType will take the type of input, and produce the type of |
65 | * output. There are multiple implementation of AD in relay, with different characteristic. However, |
66 | * they all transform the input expr according to WithGradientType. |
67 | */ |
68 | Type WithGradientType(const Type& t) { |
69 | // TODO(@M.K.): stricter checking |
70 | auto ty = t.as<FuncTypeNode>(); |
71 | ICHECK(ty) << "input should be a function" ; |
72 | return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); |
73 | } |
74 | |
75 | //! \brief if the expression is a GlobalVar, transform to it's expression. |
76 | Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) { |
77 | const auto* x = e.as<GlobalVarNode>(); |
78 | |
79 | if (mod.defined() && x) { |
80 | BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x)); |
81 | if (auto* n = base_func.as<FunctionNode>()) { |
82 | return GetRef<Function>(n); |
83 | } else { |
84 | return e; |
85 | } |
86 | } else { |
87 | return e; |
88 | } |
89 | } |
90 | |
// Type of the backpropagator cell: a mutable reference to a nullary thunk () -> ().
// Calling the stored thunk propagates gradients backwards through recorded operations.
static Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));
92 | |
93 | struct ReverseADType : TypeMutator { |
94 | Type VisitType_(const TensorTypeNode* ttn) final { |
95 | Type t = GetRef<Type>(ttn); |
96 | return TupleType({t, RelayRefType(t)}); |
97 | } |
98 | |
99 | Type VisitType_(const FuncTypeNode* ftn) final { |
100 | std::vector<Type> arg_types; |
101 | for (const auto& t : ftn->arg_types) { |
102 | arg_types.push_back(VisitType(t)); |
103 | } |
104 | arg_types.push_back(bpt); |
105 | return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints); |
106 | } |
107 | }; |
108 | |
/*! \brief Convenience wrapper: apply ReverseADType to t. */
Type ReverseType(const Type& t) { return ReverseADType()(t); }
110 | |
111 | /*! \brief Lift a function that transform Tensor to a function that also transform more type |
112 | * by doing a structure preserving map. |
113 | */ |
114 | Expr LiftTensor(const std::function<Expr(const Expr& t)>& f, |
115 | const std::function<Type(const Type&)>& tf, const Type& forward_type, const Expr& e, |
116 | LetList* ll) { |
117 | ICHECK(IsAtomic(e)) << e; |
118 | if (forward_type.as<TensorTypeNode>()) { |
119 | auto ret = ll->Push(f(e)); |
120 | ret->checked_type_ = tf(forward_type); |
121 | return std::move(ret); |
122 | } else if (auto* tt = forward_type.as<TupleTypeNode>()) { |
123 | tvm::Array<Expr> fields; |
124 | tvm::Array<Type> types; |
125 | for (size_t i = 0; i < tt->fields.size(); ++i) { |
126 | auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll); |
127 | fields.push_back(field); |
128 | types.push_back(field->checked_type_); |
129 | } |
130 | auto ret = ll->Push(Tuple(fields)); |
131 | ret->checked_type_ = TupleType(types); |
132 | return std::move(ret); |
133 | } else { |
134 | LOG(FATAL) << "unsupported input/output type: " << tt; |
135 | throw; |
136 | } |
137 | } |
138 | |
139 | /*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr, |
140 | * by stitching the references in the AD values. |
141 | */ |
142 | void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) { |
143 | ICHECK(IsAtomic(from)) << from; |
144 | ICHECK(IsAtomic(to)) << to; |
145 | if (forward_type.as<TensorTypeNode>()) { |
146 | auto from_ref = TupleGetItem(from, 1); |
147 | auto to_ref = TupleGetItem(to, 1); |
148 | ll->Push(RefWrite(to_ref, RefRead(from_ref))); |
149 | } else if (auto* tt = forward_type.as<TupleTypeNode>()) { |
150 | for (size_t i = 0; i < tt->fields.size(); ++i) { |
151 | TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)), |
152 | ll); |
153 | } |
154 | } else { |
155 | LOG(FATAL) << "Unsupported input/output type: " << forward_type; |
156 | throw; |
157 | } |
158 | } |
159 | |
160 | // TODO(@M.K.): why take Expr? |
161 | /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */ |
162 | Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { |
163 | auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); }; |
164 | auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); }; |
165 | return LiftTensor(rev, rev_type, forward_type, e, ll); |
166 | } |
167 | |
168 | /*! \brief ReverseType(t) -> t. Get the original value. */ |
169 | Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { |
170 | auto val = [&](const Expr& e) { return GetField(e, 0); }; |
171 | auto val_type = [&](const Type& forward_type) { return forward_type; }; |
172 | return LiftTensor(val, val_type, forward_type, e, ll); |
173 | } |
174 | |
175 | /*! \brief ReverseType(t) -> t. Get the gradient. */ |
176 | Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { |
177 | auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); }; |
178 | auto grad_type = [&](const Type& forward_type) { return forward_type; }; |
179 | return LiftTensor(grad, grad_type, forward_type, e, ll); |
180 | } |
181 | |
182 | void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { |
183 | if (t.as<TensorTypeNode>()) { |
184 | ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad))); |
185 | } else if (auto* tt = t.as<TupleTypeNode>()) { |
186 | for (size_t i = 0; i < tt->fields.size(); ++i) { |
187 | UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll); |
188 | } |
189 | } else { |
190 | LOG(FATAL) << "unsupported arg type of operator: " << t; |
191 | throw; |
192 | } |
193 | } |
194 | |
195 | Expr BPEmpty() { |
196 | Expr unitF = Function({}, Tuple(tvm::Array<Expr>({})), TupleType::Empty(), {}); |
197 | return RefCreate(unitF); |
198 | } |
199 | |
200 | struct ReverseAD : ExprMutator { |
201 | using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>; |
202 | using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>; |
203 | Optional<IRModule> mod; |
204 | // TODO(@M.K.) refactor AD to always use mod. |
205 | Var bp; |
206 | std::shared_ptr<ADVarMap> ad_vars; |
207 | std::shared_ptr<ADGlobalVarMap> ad_gvars; |
208 | const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient" ); |
209 | |
210 | explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp, |
211 | const std::shared_ptr<ADVarMap>& ad_vars, |
212 | const std::shared_ptr<ADGlobalVarMap>& ad_gvars) |
213 | : mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {} |
214 | |
215 | Expr VisitExpr_(const OpNode* op) final { |
216 | LOG(FATAL) << "op should only be inside call" ; |
217 | throw; |
218 | } |
219 | |
220 | Expr Remap(const Expr& e) { |
221 | struct Remapper : ExprMutator { |
222 | std::shared_ptr<ADVarMap> ad_vars; |
223 | LetList* ll; |
224 | Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {} |
225 | Expr VisitExpr_(const VarNode* var) final { |
226 | // memoize Var -> ADVar so we don't end up with free Vars when checkpointing |
227 | auto var_ref = GetRef<Var>(var); |
228 | if (ad_vars->count(var_ref) == 0) { |
229 | return std::move(var_ref); |
230 | } else { |
231 | return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll); |
232 | } |
233 | } |
234 | }; |
235 | return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); }); |
236 | } |
237 | |
238 | Expr VisitCheckpoint(const CallNode* call) { |
239 | const OpNode* op_node = call->op.as<OpNode>(); |
240 | ICHECK(op_node) << "expected op in call" ; |
241 | Op op_ref = GetRef<Op>(op_node); |
242 | ICHECK(op_ref->name == "annotation.checkpoint" ) << "expected checkpoint annotation" ; |
243 | auto x = call->args[0]; |
244 | return LetList::With([&](LetList* ll) { |
245 | auto x_var = ll->Push(Remap(x)); |
246 | auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); |
247 | auto bpv = ll->Push(RefRead(bp)); |
248 | Expr nbp = Function({}, LetList::With([&](LetList* ll) { |
249 | // we need a new ReverseAD visitor to avoid clobbering the bp local var |
250 | auto dup_bp = ll->Push(BPEmpty()); |
251 | auto dup_ad = |
252 | ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x))); |
253 | TransferGrads(call->checked_type(), ret, dup_ad, ll); |
254 | ll->Push(Call(RefRead(dup_bp), {})); |
255 | return Call(bpv, {}); |
256 | }), |
257 | TupleType::Empty(), {}); |
258 | ll->Push(RefWrite(bp, nbp)); |
259 | return ret; |
260 | }); |
261 | } |
262 | |
263 | Expr VisitExpr_(const CallNode* call) final { |
264 | if (const OpNode* op_node = call->op.as<OpNode>()) { |
265 | Op op_ref = GetRef<Op>(op_node); |
266 | |
267 | if (op_ref->name == "annotation.checkpoint" ) { |
268 | return VisitCheckpoint(call); |
269 | } |
270 | |
271 | ICHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined" ; |
272 | return LetList::With([&](LetList* ll) { |
273 | std::vector<Var> args; |
274 | for (const auto& arg : call->args) { |
275 | args.push_back(ll->Push(VisitExpr(arg))); |
276 | } |
277 | std::vector<Expr> orig_args; |
278 | for (size_t i = 0; i < args.size(); i++) { |
279 | orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll)); |
280 | } |
281 | Expr orig = Call(call->op, orig_args, call->attrs, call->type_args); |
282 | orig->checked_type_ = call->checked_type(); |
283 | Var orig_var = ll->Push(orig); |
284 | orig_var->checked_type_ = call->checked_type(); |
285 | auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); |
286 | auto bpv = ll->Push(RefRead(bp)); |
287 | Expr nbp_body = LetList::With([&](LetList* ll) { |
288 | tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); |
289 | ICHECK(args.size() == rev.size()); |
290 | for (size_t i = 0; i < args.size(); ++i) { |
291 | UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); |
292 | } |
293 | return Call(bpv, {}); |
294 | }); |
295 | Expr nbp = Function({}, nbp_body, TupleType::Empty(), {}); |
296 | ll->Push(RefWrite(bp, transform::ToANormalForm(nbp))); |
297 | // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that. |
298 | return ret; |
299 | }); |
300 | } else if (call->op.as<ConstructorNode>()) { |
301 | return ExprMutator::VisitExpr_(call); |
302 | } else { |
303 | std::vector<Expr> args; |
304 | for (const auto& arg : call->args) { |
305 | args.push_back(VisitExpr(arg)); |
306 | } |
307 | args.push_back(bp); |
308 | return Call(VisitExpr(call->op), args); |
309 | } |
310 | } |
311 | |
312 | Expr VisitExpr_(const ConstantNode* op) final { |
313 | return LetList::With([&](LetList* ll) { |
314 | Expr e = ll->Push(GetRef<Expr>(op)); |
315 | return Pair(e, RefCreate(ZerosLike(e))); |
316 | }); |
317 | } |
318 | |
319 | Expr VisitExpr_(const IfNode* op) final { |
320 | return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch), |
321 | VisitExpr(op->false_branch)); |
322 | } |
323 | |
324 | Expr VisitExpr_(const VarNode* var) final { |
325 | // memoize Var -> ADVar so we don't end up with free Vars when checkpointing |
326 | auto var_ref = GetRef<Var>(var); |
327 | if (ad_vars->count(var_ref) == 0) { |
328 | auto res = Downcast<Var>(ExprMutator::VisitExpr_(var)); |
329 | (*ad_vars)[var_ref] = res; |
330 | } |
331 | |
332 | return ad_vars->at(var_ref); |
333 | } |
334 | |
335 | Expr VisitExpr_(const GlobalVarNode* op) final { |
336 | // todo: concatenating string to add attribute seems like a brittle hack. |
337 | // maybe get module indexed by a rose tree of string? |
338 | ICHECK(mod.defined()); |
339 | auto orig_gv = GetRef<GlobalVar>(op); |
340 | if (ad_gvars->count(orig_gv) == 0) { |
341 | GlobalVar gv(op->name_hint + "_grad" ); |
342 | (*ad_gvars)[orig_gv] = gv; |
343 | Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv))); |
344 | Array<Var> params; |
345 | for (const auto& p : orig_f->params) { |
346 | params.push_back(Downcast<Var>(VisitExpr(p))); |
347 | } |
348 | params.push_back(bp); |
349 | Function f = WithFields(orig_f, params, VisitExpr(orig_f->body), VisitType(orig_f->ret_type)); |
350 | std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl; |
351 | mod.value()->Add(gv, f); |
352 | } |
353 | return ad_gvars->at(orig_gv); |
354 | } |
355 | |
356 | Expr VisitExpr_(const FunctionNode* func_node) final { |
357 | Array<Var> params; |
358 | for (const auto& var : func_node->params) { |
359 | params.push_back(Downcast<Var>(VisitExpr(var))); |
360 | } |
361 | auto new_bp = Var("bp" , bpt); |
362 | params.push_back(new_bp); |
363 | return WithFields(GetRef<Function>(func_node), params, |
364 | ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body), |
365 | VisitType(func_node->ret_type)); |
366 | } |
367 | |
368 | Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } |
369 | }; |
370 | |
371 | bool MissingGrad(const Expr& e) { |
372 | struct MGVisitor : ExprVisitor { |
373 | const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient" ); |
374 | std::unordered_set<std::string> op_names; |
375 | |
376 | void VisitExpr_(const OpNode* op) final { |
377 | Op op_ref = GetRef<Op>(op); |
378 | if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) { |
379 | op_names.insert(op_ref->name); |
380 | } |
381 | ExprVisitor::VisitExpr_(op); |
382 | } |
383 | }; |
384 | |
385 | MGVisitor mg; |
386 | mg.VisitExpr(e); |
387 | |
388 | if (mg.op_names.size() > 0) { |
389 | LOG(WARNING) << "found operators with missing gradients:" ; |
390 | for (const auto& op : mg.op_names) { |
391 | LOG(WARNING) << " " << op; |
392 | } |
393 | return true; |
394 | } |
395 | |
396 | return false; |
397 | } |
398 | |
/*! \brief Entry point of reverse-mode AD: transform a closed function of type
 * (x0, ..., xn) -> Float[] into one returning (original result, gradients of each input),
 * per WithGradientType.
 *
 * \param re the expression (or a GlobalVar resolvable through `mod`) to differentiate;
 *           must be a monomorphic Function whose parameters are all tensors.
 * \param mod optional module, used to resolve globals and to register generated
 *            gradient functions.
 * \return the AD-transformed Function.
 */
Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
  // Reverse-mode AD here does not support graph-shaped (shared subexpression) programs.
  CheckFeature(re, FeatureSet::All() - fGraph);
  if (mod.defined()) {
    CheckFeature(mod.value(), FeatureSet::All() - fGraph);
  }
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  ICHECK(f) << "input need to be a function";
  ICHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
  for (const auto& p : f->params) {
    ICHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensor";
  }
  ICHECK(!MissingGrad(e)) << "input has operators with missing gradients";
  Expr body = LetList::With([&](LetList* ll) {
    // Seed backpropagator cell, then run the reverse-mode transform over the whole function.
    Var bp = ll->Push(BPEmpty(), bpt);
    Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
                         std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
    // Wrap each original parameter as a (value, Ref<zero-grad>) pair for the transformed call.
    std::vector<Expr> normal_args, args;
    for (const auto& p : f->params) {
      auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
      normal_args.push_back(x);
      args.push_back(x);
    }
    args.push_back(bp);
    auto c = ll->Push(Call(rev, args));
    // Seed the output gradient with ones. For tuple outputs this seeds only field 0 —
    // NOTE(review): looks intentional but narrow; confirm against callers if extending.
    std::function<void(const Expr&, const Type&)> init_grad;
    init_grad = [&](const Expr& e, const Type& t) {
      if (t.as<TensorTypeNode>()) {
        ll->Push(RefWrite(GetField(e, 1), OnesLike(GetField(e, 0))));
      } else if (auto tt = t.as<TupleTypeNode>()) {
        ICHECK_GT(tt->fields.size(), 0);
        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    init_grad(c, f->body->checked_type());
    // Run the accumulated backward pass, then read each input's gradient cell.
    ll->Push(Call(RefRead(bp), {}));
    std::vector<Expr> ret;
    for (const auto& a : normal_args) {
      ret.push_back(RefRead(GetField(a, 1)));
    }
    // Strip the (value, Ref<grad>) pairing from the forward result, recursing into tuples.
    std::function<Expr(const Expr&, const Type&)> get_final_result;
    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
      if (t.as<TensorTypeNode>()) {
        return GetField(e, 0);
      } else if (auto tt = t.as<TupleTypeNode>()) {
        tvm::Array<Expr> fields;
        for (size_t i = 0; i < tt->fields.size(); ++i) {
          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
        }
        return Tuple(fields);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
  });
  Function ret = WithFields(GetRef<Function>(f), f->params, body, GradRetType(GetRef<Function>(f)),
                            /* erase type params */ Array<TypeVar>());
  CheckFeature(ret, FeatureSet::All() - fGraph);
  return std::move(ret);
}
464 | |
// Expose this pass to the Python frontend as relay._transform.gradient.
TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
466 | |
467 | } // namespace relay |
468 | } // namespace tvm |
469 | |