/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/annotate_target.cc
 * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
 * this expr should be handled by the external compiler.
 */
25 | |
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "pass_utils.h"
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | namespace annotate_target { |
36 | |
37 | static const PackedFunc* make_begin_op = |
38 | runtime::Registry::Get("relay.op.annotation._make.compiler_begin" ); |
39 | static const PackedFunc* make_end_op = |
40 | runtime::Registry::Get("relay.op.annotation._make.compiler_end" ); |
41 | static const char default_target[] = "default" ; |
42 | // A helper class to insert annotation boundaries for all the ops of a program |
43 | // region that will be handled by a specific compiler. |
44 | class AnnotateTargetRewriter : public ExprRewriter { |
45 | public: |
46 | explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {} |
47 | |
48 | protected: |
49 | /*! \brief The target backends for annotation. */ |
50 | Array<runtime::String> targets_; |
51 | /*! \brief Maintain the decision of the target for each op expr. */ |
52 | std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_; |
53 | |
54 | /*! |
55 | * \brief This function annotates a compiler end and a compiler begin to all arguments. |
56 | * |
57 | * The compiler end is based on the arg target while the compiler begin is based on the given |
58 | * target. If target is not given and all arguments are going to the same target, then we will |
59 | * use that target; otherwise we use default for this op. Note that all arg exprs must be |
60 | * available in op_expr_to_target before calling this function. |
61 | * |
62 | * \param args An array of arguments of the given node. |
63 | * \param target The target of the current node. |
64 | * \return A pair of target and annotated argument expressions. |
65 | */ |
66 | std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args, |
67 | const std::string& target = "" ) { |
68 | std::string ref_target = "" ; |
69 | Array<Expr> compiler_begins; |
70 | Array<Expr> compiler_ends; |
71 | for (auto arg : args) { |
72 | std::string arg_target = default_target; |
73 | const CallNode* call = arg.as<CallNode>(); |
74 | |
75 | if (call && call->op == CompilerBeginOp()) { |
76 | // Argument is already compiler begin node meaning that this is not the first time |
77 | // running this pass, so we simply remove it and will add a new one later. |
78 | ICHECK_EQ(call->args.size(), 1U); |
79 | // Do not alter existing annotation if not default |
80 | if (default_target != call->attrs.as<CompilerAttrs>()->compiler) { |
81 | compiler_begins.push_back(arg); |
82 | } else { |
83 | // Remove default |
84 | compiler_ends.push_back(call->args[0]); |
85 | } |
86 | const CallNode* end = call->args[0].as<CallNode>(); |
87 | if (end && end->op == CompilerEndOp()) { |
88 | arg_target = end->attrs.as<CompilerAttrs>()->compiler; |
89 | } |
90 | } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { |
91 | arg_target = op_expr_to_target_[arg]; |
92 | // If an argument is a call node and has no argument, then it should be tensor ops such as |
93 | // zeros, so we treat it as input vars. |
94 | if (call && call->args.size() == 0) { |
95 | compiler_ends.push_back(arg); |
96 | } else { |
97 | compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); |
98 | } |
99 | } else { |
100 | // Input vars. |
101 | compiler_ends.push_back(arg); |
102 | } |
103 | |
104 | // Maintain reference target in case the target of the current node is unassigned. |
105 | if (ref_target == "" ) { |
106 | ref_target = arg_target; |
107 | } else if (ref_target != arg_target) { |
108 | ref_target = default_target; |
109 | } |
110 | } |
111 | |
112 | // Determine compiler begin target. |
113 | std::string op_target = (target == "" ) ? ref_target : target; |
114 | |
115 | if (ref_target != "" ) { |
116 | for (const auto& end : compiler_ends) { |
117 | compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op)); |
118 | } |
119 | } else { |
120 | return {op_target, args}; |
121 | } |
122 | return {op_target, compiler_begins}; |
123 | } |
124 | |
125 | Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) { |
126 | Expr new_op = (*ann_op)(expr, target); |
127 | new_op->checked_type_ = expr->checked_type_; |
128 | return new_op; |
129 | } |
130 | |
131 | Expr InsertCompilerEndAndPropogateTarget(const Expr& expr) { |
132 | /*! |
133 | * \brief This function inserts compiler end to expr and maps the corresponding target to the |
134 | * new expression. |
135 | * |
136 | * This function checks for expr existence within the map and inserts the annotation. |
137 | * If the expression has a free variable (e.g: relay.zeros, relay.ones) we do not insert |
138 | * compiler end, since there are no compiler begins for it. |
139 | * Further, it propagates the target to the new expression and returns it |
140 | * |
141 | * \param expr A relay expression |
142 | * \return An annotated and target-propagated relay expression. |
143 | */ |
144 | Expr new_expr = expr; |
145 | const CallNode* call = expr.as<CallNode>(); |
146 | const TupleNode* tup = expr.as<TupleNode>(); |
147 | if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) { |
148 | // Check whether expr has args, if not - do not insert compiler_end. |
149 | if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() || |
150 | expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleGetItemNode>() || |
151 | (call && !call->args.empty()) || (tup && !tup->fields.empty())) { |
152 | std::string target = op_expr_to_target_[new_expr]; |
153 | new_expr = InsertAnnotation(new_expr, target, make_end_op); |
154 | op_expr_to_target_[new_expr] = target; |
155 | } |
156 | } else if (call && call->op == CompilerEndOp()) { |
157 | if (default_target == call->attrs.as<CompilerAttrs>()->compiler) { |
158 | ICHECK_EQ(call->args.size(), 1U); |
159 | new_expr = call->args[0]; |
160 | std::string target = op_expr_to_target_[new_expr]; |
161 | new_expr = InsertAnnotation(new_expr, target, make_end_op); |
162 | op_expr_to_target_[new_expr] = target; |
163 | } |
164 | } |
165 | |
166 | return std::move(new_expr); |
167 | } |
168 | |
169 | public: |
170 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
171 | // Supported targets for this node. The order implies the priority. |
172 | std::vector<std::string> supported_targets; |
173 | |
174 | auto op_node = pre->op.as<OpNode>(); |
175 | |
176 | // This graph has annotations, meaning that this is not the first time running this pass. |
177 | if (op_node && pre->op == CompilerBeginOp()) { |
178 | // Bypass compiler begin due to lack of target information. It will be processed |
179 | // when the following op handling arguments. |
180 | ICHECK_EQ(pre->args.size(), 1U); |
181 | // Preserve annotations |
182 | return post; |
183 | } else if (op_node && pre->op == CompilerEndOp()) { |
184 | // Override compiler end with the new target. |
185 | ICHECK_EQ(pre->args.size(), 1U); |
186 | auto input_expr = post.as<CallNode>()->args[0]; |
187 | // Already annotated. Recover target |
188 | if (op_expr_to_target_.find(input_expr) == op_expr_to_target_.end()) { |
189 | op_expr_to_target_[input_expr] = post.as<CallNode>()->attrs.as<CompilerAttrs>()->compiler; |
190 | } |
191 | ICHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); |
192 | // Preserve annotated nodes |
193 | return post; |
194 | } |
195 | // Check prior to peeking first argument |
196 | if (pre->args.size()) { |
197 | // Peek the first argument. If it is compiler begin then this node had annotated by |
198 | // another target before, so we also consider that target as a supported target. |
199 | const CallNode* first_arg_call = pre->args[0].as<CallNode>(); |
200 | if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { |
201 | std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler; |
202 | if (arg_target != default_target) { |
203 | // annotated already |
204 | return post; |
205 | } |
206 | } |
207 | } |
208 | |
209 | // Check which targets this op can be offloaded. |
210 | if (op_node) { |
211 | // TVM operators: Check target specific op checking function and add to supported_targets |
212 | // if it is supported. |
213 | Op op = Downcast<Op>(pre->op); |
214 | ICHECK(op.defined()); |
215 | for (const auto& target : this->targets_) { |
216 | if (!Op::HasAttrMap("target." + std::string(target))) { |
217 | continue; |
218 | } |
219 | auto fannotate = Op::GetAttrMap<FTVMAnnotateTarget>("target." + std::string(target)); |
220 | const Expr& ex = GetRef<Expr>(pre); |
221 | if (fannotate.count(op) && fannotate[op](ex)) { |
222 | supported_targets.push_back(target); |
223 | } |
224 | } |
225 | } else if (pre->op->IsInstance<FunctionNode>()) { |
226 | // Composite function: Add the target of a composite function to supported_targets |
227 | // if it is in the target list. |
228 | Function func = Downcast<Function>(pre->op); |
229 | ICHECK(func.defined()); |
230 | if (auto comp_name = func->GetAttr<String>(attr::kComposite)) { |
231 | std::string comp_name_str = comp_name.value(); |
232 | size_t i = comp_name_str.find('.'); |
233 | if (i != std::string::npos) { |
234 | std::string comp_target = comp_name_str.substr(0, i); |
235 | for (const auto& target : this->targets_) { |
236 | if (std::string(target) == comp_target) { |
237 | supported_targets.push_back(comp_target); |
238 | break; |
239 | } |
240 | } |
241 | } |
242 | } |
243 | } |
244 | supported_targets.push_back(default_target); // Make default as the last option. |
245 | // Visit and mutate arguments after the target of this op has been determined. |
246 | Call post_call = Downcast<Call>(post); |
247 | if (pre->op->IsInstance<VarNode>()) { |
248 | auto new_call = RewriteVarCall(post_call); |
249 | if (nullptr != new_call) return GetRef<Expr>(new_call->get()); |
250 | } |
251 | // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with |
252 | // the highest priority, but we should preserve all supported targets so that |
253 | // we can make a better decision. |
254 | std::string target = supported_targets[0]; |
255 | |
256 | // Add annotations to each arg. |
257 | auto target_n_args = AnnotateArgs(post_call->args, target); |
258 | Array<Expr> compiler_begins = std::get<1>(target_n_args); |
259 | Call new_call = Call(post_call->op, compiler_begins, post_call->attrs); |
260 | new_call->checked_type_ = pre->checked_type_; |
261 | |
262 | // Update the target map. |
263 | op_expr_to_target_[new_call] = target; |
264 | return std::move(new_call); |
265 | } |
266 | |
267 | virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; } |
268 | |
269 | Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override { |
270 | auto tuple = Downcast<Tuple>(post); |
271 | |
272 | auto target_n_args = AnnotateArgs(tuple->fields); |
273 | auto new_expr = WithFields(tuple, std::get<1>(target_n_args)); |
274 | op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
275 | return std::move(new_expr); |
276 | } |
277 | |
278 | Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { |
279 | auto expr = Downcast<TupleGetItem>(post); |
280 | |
281 | auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple})); |
282 | auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); |
283 | op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
284 | return std::move(new_expr); |
285 | } |
286 | |
287 | Expr Rewrite_(const FunctionNode* fn, const Expr& post) override { |
288 | Function func; |
289 | Expr new_body; |
290 | // don't step into composite functions |
291 | if (fn->GetAttr<String>(attr::kComposite).defined()) { |
292 | func = GetRef<Function>(fn); |
293 | new_body = func->body; |
294 | } else { |
295 | func = Downcast<Function>(post); |
296 | new_body = InsertCompilerEndAndPropogateTarget(func->body); |
297 | } |
298 | return WithFields(func, func->params, new_body); |
299 | } |
300 | |
301 | Expr Rewrite_(const LetNode* op, const Expr& post) override { |
302 | auto let = Downcast<Let>(post); |
303 | |
304 | Expr new_expr; |
305 | std::pair<std::string, Array<Expr>> target_n_args; |
306 | Expr new_body = InsertCompilerEndAndPropogateTarget(let->body); |
307 | // Do not annotate function literal with let binding. |
308 | if (let->value->IsInstance<FunctionNode>()) { |
309 | new_expr = Let(let->var, let->value, new_body); |
310 | } else { |
311 | target_n_args = AnnotateArgs({let->value}); |
312 | new_expr = Let(let->var, std::get<1>(target_n_args)[0], new_body); |
313 | } |
314 | |
315 | return std::move(new_expr); |
316 | } |
317 | |
318 | Expr Rewrite_(const IfNode* op, const Expr& post) override { |
319 | auto expr = Downcast<If>(post); |
320 | Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond); |
321 | Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch); |
322 | Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch); |
323 | |
324 | auto new_expr = If(new_cond, new_true_branch, new_false_branch); |
325 | return std::move(new_expr); |
326 | } |
327 | |
328 | Expr Rewrite_(const RefCreateNode* op, const Expr& post) override { |
329 | auto expr = Downcast<RefCreate>(post); |
330 | |
331 | auto target_n_args = AnnotateArgs(Array<Expr>({expr->value})); |
332 | auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); |
333 | op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
334 | return std::move(new_expr); |
335 | } |
336 | |
337 | Expr Rewrite_(const RefReadNode* op, const Expr& post) override { |
338 | auto expr = Downcast<RefRead>(post); |
339 | |
340 | auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref})); |
341 | auto new_expr = RefRead(std::get<1>(target_n_args)[0]); |
342 | op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
343 | return std::move(new_expr); |
344 | } |
345 | |
346 | Expr Rewrite_(const RefWriteNode* op, const Expr& post) override { |
347 | auto expr = Downcast<RefWrite>(post); |
348 | |
349 | auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value})); |
350 | auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); |
351 | op_expr_to_target_[new_expr] = std::get<0>(target_n_args); |
352 | return std::move(new_expr); |
353 | } |
354 | }; |
355 | |
356 | // A helper class to insert annotation boundaries for call ops and function nodes |
357 | // in a program region that will be handled by a specific compiler. |
358 | class CallOpsTargetRewriter : public AnnotateTargetRewriter { |
359 | public: |
360 | explicit CallOpsTargetRewriter(Array<runtime::String> targets) |
361 | : AnnotateTargetRewriter(std::move(targets)) {} |
362 | |
363 | std::unique_ptr<Call> RewriteVarCall(const Call& post_call) override { |
364 | Array<Expr> ends; |
365 | for (auto arg : post_call->args) { |
366 | ends.push_back(InsertCompilerEndAndPropogateTarget(arg)); |
367 | } |
368 | auto new_call = std::make_unique<Call>(post_call->op, ends, post_call->attrs); |
369 | (*new_call)->checked_type_ = post_call->checked_type_; |
370 | return new_call; |
371 | } |
372 | |
373 | Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override { |
374 | auto tuple = Downcast<Tuple>(post); |
375 | Array<Expr> new_fields; |
376 | new_fields.reserve(tuple->fields.size()); |
377 | |
378 | for (auto f : tuple->fields) { |
379 | new_fields.push_back(InsertCompilerEndAndPropogateTarget(f)); |
380 | } |
381 | return WithFields(tuple, new_fields); |
382 | } |
383 | |
384 | Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { |
385 | auto expr = Downcast<TupleGetItem>(post); |
386 | return std::move(TupleGetItem(InsertCompilerEndAndPropogateTarget(expr->tuple), expr->index)); |
387 | } |
388 | |
389 | Expr Rewrite_(const IfNode* op, const Expr& post) override { |
390 | auto expr = Downcast<If>(post); |
391 | Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond); |
392 | Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch); |
393 | Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch); |
394 | |
395 | auto new_expr = If(new_cond, new_true_branch, new_false_branch); |
396 | return std::move(new_expr); |
397 | } |
398 | |
399 | Expr Rewrite_(const RefCreateNode* op, const Expr& post) override { |
400 | auto expr = Downcast<RefCreate>(post); |
401 | auto new_expr = RefCreate(InsertCompilerEndAndPropogateTarget(expr->value)); |
402 | return std::move(new_expr); |
403 | } |
404 | |
405 | Expr Rewrite_(const RefReadNode* op, const Expr& post) override { |
406 | auto expr = Downcast<RefRead>(post); |
407 | auto new_expr = RefRead(InsertCompilerEndAndPropogateTarget(expr->ref)); |
408 | return std::move(new_expr); |
409 | } |
410 | |
411 | Expr Rewrite_(const RefWriteNode* op, const Expr& post) override { |
412 | auto expr = Downcast<RefWrite>(post); |
413 | auto new_expr = RefWrite(InsertCompilerEndAndPropogateTarget(expr->ref), |
414 | InsertCompilerEndAndPropogateTarget(expr->value)); |
415 | return std::move(new_expr); |
416 | } |
417 | }; |
418 | |
419 | Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets, |
420 | bool include_non_call_ops) { |
421 | auto r = include_non_call_ops ? std::make_unique<AnnotateTargetRewriter>(targets) |
422 | : std::make_unique<CallOpsTargetRewriter>(targets); |
423 | return PostOrderRewrite(expr, r.get()); |
424 | } |
425 | |
426 | } // namespace annotate_target |
427 | |
428 | namespace transform { |
429 | |
430 | Pass AnnotateTarget(const Array<runtime::String>& targets, bool include_non_call_ops) { |
431 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
432 | [=](Function f, IRModule m, PassContext pc) { |
433 | return Downcast<Function>( |
434 | relay::annotate_target::AnnotateTarget(f, targets, include_non_call_ops)); |
435 | }; |
436 | auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc" , {"InferType" }); |
437 | return transform::Sequential({func_pass, InferType()}, "AnnotateTarget" ); |
438 | } |
439 | |
440 | TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget" ).set_body_typed(AnnotateTarget); |
441 | |
442 | } // namespace transform |
443 | |
444 | } // namespace relay |
445 | } // namespace tvm |
446 | |