1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file src/relay/transforms/merge_compiler_regions.cc
 *
 * \brief After operators have been annotated with the targets that support
 * them, this pass merges neighbouring annotated regions that share a target
 * into larger regions. Regions are merged only when the result still has a
 * topological ordering, so no data dependency issues (cycles) are introduced.
 *
 * This pass only rewrites the annotations that mark the regions;
 * partition_graph must subsequently be called to lift these regions out
 * as external functions.
 */
32
33#include <tvm/relay/analysis.h>
34#include <tvm/relay/attrs/annotation.h>
35#include <tvm/relay/error.h>
36#include <tvm/relay/expr.h>
37#include <tvm/relay/expr_functor.h>
38#include <tvm/relay/transform.h>
39
40#include <string>
41#include <unordered_map>
42#include <unordered_set>
43#include <vector>
44
45#include "../analysis/annotated_region_set.h"
46#include "pass_utils.h"
47
48namespace tvm {
49namespace relay {
50namespace merge_compiler_region {
51
52class RegionMerger : public MixedModeVisitor {
53 public:
54 explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}
55
56 void VisitExpr_(const CallNode* call) final {
57 if (call->op == CompilerEndOp()) {
58 auto region = regions_->GetRegion(GetRef<Call>(call));
59
60 // Skip this region if it has been merged to the other region.
61 if (merged_regions_.find(region->GetID()) != merged_regions_.end()) {
62 return;
63 }
64
65 // Check the region target.
66 auto compiler_attrs = call->attrs.as<CompilerAttrs>();
67 ICHECK_EQ(region->GetTarget(), compiler_attrs->compiler);
68
69 // Visit the unmerged parent regions.
70 for (const auto& arg : region->GetInputs()) {
71 // Region inputs must be begin annotation, and the region of
72 // the begin annotation's argument is the parent region.
73 auto begin = Downcast<Call>(arg);
74 ICHECK_EQ(begin->op, CompilerBeginOp());
75 auto parent_region = regions_->GetRegion(begin->args[0]);
76
77 // Skip this region if it has been merged.
78 if (!parent_region.defined()) {
79 continue;
80 } else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) {
81 VisitExpr(begin->args[0]);
82 }
83 }
84
85 // Collect unmerged parent regions.
86 std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> mergeable_regions;
87 for (const auto& arg : region->GetInputs()) {
88 auto begin = Downcast<Call>(arg);
89 ICHECK_EQ(begin->op, CompilerBeginOp());
90 auto parent_region = regions_->GetRegion(begin->args[0]);
91 if (parent_region.defined()) {
92 mergeable_regions.insert(parent_region);
93 }
94 }
95
96 // Propogate all the parent restrictions to the current region.
97 auto& region_restrictions = region_restrictions_[region->GetID()];
98 for (const auto& parent_region : mergeable_regions) {
99 auto parent_restrictions = region_restrictions_[parent_region->GetID()];
100 region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end());
101 }
102
103 for (const auto& parent_region : mergeable_regions) {
104 // Skip the parent region with a different target.
105 if (parent_region->GetTarget() != compiler_attrs->compiler) {
106 region_restrictions.insert(parent_region->GetID());
107 continue;
108 }
109
110 // Skip the parent region if it is in the restriction set.
111 if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) {
112 continue;
113 }
114
115 // Merge the parent region to the current one.
116 regions_->MergeRegions(parent_region, region);
117
118 // Replace the parent region ID with the current region for all
119 // other regions' restriction sets.
120 for (const auto& r : regions_) {
121 auto& restrictions = region_restrictions_[r->GetID()];
122 if (restrictions.find(parent_region->GetID()) != restrictions.end()) {
123 restrictions.erase(parent_region->GetID());
124 restrictions.insert(region->GetID());
125 }
126 }
127 }
128 merged_regions_.insert(region->GetID());
129 }
130 }
131
132 private:
133 AnnotatedRegionSet regions_;
134 std::unordered_set<int> merged_regions_;
135 std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
136};
137
138class MergeAnnotations : public ExprRewriter {
139 public:
140 explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}
141
142 Expr Rewrite_(const CallNode* call, const Expr& post) final {
143 // Merge annotations which are now internal to a region.
144 // This happens if we see a compiler begin next to a
145 // compiler end and they're both in the same region.
146 if (call->op == CompilerBeginOp() && call->args[0]->IsInstance<CallNode>()) {
147 auto arg = Downcast<Call>(call->args[0]);
148 if (arg->op == CompilerEndOp()) {
149 auto region1 = regions_->GetRegion(GetRef<Call>(call));
150 auto region2 = regions_->GetRegion(arg);
151 if (region1 == region2) {
152 auto post_arg = post.as<CallNode>()->args[0];
153 return post_arg.as<CallNode>()->args[0];
154 }
155 }
156 }
157 return post;
158 }
159
160 private:
161 AnnotatedRegionSet regions_;
162};
163
164Expr MergeCompilerRegions(const Expr& expr) {
165 // Create regions using the annotations.
166 AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp());
167
168 // Analyze the graph to explore the opportunities of merging regions.
169 RegionMerger merger(regions);
170 merger.VisitExpr(expr);
171
172 // Remove annotations that are not in the region boundaries.
173 MergeAnnotations merge_anno(regions);
174 return PostOrderRewrite(expr, &merge_anno);
175}
176
177} // namespace merge_compiler_region
178
179namespace transform {
180
181Pass MergeCompilerRegions() {
182 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
183 [=](Function f, IRModule m, PassContext pc) {
184 return Downcast<Function>(merge_compiler_region::MergeCompilerRegions(f));
185 };
186 auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {});
187 return Sequential({merged, InferType()});
188}
189
190TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions")
191 .set_body_typed(transform::MergeCompilerRegions);
192
193} // namespace transform
194
195} // namespace relay
196} // namespace tvm
197