1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_
20#define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_
21
22#include "./concrete_schedule.h"
23
24namespace tvm {
25namespace tir {
26
27class TracedScheduleNode : public ConcreteScheduleNode {
28 friend class Schedule;
29
30 protected:
31 Trace trace_;
32
33 public:
34 void VisitAttrs(tvm::AttrVisitor* v) {
35 // `state_` is not visited
36 // `error_render_level_` is not visited
37 // `symbol_table_` is not visited
38 // `analyzer_` is not visitied
39 // `trace_` is not visited
40 }
41
42 ~TracedScheduleNode() = default;
43
44 public:
45 Optional<Trace> trace() const final { return trace_; }
46 Schedule Copy() final;
47
48 public:
49 /******** Schedule: Sampling ********/
50 ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
51 Optional<Integer> decision = NullOpt) final;
52 Array<ExprRV> SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor,
53 Optional<Array<Integer>> decision = NullOpt) final;
54 LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional<Integer> decision = NullOpt) final;
55 /******** Schedule: Get blocks & loops ********/
56 BlockRV GetBlock(const String& name, const Optional<String>& func_name) final;
57 Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
58 Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
59 Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
60 Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
61 Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
62 /******** Schedule: Transform loops ********/
63 LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
64 Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
65 bool preserve_unit_iters) final;
66 void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
67 LoopRV AddUnitLoop(const BlockRV& block_rv) final;
68 LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
69 /******** Schedule: Manipulate ForKind ********/
70 void Parallel(const LoopRV& loop_rv) final;
71 void Vectorize(const LoopRV& loop_rv) final;
72 void Bind(const LoopRV& loop_rv, const String& thread_axis) final;
73 void Unroll(const LoopRV& loop_rv) final;
74 /******** Schedule: Insert cache stages ********/
75 BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
76 const Array<BlockRV> consumer_blocks = {}) final;
77 BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
78 const Array<BlockRV> consumer_blocks = {}) final;
79 Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
80 const String& storage_scope) final;
81 BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
82 BufferIndexType buffer_index_type) final;
83 Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
84 int cse_thresh) final;
85 /******** Schedule: Compute location ********/
86 void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
87 int index = -1) final;
88 void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
89 int index = -1) final;
90 void ComputeInline(const BlockRV& block_rv) final;
91 void ReverseComputeInline(const BlockRV& block_rv) final;
92 /******** Schedule: Reduction ********/
93 BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final;
94 BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
95 /******** Schedule: Block annotation ********/
96 void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
97 int offset) final;
98 void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final;
99 /******** Schedule: Blockize & Tensorize ********/
100 BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
101 void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final;
102 void Tensorize(const LoopRV& loop_rv, const String& intrin, bool preserve_unit_iters) final;
103 /******** Schedule: Annotation ********/
104 void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
105 void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
106 void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override;
107 void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
108 /******** Schedule: Layout transformation ********/
109 void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
110 const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
111 void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
112 void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
113 BufferIndexType buffer_index_type,
114 const Array<IntImm>& axis_separators) final;
115 /******** Schedule: Padding ********/
116 BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final;
117 void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) final;
118 /******** Schedule: Buffer transformation ********/
119 void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final;
120 /******** Schedule: Misc ********/
121 void EnterPostproc() final;
122};
123
124} // namespace tir
125} // namespace tvm
126
127#endif // TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_
128