1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24String GetRuleKindFromTarget(const Target& target) {
25 if (target->kind->name == "llvm") {
26 static const PackedFunc* f_check_vnni =
27 runtime::Registry::Get("tvm.target.x86.target_has_vnni");
28 ICHECK(f_check_vnni != nullptr) << "The `target_has_vnni` func is not in tvm registry.";
29 if (target->GetAttr<String>("mcpu") &&
30 (*f_check_vnni)(target->GetAttr<String>("mcpu").value())) {
31 return "vnni";
32 } else {
33 static const PackedFunc* f_check_avx512 =
34 runtime::Registry::Get("tvm.target.x86.target_has_avx512");
35 ICHECK(f_check_avx512 != nullptr) << "The `target_has_avx512` func is not in tvm registry.";
36 if (target->GetAttr<String>("mcpu") &&
37 (*f_check_avx512)(target->GetAttr<String>("mcpu").value())) {
38 return "avx512";
39 }
40 }
41 return "llvm";
42 }
43 if (target->kind->name == "hexagon") {
44 return "hexagon";
45 }
46 if (target->kind->name == "cuda") {
47 if (Optional<String> opt_sm = target->GetAttr<String>("arch")) {
48 std::string sm = opt_sm.value();
49 if (support::StartsWith(sm, "sm_")) {
50 sm = sm.substr(3);
51 try {
52 if (std::stoi(sm) >= 75) {
53 return "cuda-tensorcore";
54 }
55 } catch (const std::invalid_argument& e) {
56 LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
57 << ". Details: " << e.what();
58 }
59 }
60 }
61 return "cuda";
62 }
63
64 if (IsGPUTarget(target->kind->name)) {
65 return "cuda";
66 }
67
68 if (target->kind->name == "c") {
69 return "c";
70 }
71 LOG(FATAL) << "Unsupported target: " << target;
72 throw;
73}
74
75void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
76 if (context->target.defined() && //
77 !(sch_rules.defined() && //
78 postprocs.defined() && //
79 mutator_probs.defined())) {
80 String kind = GetRuleKindFromTarget(context->target.value());
81 Array<ScheduleRule> default_sch_rules;
82 Array<Postproc> default_postprocs;
83 Map<Mutator, FloatImm> default_mutator_probs;
84 // for target with skylake-avx512
85 if (kind == "llvm") {
86 default_sch_rules = ScheduleRule::DefaultLLVM();
87 default_postprocs = Postproc::DefaultLLVM();
88 default_mutator_probs = Mutator::DefaultLLVM();
89 } else if (kind == "cuda") {
90 default_sch_rules = ScheduleRule::DefaultCUDA();
91 default_postprocs = Postproc::DefaultCUDA();
92 default_mutator_probs = Mutator::DefaultCUDA();
93 } else if (kind == "cuda-tensorcore") {
94 default_sch_rules = ScheduleRule::DefaultCUDATensorCore();
95 default_postprocs = Postproc::DefaultCUDATensorCore();
96 default_mutator_probs = Mutator::DefaultCUDATensorCore();
97 } else if (kind == "hexagon") {
98 default_sch_rules = ScheduleRule::DefaultHexagon();
99 default_postprocs = Postproc::DefaultHexagon();
100 default_mutator_probs = Mutator::DefaultHexagon();
101 } else if (kind == "vnni") {
102 default_sch_rules = ScheduleRule::DefaultX86("vnni");
103 default_postprocs = Postproc::DefaultCPUTensorization();
104 default_mutator_probs = Mutator::DefaultLLVM();
105 } else if (kind == "avx512") {
106 default_sch_rules = ScheduleRule::DefaultX86("avx512");
107 default_postprocs = Postproc::DefaultCPUTensorization();
108 default_mutator_probs = Mutator::DefaultLLVM();
109 } else if (kind == "c") {
110 default_sch_rules = ScheduleRule::DefaultMicro();
111 default_postprocs = Postproc::DefaultMicro();
112 default_mutator_probs = Mutator::DefaultMicro();
113 } else {
114 LOG(FATAL) << "Unsupported kind: " << kind;
115 throw;
116 }
117 if (!sch_rules.defined()) {
118 sch_rules = default_sch_rules;
119 }
120 if (!postprocs.defined()) {
121 postprocs = default_postprocs;
122 }
123 if (!mutator_probs.defined()) {
124 mutator_probs = default_mutator_probs;
125 }
126 }
127 if (sch_rules.defined()) {
128 for (ScheduleRule i : sch_rules.value()) {
129 i->InitializeWithTuneContext(context);
130 }
131 }
132 if (postprocs.defined()) {
133 for (Postproc i : postprocs.value()) {
134 i->InitializeWithTuneContext(context);
135 }
136 }
137 if (mutator_probs.defined()) {
138 for (const auto& kv : mutator_probs.value()) {
139 Mutator mutator = kv.first;
140 mutator->InitializeWithTuneContext(context);
141 }
142 }
143}
144
145void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
146 ICHECK(f_initialize_with_tune_context != nullptr)
147 << "PySpaceGenerator's InitializeWithTuneContext method not implemented!";
148 f_initialize_with_tune_context(context);
149}
150
151Array<tir::Schedule> PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) {
152 ICHECK(f_generate_design_space != nullptr)
153 << "PySpaceGenerator's GenerateDesignSpace method not implemented!";
154 return f_generate_design_space(mod);
155}
156
157SpaceGenerator PySpaceGeneratorNode::Clone() const {
158 ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!";
159 return f_clone();
160}
161
162SpaceGenerator SpaceGenerator::PySpaceGenerator(
163 Optional<Array<ScheduleRule>> sch_rules, Optional<Array<Postproc>> postprocs,
164 Optional<Map<Mutator, FloatImm>> mutator_probs,
165 FInitializeWithTuneContext f_initialize_with_tune_context,
166 FGenerateDesignSpace f_generate_design_space, FClone f_clone) {
167 ObjectPtr<PySpaceGeneratorNode> n = make_object<PySpaceGeneratorNode>();
168 n->sch_rules = sch_rules;
169 n->postprocs = postprocs;
170 n->mutator_probs = mutator_probs;
171 n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
172 n->f_generate_design_space = std::move(f_generate_design_space);
173 n->f_clone = std::move(f_clone);
174 return SpaceGenerator(n);
175}
176
// Register the SpaceGenerator node types with the TVM object system.
TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode);
TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode);

// Global FFI functions under the "meta_schedule." prefix. The set_body_method entries forward
// to the named SpaceGeneratorNode method on the SpaceGenerator passed as the first argument;
// the PySpaceGenerator entry exposes the callback-backed factory defined above.
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext")
    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace")
    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::GenerateDesignSpace);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator")
    .set_body_typed(SpaceGenerator::PySpaceGenerator);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone")
    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::Clone);
188
189} // namespace meta_schedule
190} // namespace tvm
191