1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file first_order_gradient.cc |
22 | * \brief First-order Automatic Differentiation in Relay for pure dataflow graphs. |
23 | */ |
24 | #include <tvm/ir/type_functor.h> |
25 | #include <tvm/relay/analysis.h> |
26 | #include <tvm/relay/dataflow_matcher.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/feature.h> |
29 | #include <tvm/relay/transform.h> |
30 | #include <tvm/te/operation.h> |
31 | |
32 | #include "gradient.h" |
33 | #include "let_list.h" |
34 | #include "pass_utils.h" |
35 | #include "pattern_utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | template <typename F> |
41 | Expr MultiFactory(const Type& t, F factory, DiagnosticContext diag_ctx) { |
42 | if (auto* tt = t.as<TensorTypeNode>()) { |
43 | return factory(tt->shape, tt->dtype); |
44 | } else if (auto* tt = t.as<TupleTypeNode>()) { |
45 | std::vector<Expr> res; |
46 | for (size_t i = 0; i < tt->fields.size(); i++) { |
47 | res.push_back(MultiFactory(tt->fields[i], factory, diag_ctx)); |
48 | } |
49 | return Tuple(res); |
50 | } else { |
51 | diag_ctx.EmitFatal(Diagnostic::Error(t->span) |
52 | << "could not build tensors using factory for type " << PrettyPrint(t)); |
53 | throw; |
54 | } |
55 | } |
56 | |
57 | template <typename F, typename F2> |
58 | Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like, |
59 | DiagnosticContext diag_ctx) { |
60 | if (t.as<TensorTypeNode>()) { |
61 | return factory_like(e); |
62 | } else if (auto* tt = t.as<TupleTypeNode>()) { |
63 | return MultiFactory(t, factory, diag_ctx); |
64 | } else { |
65 | diag_ctx.EmitFatal(Diagnostic::Error(t->span) |
66 | << "could not build tensors using factory for type " << PrettyPrint(t)); |
67 | throw; |
68 | } |
69 | } |
70 | |
71 | /*! \brief A fragment of the program being built by the automatic differentation |
72 | * pass. |
73 | */ |
struct ADValueNode {
  virtual ~ADValueNode() {}
  /*! \brief Downcast this AD value to concrete kind T; aborts if the kind differs. */
  template <typename T>
  T& get() {
    auto* as_t = dynamic_cast<T*>(this);
    ICHECK(as_t) << "cannot downcast";
    return *as_t;
  }
};

// Shared handle to an AD fragment; fragments are aliased across the graph.
using ADValue = std::shared_ptr<ADValueNode>;
85 | |
/*! \brief AD over a program which generates a tensor output. */
struct ADTensor : ADValueNode {
  Expr forward;           // let-bound variable holding the primal value
  mutable Expr reverse;   // must be a variable to avoid duplication
  // Let-binds `forward`, then let-binds `reverse` initialized to zeros of the
  // same (possibly tuple) type. NOTE: member-initializer order is load-bearing —
  // `forward` is pushed first so that `this->forward` already names the bound
  // variable when `reverse` is built from it.
  ADTensor(LetList* ll, const Expr& forward, DiagnosticContext diag_ctx)
      : forward(ll->Push(forward)),
        reverse(ll->Push(
            MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike, diag_ctx))) {
    // Copy the checked type onto the freshly let-bound variable.
    this->forward->checked_type_ = forward->checked_type();
  }
};
97 | |
98 | /*! \brief A staged representation of the program, we reflect |
99 | * Relay functions into a function over fragments of AD. We |
100 | * can compute away this function to obtain a reverse mode program. |
101 | */ |
102 | struct ADFunction : ADValueNode { |
103 | // (ad_args, orig) -> ad_ret |
104 | using ADFunctionType = ADValue(const std::vector<ADValue>&, const Call&); |
105 | std::function<ADFunctionType> func; |
106 | explicit ADFunction(const std::function<ADFunctionType>& func) : func(func) {} |
107 | }; |
108 | |
struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
  // Per-op registry of primal gradients: given the original call and the output
  // adjoint, each entry returns one adjoint expression per argument.
  const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
  // Closures recorded during the forward sweep; executed in reverse order by the
  // caller to perform the backward sweep (gradient accumulation).
  std::vector<std::function<void(LetList* ll)>> backprop_actions;
  // we assume no closure so no need for lexical scoping
  std::unordered_map<Expr, ADValue, ObjectPtrHash, ObjectPtrEqual> env;
  LetList* ll;                 // let-list receiving the forward computation
  DiagnosticContext diag_ctx;  // for user-facing errors with source spans

  FirstOrderReverseAD(LetList* ll, DiagnosticContext diag_ctx) : ll(ll), diag_ctx(diag_ctx) {}

  // Memoizing entry point: each Expr node is converted at most once, so a value
  // consumed in several places shares one ADTensor and hence one adjoint
  // accumulator (required for correct gradient accumulation in dataflow graphs).
  ADValue VisitExpr(const Expr& n) final {
    if (env.count(n)) {
      return env.at(n);
    }
    auto ret = ExprFunctor::VisitExpr(n);
    env[n] = ret;
    return ret;
  }

  // Element-wise `x + y` lifted over a (possibly nested) tuple-of-tensors type;
  // used to accumulate adjoints whose type may be a tuple.
  static Expr LiftedAdd(const Type& t, const Expr& x, const Expr& y, LetList* ll) {
    if (t.as<TensorTypeNode>()) {
      return ll->Push(Add(x, y));
    } else if (auto* tt = t.as<TupleTypeNode>()) {
      Array<Expr> fields;
      for (size_t i = 0; i < tt->fields.size(); ++i) {
        fields.push_back(
            LiftedAdd(tt->fields[i], ll->Push(GetField(x, i)), ll->Push(GetField(y, i)), ll));
      }
      return ll->Push(Tuple(fields));
    } else {
      LOG(FATAL) << "cannot lift addition for type " << PrettyPrint(t);
      throw;
    }
  }

  // An operator becomes a staged ADFunction: applying it emits the primal call
  // and records a backprop action that distributes the output adjoint back to
  // the arguments via the op's registered FPrimalGradient.
  ADValue VisitExpr_(const OpNode* op) final {
    Op op_ref = GetRef<Op>(op);
    if (!rev_map.count(op_ref)) {
      diag_ctx.EmitFatal(Diagnostic::Error(op->span)
                         << "the operator " << op->name << " does not have a registered gradient.");
    }
    return std::make_shared<ADFunction>([this, op_ref](const std::vector<ADValue>& ad_args,
                                                       const Call& orig) {
      std::vector<Expr> orig_args;
      for (const ADValue& adval : ad_args) {
        orig_args.push_back(adval->get<ADTensor>().forward);
      }
      // Rebuild the call over the let-bound forward values to avoid duplicating
      // argument expressions.
      auto orig_new = Call(op_ref, orig_args, orig->attrs, orig->type_args);
      orig_new->checked_type_ = orig->checked_type();
      auto ret = std::make_shared<ADTensor>(ll, orig_new, diag_ctx);
      backprop_actions.push_back([this, ad_args, orig_new, ret, op_ref](LetList* ll) {
        // One adjoint per argument, produced by the registered gradient.
        tvm::Array<Expr> rev = rev_map[op_ref](orig_new, ret->reverse);
        if (ad_args.size() != rev.size()) {
          diag_ctx.EmitFatal(Diagnostic::Error(op_ref->span)
                             << "arity mismatch for operator " << op_ref->name
                             << " and its registered gradient: expected " << ad_args.size()
                             << " but got " << rev.size() << " gradients.");
        }
        for (size_t i = 0; i < ad_args.size(); ++i) {
          auto& ad_arg = ad_args[i]->get<ADTensor>();
          // Accumulate rather than overwrite: the argument may feed multiple consumers.
          ad_arg.reverse = LiftedAdd(ad_arg.forward->checked_type(), ad_arg.reverse, rev[i], ll);
        }
      });
      return ret;
    });
  }

  ADValue VisitExpr_(const TupleGetItemNode* op) final {
    ADValue tup = VisitExpr(op->tuple);
    TupleType tt = Downcast<TupleType>(op->tuple->checked_type());
    size_t idx = op->index;
    // reconstruct projection using let-bound variable to avoid duplicating input tuple
    TupleGetItem orig = TupleGetItem(tup->get<ADTensor>().forward, idx);
    orig->checked_type_ = op->checked_type();
    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
    // for orig = pi(tup, i), pi_grad(tup, i, g) = G where pi(G, i) = g and pi(G, j) = 0 for j != i
    backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) {
      auto& ad_tup = tup->get<ADTensor>();
      std::vector<Expr> updated_grads;
      for (size_t i = 0; i < tt->fields.size(); ++i) {
        Expr grad_pre = GetField(ad_tup.reverse, i);
        // Only the projected field receives the output adjoint; others pass through.
        updated_grads.push_back(i != idx ? grad_pre
                                         : LiftedAdd(tt->fields[i], grad_pre, ret->reverse, ll));
      }
      ad_tup.reverse = ll->Push(Tuple(updated_grads));
    });
    return ret;
  }

  ADValue VisitExpr_(const TupleNode* tuple_node) final {
    auto tt = Downcast<TupleType>(tuple_node->checked_type());
    std::vector<ADValue> ad_fields;
    Array<Expr> field_bindings;
    field_bindings.reserve(tuple_node->fields.size());

    for (const auto& f : tuple_node->fields) {
      ADValue f_ad = VisitExpr(f);
      // Fields must themselves be (tuples of) tensors; functions etc. are rejected.
      if (!dynamic_cast<ADTensor*>(f_ad.get())) {
        diag_ctx.EmitFatal(Diagnostic::Error(f->span)
                           << "first-order AD only supports (nested) tuples of tensors");
      }
      ad_fields.push_back(f_ad);
      field_bindings.push_back(f_ad->get<ADTensor>().forward);
    }
    // reconstruct tuple using let-bound variables to avoid duplication
    auto orig = WithFields(GetRef<Tuple>(tuple_node), field_bindings);
    orig->checked_type_ = tt;
    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
    // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
    backprop_actions.push_back([ad_fields, tt, ret](LetList* ll) {
      for (size_t i = 0; i < ad_fields.size(); ++i) {
        auto& ad_field = ad_fields[i]->get<ADTensor>();
        ad_field.reverse =
            LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll);
      }
    });
    return ret;
  }

  // Constants are leaves: they get a zero adjoint and no backprop action.
  ADValue VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return std::make_shared<ADTensor>(ll, e, diag_ctx);
  }

  // Calls: stage callee and arguments, then apply the staged function.
  ADValue VisitExpr_(const CallNode* op) final {
    ADValue f = VisitExpr(op->op);
    std::vector<ADValue> args;
    for (const auto& arg : op->args) {
      args.push_back(VisitExpr(arg));
    }
    return f->get<ADFunction>().func(args, GetRef<Call>(op));
  }

  // A function literal becomes a staged function: on application, bind the AD
  // values of the arguments to the parameters and stage the body.
  ADValue VisitExpr_(const FunctionNode* op) final {
    Function f = GetRef<Function>(op);
    // todo: assert no closure
    return std::make_shared<ADFunction>(
        [this, f](const std::vector<ADValue>& ad_args, const Call& orig) {
          ICHECK_EQ(f->params.size(), ad_args.size());
          for (size_t i = 0; i < f->params.size(); ++i) {
            env[f->params[i]] = ad_args[i];
          }
          return VisitExpr(f->body);
        });
  }

  // Var will always be in env, handled in VisitExpr (without _), so we don't need
  // to implement its VisitExpr_.
};
258 | |
259 | namespace transform { |
260 | |
261 | Pass FirstOrderGradient() { |
262 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> f = [](IRModule mod, PassContext ctx) { |
263 | CheckFeature( |
264 | mod, FeatureSet({fVar, fConstant, fTuple, fTupleGetItem, fFunction, fOp, fCall, fGraph})); |
265 | IRModule ad_mod = GetRef<IRModule>(mod.CopyOnWrite()); |
266 | DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod); |
267 | |
268 | if (mod->functions.size() > 1) { |
269 | LOG(WARNING) << "IRModule contains multiple global functions: first-order AD will transform " |
270 | "them indepedently!" ; |
271 | } |
272 | |
273 | for (const auto& pr : mod->functions) { |
274 | const FunctionNode* func = pr.second.as<FunctionNode>(); |
275 | if (!func) { |
276 | diag_ctx.Emit(Diagnostic::Warning(pr.second->span) |
277 | << "AD can only be performed on Relay functions, skipping " |
278 | << PrettyPrint(pr.first)); |
279 | } |
280 | if (func->type_params.size() > 0) { |
281 | diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span) |
282 | << "first-order AD does not support polymorphism yet." ); |
283 | } |
284 | Expr body = LetList::With([&](LetList* ll) { |
285 | FirstOrderReverseAD reverse_ad(ll, diag_ctx); |
286 | ADValue rev = reverse_ad(pr.second); |
287 | std::vector<ADValue> args; |
288 | for (const auto& p : func->params) { |
289 | args.push_back(std::make_shared<ADTensor>(ll, p, diag_ctx)); |
290 | } |
291 | Call placeholder = Call(GetRef<Function>(func), {}); |
292 | placeholder->checked_type_ = func->checked_type().as<FuncTypeNode>()->ret_type; |
293 | auto grad_call = rev->get<ADFunction>().func(args, placeholder); |
294 | auto& res = grad_call->get<ADTensor>(); |
295 | Expr grad_tuple = LetList::With([&](LetList* ll) { |
296 | res.reverse = |
297 | MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike, diag_ctx); |
298 | for (auto it = reverse_ad.backprop_actions.rbegin(); |
299 | it != reverse_ad.backprop_actions.rend(); ++it) { |
300 | (*it)(ll); |
301 | } |
302 | std::vector<Expr> grads; |
303 | for (const auto& a : args) { |
304 | grads.push_back(a->get<ADTensor>().reverse); |
305 | } |
306 | return Tuple(grads); |
307 | }); |
308 | return Pair(res.forward, grad_tuple); |
309 | }); |
310 | ad_mod->Update(pr.first, WithFields(GetRef<Function>(func), func->params, body, |
311 | GradRetType(GetRef<Function>(func)), |
312 | /* erase type params */ Array<TypeVar>())); |
313 | } |
314 | |
315 | return ad_mod; |
316 | }; |
317 | return CreateModulePass(f, 0, "FirstOrderGradient" , {}); |
318 | } |
319 | |
// Expose the pass to the Python frontend as relay._transform.FirstOrderGradient.
TVM_REGISTER_GLOBAL("relay._transform.FirstOrderGradient").set_body_typed(FirstOrderGradient);
321 | |
322 | } // namespace transform |
323 | |
324 | } // namespace relay |
325 | } // namespace tvm |
326 | |