1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/gather_partition_specs.cc |
22 | * \brief Gather the relevant \p PartitionSpecs from the available \p Targets. |
23 | */ |
24 | |
25 | #include "./gather_partition_specs.h" |
26 | |
27 | #include "./utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace collage { |
32 | |
33 | namespace { |
34 | |
35 | PartitionRule MakeCombinePartitionRule(PartitionRule sub_rule, Array<CombinerRule> combiner_rules, |
36 | size_t max_depth) { |
37 | if (combiner_rules.empty()) { |
38 | return sub_rule; |
39 | } else { |
40 | return CombinePartitionRule("" , std::move(sub_rule), std::move(combiner_rules), max_depth); |
41 | } |
42 | } |
43 | |
44 | /*! \brief Returns the primitive combiner rules which mimic TVM's \p FuseOps. */ |
45 | Array<CombinerRule> TVMCombinerRules() { |
46 | Array<SimpleCombinerRule> simple_rules; |
47 | // Mimic the FuseOps rules. |
48 | simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast)); |
49 | simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce)); |
50 | simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective)); |
51 | |
52 | Array<CombinerRule> combiner_rules; |
53 | // Fire the simple fusion rules |
54 | combiner_rules.push_back(AllSimpleCombinerRule("combiner" , std::move(simple_rules))); |
55 | // Fuse tuple arguments |
56 | combiner_rules.push_back(TupleArgCombinerRule("tuple" )); |
57 | // Fuse tuple projection |
58 | combiner_rules.push_back(TupleProjCombinerRule("proj" )); |
59 | |
60 | return combiner_rules; |
61 | } |
62 | |
63 | size_t GetMaxDepth(std::string key) { |
64 | tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); |
65 | std::string config_key = "relay.collage." + key; |
66 | Optional<Integer> opt_max_depth = ctxt->GetConfig(config_key, Optional<Integer>()); |
67 | ICHECK(opt_max_depth.defined()) << "missing binding for '" << config_key << " in pass context" ; |
68 | ICHECK(opt_max_depth.value()->value > 0) |
69 | << "invalid value for '" << config_key << " in pass context" ; |
70 | return static_cast<size_t>(opt_max_depth.value()->value); |
71 | } |
72 | |
73 | /*! \brief Returns partition rule mimicking TVM FuseOps. */ |
74 | PartitionRule MakeTVMPartitionRule() { |
75 | size_t max_depth = GetMaxDepth("tvm_max_depth" ); |
76 | // Build singleton candidates for all calls to ops <= kOutEWiseFusable. |
77 | OpCallByKindPartitionRule op_call_by_kind("" ); |
78 | // Combine candidates according to the TVM fusion rules. |
79 | PartitionRule combine = |
80 | MakeCombinePartitionRule(std::move(op_call_by_kind), TVMCombinerRules(), max_depth); |
81 | // Discard invalid candidates. |
82 | SubGraphConfig sub_graph_config; |
83 | sub_graph_config.allow_taps = false; |
84 | sub_graph_config.max_depth = max_depth; |
85 | sub_graph_config.max_exits = 1; |
86 | return OnlyValidPartitionRule("" , std::move(combine), sub_graph_config); |
87 | // NOTE: We don't wrap by a "Primitive" since we want to defer making TVM fusion decisions until |
88 | // after running more Relay passes. |
89 | } |
90 | |
91 | /*! |
92 | * \brief Returns the fusion style for default compiler. |
93 | */ |
94 | BYOCStyle DefaultBYOCFusionStyleForCompiler(const String& compiler) { |
95 | if (compiler == "cutlass" || compiler == "cublas" || compiler == "cudnn" ) { |
96 | return kNoFusionBYOCStyle; |
97 | } else if (compiler == "tensorrt" ) { |
98 | return kTVMFusionBYOCStyle; |
99 | } else { |
100 | return kArbitraryFusionBYOCStyle; |
101 | } |
102 | } |
103 | |
104 | /*! |
105 | * \brief Returns the fusion style for given compiler. |
106 | */ |
107 | BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) { |
108 | tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); |
109 | std::string config_key = "relay.collage.byoc_fusion_style" ; |
110 | Optional<Array<String>> byoc_configs = ctxt->GetConfig(config_key, Optional<Array<String>>()); |
111 | BYOCStyle byoc_fusion_style = DefaultBYOCFusionStyleForCompiler(compiler); |
112 | if (!byoc_configs) { |
113 | return byoc_fusion_style; |
114 | } |
115 | for (auto config_ : byoc_configs.value()) { |
116 | std::vector<std::string> byoc_cfg = SplitString(config_, "." ); |
117 | if (byoc_cfg[0] == compiler) { |
118 | if (byoc_cfg[1] == "NoFusion" ) { |
119 | byoc_fusion_style = kNoFusionBYOCStyle; |
120 | } else if (byoc_cfg[1] == "TVMFusion" ) { |
121 | byoc_fusion_style = kTVMFusionBYOCStyle; |
122 | } else if (byoc_cfg[1] == "ArbitraryFusion" ) { |
123 | byoc_fusion_style = kArbitraryFusionBYOCStyle; |
124 | } else { |
125 | ICHECK(false) << "Invalid fusion name for compiler " << byoc_cfg[0] << " in pass context" ; |
126 | } |
127 | break; |
128 | } |
129 | } |
130 | return byoc_fusion_style; |
131 | } |
132 | |
133 | /*! |
134 | * \brief Returns the primitive combiner rules which allow for any touching candidates |
135 | * to be fused provided they don't have kind \p kOpaque. |
136 | */ |
137 | Array<CombinerRule> BYOCCombinerRules(const String& compiler) { |
138 | Array<SimpleCombinerRule> simple_rules; |
139 | Array<CombinerRule> combiner_rules; |
140 | switch (BYOCFusionStyleForCompiler(compiler)) { |
141 | case kNoFusionBYOCStyle: |
142 | break; |
143 | case kTVMFusionBYOCStyle: |
144 | // Conservatively assume the BYOC toolchain follows the same rules as for TVM's FuseOps. |
145 | simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast)); |
146 | simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce)); |
147 | simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective)); |
148 | combiner_rules.push_back(AllSimpleCombinerRule("combiner" , std::move(simple_rules))); |
149 | break; |
150 | case kArbitraryFusionBYOCStyle: |
151 | // Just try all combinations up to the max_depth limit. |
152 | simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kOutEWiseFusable)); |
153 | combiner_rules.push_back(AllSimpleCombinerRule("combiner" , std::move(simple_rules))); |
154 | break; |
155 | } |
156 | return combiner_rules; |
157 | } |
158 | |
159 | /*! |
160 | * \brief Returns partition rule mimicking one entry in the patterns list passed to the |
161 | * MergeComposite pass. |
162 | */ |
163 | PartitionRule MakeLabelledDFPatternPartitionRule( |
164 | const std::string& compiler, String rule_name, DFPattern dataflow_pattern, |
165 | TPatternPredicate predicate = DefaultPatternPredicate) { |
166 | DFPatternPartitionRule patterns("" , std::move(dataflow_pattern), std::move(predicate)); |
167 | return CompositePartitionRule(std::move(rule_name), std::move(patterns)); |
168 | } |
169 | |
170 | /*! |
171 | * \brief Returns partition rule mimicking |
172 | * MergeComposite/AnnotateTarget/MergeCompilerRegions/PartitionGraph passes for "compiler" |
173 | * attribute of \p target. |
174 | */ |
175 | PartitionRule MakePatternBYOCPartitionRule(const std::string& compiler, |
176 | Array<PartitionRule> sub_rules) { |
177 | size_t max_depth = GetMaxDepth("byoc_max_depth" ); |
178 | // Union all the individual pattern rules. |
179 | UnionPartitionRule unioned("" , std::move(sub_rules)); |
180 | PartitionRule combine = |
181 | MakeCombinePartitionRule(std::move(unioned), BYOCCombinerRules(compiler), max_depth); |
182 | // Ignore invalid candidates. |
183 | SubGraphConfig sub_graph_config; |
184 | sub_graph_config.allow_taps = false; |
185 | sub_graph_config.max_depth = max_depth; |
186 | sub_graph_config.max_exits = 1; |
187 | OnlyValidPartitionRule valid("" , std::move(combine), sub_graph_config); |
188 | // Wrap the candidates in a "Primitive" function with a "Compiler" attribute. |
189 | return PrimitivePartitionRule("" , std::move(valid)); |
190 | } |
191 | |
192 | TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRule" ) |
193 | .set_body_typed(MakeLabelledDFPatternPartitionRule); |
194 | |
195 | TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRuleWithPredicate" ) |
196 | .set_body_typed(MakeLabelledDFPatternPartitionRule); |
197 | |
198 | TVM_REGISTER_GLOBAL("relay.collage.MakePatternBYOCPartitionRule" ) |
199 | .set_body_typed(MakePatternBYOCPartitionRule); |
200 | |
201 | /*! |
202 | * \brief Returns the rule to pick out expression nodes which can be 'left behind' for execution |
203 | * on the host. |
204 | */ |
205 | PartitionRule MakeHostPartitionRule() { return HostPartitionRule("" ); } |
206 | |
207 | } // namespace |
208 | |
209 | Array<PartitionSpec> GatherPartitionSpecs(const CompilationConfig& config) { |
210 | Array<PartitionSpec> result; |
211 | for (const auto& primitive_target : config->primitive_targets) { |
212 | String spec_name = GetSpecName(primitive_target); |
213 | PartitionRule rule; |
214 | if (primitive_target.IsExternalCodegen()) { |
215 | // Transition to the Python side so we can get access to the BYOC pattern registry. |
216 | // That will bounce right back into the above construction helpers. |
217 | static const runtime::PackedFunc* make_byoc_partition_rule = |
218 | runtime::Registry::Get("tvm.relay.collage.make_byoc_partition_rule" ); |
219 | ICHECK(make_byoc_partition_rule); |
220 | rule = (*make_byoc_partition_rule)(spec_name); // spec_name == primitive_target->kind->name |
221 | VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for BYOC spec_name " |
222 | << spec_name << " and has default partition rule:\n" |
223 | << rule->ToString(); |
224 | } else { |
225 | rule = MakeTVMPartitionRule(); |
226 | VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for TVM spec_name " |
227 | << spec_name << " and has default partition rule:\n" |
228 | << rule->ToString(); |
229 | } |
230 | result.push_back(PartitionSpec(spec_name, primitive_target, rule)); |
231 | } |
232 | |
233 | // Add one more spec to cover the host target. |
234 | result.push_back(PartitionSpec(kHostSpecName, config->host_target, MakeHostPartitionRule())); |
235 | |
236 | return result; |
237 | } |
238 | |
239 | } // namespace collage |
240 | } // namespace relay |
241 | } // namespace tvm |
242 | |