1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/transform_step.cc |
22 | * \brief Transformation steps. These steps are used to manipulate the LoopState. |
23 | * They are similar to the schedule primitives in te::Stage. |
24 | */ |
25 | |
26 | #include <tvm/auto_scheduler/compute_dag.h> |
27 | #include <tvm/auto_scheduler/loop_state.h> |
28 | #include <tvm/auto_scheduler/transform_step.h> |
29 | #include <tvm/runtime/logging.h> |
30 | #include <tvm/runtime/registry.h> |
31 | #include <tvm/te/operation.h> |
32 | |
33 | #include <string> |
34 | #include <utility> |
35 | #include <vector> |
36 | |
37 | #include "utils.h" |
38 | |
39 | namespace dmlc { |
40 | namespace json { |
41 | |
42 | template <> |
43 | struct Handler<::tvm::Array<::tvm::Integer>> { |
44 | inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::Integer>& array) { |
45 | writer->BeginArray(false); |
46 | for (const auto& i : array) { |
47 | ICHECK(i.defined()); |
48 | writer->WriteArrayItem(i->value); |
49 | } |
50 | writer->EndArray(); |
51 | } |
52 | inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::Integer>* array) { |
53 | array->clear(); |
54 | reader->BeginArray(); |
55 | while (reader->NextArrayItem()) { |
56 | int value; |
57 | Handler<int>::Read(reader, &value); |
58 | array->push_back(value); |
59 | } |
60 | } |
61 | }; |
62 | |
63 | template <> |
64 | struct Handler<::tvm::Array<::tvm::Optional<::tvm::Integer>>> { |
65 | inline static void Write(dmlc::JSONWriter* writer, |
66 | const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& array) { |
67 | writer->BeginArray(false); |
68 | for (const auto& i : array) { |
69 | ICHECK(i); |
70 | writer->WriteArrayItem(i.value()->value); |
71 | } |
72 | writer->EndArray(); |
73 | } |
74 | inline static void Read(dmlc::JSONReader* reader, |
75 | ::tvm::Array<::tvm::Optional<::tvm::Integer>>* array) { |
76 | array->clear(); |
77 | reader->BeginArray(); |
78 | while (reader->NextArrayItem()) { |
79 | int value; |
80 | Handler<int>::Read(reader, &value); |
81 | array->push_back(::tvm::Integer(value)); |
82 | } |
83 | } |
84 | }; |
85 | |
86 | } // namespace json |
87 | } // namespace dmlc |
88 | |
89 | namespace tvm { |
90 | namespace auto_scheduler { |
91 | |
92 | // Update the te::stage to tir::IterVar axis mapping |
93 | void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) { |
94 | if (auto pop = stage->op.as<te::ComputeOpNode>()) { |
95 | Array<IterVar> axes; |
96 | for (const auto& axis : pop->axis) { |
97 | axes.push_back(axis); |
98 | } |
99 | for (const auto& axis : pop->reduce_axis) { |
100 | axes.push_back(axis); |
101 | } |
102 | stage_to_axes->Set(stage, std::move(axes)); |
103 | } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) { |
104 | {} // do nothing on Placeholder |
105 | } else { |
106 | LOG(FATAL) << "Invalid op " << stage->op; |
107 | } |
108 | } |
109 | |
// Printable name of each IteratorAnnotation, indexed by the enum's integer value (see the
// per-entry comments). The thread-binding entries are also passed verbatim as the thread tag
// to te::thread_axis when an iterator is bound, and are printed into the generated Python.
const char* IteratorAnnotationString[] = {
    "for",          // kNone = 0
    "unroll",       // kUnroll = 1
    "vectorize",    // kVectorize = 2
    "parallel",     // kParallel = 3
    "vthread",      // kVThread = 4
    "blockIdx.x",   // kBlockX = 5
    "threadIdx.x",  // kThreadX = 6
    "blockIdx.y",   // kBlockY = 7
    "threadIdx.y",  // kThreadY = 8
    "blockIdx.z",   // kBlockZ = 9
    "threadIdx.z",  // kThreadZ = 10
    "tensorize"     // kTensorized = 11
};
124 | |
125 | StepNode* Step::CopyOnWrite() { |
126 | CHECK(data_ != nullptr); |
127 | if (!data_.unique()) { |
128 | if (const auto& ps = as<AnnotationStepNode>()) { |
129 | auto n = make_object<AnnotationStepNode>(*ps); |
130 | ObjectPtr<Object>(std::move(n)).swap(data_); |
131 | } else if (const auto& ps = as<FuseStepNode>()) { |
132 | auto n = make_object<FuseStepNode>(*ps); |
133 | ObjectPtr<Object>(std::move(n)).swap(data_); |
134 | } else if (const auto& ps = as<PragmaStepNode>()) { |
135 | auto n = make_object<PragmaStepNode>(*ps); |
136 | ObjectPtr<Object>(std::move(n)).swap(data_); |
137 | } else if (const auto& ps = as<ReorderStepNode>()) { |
138 | auto n = make_object<ReorderStepNode>(*ps); |
139 | ObjectPtr<Object>(std::move(n)).swap(data_); |
140 | } else if (const auto& ps = as<SplitStepNode>()) { |
141 | auto n = make_object<SplitStepNode>(*ps); |
142 | ObjectPtr<Object>(std::move(n)).swap(data_); |
143 | } else if (const auto& ps = as<FollowSplitStepNode>()) { |
144 | auto n = make_object<FollowSplitStepNode>(*ps); |
145 | ObjectPtr<Object>(std::move(n)).swap(data_); |
146 | } else if (const auto& ps = as<FollowFusedSplitStepNode>()) { |
147 | auto n = make_object<FollowFusedSplitStepNode>(*ps); |
148 | ObjectPtr<Object>(std::move(n)).swap(data_); |
149 | } else if (const auto& ps = as<StorageAlignStepNode>()) { |
150 | auto n = make_object<StorageAlignStepNode>(*ps); |
151 | ObjectPtr<Object>(std::move(n)).swap(data_); |
152 | } else if (const auto& ps = as<ComputeAtStepNode>()) { |
153 | auto n = make_object<ComputeAtStepNode>(*ps); |
154 | ObjectPtr<Object>(std::move(n)).swap(data_); |
155 | } else if (const auto& ps = as<ComputeInlineStepNode>()) { |
156 | auto n = make_object<ComputeInlineStepNode>(*ps); |
157 | ObjectPtr<Object>(std::move(n)).swap(data_); |
158 | } else if (const auto& ps = as<ComputeRootStepNode>()) { |
159 | auto n = make_object<ComputeRootStepNode>(*ps); |
160 | ObjectPtr<Object>(std::move(n)).swap(data_); |
161 | } else if (const auto& ps = as<CacheReadStepNode>()) { |
162 | auto n = make_object<CacheReadStepNode>(*ps); |
163 | ObjectPtr<Object>(std::move(n)).swap(data_); |
164 | } else if (const auto& ps = as<CacheWriteStepNode>()) { |
165 | auto n = make_object<CacheWriteStepNode>(*ps); |
166 | ObjectPtr<Object>(std::move(n)).swap(data_); |
167 | } else if (const auto& ps = as<RfactorStepNode>()) { |
168 | auto n = make_object<RfactorStepNode>(*ps); |
169 | ObjectPtr<Object>(std::move(n)).swap(data_); |
170 | } else { |
171 | LOG(FATAL) << "Invalid step: " << (*this); |
172 | } |
173 | } |
174 | return static_cast<StepNode*>(data_.get()); |
175 | } |
176 | |
177 | Step StepReadFromRecord(dmlc::JSONReader* reader) { |
178 | std::string name; |
179 | bool s; |
180 | s = reader->NextArrayItem(); |
181 | ICHECK(s); |
182 | reader->Read(&name); |
183 | if (name == AnnotationStepNode::record_prefix_str) { |
184 | return AnnotationStep(reader); |
185 | } else if (name == FuseStepNode::record_prefix_str) { |
186 | return FuseStep(reader); |
187 | } else if (name == PragmaStepNode::record_prefix_str) { |
188 | return PragmaStep(reader); |
189 | } else if (name == ReorderStepNode::record_prefix_str) { |
190 | return ReorderStep(reader); |
191 | } else if (name == SplitStepNode::record_prefix_str) { |
192 | return SplitStep(reader); |
193 | } else if (name == FollowSplitStepNode::record_prefix_str) { |
194 | return FollowSplitStep(reader); |
195 | } else if (name == FollowFusedSplitStepNode::record_prefix_str) { |
196 | return FollowFusedSplitStep(reader); |
197 | } else if (name == StorageAlignStepNode::record_prefix_str) { |
198 | return StorageAlignStep(reader); |
199 | } else if (name == ComputeAtStepNode::record_prefix_str) { |
200 | return ComputeAtStep(reader); |
201 | } else if (name == ComputeInlineStepNode::record_prefix_str) { |
202 | return ComputeInlineStep(reader); |
203 | } else if (name == ComputeRootStepNode::record_prefix_str) { |
204 | return ComputeRootStep(reader); |
205 | } else if (name == CacheReadStepNode::record_prefix_str) { |
206 | return CacheReadStep(reader); |
207 | } else if (name == CacheWriteStepNode::record_prefix_str) { |
208 | return CacheWriteStep(reader); |
209 | } else if (name == RfactorStepNode::record_prefix_str) { |
210 | return RfactorStep(reader); |
211 | } else { |
212 | LOG(FATAL) << "Invalid step format: " << name; |
213 | } |
214 | return Step(); |
215 | } |
216 | |
217 | void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { |
218 | // We need this runtime dispatcher because different steps have different function signatures |
219 | if (auto ps = step.as<AnnotationStepNode>()) { |
220 | ps->ApplyToState(state); |
221 | } else if (auto ps = step.as<FuseStepNode>()) { |
222 | ps->ApplyToState(state); |
223 | } else if (auto ps = step.as<PragmaStepNode>()) { |
224 | ps->ApplyToState(state); |
225 | } else if (auto ps = step.as<ReorderStepNode>()) { |
226 | ps->ApplyToState(state); |
227 | } else if (auto ps = step.as<SplitStepNode>()) { |
228 | ps->ApplyToState(state); |
229 | } else if (auto ps = step.as<FollowSplitStepNode>()) { |
230 | ps->ApplyToState(state); |
231 | } else if (auto ps = step.as<FollowFusedSplitStepNode>()) { |
232 | ps->ApplyToState(state); |
233 | } else if (auto ps = step.as<StorageAlignStepNode>()) { |
234 | ps->ApplyToState(state); |
235 | } else if (auto ps = step.as<ComputeAtStepNode>()) { |
236 | ps->ApplyToState(state); |
237 | } else if (auto ps = step.as<ComputeInlineStepNode>()) { |
238 | ps->ApplyToState(state); |
239 | } else if (auto ps = step.as<ComputeRootStepNode>()) { |
240 | ps->ApplyToState(state); |
241 | } else if (auto ps = step.as<CacheReadStepNode>()) { |
242 | ps->ApplyToState(state, dag); |
243 | } else if (auto ps = step.as<CacheWriteStepNode>()) { |
244 | ps->ApplyToState(state, dag); |
245 | } else if (auto ps = step.as<RfactorStepNode>()) { |
246 | ps->ApplyToState(state, dag); |
247 | } else { |
248 | LOG(FATAL) << "Invalid step: " << step; |
249 | } |
250 | } |
251 | |
252 | void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
253 | te::Schedule* schedule, const Array<Step>& transform_steps) { |
254 | if (auto ps = step.as<AnnotationStepNode>()) { |
255 | ps->ApplyToSchedule(stages, stage_to_axes); |
256 | } else if (auto ps = step.as<FuseStepNode>()) { |
257 | ps->ApplyToSchedule(stages, stage_to_axes); |
258 | } else if (auto ps = step.as<PragmaStepNode>()) { |
259 | ps->ApplyToSchedule(stages, stage_to_axes); |
260 | } else if (auto ps = step.as<ReorderStepNode>()) { |
261 | ps->ApplyToSchedule(stages, stage_to_axes); |
262 | } else if (auto ps = step.as<SplitStepNode>()) { |
263 | ps->ApplyToSchedule(stages, stage_to_axes); |
264 | } else if (auto ps = step.as<FollowSplitStepNode>()) { |
265 | ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); |
266 | } else if (auto ps = step.as<FollowFusedSplitStepNode>()) { |
267 | ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); |
268 | } else if (auto ps = step.as<StorageAlignStepNode>()) { |
269 | ps->ApplyToSchedule(stages, stage_to_axes); |
270 | } else if (auto ps = step.as<ComputeAtStepNode>()) { |
271 | ps->ApplyToSchedule(stages, stage_to_axes); |
272 | } else if (auto ps = step.as<ComputeInlineStepNode>()) { |
273 | ps->ApplyToSchedule(stages, stage_to_axes); |
274 | } else if (auto ps = step.as<ComputeRootStepNode>()) { |
275 | ps->ApplyToSchedule(stages, stage_to_axes); |
276 | } else if (auto ps = step.as<CacheReadStepNode>()) { |
277 | ps->ApplyToSchedule(stages, stage_to_axes, schedule); |
278 | } else if (auto ps = step.as<CacheWriteStepNode>()) { |
279 | ps->ApplyToSchedule(stages, stage_to_axes, schedule); |
280 | } else if (auto ps = step.as<RfactorStepNode>()) { |
281 | ps->ApplyToSchedule(stages, stage_to_axes, schedule); |
282 | } else { |
283 | LOG(FATAL) << "Invalid Step: " << step; |
284 | } |
285 | } |
286 | |
287 | String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages, |
288 | StageToAxesMap* stage_to_axes, te::Schedule* schedule, |
289 | const Array<Step>& transform_steps) { |
290 | if (auto ps = step.as<AnnotationStepNode>()) { |
291 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
292 | } else if (auto ps = step.as<FuseStepNode>()) { |
293 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
294 | } else if (auto ps = step.as<PragmaStepNode>()) { |
295 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
296 | } else if (auto ps = step.as<ReorderStepNode>()) { |
297 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
298 | } else if (auto ps = step.as<SplitStepNode>()) { |
299 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
300 | } else if (auto ps = step.as<FollowSplitStepNode>()) { |
301 | return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); |
302 | } else if (auto ps = step.as<FollowFusedSplitStepNode>()) { |
303 | return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps); |
304 | } else if (auto ps = step.as<StorageAlignStepNode>()) { |
305 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
306 | } else if (auto ps = step.as<ComputeAtStepNode>()) { |
307 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
308 | } else if (auto ps = step.as<ComputeInlineStepNode>()) { |
309 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
310 | } else if (auto ps = step.as<ComputeRootStepNode>()) { |
311 | return ps->PrintAsPythonAPI(stages, stage_to_axes); |
312 | } else if (auto ps = step.as<CacheReadStepNode>()) { |
313 | return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); |
314 | } else if (auto ps = step.as<CacheWriteStepNode>()) { |
315 | return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); |
316 | } else if (auto ps = step.as<RfactorStepNode>()) { |
317 | return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule); |
318 | } else { |
319 | LOG(FATAL) << "Invalid Step: " << step; |
320 | } |
321 | return "" ; |
322 | } |
323 | |
324 | /********** Steps working on single stage **********/ |
325 | |
326 | /********** Annotation **********/ |
327 | AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { |
328 | auto node = make_object<AnnotationStepNode>(); |
329 | node->stage_id = stage_id; |
330 | node->iter_id = iter_id; |
331 | node->annotation = ann; |
332 | data_ = std::move(node); |
333 | } |
334 | |
335 | AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) { |
336 | auto node = make_object<AnnotationStepNode>(); |
337 | bool s; |
338 | s = reader->NextArrayItem(); |
339 | ICHECK(s); |
340 | reader->Read(&node->stage_id); |
341 | s = reader->NextArrayItem(); |
342 | ICHECK(s); |
343 | reader->Read(&node->iter_id); |
344 | s = reader->NextArrayItem(); |
345 | ICHECK(s); |
346 | int int_val; |
347 | reader->Read(&int_val); |
348 | node->annotation = IteratorAnnotation(int_val); |
349 | data_ = std::move(node); |
350 | } |
351 | |
/*!
 * \brief Append this step to an open JSON array.
 * The fields are written as record prefix, stage_id, iter_id, then the annotation as an int.
 * The AnnotationStep(JSONReader*) constructor reads them back in the same order.
 */
void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(iter_id);
  writer->WriteArrayItem(static_cast<int>(annotation));
}
359 | |
360 | Iterator AnnotationStepNode::ApplyToState(State* state) const { |
361 | const Stage& stage = (*state)->stages[stage_id]; |
362 | Iterator it = stage->iters[iter_id]; |
363 | |
364 | ICHECK(it->annotation == IteratorAnnotation::kNone); |
365 | Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation, &it->orig_iters); |
366 | Stage new_stage = stage; |
367 | new_stage.CopyOnWrite()->iters.Set(iter_id, new_it); |
368 | state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage)); |
369 | return new_it; |
370 | } |
371 | |
372 | void AnnotationStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
373 | StageToAxesMap* stage_to_axes) const { |
374 | te::Stage stage = (*stages)[stage_id]; |
375 | const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
376 | |
377 | switch (annotation) { |
378 | case IteratorAnnotation::kUnroll: |
379 | stage.unroll(axes[iter_id]); |
380 | break; |
381 | case IteratorAnnotation::kVectorize: |
382 | stage.vectorize(axes[iter_id]); |
383 | break; |
384 | case IteratorAnnotation::kParallel: |
385 | stage.parallel(axes[iter_id]); |
386 | break; |
387 | case IteratorAnnotation::kVThread: |
388 | case IteratorAnnotation::kBlockX: |
389 | case IteratorAnnotation::kBlockY: |
390 | case IteratorAnnotation::kBlockZ: |
391 | case IteratorAnnotation::kThreadX: |
392 | case IteratorAnnotation::kThreadY: |
393 | case IteratorAnnotation::kThreadZ: |
394 | stage.bind(axes[iter_id], |
395 | te::thread_axis(Range(), IteratorAnnotationString[static_cast<int>(annotation)])); |
396 | break; |
397 | case IteratorAnnotation::kNone: |
398 | break; |
399 | default: |
400 | LOG(FATAL) << "Invalid Annotation " << static_cast<int>(annotation); |
401 | break; |
402 | } |
403 | |
404 | stages->Set(stage_id, std::move(stage)); |
405 | } |
406 | |
/*!
 * \brief Print this annotation as one TE Python schedule line, then apply it to the schedule.
 * Produces e.g. `s[op].unroll(iter)` or `s[op].bind(iter, te.thread_axis("threadIdx.x"))`.
 * \return The generated Python source line (newline-terminated).
 */
String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                            StageToAxesMap* stage_to_axes) const {
  std::stringstream ss;
  const auto& stage = (*stages)[stage_id];
  const auto& iter = (*stage_to_axes)[stage][iter_id];
  const auto& op_name = CleanName(stage->op->name);

  ss << "s[" << op_name << "].";
  // First pass: emit the schedule primitive name and the opening parenthesis.
  switch (annotation) {
    case IteratorAnnotation::kUnroll:
      ss << "unroll(";
      break;
    case IteratorAnnotation::kVectorize:
      ss << "vectorize(";
      break;
    case IteratorAnnotation::kParallel:
      ss << "parallel(";
      break;
    case IteratorAnnotation::kVThread:
    case IteratorAnnotation::kBlockX:
    case IteratorAnnotation::kBlockY:
    case IteratorAnnotation::kBlockZ:
    case IteratorAnnotation::kThreadX:
    case IteratorAnnotation::kThreadY:
    case IteratorAnnotation::kThreadZ:
      ss << "bind(";
      break;
    case IteratorAnnotation::kNone:
      // NOTE(review): nothing is emitted here, so the final line reads "s[<op>].<iter>)".
      // Confirm an AnnotationStep with kNone can never be printed.
      break;
    default:
      LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
      break;
  }
  ss << CleanName(iter->var->name_hint, op_name);
  // Second pass: thread-binding kinds take the thread axis as an extra argument.
  switch (annotation) {
    case IteratorAnnotation::kVThread:
    case IteratorAnnotation::kBlockX:
    case IteratorAnnotation::kBlockY:
    case IteratorAnnotation::kBlockZ:
    case IteratorAnnotation::kThreadX:
    case IteratorAnnotation::kThreadY:
    case IteratorAnnotation::kThreadZ:
      ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
         << "\")";
      break;
    default:
      break;
  }
  ss << ")\n";

  // Keep the schedule in sync with the printed code.
  ApplyToSchedule(stages, stage_to_axes);
  return ss.str();
}
460 | |
461 | /********** Fuse **********/ |
462 | FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) { |
463 | auto node = make_object<FuseStepNode>(); |
464 | node->stage_id = stage_id; |
465 | for (const auto& x : fused_ids) { |
466 | ICHECK(x->IsInstance<IntImmNode>()); |
467 | } |
468 | node->fused_ids = fused_ids; |
469 | data_ = std::move(node); |
470 | } |
471 | |
472 | FuseStep::FuseStep(dmlc::JSONReader* reader) { |
473 | auto node = make_object<FuseStepNode>(); |
474 | bool s; |
475 | s = reader->NextArrayItem(); |
476 | ICHECK(s); |
477 | reader->Read(&node->stage_id); |
478 | s = reader->NextArrayItem(); |
479 | ICHECK(s); |
480 | reader->Read(&node->fused_ids); |
481 | data_ = std::move(node); |
482 | } |
483 | |
/*!
 * \brief Append this step to an open JSON array.
 * The fields are written as record prefix, stage_id, then fused_ids.
 * fused_ids is serialized by the Array<Integer> JSON handler.
 */
void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(fused_ids);
}
490 | |
491 | Iterator FuseStepNode::ApplyToState(State* state) const { |
492 | const Stage& stage = (*state)->stages[stage_id]; |
493 | size_t old_iter_size = static_cast<int>(stage->iters.size()); |
494 | |
495 | String new_name; |
496 | PrimExpr new_extent = 1; |
497 | IteratorKind new_iter_kind = IteratorKind::kSpecial; |
498 | std::vector<Iterator> orig_iters; |
499 | |
500 | for (size_t i = 0; i < fused_ids.size(); ++i) { |
501 | if (i > 0) { |
502 | ICHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1); |
503 | } |
504 | if (i != fused_ids.size() - 1) { |
505 | const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages; |
506 | if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i].IntValue())) != |
507 | iter_to_attached_stage.end()) { |
508 | LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some " |
509 | << "stages. State before fusion:\n" |
510 | << (*state); |
511 | } |
512 | } |
513 | |
514 | const Iterator& it = stage->iters[fused_ids[i].IntValue()]; |
515 | orig_iters.push_back(it); |
516 | new_name = new_name + it->name + "@" ; |
517 | |
518 | if (it->range.defined() && new_extent.defined()) { |
519 | new_extent = new_extent * it->range->extent; |
520 | } else { |
521 | new_extent = PrimExpr(); |
522 | } |
523 | |
524 | if (i == 0) { |
525 | new_iter_kind = it->iter_kind; |
526 | } else { |
527 | if (new_iter_kind != it->iter_kind) { |
528 | new_iter_kind = IteratorKind::kMixed; |
529 | } |
530 | } |
531 | } |
532 | |
533 | Range range; |
534 | if (new_extent.defined()) { |
535 | range = Range::FromMinExtent(0, new_extent); |
536 | } |
537 | Iterator new_it = |
538 | Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone, &orig_iters); |
539 | Array<Iterator> new_iters; |
540 | |
541 | if (fused_ids.empty()) { |
542 | new_iters.push_back(new_it); |
543 | } else { |
544 | new_iters.insert(new_iters.end(), stage->iters.begin(), |
545 | stage->iters.begin() + fused_ids.front().IntValue()); |
546 | new_iters.push_back(new_it); |
547 | new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back().IntValue() + 1, |
548 | stage->iters.end()); |
549 | } |
550 | |
551 | StateNode* pstate = state->CopyOnWrite(); |
552 | pstate->stages.Set(stage_id, |
553 | Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); |
554 | |
555 | if (fused_ids.empty()) { |
556 | return new_it; |
557 | } |
558 | |
559 | // Two vectors are used to represent the iterator relation before and after fuse |
560 | // The original iterators in AttachMap will be updated with the new iterators |
561 | std::vector<IterKey> from_iters; |
562 | std::vector<IterKey> to_iters; |
563 | const size_t begin_id = fused_ids.front().IntValue(), end_id = fused_ids.back().IntValue(); |
564 | for (size_t i = 0; i < old_iter_size; ++i) { |
565 | if (i <= begin_id) { |
566 | continue; |
567 | } else if (i > end_id) { |
568 | // move forward |
569 | from_iters.emplace_back(stage_id, i); |
570 | to_iters.emplace_back(stage_id, i - end_id + begin_id); |
571 | } else { |
572 | // move to the fused id |
573 | from_iters.emplace_back(stage_id, i); |
574 | to_iters.emplace_back(stage_id, begin_id); |
575 | } |
576 | } |
577 | pstate->attach_map.UpdateIters(from_iters, to_iters); |
578 | |
579 | return new_it; |
580 | } |
581 | |
582 | IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
583 | StageToAxesMap* stage_to_axes) const { |
584 | auto stage = (*stages)[stage_id]; |
585 | const Array<IterVar>& axes = stage_to_axes->at(stage); |
586 | |
587 | Array<IterVar> to_fuse; |
588 | for (const auto& i : fused_ids) { |
589 | to_fuse.push_back(axes[i.IntValue()]); |
590 | } |
591 | IterVar fused_axis; |
592 | stage.fuse(to_fuse, &fused_axis); |
593 | |
594 | Array<IterVar> new_axes; |
595 | if (fused_ids.empty()) { |
596 | new_axes.push_back(fused_axis); |
597 | } else { |
598 | new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front().IntValue()); |
599 | new_axes.push_back(fused_axis); |
600 | new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back().IntValue() + 1, axes.end()); |
601 | } |
602 | |
603 | stage_to_axes->Set(stage, std::move(new_axes)); |
604 | stages->Set(stage_id, std::move(stage)); |
605 | return fused_axis; |
606 | } |
607 | |
608 | String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
609 | StageToAxesMap* stage_to_axes) const { |
610 | const auto& stage = (*stages)[stage_id]; |
611 | const auto& op_name = CleanName(stage->op->name); |
612 | std::stringstream to_fuse; |
613 | |
614 | for (size_t i = 0; i < fused_ids.size(); ++i) { |
615 | to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i].IntValue()]->var->name_hint, |
616 | op_name); |
617 | if (i != fused_ids.size() - 1) { |
618 | to_fuse << ", " ; |
619 | } |
620 | } |
621 | |
622 | std::stringstream ss; |
623 | const auto& fused = ApplyToSchedule(stages, stage_to_axes); |
624 | |
625 | ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse(" |
626 | << to_fuse.str() << ")\n" ; |
627 | |
628 | return ss.str(); |
629 | } |
630 | |
631 | /********** Pragma **********/ |
632 | PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) { |
633 | auto node = make_object<PragmaStepNode>(); |
634 | node->stage_id = stage_id; |
635 | node->iter_id = iter_id; |
636 | node->pragma_type = std::move(pragma_type); |
637 | data_ = std::move(node); |
638 | } |
639 | |
640 | PragmaStep::PragmaStep(dmlc::JSONReader* reader) { |
641 | auto node = make_object<PragmaStepNode>(); |
642 | bool s; |
643 | s = reader->NextArrayItem(); |
644 | ICHECK(s); |
645 | reader->Read(&node->stage_id); |
646 | s = reader->NextArrayItem(); |
647 | ICHECK(s); |
648 | reader->Read(&node->iter_id); |
649 | s = reader->NextArrayItem(); |
650 | ICHECK(s); |
651 | std::string string_value; |
652 | reader->Read(&string_value); |
653 | node->pragma_type = std::move(string_value); |
654 | data_ = std::move(node); |
655 | } |
656 | |
/*!
 * \brief Append this step to an open JSON array.
 * The fields are written as record prefix, stage_id, iter_id, then pragma_type.
 * pragma_type is emitted as a JSON string item via an explicit separator plus WriteString.
 */
void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(iter_id);
  writer->WriteArraySeperator();
  writer->WriteString(pragma_type);
}
665 | |
666 | void PragmaStepNode::ApplyToState(State* state) const { |
667 | if (pragma_type == "debug_skip_region" ) { |
668 | StateNode* pstate = state->CopyOnWrite(); |
669 | pstate->attach_map.DeleteStage(stage_id); |
670 | } else if (StrStartsWith(pragma_type, "auto_unroll_max_step" )) { |
671 | StateNode* pstate = state->CopyOnWrite(); |
672 | Stage stage = pstate->stages[stage_id]; |
673 | size_t pos = 0; |
674 | for (; pos < pragma_type.size(); ++pos) { |
675 | if ((*(pragma_type.c_str() + pos)) == '$') { |
676 | break; |
677 | } |
678 | } |
679 | ICHECK_LT(pos, pragma_type.size()) << "max step value not found." ; |
680 | stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1); |
681 | pstate->stages.Set(stage_id, std::move(stage)); |
682 | } else { |
683 | LOG(FATAL) << "Unsupported pragma: " << pragma_type; |
684 | } |
685 | } |
686 | |
687 | void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
688 | StageToAxesMap* stage_to_axes) const { |
689 | te::Stage stage = (*stages)[stage_id]; |
690 | const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
691 | if (StrStartsWith(pragma_type, "auto_unroll_max_step" )) { |
692 | size_t pos = 0; |
693 | for (; pos < pragma_type.size(); ++pos) { |
694 | if ((*(pragma_type.c_str() + pos)) == '$') { |
695 | break; |
696 | } |
697 | } |
698 | ICHECK_LT(pos, pragma_type.size()) << "max step value not found." ; |
699 | int value = atoi(pragma_type.c_str() + pos + 1); |
700 | if (iter_id < static_cast<int>(axes.size())) { |
701 | stage.pragma(axes[iter_id], "auto_unroll_max_step" , value); |
702 | stage.pragma(axes[iter_id], "unroll_explicit" , true); |
703 | } |
704 | } else { |
705 | ICHECK_LT(iter_id, axes.size()); |
706 | stage.pragma(axes[iter_id], pragma_type); |
707 | } |
708 | stages->Set(stage_id, std::move(stage)); |
709 | } |
710 | |
711 | String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
712 | StageToAxesMap* stage_to_axes) const { |
713 | std::stringstream ss; |
714 | const auto& stage = (*stages)[stage_id]; |
715 | const auto& op_name = CleanName(stage->op->name); |
716 | |
717 | if (StrStartsWith(pragma_type, "auto_unroll_max_step" )) { |
718 | size_t pos = 0; |
719 | for (; pos < pragma_type.size(); ++pos) { |
720 | if ((*(pragma_type.c_str() + pos)) == '$') { |
721 | break; |
722 | } |
723 | } |
724 | ICHECK_LT(pos, pragma_type.size()) << "max step value not found." ; |
725 | int value = atoi(pragma_type.c_str() + pos + 1); |
726 | ss << "s[" << op_name << "].pragma(" |
727 | << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) |
728 | << ", \"auto_unroll_max_step\", " << value << ")\n" ; |
729 | ss << "s[" << op_name << "].pragma(" |
730 | << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) |
731 | << ", \"unroll_explicit\", True)\n" ; |
732 | } else { |
733 | ss << "s[" << op_name << "].pragma(" |
734 | << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \"" |
735 | << pragma_type << "\")\n" ; |
736 | } |
737 | |
738 | ApplyToSchedule(stages, stage_to_axes); |
739 | return ss.str(); |
740 | } |
741 | |
742 | /********** Reorder **********/ |
743 | ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) { |
744 | auto node = make_object<ReorderStepNode>(); |
745 | node->stage_id = stage_id; |
746 | for (const auto& x : after_ids) { |
747 | ICHECK(x->IsInstance<IntImmNode>()); |
748 | } |
749 | node->after_ids = after_ids; |
750 | data_ = std::move(node); |
751 | } |
752 | |
753 | ReorderStep::ReorderStep(dmlc::JSONReader* reader) { |
754 | auto node = make_object<ReorderStepNode>(); |
755 | bool s; |
756 | s = reader->NextArrayItem(); |
757 | ICHECK(s); |
758 | reader->Read(&node->stage_id); |
759 | s = reader->NextArrayItem(); |
760 | ICHECK(s); |
761 | reader->Read(&node->after_ids); |
762 | data_ = std::move(node); |
763 | } |
764 | |
/*!
 * \brief Append this step to an open JSON array.
 * The fields are written as record prefix, stage_id, then after_ids.
 * after_ids is serialized by the Array<Integer> JSON handler.
 */
void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(after_ids);
}
771 | |
772 | void ReorderStepNode::ApplyToState(State* state) const { |
773 | const Stage& stage = (*state)->stages[stage_id]; |
774 | Array<Iterator> iters; |
775 | for (auto x : after_ids) { |
776 | iters.push_back(stage->iters[x.IntValue()]); |
777 | } |
778 | state->CopyOnWrite()->stages.Set( |
779 | stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs)); |
780 | } |
781 | |
782 | void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
783 | StageToAxesMap* stage_to_axes) const { |
784 | auto stage = (*stages)[stage_id]; |
785 | const Array<IterVar>& axes = stage_to_axes->at(stage); |
786 | ICHECK_EQ(after_ids.size(), axes.size()); |
787 | |
788 | Array<IterVar> new_axes; |
789 | new_axes.reserve(axes.size()); |
790 | for (auto i : after_ids) { |
791 | new_axes.push_back(axes[i.IntValue()]); |
792 | } |
793 | stage.reorder(new_axes); |
794 | |
795 | stage_to_axes->Set(stage, std::move(new_axes)); |
796 | stages->Set(stage_id, std::move(stage)); |
797 | } |
798 | |
799 | String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
800 | StageToAxesMap* stage_to_axes) const { |
801 | const auto& stage = (*stages)[stage_id]; |
802 | const auto& op_name = CleanName(stage->op->name); |
803 | std::stringstream ss; |
804 | |
805 | ss << "s[" << op_name << "].reorder(" ; |
806 | for (size_t i = 0; i < after_ids.size(); ++i) { |
807 | ss << CleanName((*stage_to_axes)[stage][after_ids[i].IntValue()]->var->name_hint, op_name); |
808 | if (i != after_ids.size() - 1) { |
809 | ss << ", " ; |
810 | } |
811 | } |
812 | ss << ")\n" ; |
813 | |
814 | ApplyToSchedule(stages, stage_to_axes); |
815 | return ss.str(); |
816 | } |
817 | |
818 | /********** Split **********/ |
// common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep
//
// Split iterator `iter_id` of stage `stage_id` in the loop state into lengths.size() + 1 new
// iterators and update the stage, the state's `concrete` flag and its attach map.
//
// \param state The state to update (modified through CopyOnWrite).
// \param stage_id Index of the stage to split.
// \param iter_id Index of the iterator to split within that stage.
// \param lengths Extents of the new iterators. An undefined entry marks an unknown length and
//        clears `concrete` on the state.
// \param inner_to_outer If true, lengths give the extents of the inner iterators, the last length
//        being the innermost one, and the remaining outermost iterator is named "<name>.0".
//        Otherwise lengths give the extents of the outer iterators, lengths[0] being the
//        outermost one, and the remaining innermost iterator is named "<name>.<lengths.size()>".
// \return The new iterators ordered from outermost to innermost.
Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
                                  const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
  const Stage& stage = (*state)->stages[stage_id];
  const Iterator& it = stage->iters[iter_id];
  size_t old_iter_size = stage->iters.size();
  bool concrete = true;

  // Range of the part that is still left to split; unknown if the iterator has no range.
  Optional<PrimExpr> tosplit_min, tosplit_extent;
  if (it->range.defined()) {
    tosplit_min = it->range->min;
    tosplit_extent = it->range->extent;
  } else {
    tosplit_min = NullOpt;
    tosplit_extent = NullOpt;
  }

  // Build the iterators for each length. For inner_to_outer they are produced innermost first.
  Array<Iterator> outs;
  for (size_t i = 0; i < lengths.size(); ++i) {
    Optional<Integer> l;
    String name;
    if (inner_to_outer) {
      l = lengths[lengths.size() - i - 1];
      name = it->name + "." + std::to_string(lengths.size() - i);
    } else {
      l = lengths[i];
      name = it->name + "." + std::to_string(i);
    }
    Iterator res;
    if (l && tosplit_min && tosplit_extent) {
      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
                     IteratorAnnotation::kNone);
      // The remaining part starts at 0 and has extent ceil(extent / l).
      tosplit_min = Integer(0);
      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
    } else {
      // Once a length or the range is unknown, all following ranges are unknown too.
      res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
      tosplit_min = NullOpt;
      tosplit_extent = NullOpt;
      if (!l.defined()) {
        concrete = false;
      }
    }
    outs.push_back(std::move(res));
  }

  // The last iterator takes whatever is left of the original range.
  Range range;
  if (tosplit_min && tosplit_extent) {
    range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
  }
  if (inner_to_outer) {
    outs.push_back(Iterator(it->name + ".0" , range, it->iter_kind, IteratorAnnotation::kNone));
    // Reverse the Iterator array
    Array<Iterator> temp(outs.rbegin(), outs.rend());
    outs = std::move(temp);
  } else {
    outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
                            IteratorAnnotation::kNone));
  }

  // Replace the split iterator by the new ones, keeping the others in place.
  Array<Iterator> new_iters;
  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());

  StateNode* pstate = state->CopyOnWrite();
  pstate->stages.Set(stage_id,
                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
  pstate->concrete &= concrete;

  // Two vectors are used to represent the iterator relation before and after split
  // The original iterators in AttachMap will be updated with the new iterators
  // Iterators from iter_id onward shift by lengths.size(). This maps the split iterator itself
  // to the innermost new iterator.
  std::vector<IterKey> from_iters;
  std::vector<IterKey> to_iters;
  for (size_t i = iter_id; i < old_iter_size; ++i) {
    from_iters.emplace_back(stage_id, i);
    to_iters.emplace_back(stage_id, i + lengths.size());
  }
  pstate->attach_map.UpdateIters(from_iters, to_iters);

  return outs;
}
900 | |
901 | Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
902 | int stage_id, int iter_id, |
903 | const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
904 | auto stage = (*stages)[stage_id]; |
905 | const Array<IterVar>& axes = stage_to_axes->at(stage); |
906 | |
907 | Array<IterVar> outs; |
908 | if (inner_to_outer) { |
909 | IterVar outer = axes[iter_id], inner; |
910 | for (int i = static_cast<int>(lengths.size()) - 1; i >= 0; i--) { |
911 | IterVar to_split = outer; |
912 | stage.split(to_split, lengths[i].value(), &outer, &inner); |
913 | outs.push_back(inner); |
914 | } |
915 | outs.push_back(outer); |
916 | } else { |
917 | IterVar outer, inner = axes[iter_id]; |
918 | for (size_t i = 0; i < lengths.size(); i++) { |
919 | IterVar to_split = inner; |
920 | stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner); |
921 | outs.push_back(outer); |
922 | } |
923 | outs.push_back(inner); |
924 | } |
925 | |
926 | Array<IterVar> new_axes; |
927 | new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); |
928 | if (inner_to_outer) { |
929 | for (auto x = outs.rbegin(); x != outs.rend(); ++x) { |
930 | new_axes.push_back((*x)); |
931 | } |
932 | } else { |
933 | for (const auto& x : outs) { |
934 | new_axes.push_back(x); |
935 | } |
936 | } |
937 | new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); |
938 | |
939 | stage_to_axes->Set(stage, std::move(new_axes)); |
940 | stages->Set(stage_id, std::move(stage)); |
941 | return outs; |
942 | } |
943 | |
// Print the Python TE `split` calls equivalent to a split step. Also applies the split to the
// schedule (through ApplySplitToSchedule), which updates `stages` and `stage_to_axes`.
// inner_to_outer prints one `split(..., factor=...)` per length, from the last length to the
// first. Otherwise it prints one `split(..., nparts=...)` per length, from the first to the last.
String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, int stage_id,
                             int iter_id, const Array<Optional<Integer>>& lengths,
                             bool inner_to_outer) {
  const auto& stage = (*stages)[stage_id];
  // Capture the axis being split before ApplySplitToSchedule replaces it in stage_to_axes.
  auto to_split = stage_to_axes->at(stage)[iter_id];
  const auto& func_name = CleanName(stage->op->name);
  const auto& outs =
      ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
  ICHECK_EQ(outs.size(), lengths.size() + 1);

  std::stringstream ss;
  int size = static_cast<int>(lengths.size());
  if (inner_to_outer) {
    // Here outs is ordered innermost first. Each line names (outs[size - i], outs[size - i - 1]),
    // and the next line splits the former.
    for (int i = size - 1; i >= 0; i--) {
      ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", "
         << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name
         << "].split(" << CleanName(to_split->var->name_hint, func_name)
         << ", factor=" << lengths[i] << ")\n" ;
      to_split = outs[size - i];
    }
  } else {
    // Here outs is ordered outermost first. Each line names (outs[i], outs[i + 1]), and the next
    // line splits the latter.
    for (int i = 0; i < size; i++) {
      ss << CleanName(outs[i]->var->name_hint, func_name) << ", "
         << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split("
         << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n" ;
      to_split = outs[i + 1];
    }
  }

  return ss.str();
}
975 | |
976 | SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent, |
977 | const Array<Optional<Integer>>& lengths, bool inner_to_outer) { |
978 | auto node = make_object<SplitStepNode>(); |
979 | node->stage_id = stage_id; |
980 | // Extent can be a irreducible expression in some special cases |
981 | if (extent && extent.value()->IsInstance<IntImmNode>()) { |
982 | node->extent = tvm::Downcast<Integer>(extent.value()); |
983 | } |
984 | node->iter_id = iter_id; |
985 | node->lengths = lengths; |
986 | node->inner_to_outer = inner_to_outer; |
987 | data_ = std::move(node); |
988 | } |
989 | |
990 | SplitStep::SplitStep(dmlc::JSONReader* reader) { |
991 | auto node = make_object<SplitStepNode>(); |
992 | bool s; |
993 | s = reader->NextArrayItem(); |
994 | ICHECK(s); |
995 | reader->Read(&node->stage_id); |
996 | s = reader->NextArrayItem(); |
997 | ICHECK(s); |
998 | reader->Read(&node->iter_id); |
999 | int int_val; |
1000 | s = reader->NextArrayItem(); |
1001 | ICHECK(s); |
1002 | reader->Read(&int_val); |
1003 | if (int_val) { |
1004 | node->extent = Integer(int_val); |
1005 | } |
1006 | s = reader->NextArrayItem(); |
1007 | ICHECK(s); |
1008 | reader->Read(&node->lengths); |
1009 | s = reader->NextArrayItem(); |
1010 | ICHECK(s); |
1011 | reader->Read(&node->inner_to_outer); |
1012 | data_ = std::move(node); |
1013 | } |
1014 | |
1015 | void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
1016 | writer->WriteArraySeperator(); |
1017 | writer->WriteString(record_prefix_str); |
1018 | writer->WriteArrayItem(stage_id); |
1019 | writer->WriteArrayItem(iter_id); |
1020 | writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0); |
1021 | writer->WriteArrayItem(lengths); |
1022 | writer->WriteArrayItem(static_cast<int>(inner_to_outer)); |
1023 | } |
1024 | |
1025 | Array<Iterator> SplitStepNode::ApplyToState(State* state) const { |
1026 | return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer); |
1027 | } |
1028 | |
1029 | Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
1030 | StageToAxesMap* stage_to_axes) const { |
1031 | return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); |
1032 | } |
1033 | |
1034 | String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
1035 | StageToAxesMap* stage_to_axes) const { |
1036 | return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); |
1037 | } |
1038 | |
1039 | /********** Follow Split **********/ |
1040 | FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) { |
1041 | auto node = make_object<FollowSplitStepNode>(); |
1042 | node->stage_id = stage_id; |
1043 | node->iter_id = iter_id; |
1044 | node->src_step_id = src_step_id; |
1045 | node->n_split = n_split; |
1046 | data_ = std::move(node); |
1047 | } |
1048 | |
1049 | void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
1050 | writer->WriteArraySeperator(); |
1051 | writer->WriteString(record_prefix_str); |
1052 | writer->WriteArrayItem(stage_id); |
1053 | writer->WriteArrayItem(iter_id); |
1054 | writer->WriteArrayItem(src_step_id); |
1055 | writer->WriteArrayItem(n_split); |
1056 | } |
1057 | |
1058 | Array<Optional<Integer>> FollowSplitStepNode::( |
1059 | const Array<Step>& transform_steps) const { |
1060 | // Make sure src_step_id is within the range of transform_steps. |
1061 | ICHECK_LT(src_step_id, transform_steps.size()); |
1062 | auto ps = transform_steps[src_step_id].as<SplitStepNode>(); |
1063 | ICHECK(ps != nullptr); |
1064 | |
1065 | // Make sure the size of ps->lengths is not smaller than n_split-1. |
1066 | // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1. |
1067 | ICHECK_LE(n_split, ps->lengths.size() + 1); |
1068 | ICHECK(ps != nullptr); |
1069 | |
1070 | Array<Optional<Integer>> lengths; |
1071 | lengths.reserve(n_split); |
1072 | int j = 0; |
1073 | // Get the first (n_split-1) split factors of followed src_step. |
1074 | for (; j < n_split - 1; ++j) { |
1075 | lengths.push_back(ps->lengths[j]); |
1076 | } |
1077 | |
1078 | // Get the last split factor of src_step for splitting level if n_split is smaller than |
1079 | // ps->lengths.size()+1. |
1080 | PrimExpr last_factor = 1; |
1081 | for (; j < static_cast<int>(ps->lengths.size()); ++j) { |
1082 | if (ps->lengths[j]) { |
1083 | last_factor *= ps->lengths[j].value(); |
1084 | } else { |
1085 | last_factor = PrimExpr(); |
1086 | break; |
1087 | } |
1088 | } |
1089 | if (last_factor.defined()) { |
1090 | lengths.push_back(Downcast<Integer>(last_factor)); |
1091 | } else { |
1092 | lengths.push_back(NullOpt); |
1093 | } |
1094 | |
1095 | return lengths; |
1096 | } |
1097 | |
1098 | FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) { |
1099 | auto node = make_object<FollowSplitStepNode>(); |
1100 | bool s; |
1101 | s = reader->NextArrayItem(); |
1102 | ICHECK(s); |
1103 | reader->Read(&node->stage_id); |
1104 | s = reader->NextArrayItem(); |
1105 | ICHECK(s); |
1106 | reader->Read(&node->iter_id); |
1107 | s = reader->NextArrayItem(); |
1108 | ICHECK(s); |
1109 | reader->Read(&node->src_step_id); |
1110 | s = reader->NextArrayItem(); |
1111 | ICHECK(s); |
1112 | reader->Read(&node->n_split); |
1113 | data_ = std::move(node); |
1114 | } |
1115 | |
1116 | Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const { |
1117 | return ApplySplitToState(state, stage_id, iter_id, ExtractSplitLengths((*state)->transform_steps), |
1118 | true); |
1119 | } |
1120 | |
1121 | Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
1122 | StageToAxesMap* stage_to_axes, |
1123 | const Array<Step>& transform_steps) const { |
1124 | return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, |
1125 | ExtractSplitLengths(transform_steps), true); |
1126 | } |
1127 | |
1128 | String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
1129 | StageToAxesMap* stage_to_axes, |
1130 | const Array<Step>& transform_steps) const { |
1131 | return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, |
1132 | ExtractSplitLengths(transform_steps), true); |
1133 | } |
1134 | |
1135 | /********** Follow Fused Split **********/ |
1136 | FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, |
1137 | const Array<Integer>& src_step_ids, int level, |
1138 | bool factor_or_nparts) { |
1139 | auto node = make_object<FollowFusedSplitStepNode>(); |
1140 | node->stage_id = stage_id; |
1141 | node->iter_id = iter_id; |
1142 | node->src_step_ids = src_step_ids; |
1143 | node->level = level; |
1144 | node->factor_or_nparts = factor_or_nparts; |
1145 | data_ = std::move(node); |
1146 | } |
1147 | |
1148 | FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) { |
1149 | auto node = make_object<FollowFusedSplitStepNode>(); |
1150 | bool s; |
1151 | s = reader->NextArrayItem(); |
1152 | ICHECK(s); |
1153 | reader->Read(&node->stage_id); |
1154 | s = reader->NextArrayItem(); |
1155 | ICHECK(s); |
1156 | reader->Read(&node->iter_id); |
1157 | s = reader->NextArrayItem(); |
1158 | ICHECK(s); |
1159 | reader->Read(&node->src_step_ids); |
1160 | s = reader->NextArrayItem(); |
1161 | ICHECK(s); |
1162 | reader->Read(&node->level); |
1163 | s = reader->NextArrayItem(); |
1164 | ICHECK(s); |
1165 | reader->Read(&node->factor_or_nparts); |
1166 | data_ = std::move(node); |
1167 | } |
1168 | |
1169 | void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
1170 | writer->WriteArraySeperator(); |
1171 | writer->WriteString(record_prefix_str); |
1172 | writer->WriteArrayItem(stage_id); |
1173 | writer->WriteArrayItem(iter_id); |
1174 | writer->WriteArrayItem(src_step_ids); |
1175 | writer->WriteArrayItem(level); |
1176 | writer->WriteArrayItem(static_cast<int>(factor_or_nparts)); |
1177 | } |
1178 | |
1179 | Optional<Integer> FollowFusedSplitStepNode::( |
1180 | const Array<Step>& transform_steps) const { |
1181 | PrimExpr ret(1); |
1182 | |
1183 | for (auto src_step_id : src_step_ids) { |
1184 | // Make sure the src_step_id is within the range of transform_steps. |
1185 | ICHECK_LT(src_step_id.IntValue(), transform_steps.size()); |
1186 | auto ps = transform_steps[src_step_id.IntValue()].as<SplitStepNode>(); |
1187 | ICHECK(ps != nullptr); |
1188 | // Multiple the splitting factor on corresponding splitting level of src_steps. |
1189 | if (ps->lengths[level] && ret.defined()) { |
1190 | ret *= ps->lengths[level].value(); |
1191 | } else { |
1192 | return NullOpt; |
1193 | } |
1194 | } |
1195 | return Downcast<Integer>(ret); |
1196 | } |
1197 | |
1198 | Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const { |
1199 | return ApplySplitToState(state, stage_id, iter_id, |
1200 | {ExtractSplitLength((*state)->transform_steps)}, factor_or_nparts); |
1201 | } |
1202 | |
1203 | Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
1204 | StageToAxesMap* stage_to_axes, |
1205 | const Array<Step>& transform_steps) const { |
1206 | return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, |
1207 | {ExtractSplitLength(transform_steps)}, factor_or_nparts); |
1208 | } |
1209 | |
1210 | String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
1211 | StageToAxesMap* stage_to_axes, |
1212 | const Array<Step>& transform_steps) const { |
1213 | return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, |
1214 | {ExtractSplitLength(transform_steps)}, factor_or_nparts); |
1215 | } |
1216 | |
1217 | /********** Storage Align **********/ |
1218 | StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) { |
1219 | auto node = make_object<StorageAlignStepNode>(); |
1220 | node->stage_id = stage_id; |
1221 | node->iter_id = iter_id; |
1222 | node->factor = factor; |
1223 | node->offset = offset; |
1224 | data_ = std::move(node); |
1225 | } |
1226 | |
1227 | StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) { |
1228 | auto node = make_object<StorageAlignStepNode>(); |
1229 | bool s; |
1230 | s = reader->NextArrayItem(); |
1231 | ICHECK(s); |
1232 | reader->Read(&node->stage_id); |
1233 | s = reader->NextArrayItem(); |
1234 | ICHECK(s); |
1235 | reader->Read(&node->iter_id); |
1236 | s = reader->NextArrayItem(); |
1237 | ICHECK(s); |
1238 | reader->Read(&node->factor); |
1239 | s = reader->NextArrayItem(); |
1240 | ICHECK(s); |
1241 | reader->Read(&node->offset); |
1242 | data_ = std::move(node); |
1243 | } |
1244 | |
1245 | void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
1246 | writer->WriteArraySeperator(); |
1247 | writer->WriteString(record_prefix_str); |
1248 | writer->WriteArrayItem(stage_id); |
1249 | writer->WriteArrayItem(iter_id); |
1250 | writer->WriteArrayItem(factor); |
1251 | writer->WriteArrayItem(offset); |
1252 | } |
1253 | |
1254 | void StorageAlignStepNode::ApplyToState(State* state) const { |
1255 | StateNode* pstate = state->CopyOnWrite(); |
1256 | Stage stage = pstate->stages[stage_id]; |
1257 | stage.CopyOnWrite()->attrs.storage_offset = offset; |
1258 | pstate->stages.Set(stage_id, std::move(stage)); |
1259 | } |
1260 | |
1261 | void StorageAlignStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
1262 | StageToAxesMap* stage_to_axes) const { |
1263 | te::Stage stage = (*stages)[stage_id]; |
1264 | const Array<IterVar>& axes = (*stage_to_axes)[stage]; |
1265 | stage.storage_align(axes[iter_id], factor, offset); |
1266 | stages->Set(stage_id, std::move(stage)); |
1267 | } |
1268 | |
1269 | String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
1270 | StageToAxesMap* stage_to_axes) const { |
1271 | std::stringstream ss; |
1272 | const auto& stage = (*stages)[stage_id]; |
1273 | const auto& op_name = CleanName(stage->op->name); |
1274 | ss << "s[" << op_name << "].storage_align(" |
1275 | << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor |
1276 | << ", " << offset << ")\n" ; |
1277 | |
1278 | ApplyToSchedule(stages, stage_to_axes); |
1279 | return ss.str(); |
1280 | } |
1281 | |
1282 | /********** Steps working on multiple stages **********/ |
1283 | |
1284 | /********** Compute At **********/ |
1285 | ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { |
1286 | auto node = make_object<ComputeAtStepNode>(); |
1287 | node->stage_id = stage_id; |
1288 | node->target_stage_id = target_stage_id; |
1289 | node->target_iter_id = target_iter_id; |
1290 | data_ = std::move(node); |
1291 | } |
1292 | |
1293 | ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) { |
1294 | auto node = make_object<ComputeAtStepNode>(); |
1295 | bool s; |
1296 | s = reader->NextArrayItem(); |
1297 | ICHECK(s); |
1298 | reader->Read(&node->stage_id); |
1299 | s = reader->NextArrayItem(); |
1300 | ICHECK(s); |
1301 | reader->Read(&node->target_stage_id); |
1302 | s = reader->NextArrayItem(); |
1303 | ICHECK(s); |
1304 | reader->Read(&node->target_iter_id); |
1305 | data_ = std::move(node); |
1306 | } |
1307 | |
1308 | void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
1309 | writer->WriteArraySeperator(); |
1310 | writer->WriteString(record_prefix_str); |
1311 | writer->WriteArrayItem(stage_id); |
1312 | writer->WriteArrayItem(target_stage_id); |
1313 | writer->WriteArrayItem(target_iter_id); |
1314 | } |
1315 | void ComputeAtStepNode::ApplyToState(State* state) const { |
1316 | const Stage& stage = (*state)->stages[stage_id]; |
1317 | |
1318 | // Remove the bound information of each iterator since they may not be accurate after |
1319 | // compute at |
1320 | Array<Iterator> new_iters; |
1321 | for (const Iterator& it : stage->iters) { |
1322 | new_iters.push_back( |
1323 | Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->orig_iters)); |
1324 | } |
1325 | |
1326 | StateNode* pstate = state->CopyOnWrite(); |
1327 | pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), |
1328 | ComputeAtKind::kIter, stage->attrs)); |
1329 | // Update attach map |
1330 | pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id); |
1331 | } |
1332 | |
1333 | void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
1334 | StageToAxesMap* stage_to_axes) const { |
1335 | te::Stage stage = (*stages)[stage_id]; |
1336 | const auto& target_stage = (*stages)[target_stage_id]; |
1337 | const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id]; |
1338 | stage.compute_at(target_stage, target_axis); |
1339 | |
1340 | stages->Set(stage_id, std::move(stage)); |
1341 | } |
1342 | |
1343 | String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, |
1344 | StageToAxesMap* stage_to_axes) const { |
1345 | std::stringstream ss; |
1346 | const auto& stage = (*stages)[stage_id]; |
1347 | const auto& target_stage = (*stages)[target_stage_id]; |
1348 | const auto& op_name = CleanName(stage->op->name); |
1349 | const auto& target_op_name = CleanName(target_stage->op->name); |
1350 | ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], " |
1351 | << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name) |
1352 | << ")\n" ; |
1353 | ApplyToSchedule(stages, stage_to_axes); |
1354 | return ss.str(); |
1355 | } |
1356 | |
1357 | /********** Compute Inline **********/ |
1358 | ComputeInlineStep::ComputeInlineStep(int stage_id) { |
1359 | auto node = make_object<ComputeInlineStepNode>(); |
1360 | node->stage_id = stage_id; |
1361 | data_ = std::move(node); |
1362 | } |
1363 | |
1364 | ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) { |
1365 | auto node = make_object<ComputeInlineStepNode>(); |
1366 | bool s; |
1367 | s = reader->NextArrayItem(); |
1368 | ICHECK(s); |
1369 | reader->Read(&node->stage_id); |
1370 | data_ = std::move(node); |
1371 | } |
1372 | |
1373 | void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const { |
1374 | writer->WriteArraySeperator(); |
1375 | writer->WriteString(record_prefix_str); |
1376 | writer->WriteArrayItem(stage_id); |
1377 | } |
1378 | |
1379 | void ComputeInlineStepNode::ApplyToState(State* state) const { |
1380 | const Stage& stage = (*state)->stages[stage_id]; |
1381 | |
1382 | // Check the validity of compute_inline |
1383 | for (size_t i = 0; i < stage->iters.size(); ++i) { |
1384 | ICHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0) |
1385 | << "Invalid compute_inline: There are some other stages that are attached to the " |
1386 | << "target stage" ; |
1387 | } |
1388 | |
1389 | StateNode* pstate = state->CopyOnWrite(); |
1390 | auto new_stage = pstate->stages[stage_id]; |
1391 | new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined; |
1392 | pstate->stages.Set(stage_id, std::move(new_stage)); |
1393 | // Update attach map |
1394 | pstate->attach_map.DeleteStage(stage_id); |
1395 | } |
1396 | |
1397 | void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages, |
1398 | StageToAxesMap* stage_to_axes) const { |
1399 | auto stage = (*stages)[stage_id]; |
1400 | stage.compute_inline(); |
1401 | stages->Set(stage_id, std::move(stage)); |
1402 | } |
1403 | |
1404 | String ComputeInlineStepNode::PrintAsPythonAPI( |
---|