1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/merge_compiler_regions.cc |
22 | * |
23 | * \brief After operators have been annotated with the targets that support |
24 | * them, this pass creates regions of the operators for each target. It |
25 | * is guaranteed that the regions will have a topological ordering so that |
26 | * no data dependency issues exist. |
27 | * |
28 | * This pass only introduces annotations to indicate the regions. |
29 | * partition_graph must subsequently be called to lift these regions out |
30 | * as external functions. |
31 | */ |
32 | |
33 | #include <tvm/relay/analysis.h> |
34 | #include <tvm/relay/attrs/annotation.h> |
35 | #include <tvm/relay/error.h> |
36 | #include <tvm/relay/expr.h> |
37 | #include <tvm/relay/expr_functor.h> |
38 | #include <tvm/relay/transform.h> |
39 | |
40 | #include <string> |
41 | #include <unordered_map> |
42 | #include <unordered_set> |
43 | #include <vector> |
44 | |
45 | #include "../analysis/annotated_region_set.h" |
46 | #include "pass_utils.h" |
47 | |
48 | namespace tvm { |
49 | namespace relay { |
50 | namespace merge_compiler_region { |
51 | |
52 | class RegionMerger : public MixedModeVisitor { |
53 | public: |
54 | explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} |
55 | |
56 | void VisitExpr_(const CallNode* call) final { |
57 | if (call->op == CompilerEndOp()) { |
58 | auto region = regions_->GetRegion(GetRef<Call>(call)); |
59 | |
60 | // Skip this region if it has been merged to the other region. |
61 | if (merged_regions_.find(region->GetID()) != merged_regions_.end()) { |
62 | return; |
63 | } |
64 | |
65 | // Check the region target. |
66 | auto compiler_attrs = call->attrs.as<CompilerAttrs>(); |
67 | ICHECK_EQ(region->GetTarget(), compiler_attrs->compiler); |
68 | |
69 | // Visit the unmerged parent regions. |
70 | for (const auto& arg : region->GetInputs()) { |
71 | // Region inputs must be begin annotation, and the region of |
72 | // the begin annotation's argument is the parent region. |
73 | auto begin = Downcast<Call>(arg); |
74 | ICHECK_EQ(begin->op, CompilerBeginOp()); |
75 | auto parent_region = regions_->GetRegion(begin->args[0]); |
76 | |
77 | // Skip this region if it has been merged. |
78 | if (!parent_region.defined()) { |
79 | continue; |
80 | } else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { |
81 | VisitExpr(begin->args[0]); |
82 | } |
83 | } |
84 | |
85 | // Collect unmerged parent regions. |
86 | std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> mergeable_regions; |
87 | for (const auto& arg : region->GetInputs()) { |
88 | auto begin = Downcast<Call>(arg); |
89 | ICHECK_EQ(begin->op, CompilerBeginOp()); |
90 | auto parent_region = regions_->GetRegion(begin->args[0]); |
91 | if (parent_region.defined()) { |
92 | mergeable_regions.insert(parent_region); |
93 | } |
94 | } |
95 | |
96 | // Propogate all the parent restrictions to the current region. |
97 | auto& region_restrictions = region_restrictions_[region->GetID()]; |
98 | for (const auto& parent_region : mergeable_regions) { |
99 | auto parent_restrictions = region_restrictions_[parent_region->GetID()]; |
100 | region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end()); |
101 | } |
102 | |
103 | for (const auto& parent_region : mergeable_regions) { |
104 | // Skip the parent region with a different target. |
105 | if (parent_region->GetTarget() != compiler_attrs->compiler) { |
106 | region_restrictions.insert(parent_region->GetID()); |
107 | continue; |
108 | } |
109 | |
110 | // Skip the parent region if it is in the restriction set. |
111 | if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) { |
112 | continue; |
113 | } |
114 | |
115 | // Merge the parent region to the current one. |
116 | regions_->MergeRegions(parent_region, region); |
117 | |
118 | // Replace the parent region ID with the current region for all |
119 | // other regions' restriction sets. |
120 | for (const auto& r : regions_) { |
121 | auto& restrictions = region_restrictions_[r->GetID()]; |
122 | if (restrictions.find(parent_region->GetID()) != restrictions.end()) { |
123 | restrictions.erase(parent_region->GetID()); |
124 | restrictions.insert(region->GetID()); |
125 | } |
126 | } |
127 | } |
128 | merged_regions_.insert(region->GetID()); |
129 | } |
130 | } |
131 | |
132 | private: |
133 | AnnotatedRegionSet regions_; |
134 | std::unordered_set<int> merged_regions_; |
135 | std::unordered_map<int, std::unordered_set<int>> region_restrictions_; |
136 | }; |
137 | |
138 | class MergeAnnotations : public ExprRewriter { |
139 | public: |
140 | explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} |
141 | |
142 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
143 | // Merge annotations which are now internal to a region. |
144 | // This happens if we see a compiler begin next to a |
145 | // compiler end and they're both in the same region. |
146 | if (call->op == CompilerBeginOp() && call->args[0]->IsInstance<CallNode>()) { |
147 | auto arg = Downcast<Call>(call->args[0]); |
148 | if (arg->op == CompilerEndOp()) { |
149 | auto region1 = regions_->GetRegion(GetRef<Call>(call)); |
150 | auto region2 = regions_->GetRegion(arg); |
151 | if (region1 == region2) { |
152 | auto post_arg = post.as<CallNode>()->args[0]; |
153 | return post_arg.as<CallNode>()->args[0]; |
154 | } |
155 | } |
156 | } |
157 | return post; |
158 | } |
159 | |
160 | private: |
161 | AnnotatedRegionSet regions_; |
162 | }; |
163 | |
164 | Expr MergeCompilerRegions(const Expr& expr) { |
165 | // Create regions using the annotations. |
166 | AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp()); |
167 | |
168 | // Analyze the graph to explore the opportunities of merging regions. |
169 | RegionMerger merger(regions); |
170 | merger.VisitExpr(expr); |
171 | |
172 | // Remove annotations that are not in the region boundaries. |
173 | MergeAnnotations merge_anno(regions); |
174 | return PostOrderRewrite(expr, &merge_anno); |
175 | } |
176 | |
177 | } // namespace merge_compiler_region |
178 | |
179 | namespace transform { |
180 | |
181 | Pass MergeCompilerRegions() { |
182 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func = |
183 | [=](Function f, IRModule m, PassContext pc) { |
184 | return Downcast<Function>(merge_compiler_region::MergeCompilerRegions(f)); |
185 | }; |
186 | auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions" , {}); |
187 | return Sequential({merged, InferType()}); |
188 | } |
189 | |
190 | TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions" ) |
191 | .set_body_typed(transform::MergeCompilerRegions); |
192 | |
193 | } // namespace transform |
194 | |
195 | } // namespace relay |
196 | } // namespace tvm |
197 | |