1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24using tir::Instruction;
25using tir::InstructionKind;
26using tir::Trace;
27
28/*! \brief A mutator that mutates the thread binding factor decision of SampleCategorical */
29class MutateThreadBindingNode : public MutatorNode {
30 public:
31 /*! \brief JSON representation of the workload */
32 std::string json_mod_;
33
34 void VisitAttrs(tvm::AttrVisitor* v) {}
35 static constexpr const char* _type_key = "meta_schedule.MutateThreadBinding";
36 TVM_DECLARE_FINAL_OBJECT_INFO(MutateThreadBindingNode, MutatorNode);
37
38 public:
39 // Inherit from `MutatorNode`
40 void InitializeWithTuneContext(const TuneContext& context) final {
41 this->json_mod_ = SaveJSON(context->mod.value());
42 }
43 // Inherit from `MutatorNode`
44 Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
45 // Inherit from `MutatorNode`
46 Mutator Clone() const final {
47 ObjectPtr<MutateThreadBindingNode> n = make_object<MutateThreadBindingNode>(*this);
48 return Mutator(n);
49 }
50
51 private:
52 struct Candidate {
53 /*! \brief The sampling instruction to be mutated */
54 Instruction inst;
55 /*! \brief The probability */
56 std::vector<double> probs;
57 /*! \brief The decision made */
58 int decision;
59
60 explicit Candidate(Instruction inst, std::vector<double> probs, int decision)
61 : inst(std::move(inst)), probs(std::move(probs)), decision(std::move(decision)) {}
62 };
63
64 std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state);
65};
66
67/*!
68 * \brief Find Candidate with the following pattern:
69 * \code
70 * v = sch.sample_categorical(...)
71 * l1, l2 = sch.split(loop=l0, factors=[None, v])
72 * sch.bind(loop=l2, thread_axis="threadIdx.x")
73 * \endcode
74 *
75 * \param trace The trace from which to find the instructions
76 * \return All the candidate instructions
77 */
78std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCandidates(
79 const Trace& trace, TRandState* rand_state) {
80 using tir::InstructionNode;
81
82 static InstructionKind inst_sample_categorical = InstructionKind::Get("SampleCategorical");
83 static InstructionKind inst_split = InstructionKind::Get("Split");
84 static InstructionKind inst_bind = InstructionKind::Get("Bind");
85
86 std::vector<MutateThreadBindingNode::Candidate> candidates;
87 std::unordered_map<const PrimExprNode*, const tir::InstructionNode*> sample_insts;
88 std::unordered_map<const tir::LoopRVNode*, const tir::InstructionNode*> sampled_split_insts;
89 std::vector<const InstructionNode*> bind_insts;
90
91 auto is_split_by_sample = [&sample_insts](const Instruction& inst) -> bool {
92 if (!inst->kind.same_as(inst_split)) {
93 return false;
94 }
95 // Only consider cases with 2 factors and the first one is None
96 if (inst->inputs.size() != 3 || inst->inputs[1].defined()) return false;
97 ICHECK(inst->inputs[2].defined());
98
99 return sample_insts.find(Downcast<PrimExpr>(inst->inputs[2]).get()) != sample_insts.end();
100 };
101
102 auto is_thread_binding_by_sample = [&sampled_split_insts](const Instruction& inst) -> bool {
103 if (!inst->kind.same_as(inst_bind)) {
104 return false;
105 }
106 ICHECK_EQ(inst->inputs.size(), 1);
107 ICHECK_EQ(inst->attrs.size(), 1);
108 if (Downcast<String>(inst->attrs[0]) != "threadIdx.x") return false;
109
110 return sampled_split_insts.find(Downcast<tir::LoopRV>(inst->inputs[0]).get()) !=
111 sampled_split_insts.end();
112 };
113
114 for (const Instruction& inst : trace->insts) {
115 if (inst->kind.same_as(inst_sample_categorical)) {
116 ICHECK_EQ(inst->outputs.size(), 1);
117 const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
118 sample_insts[var_rv] = inst.get();
119 } else if (is_split_by_sample(inst)) {
120 CHECK_EQ(inst->outputs.size(), 2);
121 // Only consider the inner loop, which can be bound to threadIdx.x
122 const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode);
123 sampled_split_insts[var_rv] = inst.get();
124 } else if (is_thread_binding_by_sample(inst)) {
125 bind_insts.push_back(inst.get());
126 }
127 }
128
129 for (const InstructionNode* bind_inst : bind_insts) {
130 const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode);
131 auto split_it = sampled_split_insts.find(loop_rv);
132 ICHECK(split_it != sampled_split_insts.end());
133 const InstructionNode* split_inst = split_it->second;
134
135 const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode);
136 auto sample_it = sample_insts.find(expr_rv);
137 ICHECK(sample_it != sample_insts.end());
138 const InstructionNode* sample_inst = sample_it->second;
139
140 int decision = Downcast<Integer>(trace->decisions[GetRef<Instruction>(sample_inst)])->value;
141
142 std::vector<double> probs =
143 support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1]));
144
145 candidates.emplace_back(GetRef<Instruction>(sample_inst), probs, decision);
146 }
147 return candidates;
148}
149
150Optional<Trace> MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) {
151 std::vector<Candidate> candidates = FindCandidates(trace, rand_state);
152 if (candidates.empty()) {
153 return NullOpt;
154 }
155 Candidate candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())];
156 // Remove the current decision
157 candidate.probs.erase(candidate.probs.begin() + candidate.decision);
158 int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)();
159 if (result >= candidate.decision) {
160 result += 1;
161 }
162 return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true);
163}
164
165Mutator Mutator::MutateThreadBinding() { return Mutator(make_object<MutateThreadBindingNode>()); }
166
167TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode);
168TVM_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding")
169 .set_body_typed(Mutator::MutateThreadBinding);
170
171} // namespace meta_schedule
172} // namespace tvm
173