/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file first_order_gradient.cc
 * \brief First-order Automatic Differentiation in Relay for pure dataflow graphs.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

#include "gradient.h"
#include "let_list.h"
#include "pass_utils.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

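/*! \brief Recursively apply a tensor-constructing factory (e.g. Zeros or Ones) over a type,
 *  producing a single tensor for a TensorType and a nested tuple of tensors for a TupleType.
 */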
template <typename F>
Expr MultiFactory(const Type& t, F factory, DiagnosticContext diag_ctx) {
  if (auto* tt = t.as<TensorTypeNode>()) {
    return factory(tt->shape, tt->dtype);
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    std::vector<Expr> res;
    for (size_t i = 0; i < tt->fields.size(); i++) {
      res.push_back(MultiFactory(tt->fields[i], factory, diag_ctx));
    }
    return Tuple(res);
  } else {
    diag_ctx.EmitFatal(Diagnostic::Error(t->span)
                       << "could not build tensors using factory for type " << PrettyPrint(t));
    throw;
  }
}

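/*! \brief Like MultiFactory, but when the type is a plain tensor, build the result directly from
 *  the expression with factory_like (e.g. ZerosLike); tuple types fall back to MultiFactory since
 *  there is no "like" constructor over tuples.
 */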
template <typename F, typename F2>
Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like,
                      DiagnosticContext diag_ctx) {
  if (t.as<TensorTypeNode>()) {
    return factory_like(e);
  } else if (t.as<TupleTypeNode>()) {
    return MultiFactory(t, factory, diag_ctx);
  } else {
    diag_ctx.EmitFatal(Diagnostic::Error(t->span)
                       << "could not build tensors using factory for type " << PrettyPrint(t));
    throw;
  }
}

/*! \brief A fragment of the program being built by the automatic differentiation
 * pass.
 */
struct ADValueNode {
  virtual ~ADValueNode() {}
  template <typename T>
  T& get() {
    auto ret = dynamic_cast<T*>(this);
    ICHECK(ret) << "cannot downcast";
    return *ret;
  }
};

using ADValue = std::shared_ptr<ADValueNode>;

/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;
  mutable Expr reverse;  // must be a variable to avoid duplication
  ADTensor(LetList* ll, const Expr& forward, DiagnosticContext diag_ctx)
      : forward(ll->Push(forward)),
        reverse(ll->Push(
            MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike, diag_ctx))) {
    this->forward->checked_type_ = forward->checked_type();
  }
};

/*! \brief A staged representation of the program: we reflect
 * Relay functions into a function over fragments of AD. We
 * can compute away this function to obtain a reverse mode program.
 */
struct ADFunction : ADValueNode {
  // (ad_args, orig) -> ad_ret
  using ADFunctionType = ADValue(const std::vector<ADValue>&, const Call&);
  std::function<ADFunctionType> func;
  explicit ADFunction(const std::function<ADFunctionType>& func) : func(func) {}
};

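/*! \brief Expression visitor that performs the first-order reverse-mode AD.
 *
 *  The forward computation is emitted eagerly as let bindings through `ll`, while each visited
 *  node also registers a closure in `backprop_actions`. Replaying those closures in reverse
 *  order propagates the output gradient into the `reverse` field of every ADTensor.
 */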
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
  const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
  std::vector<std::function<void(LetList* ll)>> backprop_actions;
  // we assume no closure so no need for lexical scoping
  std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
  LetList* ll;
  DiagnosticContext diag_ctx;

  FirstOrderReverseAD(LetList* ll, DiagnosticContext diag_ctx) : ll(ll), diag_ctx(diag_ctx) {}

  ADValue VisitExpr(const Expr& n) final {
    if (env.count(n)) {
      return env.at(n);
    }
    auto ret = ExprFunctor::VisitExpr(n);
    env[n] = ret;
    return ret;
  }

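  // Sum two gradients of the same type, recursing elementwise through (nested) tuple types.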
  static Expr LiftedAdd(const Type& t, const Expr& x, const Expr& y, LetList* ll) {
    if (t.as<TensorTypeNode>()) {
      return ll->Push(Add(x, y));
    } else if (auto* tt = t.as<TupleTypeNode>()) {
      Array<Expr> fields;
      for (size_t i = 0; i < tt->fields.size(); ++i) {
        fields.push_back(
            LiftedAdd(tt->fields[i], ll->Push(GetField(x, i)), ll->Push(GetField(y, i)), ll));
      }
      return ll->Push(Tuple(fields));
    } else {
      LOG(FATAL) << "cannot lift addition for type " << PrettyPrint(t);
      throw;
    }
  }

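  // An operator becomes an ADFunction that emits the forward call on the let-bound arguments and
  // registers a backprop action which invokes the registered FPrimalGradient and accumulates the
  // resulting gradients into the arguments' reverse values.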
  ADValue VisitExpr_(const OpNode* op) final {
    Op op_ref = GetRef<Op>(op);
    if (!rev_map.count(op_ref)) {
      diag_ctx.EmitFatal(Diagnostic::Error(op->span)
                         << "the operator " << op->name << " does not have a registered gradient.");
    }
    return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& ad_args,
                                                       const Call& orig) {
      std::vector<Expr> orig_args;
      for (const ADValue& adval : ad_args) {
        orig_args.push_back(adval->get<ADTensor>().forward);
      }
      auto orig_new = Call(op_ref, orig_args, orig->attrs, orig->type_args);
      orig_new->checked_type_ = orig->checked_type();
      auto ret = std::make_shared<ADTensor>(ll, orig_new, diag_ctx);
      backprop_actions.push_back([this, ad_args, orig_new, ret, op_ref](LetList* ll) {
        tvm::Array<Expr> rev = rev_map[op_ref](orig_new, ret->reverse);
        if (ad_args.size() != rev.size()) {
          diag_ctx.EmitFatal(Diagnostic::Error(op_ref->span)
                             << "arity mismatch for operator " << op_ref->name
                             << " and its registered gradient: expected " << ad_args.size()
                             << " but got " << rev.size() << " gradients.");
        }
        for (size_t i = 0; i < ad_args.size(); ++i) {
          auto& ad_arg = ad_args[i]->get<ADTensor>();
          ad_arg.reverse = LiftedAdd(ad_arg.forward->checked_type(), ad_arg.reverse, rev[i], ll);
        }
      });
      return ret;
    });
  }

  ADValue VisitExpr_(const TupleGetItemNode* op) final {
    ADValue tup = VisitExpr(op->tuple);
    TupleType tt = Downcast<TupleType>(op->tuple->checked_type());
    size_t idx = op->index;
    // reconstruct projection using let-bound variable to avoid duplicating input tuple
    TupleGetItem orig = TupleGetItem(tup->get<ADTensor>().forward, idx);
    orig->checked_type_ = op->checked_type();
    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
    // for orig = pi(tup, i), pi_grad(tup, i, g) = G where pi(G, i) = g and pi(G, j) = 0 for j != i
    backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) {
      auto& ad_tup = tup->get<ADTensor>();
      std::vector<Expr> updated_grads;
      for (size_t i = 0; i < tt->fields.size(); ++i) {
        Expr grad_pre = GetField(ad_tup.reverse, i);
        updated_grads.push_back(i != idx ? grad_pre
                                         : LiftedAdd(tt->fields[i], grad_pre, ret->reverse, ll));
      }
      ad_tup.reverse = ll->Push(Tuple(updated_grads));
    });
    return ret;
  }

  ADValue VisitExpr_(const TupleNode* tuple_node) final {
    auto tt = Downcast<TupleType>(tuple_node->checked_type());
    std::vector<ADValue> ad_fields;
    Array<Expr> field_bindings;
    field_bindings.reserve(tuple_node->fields.size());

    for (const auto& f : tuple_node->fields) {
      ADValue f_ad = VisitExpr(f);
      if (!dynamic_cast<ADTensor*>(f_ad.get())) {
        diag_ctx.EmitFatal(Diagnostic::Error(f->span)
                           << "first-order AD only supports (nested) tuples of tensors");
      }
      ad_fields.push_back(f_ad);
      field_bindings.push_back(f_ad->get<ADTensor>().forward);
    }
    // reconstruct tuple using let-bound variables to avoid duplication
    auto orig = WithFields(GetRef<Tuple>(tuple_node), field_bindings);
    orig->checked_type_ = tt;
    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
    // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
    backprop_actions.push_back([ad_fields, tt, ret](LetList* ll) {
      for (size_t i = 0; i < ad_fields.size(); ++i) {
        auto& ad_field = ad_fields[i]->get<ADTensor>();
        ad_field.reverse =
            LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll);
      }
    });
    return ret;
  }

  ADValue VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return std::make_shared<ADTensor>(ll, e, diag_ctx);
  }

  ADValue VisitExpr_(const CallNode* op) final {
    ADValue f = VisitExpr(op->op);
    std::vector<ADValue> args;
    for (const auto& arg : op->args) {
      args.push_back(VisitExpr(arg));
    }
    return f->get<ADFunction>().func(args, GetRef<Call>(op));
  }

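  // A function literal becomes an ADFunction that binds its parameters to the supplied AD values
  // and then differentiates its body; closures are not supported (see the todo below).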
  ADValue VisitExpr_(const FunctionNode* op) final {
    Function f = GetRef<Function>(op);
    // todo: assert no closure
    return std::make_shared<ADFunction>(
        [this, f](const std::vector<ADValue>& ad_args, const Call& orig) {
          ICHECK_EQ(f->params.size(), ad_args.size());
          for (size_t i = 0; i < f->params.size(); ++i) {
            env[f->params[i]] = ad_args[i];
          }
          return VisitExpr(f->body);
        });
  }

  // Var will always be in env, handled in VisitExpr (without _), so we don't need
  // to implement its VisitExpr_.
};

namespace transform {

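/*!
 * \brief Create the FirstOrderGradient pass. For every global Relay function, the body is
 * replaced by a pair of (original result, tuple of gradients with respect to each parameter),
 * computed by first-order reverse-mode AD.
 */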
Pass FirstOrderGradient() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> f = [](IRModule mod, PassContext ctx) {
    CheckFeature(
        mod, FeatureSet({fVar, fConstant, fTuple, fTupleGetItem, fFunction, fOp, fCall, fGraph}));
    IRModule ad_mod = GetRef<IRModule>(mod.CopyOnWrite());
    DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod);

    if (mod->functions.size() > 1) {
      LOG(WARNING) << "IRModule contains multiple global functions: first-order AD will transform "
                      "them independently!";
    }

    for (const auto& pr : mod->functions) {
      const FunctionNode* func = pr.second.as<FunctionNode>();
      if (!func) {
        diag_ctx.Emit(Diagnostic::Warning(pr.second->span)
                      << "AD can only be performed on Relay functions, skipping "
                      << PrettyPrint(pr.first));
        // actually skip, otherwise `func` would be dereferenced as a null pointer below
        continue;
      }
      if (func->type_params.size() > 0) {
        diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span)
                           << "first-order AD does not support polymorphism yet.");
      }
      Expr body = LetList::With([&](LetList* ll) {
        FirstOrderReverseAD reverse_ad(ll, diag_ctx);
        ADValue rev = reverse_ad(pr.second);
        std::vector<ADValue> args;
        for (const auto& p : func->params) {
          args.push_back(std::make_shared<ADTensor>(ll, p, diag_ctx));
        }
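        // Apply the AD-reflected function to the AD-wrapped parameters. The placeholder call is
        // never emitted into the output program; it only stands in for the original call site,
        // carrying the function's return type as its checked type.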
        Call placeholder = Call(GetRef<Function>(func), {});
        placeholder->checked_type_ = func->checked_type().as<FuncTypeNode>()->ret_type;
        auto grad_call = rev->get<ADFunction>().func(args, placeholder);
        auto& res = grad_call->get<ADTensor>();
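        // Seed the output gradient with ones, replay the recorded backprop actions in reverse
        // order, then collect the accumulated gradient of every parameter.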
        Expr grad_tuple = LetList::With([&](LetList* ll) {
          res.reverse =
              MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike, diag_ctx);
          for (auto it = reverse_ad.backprop_actions.rbegin();
               it != reverse_ad.backprop_actions.rend(); ++it) {
            (*it)(ll);
          }
          std::vector<Expr> grads;
          for (const auto& a : args) {
            grads.push_back(a->get<ADTensor>().reverse);
          }
          return Tuple(grads);
        });
        return Pair(res.forward, grad_tuple);
      });
      ad_mod->Update(pr.first, WithFields(GetRef<Function>(func), func->params, body,
                                          GradRetType(GetRef<Function>(func)),
                                          /* erase type params */ Array<TypeVar>()));
    }

    return ad_mod;
  };
  return CreateModulePass(f, 0, "FirstOrderGradient", {});
}

TVM_REGISTER_GLOBAL("relay._transform.FirstOrderGradient").set_body_typed(FirstOrderGradient);

}  // namespace transform

}  // namespace relay
}  // namespace tvm