1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*!
25 * \brief Check if an instruction is annotate with
26 * `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit`
27 * \param inst The instruction to be checked
28 * \return Whether the instruction is annotated
29 */
30bool IsAnnotateWithUnroll(const Instruction& inst) {
31 static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
32 if (!inst->kind.same_as(inst_annotate)) {
33 return false;
34 }
35 ICHECK_EQ(inst->attrs.size(), 1);
36 String ann_key = Downcast<String>(inst->attrs[0]);
37 return ann_key == attr::meta_schedule_unroll_explicit ||
38 ann_key == attr::meta_schedule_unroll_implicit;
39}
40
41} // namespace tir
42} // namespace tvm
43
44namespace tvm {
45namespace meta_schedule {
46
47using tir::Instruction;
48using tir::Trace;
49
50/*! \brief Create a Mutator that mutates auto unroll step */
51class MutateUnrollNode : public MutatorNode {
52 public:
53 void VisitAttrs(tvm::AttrVisitor* v) {}
54 static constexpr const char* _type_key = "meta_schedule.MutateUnroll";
55 TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode);
56
57 public:
58 struct Candidate;
59 // Inherit from `MutatorNode`
60 void InitializeWithTuneContext(const TuneContext& context) final {}
61 // Inherit from `MutatorNode`
62 Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
63 // Inherit from `MutatorNode`
64 Mutator Clone() const final {
65 ObjectPtr<MutateUnrollNode> n = make_object<MutateUnrollNode>(*this);
66 return Mutator(n);
67 }
68};
69
70/*! \brief A candidate to be mutated */
71struct MutateUnrollNode::Candidate {
72 /*! \brief The sampling instruction to be mutated */
73 Instruction inst;
74 /*! \brief The probability */
75 std::vector<double> probs;
76 /*! \brief The decision made */
77 int decision;
78};
79
80/*!
81 * \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step
82 * \param trace The trace to be mutated
83 * \param rand_state The random state
84 * \param candidates The mutation candidate
85 * \return Whether a decision is found
86 */
87bool FindUnrollDecision(const Trace& trace, TRandState* rand_state,
88 MutateUnrollNode::Candidate* candidate) {
89 using tir::InstructionKind;
90 using tir::InstructionNode;
91 static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
92 std::unordered_map<const PrimExprNode*, const InstructionNode*> sample_insts;
93 std::vector<const InstructionNode*> ann_insts;
94 sample_insts.reserve(trace->insts.size());
95 ann_insts.reserve(trace->insts.size());
96 for (const Instruction& inst : trace->insts) {
97 if (inst->kind.same_as(inst_sample_categorical)) {
98 ICHECK_EQ(inst->outputs.size(), 1);
99 const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
100 sample_insts[var_rv] = inst.get();
101 } else if (IsAnnotateWithUnroll(inst)) {
102 ann_insts.push_back(inst.get());
103 }
104 }
105 int n_ann_insts = ann_insts.size();
106 if (n_ann_insts == 0) {
107 return false;
108 }
109 const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
110 ICHECK_EQ(ann_inst->inputs.size(), 2);
111 const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode);
112 ICHECK(sample_insts.count(var_rv));
113 const InstructionNode* sample_inst = sample_insts.at(var_rv);
114 ICHECK_EQ(sample_inst->attrs.size(), 2);
115 candidate->inst = GetRef<Instruction>(sample_inst);
116 candidate->decision =
117 Downcast<Integer>(trace->decisions[GetRef<Instruction>(sample_inst)])->value;
118 candidate->probs =
119 support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1]));
120 return true;
121}
122
123Optional<Trace> MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) {
124 Candidate candidate;
125 if (!FindUnrollDecision(trace, rand_state, &candidate)) {
126 return NullOpt;
127 }
128 if (candidate.probs.size() == 0) {
129 return NullOpt;
130 }
131 candidate.probs.erase(candidate.probs.begin() + candidate.decision);
132 int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)();
133 if (result >= candidate.decision) {
134 result += 1;
135 }
136 return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true);
137}
138
139Mutator Mutator::MutateUnroll() { return Mutator(make_object<MutateUnrollNode>()); }
140
141TVM_REGISTER_NODE_TYPE(MutateUnrollNode);
142TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll);
143
144} // namespace meta_schedule
145} // namespace tvm
146