1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_TRACE_H_
20#define TVM_TIR_SCHEDULE_TRACE_H_
21
22#include <tvm/tir/schedule/instruction.h>
23
24namespace tvm {
25namespace tir {
26
27// Forward declaration
28class Trace;
29
30/*!
31 * \brief A callback that allows users to mutate decisions on the fly
32 * when applying instructions. The signature of the callback is:
33 * \param inst The instruction
34 * \param inputs The input random variables
35 * \param attrs The attributes
36 * \param decision The original decision
37 * \return A new decision
38 */
39using FTraceDecisionProvider = runtime::TypedPackedFunc<ObjectRef(
40 const Instruction& inst, const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs,
41 const Optional<ObjectRef>& decision)>;
42
43/*!
44 * \brief An execution trace of a scheduling program
45 *
46 * A trace has two parts:
47 * 1) The instructions invoked so far in the program execution
48 * 2) The random decisions made upon those instructions, if any
49 *
50 * A trace can be serialized to:
51 * 1) Roundtrippable JSON format: can be saved to file and loaded back
52 * 2) Python syntax: allows users to copy-paste the trace to reproduce the scheduling process
53 *
54 * A trace can be applied to a TensorIR schedule by re-applying all its instructions possibly with
55 * their decisions accordingly. Re-sampling is invoked if a sampling instruction doesn't have its
56 * corresponding decision; Otherwise the existing decision will be reused accordingly.
57 */
58class TraceNode : public runtime::Object {
59 public:
60 /*! \brief The instructions invoked so far in the program execution */
61 Array<Instruction> insts;
62 /*! \brief The random decisions made upon those instructions */
63 Map<Instruction, ObjectRef> decisions;
64
65 void VisitAttrs(tvm::AttrVisitor* v) {
66 v->Visit("insts", &insts);
67 v->Visit("decisions", &decisions);
68 }
69
70 static constexpr const char* _type_key = "tir.Trace";
71 TVM_DECLARE_FINAL_OBJECT_INFO(TraceNode, runtime::Object);
72
73 public:
74 /*!
75 * \brief Retrieve the decision made on a specific instruction
76 * \param inst The instruction whose decision is to be retrieved
77 * \return The corresponding decision; NullOpt if there is no decision made on the instruction
78 */
79 Optional<ObjectRef> GetDecision(const Instruction& inst) const;
80 /*!
81 * \brief Append a new instruction to the trace
82 * \param inst The new instruction to be appended
83 */
84 void Append(Instruction inst);
85 /*!
86 * \brief Append a new instruction with a random decision to the trace
87 * \param inst The new instruction to be appended
88 * \param decision The random decision made on this instruction
89 * The type of `decision` depends on the instruction, e.g.
90 * the decision of `SamplePerfectTile` has type `Array<IntImm>`
91 */
92 void Append(Instruction inst, ObjectRef decision);
93 /*!
94 * \brief Remove the last instruction, along with the decision made on that instruction, if any
95 * \return The instruction removed; NullOpt if the trace is empty
96 */
97 Optional<Instruction> Pop();
98 /*!
99 * \brief Apply the trace to a TensorIR schedule
100 * \param sch The schedule to be applied onto
101 * \param remove_postproc If postprocessing instructions are removed
102 * \param decision_provider A callback that allows users to mutate decisions on the fly
103 * when applying instructions.
104 * \sa FTraceDecisionProvider
105 */
106 void ApplyToSchedule(Schedule sch, bool remove_postproc,
107 FTraceDecisionProvider decision_provider = nullptr) const;
108 /*!
109 * \brief Serialize the trace as a JSON-style object
110 * \param remove_postproc If postprocessing instructions are removed
111 * \return The JSON-style object
112 */
113 ObjectRef AsJSON(bool remove_postproc) const;
114 /*!
115 * \brief Serialize the trace as a sequence of python statements
116 * \param remove_postproc If postprocessing instructions are removed
117 * \return A sequence of python statements
118 */
119 Array<String> AsPython(bool remove_postproc) const;
120 /*!
121 * \brief Create a new trace with an instruction whose decision is changed,
122 * assuming this instruction exists in the resulting trace
123 * \param inst The instruction whose decision is to be changed
124 * \param decision The decision to be changed to
125 * \param remove_postproc If postprocessing instructions are removed
126 * \return The new trace with the decision changed
127 */
128 Trace WithDecision(Instruction inst, ObjectRef decision, bool remove_postproc) const;
129 /*!
130 * \brief Simplify the trace with dead-code elimination
131 * \param remove_postproc If postprocessing instructions are removed
132 * \return A simplified trace
133 */
134 Trace Simplified(bool remove_postproc) const;
135};
136
137/*!
138 * \brief Managed reference to TraceNode
139 * \sa TraceNode
140 */
141class Trace : public runtime::ObjectRef {
142 public:
143 /*! \brief Default constructor. Creating an empty trace. */
144 Trace();
145 /*!
146 * \brief Constructor. Creating a trace from existing instructions and their decisions
147 * \param insts The instructions used
148 * \param decisions The decisions made in sampling
149 */
150 explicit Trace(Array<Instruction> insts, Map<Instruction, ObjectRef> decisions);
151 /*!
152 * \brief Apply a JSON-serialized trace to a TensorIR schedule
153 * \param json The JSON-serialized trace
154 * \param sch The TensorIR schedule
155 */
156 static void ApplyJSONToSchedule(ObjectRef json, Schedule sch);
157
158 TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Trace, runtime::ObjectRef, TraceNode);
159};
160
161} // namespace tir
162} // namespace tvm
163
164#endif // TVM_TIR_SCHEDULE_TRACE_H_
165