/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file higher_order_gradient.cc
 * \brief Higher-order Automatic Differentiation in Relay IR, for non-graph programs.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

#include "gradient.h"
#include "let_list.h"
#include "pass_utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

using namespace tvm::runtime;

/*! What is automatic differentiation (AD) and why is it important?
 * By AD we roughly mean: given a term that denotes some mathematical function,
 * derive a term that denotes the derivative of that mathematical function.
 * AD can be done at compile time, as a macro over a completely known function.
 * Formally speaking, this requires the input function to be a closed expression -
 * that is, one that only refers to local variables that are either its parameters or defined
 * inside it. Every top-level definition satisfies this criterion.
 * AD can also be done at run time, in which case it is merely a function term of type
 * AD : (Float[] -> Float[]) -> (Float[] -> Float[]). In Relay we currently only support
 * compile-time AD, but that is enough for a lot of use cases.
 *
 * In deep learning, the most common way to train a deep neural network is gradient descent or
 * one of its variants. Such optimization methods require the gradient of the neural network,
 * which can be obtained easily using AD. In fact, backpropagation is essentially
 * reverse-mode automatic differentiation, a kind of AD!
 */
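
/* As an illustration (a sketch, not taken from the implementation): applying compile-time AD to
 * the closed Relay function
 *   fn (%x: Tensor[(3, 4), float32]) { %x * %x }
 * yields, roughly, a function returning both the original value and the gradient with respect to
 * the input, seeded with ones:
 *   fn (%x: Tensor[(3, 4), float32]) { (%x * %x, (2 * %x,)) }
 * The precise shape of the transformed type is spelled out by WithGradientType below.
 */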

/*! In Relay, automatic differentiation (AD) is a macro
 * that transforms a closed expr (an expr without free variables or free type variables) of type
 * (x0, x1, x2, ...) -> Float[] into one of type
 * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)),
 * where x0, x1, x2, ... are Floats of possibly different shapes.
 * The return value is a pair, with the left-hand side being the original value and the right-hand
 * side being the gradient of the input. WithGradientType takes the type of the input and produces
 * the type of the output. There are multiple implementations of AD in Relay, with different
 * characteristics. However, they all transform the input expr according to WithGradientType.
 */
Type WithGradientType(const Type& t) {
  // TODO(@M.K.): stricter checking
  auto ty = t.as<FuncTypeNode>();
  ICHECK(ty) << "input should be a function";
  return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {});
}
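
/* For example (an illustrative sketch of the transformation above):
 *   WithGradientType(fn (Tensor[(3, 4), float32], Tensor[(4,), float32]) -> Tensor[(3,), float32])
 * produces
 *   fn (Tensor[(3, 4), float32], Tensor[(4,), float32])
 *     -> (Tensor[(3,), float32], (Tensor[(3, 4), float32], Tensor[(4,), float32]))
 * i.e. the original result paired with one gradient per argument.
 */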

//! \brief If the expression is a GlobalVar, replace it with the expression it refers to.
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
  const auto* x = e.as<GlobalVarNode>();

  if (mod.defined() && x) {
    BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
    if (auto* n = base_func.as<FunctionNode>()) {
      return GetRef<Function>(n);
    } else {
      return e;
    }
  } else {
    return e;
  }
}

// The type of the backpropagator: a reference to a () -> () closure that, when called, performs
// all gradient updates accumulated so far.
static Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));

struct ReverseADType : TypeMutator {
  Type VisitType_(const TensorTypeNode* ttn) final {
    Type t = GetRef<Type>(ttn);
    return TupleType({t, RelayRefType(t)});
  }

  Type VisitType_(const FuncTypeNode* ftn) final {
    std::vector<Type> arg_types;
    for (const auto& t : ftn->arg_types) {
      arg_types.push_back(VisitType(t));
    }
    arg_types.push_back(bpt);
    return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
  }
};

Type ReverseType(const Type& t) { return ReverseADType()(t); }
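
/* Illustrative sketch: ReverseType maps every tensor type to a pair of the forward value and a
 * mutable reference cell holding its accumulated gradient, and threads an extra backpropagator
 * argument through every function type. For example,
 *   ReverseType(Tensor[(3,), float32])
 *     = (Tensor[(3,), float32], Ref[Tensor[(3,), float32]])
 *   ReverseType(fn (Tensor[(3,), float32]) -> Tensor[(3,), float32])
 *     = fn ((Tensor[(3,), float32], Ref[Tensor[(3,), float32]]), Ref[fn () -> ()])
 *         -> Tensor[(3,), float32]
 * (note that, as written above, the result type of a function is left unchanged).
 */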

/*! \brief Lift a function that transforms Tensors into a function that also transforms more types,
 * by doing a structure-preserving map.
 */
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
                const std::function<Type(const Type&)>& tf, const Type& forward_type, const Expr& e,
                LetList* ll) {
  ICHECK(IsAtomic(e)) << e;
  if (forward_type.as<TensorTypeNode>()) {
    auto ret = ll->Push(f(e));
    ret->checked_type_ = tf(forward_type);
    return std::move(ret);
  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
    tvm::Array<Expr> fields;
    tvm::Array<Type> types;
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll);
      fields.push_back(field);
      types.push_back(field->checked_type_);
    }
    auto ret = ll->Push(Tuple(fields));
    ret->checked_type_ = TupleType(types);
    return std::move(ret);
  } else {
    LOG(FATAL) << "unsupported input/output type: " << forward_type;
    throw;
  }
}
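
/* Sketch of the lifting: given f and tf that act on a single tensor, LiftTensor applies them
 * leaf-wise through nested tuples. For a forward type (T0, (T1, T2)) and an atomic expr %e it
 * emits, roughly,
 *   let %a = f(%e.0);
 *   let %b = f(%e.1.0);
 *   let %c = f(%e.1.1);
 *   (%a, (%b, %c))
 * with the checked types rebuilt via tf at every leaf.
 */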

/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
 * by stitching the references in the AD values.
 */
void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) {
  ICHECK(IsAtomic(from)) << from;
  ICHECK(IsAtomic(to)) << to;
  if (forward_type.as<TensorTypeNode>()) {
    auto from_ref = TupleGetItem(from, 1);
    auto to_ref = TupleGetItem(to, 1);
    ll->Push(RefWrite(to_ref, RefRead(from_ref)));
  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)),
                    ll);
    }
  } else {
    LOG(FATAL) << "Unsupported input/output type: " << forward_type;
    throw;
  }
}
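
/* For a single tensor this emits, in effect, `to.1 := !from.1` (Relay-style pseudocode, for
 * illustration only): the gradient currently stored in the source AD value's reference cell is
 * copied into the duplicate's cell; tuples are handled field by field.
 */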

// TODO(@M.K.): why take Expr?
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
  auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
  auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
  return LiftTensor(rev, rev_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
  auto val = [&](const Expr& e) { return GetField(e, 0); };
  auto val_type = [&](const Type& forward_type) { return forward_type; };
  return LiftTensor(val, val_type, forward_type, e, ll);
}

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
  auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
  auto grad_type = [&](const Type& forward_type) { return forward_type; };
  return LiftTensor(grad, grad_type, forward_type, e, ll);
}
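
/* Roughly (illustrative pseudocode), for a tensor-typed forward value %v and its reverse-mode AD
 * value %a these helpers emit:
 *   GetRev(%v):   (%v, ref(zeros_like(%v)))  -- pair the value with a zero-initialized grad cell
 *   GetValue(%a): %a.0                       -- project the forward value back out
 *   GetGrad(%a):  !(%a.1)                    -- read the gradient accumulated in the cell
 * with LiftTensor extending each over nested tuples.
 */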

/*! \brief Accumulate the gradient `grad` into the gradient reference cell of the AD value `arg`,
 * recursing through tuple types.
 */
void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
  if (t.as<TensorTypeNode>()) {
    ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
    }
  } else {
    LOG(FATAL) << "unsupported arg type of operator: " << t;
    throw;
  }
}

Expr BPEmpty() {
  Expr unitF = Function({}, Tuple(tvm::Array<Expr>({})), TupleType::Empty(), {});
  return RefCreate(unitF);
}
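
/* ReverseAD (below) rewrites an expression into reverse-mode form: every value of tensor type is
 * replaced by an AD value (value, Ref(gradient)), and a shared backpropagator reference `bp` is
 * threaded through. Each visited op call prepends its own gradient updates to `bp`, so calling the
 * closure stored in `bp` afterwards replays all updates from the most recent call back to the
 * first.
 */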

struct ReverseAD : ExprMutator {
  using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
  using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
  Optional<IRModule> mod;
  // TODO(@M.K.) refactor AD to always use mod.
  Var bp;
  std::shared_ptr<ADVarMap> ad_vars;
  std::shared_ptr<ADGlobalVarMap> ad_gvars;
  const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");

  explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp,
                     const std::shared_ptr<ADVarMap>& ad_vars,
                     const std::shared_ptr<ADGlobalVarMap>& ad_gvars)
      : mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {}

  Expr VisitExpr_(const OpNode* op) final {
    LOG(FATAL) << "op should only be inside call";
    throw;
  }

  Expr Remap(const Expr& e) {
    struct Remapper : ExprMutator {
      std::shared_ptr<ADVarMap> ad_vars;
      LetList* ll;
      Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
      Expr VisitExpr_(const VarNode* var) final {
        // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
        auto var_ref = GetRef<Var>(var);
        if (ad_vars->count(var_ref) == 0) {
          return std::move(var_ref);
        } else {
          return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
        }
      }
    };
    return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
  }

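  /* annotation.checkpoint marks a subexpression whose intermediate values should be recomputed
   * during the backward pass rather than kept alive. Roughly: the forward value is computed here
   * (paired with fresh zero gradient cells), and the new backpropagator re-runs reverse-mode AD on
   * a de-duplicated copy of the argument, transfers the gradients accumulated in this call's
   * result into that copy, backpropagates through it, and then continues with the previous
   * backpropagator.
   */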
  Expr VisitCheckpoint(const CallNode* call) {
    const OpNode* op_node = call->op.as<OpNode>();
    ICHECK(op_node) << "expected op in call";
    Op op_ref = GetRef<Op>(op_node);
    ICHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
    auto x = call->args[0];
    return LetList::With([&](LetList* ll) {
      auto x_var = ll->Push(Remap(x));
      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
      auto bpv = ll->Push(RefRead(bp));
      Expr nbp = Function({}, LetList::With([&](LetList* ll) {
                   // we need a new ReverseAD visitor to avoid clobbering the bp local var
                   auto dup_bp = ll->Push(BPEmpty());
                   auto dup_ad = ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));
                   TransferGrads(call->checked_type(), ret, dup_ad, ll);
                   ll->Push(Call(RefRead(dup_bp), {}));
                   return Call(bpv, {});
                 }),
                 TupleType::Empty(), {});
      ll->Push(RefWrite(bp, nbp));
      return ret;
    });
  }

  Expr VisitExpr_(const CallNode* call) final {
    if (const OpNode* op_node = call->op.as<OpNode>()) {
      Op op_ref = GetRef<Op>(op_node);

      if (op_ref->name == "annotation.checkpoint") {
        return VisitCheckpoint(call);
      }

      ICHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined";
      return LetList::With([&](LetList* ll) {
        std::vector<Var> args;
        for (const auto& arg : call->args) {
          args.push_back(ll->Push(VisitExpr(arg)));
        }
        std::vector<Expr> orig_args;
        for (size_t i = 0; i < args.size(); i++) {
          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
        }
        Expr orig = Call(call->op, orig_args, call->attrs, call->type_args);
        orig->checked_type_ = call->checked_type();
        Var orig_var = ll->Push(orig);
        orig_var->checked_type_ = call->checked_type();
        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
        auto bpv = ll->Push(RefRead(bp));
        Expr nbp_body = LetList::With([&](LetList* ll) {
          tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
          ICHECK(args.size() == rev.size());
          for (size_t i = 0; i < args.size(); ++i) {
            UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
          }
          return Call(bpv, {});
        });
        Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
        ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
        // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
        return ret;
      });
    } else if (call->op.as<ConstructorNode>()) {
      return ExprMutator::VisitExpr_(call);
    } else {
      std::vector<Expr> args;
      for (const auto& arg : call->args) {
        args.push_back(VisitExpr(arg));
      }
      args.push_back(bp);
      return Call(VisitExpr(call->op), args);
    }
  }
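
  /* For a primitive call such as `add(%a, %b)` the rewrite above produces, roughly (illustrative
   * pseudocode, not the exact emitted IR):
   *   let %x = add(%a.0, %b.0);              // forward value computed from the AD values
   *   let %ret = (%x, ref(zeros_like(%x)));  // fresh gradient cell for the result
   *   let %old_bp = !bp;
   *   bp := fn () {
   *     // FPrimalGradient of the op maps the output gradient to per-argument gradients
   *     %a.1 := !%a.1 + !(%ret.1);
   *     %b.1 := !%b.1 + !(%ret.1);
   *     %old_bp()
   *   };
   *   %ret
   */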

  Expr VisitExpr_(const ConstantNode* op) final {
    return LetList::With([&](LetList* ll) {
      Expr e = ll->Push(GetRef<Expr>(op));
      return Pair(e, RefCreate(ZerosLike(e)));
    });
  }

  Expr VisitExpr_(const IfNode* op) final {
    return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch),
              VisitExpr(op->false_branch));
  }

  Expr VisitExpr_(const VarNode* var) final {
    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
    auto var_ref = GetRef<Var>(var);
    if (ad_vars->count(var_ref) == 0) {
      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
      (*ad_vars)[var_ref] = res;
    }

    return ad_vars->at(var_ref);
  }

  Expr VisitExpr_(const GlobalVarNode* op) final {
    // TODO: concatenating a string to add an attribute seems like a brittle hack.
    // Maybe index the module by a rose tree of strings instead?
    ICHECK(mod.defined());
    auto orig_gv = GetRef<GlobalVar>(op);
    if (ad_gvars->count(orig_gv) == 0) {
      GlobalVar gv(op->name_hint + "_grad");
      (*ad_gvars)[orig_gv] = gv;
      Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
      Array<Var> params;
      for (const auto& p : orig_f->params) {
        params.push_back(Downcast<Var>(VisitExpr(p)));
      }
      params.push_back(bp);
      Function f = WithFields(orig_f, params, VisitExpr(orig_f->body), VisitType(orig_f->ret_type));
      std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
      mod.value()->Add(gv, f);
    }
    return ad_gvars->at(orig_gv);
  }

  Expr VisitExpr_(const FunctionNode* func_node) final {
    Array<Var> params;
    for (const auto& var : func_node->params) {
      params.push_back(Downcast<Var>(VisitExpr(var)));
    }
    auto new_bp = Var("bp", bpt);
    params.push_back(new_bp);
    return WithFields(GetRef<Function>(func_node), params,
                      ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body),
                      VisitType(func_node->ret_type));
  }

  Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};

bool MissingGrad(const Expr& e) {
  struct MGVisitor : ExprVisitor {
    const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
    std::unordered_set<std::string> op_names;

    void VisitExpr_(const OpNode* op) final {
      Op op_ref = GetRef<Op>(op);
      if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
        op_names.insert(op_ref->name);
      }
      ExprVisitor::VisitExpr_(op);
    }
  };

  MGVisitor mg;
  mg.VisitExpr(e);

  if (mg.op_names.size() > 0) {
    LOG(WARNING) << "found operators with missing gradients:";
    for (const auto& op : mg.op_names) {
      LOG(WARNING) << "    " << op;
    }
    return true;
  }

  return false;
}

/*! \brief The entry point of higher-order reverse-mode AD: transform a closed, monomorphic
 * function whose parameters are all tensors into one that returns the original result paired with
 * the gradients of all parameters (see WithGradientType).
 */
Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
  CheckFeature(re, FeatureSet::All() - fGraph);
  if (mod.defined()) {
    CheckFeature(mod.value(), FeatureSet::All() - fGraph);
  }
  auto e = DeGlobal(mod, re);
  auto f = e.as<FunctionNode>();
  ICHECK(f) << "input needs to be a function";
  ICHECK(f->type_params.size() == 0) << "no polymorphism supported for now";
  for (const auto& p : f->params) {
    ICHECK(p->checked_type().as<TensorTypeNode>()) << "input parameters need to be tensors";
  }
  ICHECK(!MissingGrad(e)) << "input has operators with missing gradients";
  Expr body = LetList::With([&](LetList* ll) {
    Var bp = ll->Push(BPEmpty(), bpt);
    Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
                         std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
    std::vector<Expr> normal_args, args;
    for (const auto& p : f->params) {
      auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
      normal_args.push_back(x);
      args.push_back(x);
    }
    args.push_back(bp);
    auto c = ll->Push(Call(rev, args));
    std::function<void(const Expr&, const Type&)> init_grad;
    init_grad = [&](const Expr& e, const Type& t) {
      if (t.as<TensorTypeNode>()) {
        ll->Push(RefWrite(GetField(e, 1), OnesLike(GetField(e, 0))));
      } else if (auto tt = t.as<TupleTypeNode>()) {
        ICHECK_GT(tt->fields.size(), 0);
        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    init_grad(c, f->body->checked_type());
    ll->Push(Call(RefRead(bp), {}));
    std::vector<Expr> ret;
    for (const auto& a : normal_args) {
      ret.push_back(RefRead(GetField(a, 1)));
    }
    std::function<Expr(const Expr&, const Type&)> get_final_result;
    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
      if (t.as<TensorTypeNode>()) {
        return GetField(e, 0);
      } else if (auto tt = t.as<TupleTypeNode>()) {
        tvm::Array<Expr> fields;
        for (size_t i = 0; i < tt->fields.size(); ++i) {
          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
        }
        return Tuple(fields);
      } else {
        LOG(FATAL) << "unhandled type " << t;
        throw;
      }
    };
    return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
  });
  Function ret = WithFields(GetRef<Function>(f), f->params, body, GradRetType(GetRef<Function>(f)),
                            /* erase type params */ Array<TypeVar>());
  CheckFeature(ret, FeatureSet::All() - fGraph);
  return std::move(ret);
}

TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
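
/* A minimal usage sketch (assumptions: a type-checked module with a "main" function; the global
 * registered above is "relay._transform.gradient"). From C++ one could do, roughly:
 *
 *   IRModule mod = ...;                    // module containing the function of interest
 *   Function func = Downcast<Function>(mod->Lookup("main"));
 *   Expr grad_func = Gradient(func, mod);  // func must have checked types populated
 *
 * The same transform is typically reached from Python via relay.transform.gradient(expr, mod,
 * mode="higher_order"); consult the Python API for the authoritative signature.
 */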

}  // namespace relay
}  // namespace tvm