1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/annotate_target.cc
22 * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
23 * this expr should be handled by the external compiler.
24 */
25
26#include <tvm/relay/attrs/annotation.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/op_attr_types.h>
29#include <tvm/relay/transform.h>
30
31#include "pass_utils.h"
32
33namespace tvm {
34namespace relay {
35namespace annotate_target {
36
37static const PackedFunc* make_begin_op =
38 runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
39static const PackedFunc* make_end_op =
40 runtime::Registry::Get("relay.op.annotation._make.compiler_end");
41static const char default_target[] = "default";
42// A helper class to insert annotation boundaries for all the ops of a program
43// region that will be handled by a specific compiler.
44class AnnotateTargetRewriter : public ExprRewriter {
45 public:
46 explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}
47
48 protected:
49 /*! \brief The target backends for annotation. */
50 Array<runtime::String> targets_;
51 /*! \brief Maintain the decision of the target for each op expr. */
52 std::unordered_map<Expr, std::string, ObjectPtrHash, ObjectPtrEqual> op_expr_to_target_;
53
54 /*!
55 * \brief This function annotates a compiler end and a compiler begin to all arguments.
56 *
57 * The compiler end is based on the arg target while the compiler begin is based on the given
58 * target. If target is not given and all arguments are going to the same target, then we will
59 * use that target; otherwise we use default for this op. Note that all arg exprs must be
60 * available in op_expr_to_target before calling this function.
61 *
62 * \param args An array of arguments of the given node.
63 * \param target The target of the current node.
64 * \return A pair of target and annotated argument expressions.
65 */
66 std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
67 const std::string& target = "") {
68 std::string ref_target = "";
69 Array<Expr> compiler_begins;
70 Array<Expr> compiler_ends;
71 for (auto arg : args) {
72 std::string arg_target = default_target;
73 const CallNode* call = arg.as<CallNode>();
74
75 if (call && call->op == CompilerBeginOp()) {
76 // Argument is already compiler begin node meaning that this is not the first time
77 // running this pass, so we simply remove it and will add a new one later.
78 ICHECK_EQ(call->args.size(), 1U);
79 // Do not alter existing annotation if not default
80 if (default_target != call->attrs.as<CompilerAttrs>()->compiler) {
81 compiler_begins.push_back(arg);
82 } else {
83 // Remove default
84 compiler_ends.push_back(call->args[0]);
85 }
86 const CallNode* end = call->args[0].as<CallNode>();
87 if (end && end->op == CompilerEndOp()) {
88 arg_target = end->attrs.as<CompilerAttrs>()->compiler;
89 }
90 } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
91 arg_target = op_expr_to_target_[arg];
92 // If an argument is a call node and has no argument, then it should be tensor ops such as
93 // zeros, so we treat it as input vars.
94 if (call && call->args.size() == 0) {
95 compiler_ends.push_back(arg);
96 } else {
97 compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op));
98 }
99 } else {
100 // Input vars.
101 compiler_ends.push_back(arg);
102 }
103
104 // Maintain reference target in case the target of the current node is unassigned.
105 if (ref_target == "") {
106 ref_target = arg_target;
107 } else if (ref_target != arg_target) {
108 ref_target = default_target;
109 }
110 }
111
112 // Determine compiler begin target.
113 std::string op_target = (target == "") ? ref_target : target;
114
115 if (ref_target != "") {
116 for (const auto& end : compiler_ends) {
117 compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op));
118 }
119 } else {
120 return {op_target, args};
121 }
122 return {op_target, compiler_begins};
123 }
124
125 Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
126 Expr new_op = (*ann_op)(expr, target);
127 new_op->checked_type_ = expr->checked_type_;
128 return new_op;
129 }
130
131 Expr InsertCompilerEndAndPropogateTarget(const Expr& expr) {
132 /*!
133 * \brief This function inserts compiler end to expr and maps the corresponding target to the
134 * new expression.
135 *
136 * This function checks for expr existence within the map and inserts the annotation.
137 * If the expression has a free variable (e.g: relay.zeros, relay.ones) we do not insert
138 * compiler end, since there are no compiler begins for it.
139 * Further, it propagates the target to the new expression and returns it
140 *
141 * \param expr A relay expression
142 * \return An annotated and target-propagated relay expression.
143 */
144 Expr new_expr = expr;
145 const CallNode* call = expr.as<CallNode>();
146 const TupleNode* tup = expr.as<TupleNode>();
147 if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
148 // Check whether expr has args, if not - do not insert compiler_end.
149 if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
150 expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleGetItemNode>() ||
151 (call && !call->args.empty()) || (tup && !tup->fields.empty())) {
152 std::string target = op_expr_to_target_[new_expr];
153 new_expr = InsertAnnotation(new_expr, target, make_end_op);
154 op_expr_to_target_[new_expr] = target;
155 }
156 } else if (call && call->op == CompilerEndOp()) {
157 if (default_target == call->attrs.as<CompilerAttrs>()->compiler) {
158 ICHECK_EQ(call->args.size(), 1U);
159 new_expr = call->args[0];
160 std::string target = op_expr_to_target_[new_expr];
161 new_expr = InsertAnnotation(new_expr, target, make_end_op);
162 op_expr_to_target_[new_expr] = target;
163 }
164 }
165
166 return std::move(new_expr);
167 }
168
169 public:
170 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
171 // Supported targets for this node. The order implies the priority.
172 std::vector<std::string> supported_targets;
173
174 auto op_node = pre->op.as<OpNode>();
175
176 // This graph has annotations, meaning that this is not the first time running this pass.
177 if (op_node && pre->op == CompilerBeginOp()) {
178 // Bypass compiler begin due to lack of target information. It will be processed
179 // when the following op handling arguments.
180 ICHECK_EQ(pre->args.size(), 1U);
181 // Preserve annotations
182 return post;
183 } else if (op_node && pre->op == CompilerEndOp()) {
184 // Override compiler end with the new target.
185 ICHECK_EQ(pre->args.size(), 1U);
186 auto input_expr = post.as<CallNode>()->args[0];
187 // Already annotated. Recover target
188 if (op_expr_to_target_.find(input_expr) == op_expr_to_target_.end()) {
189 op_expr_to_target_[input_expr] = post.as<CallNode>()->attrs.as<CompilerAttrs>()->compiler;
190 }
191 ICHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
192 // Preserve annotated nodes
193 return post;
194 }
195 // Check prior to peeking first argument
196 if (pre->args.size()) {
197 // Peek the first argument. If it is compiler begin then this node had annotated by
198 // another target before, so we also consider that target as a supported target.
199 const CallNode* first_arg_call = pre->args[0].as<CallNode>();
200 if (first_arg_call && first_arg_call->op == CompilerBeginOp()) {
201 std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
202 if (arg_target != default_target) {
203 // annotated already
204 return post;
205 }
206 }
207 }
208
209 // Check which targets this op can be offloaded.
210 if (op_node) {
211 // TVM operators: Check target specific op checking function and add to supported_targets
212 // if it is supported.
213 Op op = Downcast<Op>(pre->op);
214 ICHECK(op.defined());
215 for (const auto& target : this->targets_) {
216 if (!Op::HasAttrMap("target." + std::string(target))) {
217 continue;
218 }
219 auto fannotate = Op::GetAttrMap<FTVMAnnotateTarget>("target." + std::string(target));
220 const Expr& ex = GetRef<Expr>(pre);
221 if (fannotate.count(op) && fannotate[op](ex)) {
222 supported_targets.push_back(target);
223 }
224 }
225 } else if (pre->op->IsInstance<FunctionNode>()) {
226 // Composite function: Add the target of a composite function to supported_targets
227 // if it is in the target list.
228 Function func = Downcast<Function>(pre->op);
229 ICHECK(func.defined());
230 if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
231 std::string comp_name_str = comp_name.value();
232 size_t i = comp_name_str.find('.');
233 if (i != std::string::npos) {
234 std::string comp_target = comp_name_str.substr(0, i);
235 for (const auto& target : this->targets_) {
236 if (std::string(target) == comp_target) {
237 supported_targets.push_back(comp_target);
238 break;
239 }
240 }
241 }
242 }
243 }
244 supported_targets.push_back(default_target); // Make default as the last option.
245 // Visit and mutate arguments after the target of this op has been determined.
246 Call post_call = Downcast<Call>(post);
247 if (pre->op->IsInstance<VarNode>()) {
248 auto new_call = RewriteVarCall(post_call);
249 if (nullptr != new_call) return GetRef<Expr>(new_call->get());
250 }
251 // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with
252 // the highest priority, but we should preserve all supported targets so that
253 // we can make a better decision.
254 std::string target = supported_targets[0];
255
256 // Add annotations to each arg.
257 auto target_n_args = AnnotateArgs(post_call->args, target);
258 Array<Expr> compiler_begins = std::get<1>(target_n_args);
259 Call new_call = Call(post_call->op, compiler_begins, post_call->attrs);
260 new_call->checked_type_ = pre->checked_type_;
261
262 // Update the target map.
263 op_expr_to_target_[new_call] = target;
264 return std::move(new_call);
265 }
266
267 virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }
268
269 Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
270 auto tuple = Downcast<Tuple>(post);
271
272 auto target_n_args = AnnotateArgs(tuple->fields);
273 auto new_expr = WithFields(tuple, std::get<1>(target_n_args));
274 op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
275 return std::move(new_expr);
276 }
277
278 Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
279 auto expr = Downcast<TupleGetItem>(post);
280
281 auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
282 auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
283 op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
284 return std::move(new_expr);
285 }
286
287 Expr Rewrite_(const FunctionNode* fn, const Expr& post) override {
288 Function func;
289 Expr new_body;
290 // don't step into composite functions
291 if (fn->GetAttr<String>(attr::kComposite).defined()) {
292 func = GetRef<Function>(fn);
293 new_body = func->body;
294 } else {
295 func = Downcast<Function>(post);
296 new_body = InsertCompilerEndAndPropogateTarget(func->body);
297 }
298 return WithFields(func, func->params, new_body);
299 }
300
301 Expr Rewrite_(const LetNode* op, const Expr& post) override {
302 auto let = Downcast<Let>(post);
303
304 Expr new_expr;
305 std::pair<std::string, Array<Expr>> target_n_args;
306 Expr new_body = InsertCompilerEndAndPropogateTarget(let->body);
307 // Do not annotate function literal with let binding.
308 if (let->value->IsInstance<FunctionNode>()) {
309 new_expr = Let(let->var, let->value, new_body);
310 } else {
311 target_n_args = AnnotateArgs({let->value});
312 new_expr = Let(let->var, std::get<1>(target_n_args)[0], new_body);
313 }
314
315 return std::move(new_expr);
316 }
317
318 Expr Rewrite_(const IfNode* op, const Expr& post) override {
319 auto expr = Downcast<If>(post);
320 Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
321 Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
322 Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch);
323
324 auto new_expr = If(new_cond, new_true_branch, new_false_branch);
325 return std::move(new_expr);
326 }
327
328 Expr Rewrite_(const RefCreateNode* op, const Expr& post) override {
329 auto expr = Downcast<RefCreate>(post);
330
331 auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
332 auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
333 op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
334 return std::move(new_expr);
335 }
336
337 Expr Rewrite_(const RefReadNode* op, const Expr& post) override {
338 auto expr = Downcast<RefRead>(post);
339
340 auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
341 auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
342 op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
343 return std::move(new_expr);
344 }
345
346 Expr Rewrite_(const RefWriteNode* op, const Expr& post) override {
347 auto expr = Downcast<RefWrite>(post);
348
349 auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
350 auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
351 op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
352 return std::move(new_expr);
353 }
354};
355
356// A helper class to insert annotation boundaries for call ops and function nodes
357// in a program region that will be handled by a specific compiler.
358class CallOpsTargetRewriter : public AnnotateTargetRewriter {
359 public:
360 explicit CallOpsTargetRewriter(Array<runtime::String> targets)
361 : AnnotateTargetRewriter(std::move(targets)) {}
362
363 std::unique_ptr<Call> RewriteVarCall(const Call& post_call) override {
364 Array<Expr> ends;
365 for (auto arg : post_call->args) {
366 ends.push_back(InsertCompilerEndAndPropogateTarget(arg));
367 }
368 auto new_call = std::make_unique<Call>(post_call->op, ends, post_call->attrs);
369 (*new_call)->checked_type_ = post_call->checked_type_;
370 return new_call;
371 }
372
373 Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
374 auto tuple = Downcast<Tuple>(post);
375 Array<Expr> new_fields;
376 new_fields.reserve(tuple->fields.size());
377
378 for (auto f : tuple->fields) {
379 new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
380 }
381 return WithFields(tuple, new_fields);
382 }
383
384 Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
385 auto expr = Downcast<TupleGetItem>(post);
386 return std::move(TupleGetItem(InsertCompilerEndAndPropogateTarget(expr->tuple), expr->index));
387 }
388
389 Expr Rewrite_(const IfNode* op, const Expr& post) override {
390 auto expr = Downcast<If>(post);
391 Expr new_cond = InsertCompilerEndAndPropogateTarget(expr->cond);
392 Expr new_true_branch = InsertCompilerEndAndPropogateTarget(expr->true_branch);
393 Expr new_false_branch = InsertCompilerEndAndPropogateTarget(expr->false_branch);
394
395 auto new_expr = If(new_cond, new_true_branch, new_false_branch);
396 return std::move(new_expr);
397 }
398
399 Expr Rewrite_(const RefCreateNode* op, const Expr& post) override {
400 auto expr = Downcast<RefCreate>(post);
401 auto new_expr = RefCreate(InsertCompilerEndAndPropogateTarget(expr->value));
402 return std::move(new_expr);
403 }
404
405 Expr Rewrite_(const RefReadNode* op, const Expr& post) override {
406 auto expr = Downcast<RefRead>(post);
407 auto new_expr = RefRead(InsertCompilerEndAndPropogateTarget(expr->ref));
408 return std::move(new_expr);
409 }
410
411 Expr Rewrite_(const RefWriteNode* op, const Expr& post) override {
412 auto expr = Downcast<RefWrite>(post);
413 auto new_expr = RefWrite(InsertCompilerEndAndPropogateTarget(expr->ref),
414 InsertCompilerEndAndPropogateTarget(expr->value));
415 return std::move(new_expr);
416 }
417};
418
419Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets,
420 bool include_non_call_ops) {
421 auto r = include_non_call_ops ? std::make_unique<AnnotateTargetRewriter>(targets)
422 : std::make_unique<CallOpsTargetRewriter>(targets);
423 return PostOrderRewrite(expr, r.get());
424}
425
426} // namespace annotate_target
427
428namespace transform {
429
430Pass AnnotateTarget(const Array<runtime::String>& targets, bool include_non_call_ops) {
431 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
432 [=](Function f, IRModule m, PassContext pc) {
433 return Downcast<Function>(
434 relay::annotate_target::AnnotateTarget(f, targets, include_non_call_ops));
435 };
436 auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"});
437 return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
438}
439
440TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget);
441
442} // namespace transform
443
444} // namespace relay
445} // namespace tvm
446