1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! |
25 | * \brief Check if an instruction is annotate with |
26 | * `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit` |
27 | * \param inst The instruction to be checked |
28 | * \return Whether the instruction is annotated |
29 | */ |
30 | bool IsAnnotateWithUnroll(const Instruction& inst) { |
31 | static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate" ); |
32 | if (!inst->kind.same_as(inst_annotate)) { |
33 | return false; |
34 | } |
35 | ICHECK_EQ(inst->attrs.size(), 1); |
36 | String ann_key = Downcast<String>(inst->attrs[0]); |
37 | return ann_key == attr::meta_schedule_unroll_explicit || |
38 | ann_key == attr::meta_schedule_unroll_implicit; |
39 | } |
40 | |
41 | } // namespace tir |
42 | } // namespace tvm |
43 | |
44 | namespace tvm { |
45 | namespace meta_schedule { |
46 | |
47 | using tir::Instruction; |
48 | using tir::Trace; |
49 | |
50 | /*! \brief Create a Mutator that mutates auto unroll step */ |
51 | class MutateUnrollNode : public MutatorNode { |
52 | public: |
53 | void VisitAttrs(tvm::AttrVisitor* v) {} |
54 | static constexpr const char* _type_key = "meta_schedule.MutateUnroll" ; |
55 | TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); |
56 | |
57 | public: |
58 | struct Candidate; |
59 | // Inherit from `MutatorNode` |
60 | void InitializeWithTuneContext(const TuneContext& context) final {} |
61 | // Inherit from `MutatorNode` |
62 | Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final; |
63 | // Inherit from `MutatorNode` |
64 | Mutator Clone() const final { |
65 | ObjectPtr<MutateUnrollNode> n = make_object<MutateUnrollNode>(*this); |
66 | return Mutator(n); |
67 | } |
68 | }; |
69 | |
70 | /*! \brief A candidate to be mutated */ |
71 | struct MutateUnrollNode::Candidate { |
72 | /*! \brief The sampling instruction to be mutated */ |
73 | Instruction inst; |
74 | /*! \brief The probability */ |
75 | std::vector<double> probs; |
76 | /*! \brief The decision made */ |
77 | int decision; |
78 | }; |
79 | |
80 | /*! |
81 | * \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step |
82 | * \param trace The trace to be mutated |
83 | * \param rand_state The random state |
84 | * \param candidates The mutation candidate |
85 | * \return Whether a decision is found |
86 | */ |
87 | bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, |
88 | MutateUnrollNode::Candidate* candidate) { |
89 | using tir::InstructionKind; |
90 | using tir::InstructionNode; |
91 | static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical" ); |
92 | std::unordered_map<const PrimExprNode*, const InstructionNode*> sample_insts; |
93 | std::vector<const InstructionNode*> ann_insts; |
94 | sample_insts.reserve(trace->insts.size()); |
95 | ann_insts.reserve(trace->insts.size()); |
96 | for (const Instruction& inst : trace->insts) { |
97 | if (inst->kind.same_as(inst_sample_categorical)) { |
98 | ICHECK_EQ(inst->outputs.size(), 1); |
99 | const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode); |
100 | sample_insts[var_rv] = inst.get(); |
101 | } else if (IsAnnotateWithUnroll(inst)) { |
102 | ann_insts.push_back(inst.get()); |
103 | } |
104 | } |
105 | int n_ann_insts = ann_insts.size(); |
106 | if (n_ann_insts == 0) { |
107 | return false; |
108 | } |
109 | const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; |
110 | ICHECK_EQ(ann_inst->inputs.size(), 2); |
111 | const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode); |
112 | ICHECK(sample_insts.count(var_rv)); |
113 | const InstructionNode* sample_inst = sample_insts.at(var_rv); |
114 | ICHECK_EQ(sample_inst->attrs.size(), 2); |
115 | candidate->inst = GetRef<Instruction>(sample_inst); |
116 | candidate->decision = |
117 | Downcast<Integer>(trace->decisions[GetRef<Instruction>(sample_inst)])->value; |
118 | candidate->probs = |
119 | support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(sample_inst->attrs[1])); |
120 | return true; |
121 | } |
122 | |
123 | Optional<Trace> MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { |
124 | Candidate candidate; |
125 | if (!FindUnrollDecision(trace, rand_state, &candidate)) { |
126 | return NullOpt; |
127 | } |
128 | if (candidate.probs.size() == 0) { |
129 | return NullOpt; |
130 | } |
131 | candidate.probs.erase(candidate.probs.begin() + candidate.decision); |
132 | int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); |
133 | if (result >= candidate.decision) { |
134 | result += 1; |
135 | } |
136 | return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); |
137 | } |
138 | |
139 | Mutator Mutator::MutateUnroll() { return Mutator(make_object<MutateUnrollNode>()); } |
140 | |
141 | TVM_REGISTER_NODE_TYPE(MutateUnrollNode); |
142 | TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll" ).set_body_typed(Mutator::MutateUnroll); |
143 | |
144 | } // namespace meta_schedule |
145 | } // namespace tvm |
146 | |