1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /**************** Constructors ****************/ |
25 | |
26 | Trace::Trace() { data_ = make_object<TraceNode>(); } |
27 | |
28 | Trace::Trace(Array<Instruction> insts, Map<Instruction, ObjectRef> decisions) { |
29 | ObjectPtr<TraceNode> n = make_object<TraceNode>(); |
30 | n->insts = std::move(insts); |
31 | n->decisions = std::move(decisions); |
32 | data_ = std::move(n); |
33 | } |
34 | |
35 | /**************** Utilities ****************/ |
36 | |
37 | int GetNumValidInstructions(const Array<Instruction>& insts, bool remove_postproc) { |
38 | if (!remove_postproc) { |
39 | return insts.size(); |
40 | } |
41 | int n_insts = 0; |
42 | for (const Instruction& inst : insts) { |
43 | if (!inst->kind->IsPostproc()) { |
44 | ++n_insts; |
45 | } else { |
46 | break; |
47 | } |
48 | } |
49 | return n_insts; |
50 | } |
51 | |
52 | /**************** TranslateInputRVs ****************/ |
53 | |
54 | Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs, |
55 | const std::unordered_map<const Object*, const Object*>& rv_map) { |
56 | Array<ObjectRef> result; |
57 | result.reserve(inputs.size()); |
58 | auto f_subst_with_rv_map = [&rv_map](const Var& var) -> Optional<PrimExpr> { |
59 | auto it = rv_map.find(var.get()); |
60 | if (it == rv_map.end()) { |
61 | return NullOpt; |
62 | } |
63 | const Object* dst = it->second; |
64 | ICHECK(dst->IsInstance<VarNode>()) |
65 | << "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey(); |
66 | return GetRef<Var>(static_cast<const VarNode*>(dst)); |
67 | }; |
68 | |
69 | for (const ObjectRef& input : inputs) { |
70 | if (!input.defined() || // constant: nullptr |
71 | input->IsInstance<StringObj>() || // constant: string |
72 | input->IsInstance<IntImmNode>() || // constant: integer |
73 | input->IsInstance<FloatImmNode>()) { // constant: float |
74 | result.push_back(input); |
75 | } else if (input->IsInstance<BlockRVNode>() || // RV: block |
76 | input->IsInstance<LoopRVNode>() || // RV: loop |
77 | input->IsInstance<VarNode>()) { // RV: var |
78 | auto it = rv_map.find(input.get()); |
79 | ICHECK(it != rv_map.end()) << "IndexError: Random variable doesn't exist: " << input; |
80 | result.push_back(GetRef<ObjectRef>(it->second)); |
81 | } else if (const auto* expr = input.as<PrimExprNode>()) { // RV: Expr |
82 | result.push_back(Substitute(GetRef<PrimExpr>(expr), f_subst_with_rv_map)); |
83 | } else if (const auto* index_map = input.as<IndexMapNode>()) { |
84 | result.push_back(Substitute(GetRef<IndexMap>(index_map), f_subst_with_rv_map)); |
85 | } else if (input->IsInstance<ArrayNode>()) { |
86 | // Recursively convert elements of the array into a new list of ObjectRefs. |
87 | result.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_map)); |
88 | } else { |
89 | ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: " |
90 | << input->GetTypeKey(); |
91 | throw; |
92 | } |
93 | } |
94 | return result; |
95 | } |
96 | |
97 | Array<ObjectRef> TranslateInputRVs( |
98 | const Array<ObjectRef>& inputs, |
99 | const std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual>& rv_names) { |
100 | Array<ObjectRef> results; |
101 | results.reserve(inputs.size()); |
102 | for (const ObjectRef& input : inputs) { |
103 | if (!input.defined()) { |
104 | // Case 0. nullptr => None |
105 | results.push_back(String("None" )); |
106 | continue; |
107 | } |
108 | auto it = rv_names.find(input); |
109 | if (it != rv_names.end()) { |
110 | // Case 1. BlockRV, LoopRV, VarRV |
111 | results.push_back(it->second); |
112 | } else if (const auto* str_obj = input.as<StringObj>()) { |
113 | // Case 2. string => "content" |
114 | results.push_back(String('"' + std::string(str_obj->data) + '"')); |
115 | } else if (input->IsInstance<IntImmNode>() || input->IsInstance<FloatImmNode>()) { |
116 | // Case 3. integer or floating-point number |
117 | results.push_back(input); |
118 | } else if (input->IsInstance<ArrayNode>()) { |
119 | // Case 4: array |
120 | results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_names)); |
121 | } else if (input->IsInstance<MapNode>()) { |
122 | // Case 5: dict |
123 | results.push_back(input); |
124 | } else if (input->IsInstance<IndexMapNode>()) { |
125 | // // Case 6: IndexMap |
126 | IndexMap index_map = Downcast<IndexMap>(input); |
127 | index_map = index_map.RenameVariables([&rv_names](const Var& var) -> Optional<String> { |
128 | if (auto it = rv_names.find(var); it != rv_names.end()) { |
129 | return it->second; |
130 | } |
131 | return NullOpt; |
132 | }); |
133 | results.push_back(index_map); |
134 | } else if (input->IsInstance<BlockRVNode>() || inputs->IsInstance<LoopRVNode>() || |
135 | inputs->IsInstance<VarNode>()) { |
136 | LOG(FATAL) << "IndexError: Random variable is not defined " << input; |
137 | throw; |
138 | } else { |
139 | LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << input->GetTypeKey(); |
140 | throw; |
141 | } |
142 | } |
143 | return results; |
144 | } |
145 | |
146 | Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs, |
147 | const std::unordered_map<std::string, ObjectRef>& named_rvs) { |
148 | Array<ObjectRef> results; |
149 | results.reserve(inputs.size()); |
150 | for (const ObjectRef& input : inputs) { |
151 | // Case 3. integer or floating-point number |
152 | if (input->IsInstance<IntImmNode>() || input->IsInstance<FloatImmNode>()) { |
153 | results.push_back(input); |
154 | continue; |
155 | } |
156 | // Case 4. array |
157 | if (input->IsInstance<ArrayNode>()) { |
158 | results.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), named_rvs)); |
159 | continue; |
160 | } |
161 | // Case 5. dict |
162 | if (input->IsInstance<MapNode>()) { |
163 | results.push_back(input); |
164 | continue; |
165 | } |
166 | const auto* str = input.as<StringObj>(); |
167 | CHECK(str) << "TypeError: Expect String, but gets: " << input->GetTypeKey(); |
168 | CHECK_GT(str->size, 0) << "ValueError: Empty string is not allowed in input names" ; |
169 | const char* name = str->data; |
170 | int64_t size = str->size; |
171 | if (name[0] == '{' && name[size - 1] == '}') { |
172 | ObjectRef obj = LoadJSON(name); |
173 | // Case 6. IndexMap |
174 | if (obj->IsInstance<IndexMapNode>()) { |
175 | IndexMap index_map = Downcast<IndexMap>(obj); |
176 | index_map = Substitute(index_map, [&named_rvs](const Var& var) -> Optional<PrimExpr> { |
177 | auto it = named_rvs.find(var->name_hint); |
178 | if (it != named_rvs.end()) { |
179 | return Downcast<Var>(it->second); |
180 | } |
181 | return NullOpt; |
182 | }); |
183 | results.push_back(index_map); |
184 | continue; |
185 | } else { |
186 | LOG(FATAL) << "TypeError: Unexpected object: " << obj->GetTypeKey(); |
187 | throw; |
188 | } |
189 | } |
190 | // Case 2. string |
191 | if (size >= 2 && name[0] == '"' && name[size - 1] == '"') { |
192 | results.push_back(String(std::string(name + 1, size - 2))); |
193 | continue; |
194 | } |
195 | // Case 0 & 1. None, BlockRV, LoopRV, VarRV |
196 | auto it = named_rvs.find(name); |
197 | CHECK(it != named_rvs.end()) << "ValueError: The random variable is not defined: " << name; |
198 | results.push_back(it->second); |
199 | } |
200 | return results; |
201 | } |
202 | |
203 | /**************** TranslateAddOutputRVs ****************/ |
204 | |
205 | void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs, |
206 | std::unordered_map<const Object*, const Object*>* rv_map) { |
207 | ICHECK_EQ(old_outputs.size(), new_outputs.size()); |
208 | int n = old_outputs.size(); |
209 | const ObjectRef* p_old = old_outputs.GetArrayNode()->begin(); |
210 | const ObjectRef* p_new = new_outputs.GetArrayNode()->begin(); |
211 | for (int i = 0; i < n; ++i) { |
212 | (*rv_map)[p_old[i].get()] = p_new[i].get(); |
213 | } |
214 | } |
215 | |
216 | Array<String> TranslateAddOutputRVs( |
217 | const Array<ObjectRef>& outputs, |
218 | std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual>* rv_names) { |
219 | Array<String> results; |
220 | results.reserve(outputs.size()); |
221 | for (const ObjectRef& output : outputs) { |
222 | int i = rv_names->size(); |
223 | ICHECK(!rv_names->count(output)) |
224 | << "ValueError: The random variable has been produced once: " << rv_names->at(output); |
225 | String result{ObjectPtr<StringObj>{nullptr}}; |
226 | if (output->IsInstance<BlockRVNode>()) { |
227 | result = "b" + std::to_string(i); |
228 | } else if (output->IsInstance<LoopRVNode>()) { |
229 | result = "l" + std::to_string(i); |
230 | } else if (output->IsInstance<VarNode>()) { |
231 | result = "v" + std::to_string(i); |
232 | } else { |
233 | LOG(FATAL) << "TypeError: Cannot recognize the type of the random variable: " |
234 | << output->GetTypeKey(); |
235 | throw; |
236 | } |
237 | results.push_back(result); |
238 | rv_names->emplace(output, std::move(result)); |
239 | } |
240 | return results; |
241 | } |
242 | |
243 | void TranslateAddOutputRVs(const Array<String>& old_outputs, const Array<ObjectRef>& new_outputs, |
244 | std::unordered_map<std::string, ObjectRef>* named_rvs) { |
245 | ICHECK_EQ(old_outputs.size(), new_outputs.size()); |
246 | int n = old_outputs.size(); |
247 | const ObjectRef* p_old = old_outputs.GetArrayNode()->begin(); |
248 | const ObjectRef* p_new = new_outputs.GetArrayNode()->begin(); |
249 | for (int i = 0; i < n; ++i) { |
250 | const auto* name = static_cast<const StringObj*>(p_old[i].get()); |
251 | named_rvs->emplace(std::string(name->data, name->size), p_new[i]); |
252 | } |
253 | } |
254 | |
255 | /**************** Add/Remove/Get ****************/ |
256 | |
257 | Optional<ObjectRef> TraceNode::GetDecision(const Instruction& inst) const { |
258 | auto it = this->decisions.find(inst); |
259 | return it == this->decisions.end() ? Optional<ObjectRef>(NullOpt) : (*it).second; |
260 | } |
261 | |
262 | void TraceNode::Append(Instruction inst) { insts.push_back(std::move(inst)); } |
263 | |
264 | void TraceNode::Append(Instruction inst, ObjectRef decision) { |
265 | decisions.Set(inst, std::move(decision)); |
266 | insts.push_back(std::move(inst)); |
267 | } |
268 | |
269 | Optional<Instruction> TraceNode::Pop() { |
270 | if (insts.empty()) { |
271 | return NullOpt; |
272 | } |
273 | Instruction inst = insts.back(); |
274 | insts.pop_back(); |
275 | if (decisions.count(inst)) { |
276 | decisions.erase(inst); |
277 | } |
278 | return inst; |
279 | } |
280 | |
281 | /**************** Interfacing with InstructionKind ****************/ |
282 | |
283 | void TraceNode::ApplyToSchedule( |
284 | Schedule sch, bool remove_postproc, |
285 | runtime::TypedPackedFunc<ObjectRef(const Instruction& inst, const Array<ObjectRef>& inputs, // |
286 | const Array<ObjectRef>& attrs, // |
287 | const Optional<ObjectRef>& decision)> |
288 | decision_provider) const { |
289 | std::unordered_map<const Object*, const Object*> rv_map; |
290 | for (const Instruction& inst : this->insts) { |
291 | if (remove_postproc && inst->kind->IsPostproc()) { |
292 | break; |
293 | } |
294 | Array<ObjectRef> inputs = TranslateInputRVs(inst->inputs, rv_map); |
295 | Array<ObjectRef> attrs = inst->attrs; |
296 | Optional<ObjectRef> decision = this->GetDecision(inst); |
297 | if (decision_provider != nullptr) { |
298 | decision = decision_provider(inst, inputs, attrs, decision); |
299 | } |
300 | Array<ObjectRef> outputs = inst->kind->f_apply_to_schedule(sch, inputs, attrs, decision); |
301 | TranslateAddOutputRVs(inst->outputs, outputs, &rv_map); |
302 | } |
303 | } |
304 | |
305 | ObjectRef TraceNode::AsJSON(bool remove_postproc) const { |
306 | std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual> rv_names; |
307 | Array<ObjectRef> json_insts; |
308 | Array<ObjectRef> json_decisions; |
309 | json_insts.reserve(this->insts.size()); |
310 | json_decisions.reserve(this->insts.size()); |
311 | |
312 | int i = 0; |
313 | for (const Instruction& inst : this->insts) { |
314 | const InstructionKind& kind = inst->kind; |
315 | if (remove_postproc && kind->IsPostproc()) { |
316 | break; |
317 | } |
318 | json_insts.push_back(Array<ObjectRef>{ |
319 | /* 0: inst name */ kind->name, |
320 | /* 1: inputs */ TranslateInputRVs(inst->inputs, rv_names), |
321 | /* 2: attrs */ kind->f_attrs_as_json != nullptr ? kind->f_attrs_as_json(inst->attrs) |
322 | : ObjectRef(inst->attrs), |
323 | /* 3: outputs */ TranslateAddOutputRVs(inst->outputs, &rv_names), |
324 | }); |
325 | if (Optional<ObjectRef> decision = this->GetDecision(inst)) { |
326 | json_decisions.push_back(Array<ObjectRef>{ |
327 | /* 0: index */ Integer(i), |
328 | /* 1: decision */ decision.value(), |
329 | }); |
330 | } |
331 | ++i; |
332 | } |
333 | return Array<ObjectRef>{ |
334 | /* 0: trace */ std::move(json_insts), |
335 | /* 1: decision */ std::move(json_decisions), |
336 | }; |
337 | } |
338 | |
339 | Array<String> TraceNode::AsPython(bool remove_postproc) const { |
340 | std::unordered_map<ObjectRef, String, ObjectPtrHash, ObjectPtrEqual> rv_names; |
341 | Array<String> py_trace; |
342 | py_trace.reserve(this->insts.size()); |
343 | for (const Instruction& inst : this->insts) { |
344 | if (remove_postproc && inst->kind->IsPostproc()) { |
345 | break; |
346 | } |
347 | Array<ObjectRef> attrs; |
348 | attrs.reserve(inst->attrs.size()); |
349 | for (const ObjectRef& obj : inst->attrs) { |
350 | if (const auto* str = obj.as<StringObj>()) { |
351 | attrs.push_back(String('"' + std::string(str->data) + '"')); |
352 | } else { |
353 | attrs.push_back(obj); |
354 | } |
355 | } |
356 | py_trace.push_back( |
357 | inst->kind->f_as_python(/*inputs=*/TranslateInputRVs(inst->inputs, rv_names), |
358 | /*attrs=*/attrs, |
359 | /*decision=*/this->GetDecision(inst), |
360 | /*outputs=*/TranslateAddOutputRVs(inst->outputs, &rv_names))); |
361 | } |
362 | return py_trace; |
363 | } |
364 | |
365 | void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { |
366 | Array<ObjectRef> json_insts{nullptr}; |
367 | Array<ObjectRef> json_decisions{nullptr}; |
368 | // Parse `json` into `json_insts` and `json_decisions` |
369 | try { |
370 | const ArrayNode* arr = json.as<ArrayNode>(); |
371 | ICHECK(arr && arr->size() == 2); |
372 | const auto* arr0 = arr->at(0).as<ArrayNode>(); |
373 | const auto* arr1 = arr->at(1).as<ArrayNode>(); |
374 | ICHECK(arr0 && arr1); |
375 | json_insts = GetRef<Array<ObjectRef>>(arr0); |
376 | json_decisions = GetRef<Array<ObjectRef>>(arr1); |
377 | } catch (const tvm::Error& e) { |
378 | LOG(FATAL) << "ValueError: The json entry of a trace should contain two arrays, an array of " |
379 | "instructions and an array of decisions, but gets: " |
380 | << json; |
381 | throw; |
382 | } |
383 | // Parse `json_decisions` |
384 | std::vector<Optional<ObjectRef>> decisions(json_insts.size(), NullOpt); |
385 | for (const ObjectRef& decision_entry : json_decisions) { |
386 | int index = -1; |
387 | ObjectRef decision{nullptr}; |
388 | try { |
389 | const ArrayNode* arr = decision_entry.as<ArrayNode>(); |
390 | ICHECK(arr && arr->size() == 2); |
391 | const IntImmNode* arr0 = arr->at(0).as<IntImmNode>(); |
392 | ICHECK(arr0); |
393 | index = arr0->value; |
394 | decision = arr->at(1); |
395 | } catch (const tvm::Error& e) { |
396 | LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " |
397 | "decision], but gets: " |
398 | << decision_entry; |
399 | throw; |
400 | } |
401 | decisions[index] = std::move(decision); |
402 | } |
403 | // Parse `json_insts` |
404 | std::unordered_map<std::string, ObjectRef> named_rvs{{"None" , ObjectRef{nullptr}}}; |
405 | int i = 0; |
406 | for (const ObjectRef& inst_entry : json_insts) { |
407 | InstructionKind kind{nullptr}; |
408 | Array<ObjectRef> inputs{nullptr}; |
409 | Array<ObjectRef> attrs{nullptr}; |
410 | Array<String> outputs{ObjectPtr<Object>{nullptr}}; |
411 | // Parse the entry |
412 | try { |
413 | const auto* arr = inst_entry.as<ArrayNode>(); |
414 | ICHECK(arr && arr->size() == 4); |
415 | const auto* arr0 = arr->at(0).as<StringObj>(); |
416 | const auto* arr1 = arr->at(1).as<ArrayNode>(); |
417 | const auto* arr2 = arr->at(2).as<ArrayNode>(); |
418 | const auto* arr3 = arr->at(3).as<ArrayNode>(); |
419 | ICHECK(arr0 && arr1 && arr2 && arr3); |
420 | for (const ObjectRef& str : *arr3) { |
421 | ICHECK(str->IsInstance<StringObj>()); |
422 | } |
423 | kind = InstructionKind::Get(arr0->data); |
424 | inputs = GetRef<Array<ObjectRef>>(arr1); |
425 | attrs = GetRef<Array<ObjectRef>>(arr2); |
426 | outputs = GetRef<Array<String>>(arr3); |
427 | } catch (const tvm::Error& e) { |
428 | LOG(FATAL) << "ValueError: Each entry of a json instruction should be a tuple [inst_name, " |
429 | "inputs, attrs, outputs], but gets: " |
430 | << inst_entry; |
431 | throw; |
432 | } |
433 | // Parse inputs |
434 | inputs = TranslateInputRVs(inputs, named_rvs); |
435 | // Parse attrs |
436 | if (kind->f_attrs_from_json != nullptr) { |
437 | attrs = kind->f_attrs_from_json(attrs); |
438 | } |
439 | // Apply to the schedule |
440 | Array<ObjectRef> new_outputs = kind->f_apply_to_schedule(sch, inputs, attrs, decisions[i]); |
441 | // Parse outputs |
442 | TranslateAddOutputRVs(outputs, new_outputs, &named_rvs); |
443 | ++i; |
444 | } |
445 | } |
446 | |
447 | /**************** Creation ****************/ |
448 | |
449 | Trace TraceNode::WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const { |
450 | int n_insts = GetNumValidInstructions(this->insts, remove_postproc); |
451 | Array<Instruction> new_insts = |
452 | Array<Instruction>{this->insts.begin(), this->insts.begin() + n_insts}; |
453 | Map<Instruction, ObjectRef> new_decisions{this->decisions.begin(), this->decisions.end()}; |
454 | new_decisions.Set(std::move(inst), std::move(decision)); |
455 | return Trace(new_insts, new_decisions); |
456 | } |
457 | |
458 | Trace TraceNode::Simplified(bool remove_postproc) const { |
459 | int n_insts = GetNumValidInstructions(this->insts, remove_postproc); |
460 | std::unordered_set<const Object*> used_rvs; |
461 | std::vector<Instruction> new_insts; |
462 | std::unordered_map<Instruction, ObjectRef, ObjectPtrHash, ObjectPtrEqual> new_decisions; |
463 | new_insts.reserve(n_insts); |
464 | new_decisions.reserve(this->decisions.size()); |
465 | for (int inst_idx = n_insts - 1; inst_idx >= 0; --inst_idx) { |
466 | const Instruction& inst = this->insts[inst_idx]; |
467 | // Check if all the variables the instruction defined are dead |
468 | // If so, and the instruction is pure, we can safely remove this instruction |
469 | bool all_defs_dead = inst->kind->is_pure; |
470 | if (all_defs_dead) { |
471 | for (const ObjectRef& obj : inst->outputs) { |
472 | if (used_rvs.count(obj.get())) { |
473 | all_defs_dead = false; |
474 | break; |
475 | } |
476 | } |
477 | } |
478 | // Remove this instruction |
479 | if (all_defs_dead) { |
480 | continue; |
481 | } |
482 | // Otherwise this instruction is not dead |
483 | new_insts.push_back(inst); |
484 | if (Optional<ObjectRef> decision = this->GetDecision(inst)) { |
485 | new_decisions.emplace(inst, std::move(decision)); |
486 | } |
487 | // Add its inputs as "used" ones |
488 | for (const ObjectRef& obj : inst->inputs) { |
489 | if (!obj.defined()) { |
490 | continue; |
491 | } else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() || |
492 | obj->IsInstance<VarNode>()) { |
493 | used_rvs.insert(obj.get()); |
494 | continue; |
495 | } else if (obj->IsInstance<PrimExprNode>()) { |
496 | PostOrderVisit(obj, [&used_rvs](const ObjectRef& obj) -> void { |
497 | if (obj->IsInstance<VarNode>()) { |
498 | used_rvs.insert(obj.get()); |
499 | } |
500 | }); |
501 | } |
502 | } |
503 | } |
504 | return Trace(Array<Instruction>(new_insts.rbegin(), new_insts.rend()), |
505 | Map<Instruction, ObjectRef>(new_decisions)); |
506 | } |
507 | |
508 | /**************** Repr ****************/ |
509 | |
510 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
511 | .set_dispatch<TraceNode>([](const ObjectRef& obj, ReprPrinter* p) { |
512 | const auto* self = obj.as<TraceNode>(); |
513 | ICHECK_NOTNULL(self); |
514 | p->stream << "# from tvm import tir\n" ; |
515 | p->stream << "def apply_trace(sch: tir.Schedule) -> None:\n" ; |
516 | Array<String> repr = self->AsPython(/*remove_postproc=*/false); |
517 | bool is_first = true; |
518 | for (const String& line : repr) { |
519 | if (is_first) { |
520 | is_first = false; |
521 | } else { |
522 | p->stream << '\n'; |
523 | } |
524 | p->stream << " " << line; |
525 | } |
526 | if (is_first) { |
527 | p->stream << " pass" ; |
528 | } |
529 | p->stream << std::flush; |
530 | }); |
531 | |
532 | /**************** Instruction Registration ****************/ |
533 | |
534 | struct EnterPostprocTraits : public UnpackedInstTraits<EnterPostprocTraits> { |
535 | static constexpr const char* kName = "EnterPostproc" ; |
536 | static constexpr bool kIsPure = false; |
537 | |
538 | private: |
539 | static constexpr size_t kNumInputs = 0; |
540 | static constexpr size_t kNumAttrs = 0; |
541 | static constexpr size_t kNumDecisions = 0; |
542 | |
543 | static void UnpackedApplyToSchedule(Schedule sch) { return sch->EnterPostproc(); } |
544 | |
545 | static String UnpackedAsPython(Array<String> outputs) { |
546 | PythonAPICall py("enter_postproc" ); |
547 | return py.Str(); |
548 | } |
549 | |
550 | template <typename> |
551 | friend struct ::tvm::tir::UnpackedInstTraits; |
552 | }; |
553 | |
554 | TVM_REGISTER_INST_KIND_TRAITS(EnterPostprocTraits); |
555 | |
556 | /**************** FFI ****************/ |
557 | |
// Register TraceNode with the object system and expose the Trace API as global functions
// under the "tir.schedule." prefix.
TVM_REGISTER_NODE_TYPE(TraceNode);
// Constructor: a missing argument defaults to an empty instruction array / decision map
TVM_REGISTER_GLOBAL("tir.schedule.Trace")
    .set_body_typed([](Optional<Array<Instruction>> insts,
                       Optional<Map<Instruction, ObjectRef>> decisions) {
      return Trace(insts.value_or(Array<Instruction>()),
                   decisions.value_or(Map<Instruction, ObjectRef>()));
    });
TVM_REGISTER_GLOBAL("tir.schedule.TraceGetDecision")
    .set_body_method<Trace>(&TraceNode::GetDecision);
// Append dispatches to the overload that also records a decision only when one is given
TVM_REGISTER_GLOBAL("tir.schedule.TraceAppend")
    .set_body_typed([](Trace self, Instruction inst, Optional<ObjectRef> decision) {
      if (decision.defined()) {
        return self->Append(inst, decision.value());
      } else {
        return self->Append(inst);
      }
    });
TVM_REGISTER_GLOBAL("tir.schedule.TracePop").set_body_method<Trace>(&TraceNode::Pop);
TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyToSchedule")
    .set_body_method<Trace>(&TraceNode::ApplyToSchedule);
TVM_REGISTER_GLOBAL("tir.schedule.TraceAsJSON").set_body_method<Trace>(&TraceNode::AsJSON);
TVM_REGISTER_GLOBAL("tir.schedule.TraceAsPython").set_body_method<Trace>(&TraceNode::AsPython);
TVM_REGISTER_GLOBAL("tir.schedule.TraceWithDecision")
    .set_body_method<Trace>(&TraceNode::WithDecision);
TVM_REGISTER_GLOBAL("tir.schedule.TraceSimplified").set_body_method<Trace>(&TraceNode::Simplified);
// Static function: replays a JSON-serialized trace onto a schedule
TVM_REGISTER_GLOBAL("tir.schedule.TraceApplyJSONToSchedule")
    .set_body_typed(Trace::ApplyJSONToSchedule);
585 | |
586 | } // namespace tir |
587 | } // namespace tvm |
588 | |