1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <stdexcept>
#include <string>

#include "../utils.h"
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | String GetRuleKindFromTarget(const Target& target) { |
25 | if (target->kind->name == "llvm" ) { |
26 | static const PackedFunc* f_check_vnni = |
27 | runtime::Registry::Get("tvm.target.x86.target_has_vnni" ); |
28 | ICHECK(f_check_vnni != nullptr) << "The `target_has_vnni` func is not in tvm registry." ; |
29 | if (target->GetAttr<String>("mcpu" ) && |
30 | (*f_check_vnni)(target->GetAttr<String>("mcpu" ).value())) { |
31 | return "vnni" ; |
32 | } else { |
33 | static const PackedFunc* f_check_avx512 = |
34 | runtime::Registry::Get("tvm.target.x86.target_has_avx512" ); |
35 | ICHECK(f_check_avx512 != nullptr) << "The `target_has_avx512` func is not in tvm registry." ; |
36 | if (target->GetAttr<String>("mcpu" ) && |
37 | (*f_check_avx512)(target->GetAttr<String>("mcpu" ).value())) { |
38 | return "avx512" ; |
39 | } |
40 | } |
41 | return "llvm" ; |
42 | } |
43 | if (target->kind->name == "hexagon" ) { |
44 | return "hexagon" ; |
45 | } |
46 | if (target->kind->name == "cuda" ) { |
47 | if (Optional<String> opt_sm = target->GetAttr<String>("arch" )) { |
48 | std::string sm = opt_sm.value(); |
49 | if (support::StartsWith(sm, "sm_" )) { |
50 | sm = sm.substr(3); |
51 | try { |
52 | if (std::stoi(sm) >= 75) { |
53 | return "cuda-tensorcore" ; |
54 | } |
55 | } catch (const std::invalid_argument& e) { |
56 | LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm |
57 | << ". Details: " << e.what(); |
58 | } |
59 | } |
60 | } |
61 | return "cuda" ; |
62 | } |
63 | |
64 | if (IsGPUTarget(target->kind->name)) { |
65 | return "cuda" ; |
66 | } |
67 | |
68 | if (target->kind->name == "c" ) { |
69 | return "c" ; |
70 | } |
71 | LOG(FATAL) << "Unsupported target: " << target; |
72 | throw; |
73 | } |
74 | |
75 | void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { |
76 | if (context->target.defined() && // |
77 | !(sch_rules.defined() && // |
78 | postprocs.defined() && // |
79 | mutator_probs.defined())) { |
80 | String kind = GetRuleKindFromTarget(context->target.value()); |
81 | Array<ScheduleRule> default_sch_rules; |
82 | Array<Postproc> default_postprocs; |
83 | Map<Mutator, FloatImm> default_mutator_probs; |
84 | // for target with skylake-avx512 |
85 | if (kind == "llvm" ) { |
86 | default_sch_rules = ScheduleRule::DefaultLLVM(); |
87 | default_postprocs = Postproc::DefaultLLVM(); |
88 | default_mutator_probs = Mutator::DefaultLLVM(); |
89 | } else if (kind == "cuda" ) { |
90 | default_sch_rules = ScheduleRule::DefaultCUDA(); |
91 | default_postprocs = Postproc::DefaultCUDA(); |
92 | default_mutator_probs = Mutator::DefaultCUDA(); |
93 | } else if (kind == "cuda-tensorcore" ) { |
94 | default_sch_rules = ScheduleRule::DefaultCUDATensorCore(); |
95 | default_postprocs = Postproc::DefaultCUDATensorCore(); |
96 | default_mutator_probs = Mutator::DefaultCUDATensorCore(); |
97 | } else if (kind == "hexagon" ) { |
98 | default_sch_rules = ScheduleRule::DefaultHexagon(); |
99 | default_postprocs = Postproc::DefaultHexagon(); |
100 | default_mutator_probs = Mutator::DefaultHexagon(); |
101 | } else if (kind == "vnni" ) { |
102 | default_sch_rules = ScheduleRule::DefaultX86("vnni" ); |
103 | default_postprocs = Postproc::DefaultCPUTensorization(); |
104 | default_mutator_probs = Mutator::DefaultLLVM(); |
105 | } else if (kind == "avx512" ) { |
106 | default_sch_rules = ScheduleRule::DefaultX86("avx512" ); |
107 | default_postprocs = Postproc::DefaultCPUTensorization(); |
108 | default_mutator_probs = Mutator::DefaultLLVM(); |
109 | } else if (kind == "c" ) { |
110 | default_sch_rules = ScheduleRule::DefaultMicro(); |
111 | default_postprocs = Postproc::DefaultMicro(); |
112 | default_mutator_probs = Mutator::DefaultMicro(); |
113 | } else { |
114 | LOG(FATAL) << "Unsupported kind: " << kind; |
115 | throw; |
116 | } |
117 | if (!sch_rules.defined()) { |
118 | sch_rules = default_sch_rules; |
119 | } |
120 | if (!postprocs.defined()) { |
121 | postprocs = default_postprocs; |
122 | } |
123 | if (!mutator_probs.defined()) { |
124 | mutator_probs = default_mutator_probs; |
125 | } |
126 | } |
127 | if (sch_rules.defined()) { |
128 | for (ScheduleRule i : sch_rules.value()) { |
129 | i->InitializeWithTuneContext(context); |
130 | } |
131 | } |
132 | if (postprocs.defined()) { |
133 | for (Postproc i : postprocs.value()) { |
134 | i->InitializeWithTuneContext(context); |
135 | } |
136 | } |
137 | if (mutator_probs.defined()) { |
138 | for (const auto& kv : mutator_probs.value()) { |
139 | Mutator mutator = kv.first; |
140 | mutator->InitializeWithTuneContext(context); |
141 | } |
142 | } |
143 | } |
144 | |
/*!
 * \brief Forward initialization to the `f_initialize_with_tune_context` packed function.
 * \param context The tuning context, passed through unchanged.
 * \note Unlike SpaceGeneratorNode::InitializeWithTuneContext, this override does not itself touch
 *  `sch_rules`, `postprocs` or `mutator_probs`. Fails an ICHECK if the callback is unset.
 */
void PySpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
  ICHECK(f_initialize_with_tune_context != nullptr)
      << "PySpaceGenerator's InitializeWithTuneContext method not implemented!";
  f_initialize_with_tune_context(context);
}
150 | |
/*!
 * \brief Produce the design space by delegating to the `f_generate_design_space` packed function.
 * \param mod The module to generate schedules for.
 * \return Whatever array of schedules the callback returns.
 * \note Fails an ICHECK if the callback is unset.
 */
Array<tir::Schedule> PySpaceGeneratorNode::GenerateDesignSpace(const IRModule& mod) {
  ICHECK(f_generate_design_space != nullptr)
      << "PySpaceGenerator's GenerateDesignSpace method not implemented!";
  return f_generate_design_space(mod);
}
156 | |
/*!
 * \brief Clone this generator by delegating to the `f_clone` packed function.
 * \return The space generator returned by the callback.
 * \note Fails an ICHECK if the callback is unset.
 */
SpaceGenerator PySpaceGeneratorNode::Clone() const {
  ICHECK(f_clone != nullptr) << "PySpaceGenerator's Clone method not implemented!";
  return f_clone();
}
161 | |
162 | SpaceGenerator SpaceGenerator::PySpaceGenerator( |
163 | Optional<Array<ScheduleRule>> sch_rules, Optional<Array<Postproc>> postprocs, |
164 | Optional<Map<Mutator, FloatImm>> mutator_probs, |
165 | FInitializeWithTuneContext f_initialize_with_tune_context, |
166 | FGenerateDesignSpace f_generate_design_space, FClone f_clone) { |
167 | ObjectPtr<PySpaceGeneratorNode> n = make_object<PySpaceGeneratorNode>(); |
168 | n->sch_rules = sch_rules; |
169 | n->postprocs = postprocs; |
170 | n->mutator_probs = mutator_probs; |
171 | n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); |
172 | n->f_generate_design_space = std::move(f_generate_design_space); |
173 | n->f_clone = std::move(f_clone); |
174 | return SpaceGenerator(n); |
175 | } |
176 | |
// Register the space-generator node types with TVM's object system.
TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode);
TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode);

// Expose the SpaceGenerator API as global packed functions under the `meta_schedule.` prefix.
// The registered names are part of the external interface and must stay stable.
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext")
    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::InitializeWithTuneContext);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace")
    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::GenerateDesignSpace);
// Factory for callback-backed generators (see SpaceGenerator::PySpaceGenerator).
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator")
    .set_body_typed(SpaceGenerator::PySpaceGenerator);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorClone")
    .set_body_method<SpaceGenerator>(&SpaceGeneratorNode::Clone);
188 | |
189 | } // namespace meta_schedule |
190 | } // namespace tvm |
191 | |