1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/gather_partition_specs.cc
22 * \brief Gather the relevant \p PartitionSpecs from the available \p Targets.
23 */
24
25#include "./gather_partition_specs.h"
26
27#include "./utils.h"
28
29namespace tvm {
30namespace relay {
31namespace collage {
32
33namespace {
34
35PartitionRule MakeCombinePartitionRule(PartitionRule sub_rule, Array<CombinerRule> combiner_rules,
36 size_t max_depth) {
37 if (combiner_rules.empty()) {
38 return sub_rule;
39 } else {
40 return CombinePartitionRule("", std::move(sub_rule), std::move(combiner_rules), max_depth);
41 }
42}
43
44/*! \brief Returns the primitive combiner rules which mimic TVM's \p FuseOps. */
45Array<CombinerRule> TVMCombinerRules() {
46 Array<SimpleCombinerRule> simple_rules;
47 // Mimic the FuseOps rules.
48 simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
49 simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
50 simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
51
52 Array<CombinerRule> combiner_rules;
53 // Fire the simple fusion rules
54 combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
55 // Fuse tuple arguments
56 combiner_rules.push_back(TupleArgCombinerRule("tuple"));
57 // Fuse tuple projection
58 combiner_rules.push_back(TupleProjCombinerRule("proj"));
59
60 return combiner_rules;
61}
62
63size_t GetMaxDepth(std::string key) {
64 tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
65 std::string config_key = "relay.collage." + key;
66 Optional<Integer> opt_max_depth = ctxt->GetConfig(config_key, Optional<Integer>());
67 ICHECK(opt_max_depth.defined()) << "missing binding for '" << config_key << " in pass context";
68 ICHECK(opt_max_depth.value()->value > 0)
69 << "invalid value for '" << config_key << " in pass context";
70 return static_cast<size_t>(opt_max_depth.value()->value);
71}
72
73/*! \brief Returns partition rule mimicking TVM FuseOps. */
74PartitionRule MakeTVMPartitionRule() {
75 size_t max_depth = GetMaxDepth("tvm_max_depth");
76 // Build singleton candidates for all calls to ops <= kOutEWiseFusable.
77 OpCallByKindPartitionRule op_call_by_kind("");
78 // Combine candidates according to the TVM fusion rules.
79 PartitionRule combine =
80 MakeCombinePartitionRule(std::move(op_call_by_kind), TVMCombinerRules(), max_depth);
81 // Discard invalid candidates.
82 SubGraphConfig sub_graph_config;
83 sub_graph_config.allow_taps = false;
84 sub_graph_config.max_depth = max_depth;
85 sub_graph_config.max_exits = 1;
86 return OnlyValidPartitionRule("", std::move(combine), sub_graph_config);
87 // NOTE: We don't wrap by a "Primitive" since we want to defer making TVM fusion decisions until
88 // after running more Relay passes.
89}
90
91/*!
92 * \brief Returns the fusion style for default compiler.
93 */
94BYOCStyle DefaultBYOCFusionStyleForCompiler(const String& compiler) {
95 if (compiler == "cutlass" || compiler == "cublas" || compiler == "cudnn") {
96 return kNoFusionBYOCStyle;
97 } else if (compiler == "tensorrt") {
98 return kTVMFusionBYOCStyle;
99 } else {
100 return kArbitraryFusionBYOCStyle;
101 }
102}
103
104/*!
105 * \brief Returns the fusion style for given compiler.
106 */
107BYOCStyle BYOCFusionStyleForCompiler(const String& compiler) {
108 tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
109 std::string config_key = "relay.collage.byoc_fusion_style";
110 Optional<Array<String>> byoc_configs = ctxt->GetConfig(config_key, Optional<Array<String>>());
111 BYOCStyle byoc_fusion_style = DefaultBYOCFusionStyleForCompiler(compiler);
112 if (!byoc_configs) {
113 return byoc_fusion_style;
114 }
115 for (auto config_ : byoc_configs.value()) {
116 std::vector<std::string> byoc_cfg = SplitString(config_, ".");
117 if (byoc_cfg[0] == compiler) {
118 if (byoc_cfg[1] == "NoFusion") {
119 byoc_fusion_style = kNoFusionBYOCStyle;
120 } else if (byoc_cfg[1] == "TVMFusion") {
121 byoc_fusion_style = kTVMFusionBYOCStyle;
122 } else if (byoc_cfg[1] == "ArbitraryFusion") {
123 byoc_fusion_style = kArbitraryFusionBYOCStyle;
124 } else {
125 ICHECK(false) << "Invalid fusion name for compiler " << byoc_cfg[0] << " in pass context";
126 }
127 break;
128 }
129 }
130 return byoc_fusion_style;
131}
132
133/*!
134 * \brief Returns the primitive combiner rules which allow for any touching candidates
135 * to be fused provided they don't have kind \p kOpaque.
136 */
137Array<CombinerRule> BYOCCombinerRules(const String& compiler) {
138 Array<SimpleCombinerRule> simple_rules;
139 Array<CombinerRule> combiner_rules;
140 switch (BYOCFusionStyleForCompiler(compiler)) {
141 case kNoFusionBYOCStyle:
142 break;
143 case kTVMFusionBYOCStyle:
144 // Conservatively assume the BYOC toolchain follows the same rules as for TVM's FuseOps.
145 simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kBroadcast));
146 simple_rules.push_back(ByKindSimpleCombinerRule(kBroadcast, kCommReduce));
147 simple_rules.push_back(ByKindSimpleCombinerRule(kInjective, kInjective));
148 combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
149 break;
150 case kArbitraryFusionBYOCStyle:
151 // Just try all combinations up to the max_depth limit.
152 simple_rules.push_back(ByKindSimpleCombinerRule(kOutEWiseFusable, kOutEWiseFusable));
153 combiner_rules.push_back(AllSimpleCombinerRule("combiner", std::move(simple_rules)));
154 break;
155 }
156 return combiner_rules;
157}
158
159/*!
160 * \brief Returns partition rule mimicking one entry in the patterns list passed to the
161 * MergeComposite pass.
162 */
163PartitionRule MakeLabelledDFPatternPartitionRule(
164 const std::string& compiler, String rule_name, DFPattern dataflow_pattern,
165 TPatternPredicate predicate = DefaultPatternPredicate) {
166 DFPatternPartitionRule patterns("", std::move(dataflow_pattern), std::move(predicate));
167 return CompositePartitionRule(std::move(rule_name), std::move(patterns));
168}
169
170/*!
171 * \brief Returns partition rule mimicking
172 * MergeComposite/AnnotateTarget/MergeCompilerRegions/PartitionGraph passes for "compiler"
173 * attribute of \p target.
174 */
175PartitionRule MakePatternBYOCPartitionRule(const std::string& compiler,
176 Array<PartitionRule> sub_rules) {
177 size_t max_depth = GetMaxDepth("byoc_max_depth");
178 // Union all the individual pattern rules.
179 UnionPartitionRule unioned("", std::move(sub_rules));
180 PartitionRule combine =
181 MakeCombinePartitionRule(std::move(unioned), BYOCCombinerRules(compiler), max_depth);
182 // Ignore invalid candidates.
183 SubGraphConfig sub_graph_config;
184 sub_graph_config.allow_taps = false;
185 sub_graph_config.max_depth = max_depth;
186 sub_graph_config.max_exits = 1;
187 OnlyValidPartitionRule valid("", std::move(combine), sub_graph_config);
188 // Wrap the candidates in a "Primitive" function with a "Compiler" attribute.
189 return PrimitivePartitionRule("", std::move(valid));
190}
191
// FFI entry points for the Python side. GatherPartitionSpecs calls the Python function
// "tvm.relay.collage.make_byoc_partition_rule", which calls back into these helpers to build the
// BYOC partition rules.
//
// NOTE(review): set_body_typed derives its argument list from the function's type, and C++
// default arguments are not part of a function's type. The `predicate = DefaultPatternPredicate`
// default of MakeLabelledDFPatternPartitionRule therefore presumably does not apply through either
// registration below, so both names likely expect all four arguments. Confirm the intended arity
// of the non-"WithPredicate" variant.
TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRule")
    .set_body_typed(MakeLabelledDFPatternPartitionRule);

// Same underlying function as above; the caller supplies the predicate explicitly.
TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRuleWithPredicate")
    .set_body_typed(MakeLabelledDFPatternPartitionRule);

TVM_REGISTER_GLOBAL("relay.collage.MakePatternBYOCPartitionRule")
    .set_body_typed(MakePatternBYOCPartitionRule);
200
201/*!
202 * \brief Returns the rule to pick out expression nodes which can be 'left behind' for execution
203 * on the host.
204 */
205PartitionRule MakeHostPartitionRule() { return HostPartitionRule(""); }
206
207} // namespace
208
209Array<PartitionSpec> GatherPartitionSpecs(const CompilationConfig& config) {
210 Array<PartitionSpec> result;
211 for (const auto& primitive_target : config->primitive_targets) {
212 String spec_name = GetSpecName(primitive_target);
213 PartitionRule rule;
214 if (primitive_target.IsExternalCodegen()) {
215 // Transition to the Python side so we can get access to the BYOC pattern registry.
216 // That will bounce right back into the above construction helpers.
217 static const runtime::PackedFunc* make_byoc_partition_rule =
218 runtime::Registry::Get("tvm.relay.collage.make_byoc_partition_rule");
219 ICHECK(make_byoc_partition_rule);
220 rule = (*make_byoc_partition_rule)(spec_name); // spec_name == primitive_target->kind->name
221 VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for BYOC spec_name "
222 << spec_name << " and has default partition rule:\n"
223 << rule->ToString();
224 } else {
225 rule = MakeTVMPartitionRule();
226 VLOG(1) << "Target " << primitive_target->ToDebugString() << " is for TVM spec_name "
227 << spec_name << " and has default partition rule:\n"
228 << rule->ToString();
229 }
230 result.push_back(PartitionSpec(spec_name, primitive_target, rule));
231 }
232
233 // Add one more spec to cover the host target.
234 result.push_back(PartitionSpec(kHostSpecName, config->host_target, MakeHostPartitionRule()));
235
236 return result;
237}
238
239} // namespace collage
240} // namespace relay
241} // namespace tvm
242