1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ |
20 | #define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ |
21 | |
22 | #include "./concrete_schedule.h" |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | class TracedScheduleNode : public ConcreteScheduleNode { |
28 | friend class Schedule; |
29 | |
30 | protected: |
31 | Trace trace_; |
32 | |
33 | public: |
34 | void VisitAttrs(tvm::AttrVisitor* v) { |
35 | // `state_` is not visited |
36 | // `error_render_level_` is not visited |
37 | // `symbol_table_` is not visited |
38 | // `analyzer_` is not visitied |
39 | // `trace_` is not visited |
40 | } |
41 | |
42 | ~TracedScheduleNode() = default; |
43 | |
44 | public: |
45 | Optional<Trace> trace() const final { return trace_; } |
46 | Schedule Copy() final; |
47 | |
48 | public: |
49 | /******** Schedule: Sampling ********/ |
50 | ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs, |
51 | Optional<Integer> decision = NullOpt) final; |
52 | Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, |
53 | Optional<Array<Integer>> decision = NullOpt) final; |
54 | LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional<Integer> decision = NullOpt) final; |
55 | /******** Schedule: Get blocks & loops ********/ |
56 | BlockRV GetBlock(const String& name, const Optional<String>& func_name) final; |
57 | Array<LoopRV> GetLoops(const BlockRV& block_rv) final; |
58 | Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final; |
59 | Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final; |
60 | Array<BlockRV> GetProducers(const BlockRV& block_rv) final; |
61 | Array<BlockRV> GetConsumers(const BlockRV& block_rv) final; |
62 | /******** Schedule: Transform loops ********/ |
63 | LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final; |
64 | Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs, |
65 | bool preserve_unit_iters) final; |
66 | void Reorder(const Array<LoopRV>& ordered_loop_rvs) final; |
67 | LoopRV AddUnitLoop(const BlockRV& block_rv) final; |
68 | LoopRV AddUnitLoop(const LoopRV& loop_rv) final; |
69 | /******** Schedule: Manipulate ForKind ********/ |
70 | void Parallel(const LoopRV& loop_rv) final; |
71 | void Vectorize(const LoopRV& loop_rv) final; |
72 | void Bind(const LoopRV& loop_rv, const String& thread_axis) final; |
73 | void Unroll(const LoopRV& loop_rv) final; |
74 | /******** Schedule: Insert cache stages ********/ |
75 | BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, |
76 | const Array<BlockRV> consumer_blocks = {}) final; |
77 | BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, |
78 | const Array<BlockRV> consumer_blocks = {}) final; |
79 | Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index, |
80 | const String& storage_scope) final; |
81 | BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, |
82 | BufferIndexType buffer_index_type) final; |
83 | Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope, |
84 | int cse_thresh) final; |
85 | /******** Schedule: Compute location ********/ |
86 | void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, |
87 | int index = -1) final; |
88 | void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, |
89 | int index = -1) final; |
90 | void ComputeInline(const BlockRV& block_rv) final; |
91 | void ReverseComputeInline(const BlockRV& block_rv) final; |
92 | /******** Schedule: Reduction ********/ |
93 | BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final; |
94 | BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; |
95 | /******** Schedule: Block annotation ********/ |
96 | void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, |
97 | int offset) final; |
98 | void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; |
99 | /******** Schedule: Blockize & Tensorize ********/ |
100 | BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final; |
101 | void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final; |
102 | void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final; |
103 | /******** Schedule: Annotation ********/ |
104 | void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; |
105 | void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; |
106 | void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override; |
107 | void Unannotate(const BlockRV& block_rv, const String& ann_key) override; |
108 | /******** Schedule: Layout transformation ********/ |
109 | void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, |
110 | const IndexMap& index_map, const Optional<IndexMap>& pad_value) override; |
111 | void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; |
112 | void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, |
113 | BufferIndexType buffer_index_type, |
114 | const Array<IntImm>& axis_separators) final; |
115 | /******** Schedule: Padding ********/ |
116 | BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; |
117 | void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) final; |
118 | /******** Schedule: Buffer transformation ********/ |
119 | void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final; |
120 | /******** Schedule: Misc ********/ |
121 | void EnterPostproc() final; |
122 | }; |
123 | |
124 | } // namespace tir |
125 | } // namespace tvm |
126 | |
127 | #endif // TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ |
128 | |