1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_TRACE_H_ |
20 | #define TVM_TIR_SCHEDULE_TRACE_H_ |
21 | |
22 | #include <tvm/tir/schedule/instruction.h> |
23 | |
24 | namespace tvm { |
25 | namespace tir { |
26 | |
27 | // Forward declaration |
28 | class Trace; |
29 | |
30 | /*! |
31 | * \brief A callback that allows users to mutate decisions on the fly |
32 | * when applying instructions. The signature of the callback is: |
33 | * \param inst The instruction |
34 | * \param inputs The input random variables |
35 | * \param attrs The attributes |
36 | * \param decision The original decision |
37 | * \return A new decision |
38 | */ |
39 | using FTraceDecisionProvider = runtime::TypedPackedFunc<ObjectRef( |
40 | const Instruction& inst, const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs, |
41 | const Optional<ObjectRef>& decision)>; |
42 | |
43 | /*! |
44 | * \brief An execution trace of a scheduling program |
45 | * |
46 | * A trace has two parts: |
47 | * 1) The instructions invoked so far in the program execution |
48 | * 2) The random decisions made upon those instructions, if any |
49 | * |
50 | * A trace can be serialized to: |
51 | * 1) Roundtrippable JSON format: can be saved to file and loaded back |
52 | * 2) Python syntax: allows users to copy-paste the trace to reproduce the scheduling process |
53 | * |
54 | * A trace can be applied to a TensorIR schedule by re-applying all its instructions possibly with |
55 | * their decisions accordingly. Re-sampling is invoked if a sampling instruction doesn't have its |
56 | * corresponding decision; Otherwise the existing decision will be reused accordingly. |
57 | */ |
58 | class TraceNode : public runtime::Object { |
59 | public: |
60 | /*! \brief The instructions invoked so far in the program execution */ |
61 | Array<Instruction> insts; |
62 | /*! \brief The random decisions made upon those instructions */ |
63 | Map<Instruction, ObjectRef> decisions; |
64 | |
65 | void VisitAttrs(tvm::AttrVisitor* v) { |
66 | v->Visit("insts" , &insts); |
67 | v->Visit("decisions" , &decisions); |
68 | } |
69 | |
70 | static constexpr const char* _type_key = "tir.Trace" ; |
71 | TVM_DECLARE_FINAL_OBJECT_INFO(TraceNode, runtime::Object); |
72 | |
73 | public: |
74 | /*! |
75 | * \brief Retrieve the decision made on a specific instruction |
76 | * \param inst The instruction whose decision is to be retrieved |
77 | * \return The corresponding decision; NullOpt if there is no decision made on the instruction |
78 | */ |
79 | Optional<ObjectRef> GetDecision(const Instruction& inst) const; |
80 | /*! |
81 | * \brief Append a new instruction to the trace |
82 | * \param inst The new instruction to be appended |
83 | */ |
84 | void Append(Instruction inst); |
85 | /*! |
86 | * \brief Append a new instruction with a random decision to the trace |
87 | * \param inst The new instruction to be appended |
88 | * \param decision The random decision made on this instruction |
89 | * The type of `decision` depends on the instruction, e.g. |
90 | * the decision of `SamplePerfectTile` has type `Array<IntImm>` |
91 | */ |
92 | void Append(Instruction inst, ObjectRef decision); |
93 | /*! |
94 | * \brief Remove the last instruction, along with the decision made on that instruction, if any |
95 | * \return The instruction removed; NullOpt if the trace is empty |
96 | */ |
97 | Optional<Instruction> Pop(); |
98 | /*! |
99 | * \brief Apply the trace to a TensorIR schedule |
100 | * \param sch The schedule to be applied onto |
101 | * \param remove_postproc If postprocessing instructions are removed |
102 | * \param decision_provider A callback that allows users to mutate decisions on the fly |
103 | * when applying instructions. |
104 | * \sa FTraceDecisionProvider |
105 | */ |
106 | void ApplyToSchedule(Schedule sch, bool remove_postproc, |
107 | FTraceDecisionProvider decision_provider = nullptr) const; |
108 | /*! |
109 | * \brief Serialize the trace as a JSON-style object |
110 | * \param remove_postproc If postprocessing instructions are removed |
111 | * \return The JSON-style object |
112 | */ |
113 | ObjectRef AsJSON(bool remove_postproc) const; |
114 | /*! |
115 | * \brief Serialize the trace as a sequence of python statements |
116 | * \param remove_postproc If postprocessing instructions are removed |
117 | * \return A sequence of python statements |
118 | */ |
119 | Array<String> AsPython(bool remove_postproc) const; |
120 | /*! |
121 | * \brief Create a new trace with an instruction whose decision is changed, |
122 | * assuming this instruction exists in the resulting trace |
123 | * \param inst The instruction whose decision is to be changed |
124 | * \param decision The decision to be changed to |
125 | * \param remove_postproc If postprocessing instructions are removed |
126 | * \return The new trace with the decision changed |
127 | */ |
128 | Trace WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const; |
129 | /*! |
130 | * \brief Simplify the trace with dead-code elimination |
131 | * \param remove_postproc If postprocessing instructions are removed |
132 | * \return A simplified trace |
133 | */ |
134 | Trace Simplified(bool remove_postproc) const; |
135 | }; |
136 | |
137 | /*! |
138 | * \brief Managed reference to TraceNode |
139 | * \sa TraceNode |
140 | */ |
141 | class Trace : public runtime::ObjectRef { |
142 | public: |
143 | /*! \brief Default constructor. Creating an empty trace. */ |
144 | Trace(); |
145 | /*! |
146 | * \brief Constructor. Creating a trace from existing instructions and their decisions |
147 | * \param insts The instructions used |
148 | * \param decisions The decisions made in sampling |
149 | */ |
150 | explicit Trace(Array<Instruction> insts, Map<Instruction, ObjectRef> decisions); |
151 | /*! |
152 | * \brief Apply a JSON-serialized trace to a TensorIR schedule |
153 | * \param json The JSON-serialized trace |
154 | * \param sch The TensorIR schedule |
155 | */ |
156 | static void ApplyJSONToSchedule(ObjectRef json, Schedule sch); |
157 | |
158 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode); |
159 | }; |
160 | |
161 | } // namespace tir |
162 | } // namespace tvm |
163 | |
164 | #endif // TVM_TIR_SCHEDULE_TRACE_H_ |
165 | |