1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <unordered_set>

#include "../utils.h"
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | using tir::Instruction; |
25 | using tir::InstructionKind; |
26 | using tir::Trace; |
27 | |
28 | /*! \brief A mutator that mutates the thread binding factor decision of SampleCategorical */ |
29 | class MutateThreadBindingNode : public MutatorNode { |
30 | public: |
31 | /*! \brief JSON representation of the workload */ |
32 | std::string json_mod_; |
33 | |
34 | void VisitAttrs(tvm::AttrVisitor* v) {} |
35 | static constexpr const char* _type_key = "meta_schedule.MutateThreadBinding" ; |
36 | TVM_DECLARE_FINAL_OBJECT_INFO(MutateThreadBindingNode, MutatorNode); |
37 | |
38 | public: |
39 | // Inherit from `MutatorNode` |
40 | void InitializeWithTuneContext(const TuneContext& context) final { |
41 | this->json_mod_ = SaveJSON(context->mod.value()); |
42 | } |
43 | // Inherit from `MutatorNode` |
44 | Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final; |
45 | // Inherit from `MutatorNode` |
46 | Mutator Clone() const final { |
47 | ObjectPtr<MutateThreadBindingNode> n = make_object<MutateThreadBindingNode>(*this); |
48 | return Mutator(n); |
49 | } |
50 | |
51 | private: |
52 | struct Candidate { |
53 | /*! \brief The sampling instruction to be mutated */ |
54 | Instruction inst; |
55 | /*! \brief The probability */ |
56 | std::vector<double> probs; |
57 | /*! \brief The decision made */ |
58 | int decision; |
59 | |
60 | explicit Candidate(Instruction inst, std::vector<double> probs, int decision) |
61 | : inst(std::move(inst)), probs(std::move(probs)), decision(std::move(decision)) {} |
62 | }; |
63 | |
64 | std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state); |
65 | }; |
66 | |
67 | /*! |
68 | * \brief Find Candidate with the following pattern: |
69 | * \code |
70 | * v = sch.sample_categorical(...) |
71 | * l1, l2 = sch.split(loop=l0, factors=[None, v]) |
72 | * sch.bind(loop=l2, thread_axis="threadIdx.x") |
73 | * \endcode |
74 | * |
75 | * \param trace The trace from which to find the instructions |
76 | * \return All the candidate instructions |
77 | */ |
78 | std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCandidates( |
79 | const Trace& trace, TRandState* rand_state) { |
80 | using tir::InstructionNode; |
81 | |
82 | static InstructionKind inst_sample_categorical = InstructionKind::Get("SampleCategorical" ); |
83 | static InstructionKind inst_split = InstructionKind::Get("Split" ); |
84 | static InstructionKind inst_bind = InstructionKind::Get("Bind" ); |
85 | |
86 | std::vector<MutateThreadBindingNode::Candidate> candidates; |
87 | std::unordered_map<const PrimExprNode*, const tir::InstructionNode*> sample_insts; |
88 | std::unordered_map<const tir::LoopRVNode*, const tir::InstructionNode*> sampled_split_insts; |
89 | std::vector<const InstructionNode*> bind_insts; |
90 | |
91 | auto is_split_by_sample = [&sample_insts](const Instruction& inst) -> bool { |
92 | if (!inst->kind.same_as(inst_split)) { |
93 | return false; |
94 | } |
95 | // Only consider cases with 2 factors and the first one is None |
96 | if (inst->inputs.size() != 3 || inst->inputs[1].defined()) return false; |
97 | ICHECK(inst->inputs[2].defined()); |
98 | |
99 | return sample_insts.find(Downcast<PrimExpr>(inst->inputs[2]).get()) != sample_insts.end(); |
100 | }; |
101 | |
102 | auto is_thread_binding_by_sample = [&sampled_split_insts](const Instruction& inst) -> bool { |
103 | if (!inst->kind.same_as(inst_bind)) { |
104 | return false; |
105 | } |
106 | ICHECK_EQ(inst->inputs.size(), 1); |
107 | ICHECK_EQ(inst->attrs.size(), 1); |
108 | if (Downcast<String>(inst->attrs[0]) != "threadIdx.x" ) return false; |
109 | |
110 | return sampled_split_insts.find(Downcast<tir::LoopRV>(inst->inputs[0]).get()) != |
111 | sampled_split_insts.end(); |
112 | }; |
113 | |
114 | for (const Instruction& inst : trace->insts) { |
115 | if (inst->kind.same_as(inst_sample_categorical)) { |
116 | ICHECK_EQ(inst->outputs.size(), 1); |
117 | const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); |
118 | sample_insts[var_rv] = inst.get(); |
119 | } else if (is_split_by_sample(inst)) { |
120 | CHECK_EQ(inst->outputs.size(), 2); |
121 | // Only consider the inner loop, which can be bound to threadIdx.x |
122 | const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode); |
123 | sampled_split_insts[var_rv] = inst.get(); |
124 | } else if (is_thread_binding_by_sample(inst)) { |
125 | bind_insts.push_back(inst.get()); |
126 | } |
127 | } |
128 | |
129 | for (const InstructionNode* bind_inst : bind_insts) { |
130 | const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode); |
131 | auto split_it = sampled_split_insts.find(loop_rv); |
132 | ICHECK(split_it != sampled_split_insts.end()); |
133 | const InstructionNode* split_inst = split_it->second; |
134 | |
135 | const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode); |
136 | auto sample_it = sample_insts.find(expr_rv); |
137 | ICHECK(sample_it != sample_insts.end()); |
138 | const InstructionNode* sample_inst = sample_it->second; |
139 | |
140 | int decision = Downcast<Integer>(trace->decisions[GetRef<Instruction>(sample_inst)])->value; |
141 | |
142 | std::vector<double> probs = |
143 | support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1])); |
144 | |
145 | candidates.emplace_back(GetRef<Instruction>(sample_inst), probs, decision); |
146 | } |
147 | return candidates; |
148 | } |
149 | |
150 | Optional<Trace> MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { |
151 | std::vector<Candidate> candidates = FindCandidates(trace, rand_state); |
152 | if (candidates.empty()) { |
153 | return NullOpt; |
154 | } |
155 | Candidate candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; |
156 | // Remove the current decision |
157 | candidate.probs.erase(candidate.probs.begin() + candidate.decision); |
158 | int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); |
159 | if (result >= candidate.decision) { |
160 | result += 1; |
161 | } |
162 | return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); |
163 | } |
164 | |
165 | Mutator Mutator::MutateThreadBinding() { return Mutator(make_object<MutateThreadBindingNode>()); } |
166 | |
167 | TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode); |
168 | TVM_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding" ) |
169 | .set_body_typed(Mutator::MutateThreadBinding); |
170 | |
171 | } // namespace meta_schedule |
172 | } // namespace tvm |
173 | |