1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | using tir::Instruction; |
25 | using tir::InstructionKind; |
26 | using tir::Trace; |
27 | |
28 | /*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */ |
29 | class MutateComputeLocationNode : public MutatorNode { |
30 | public: |
31 | /*! \brief JSON representation of the workload */ |
32 | std::string json_mod_; |
33 | |
34 | void VisitAttrs(tvm::AttrVisitor* v) {} |
35 | static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation" ; |
36 | TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode); |
37 | |
38 | public: |
39 | // Inherit from `MutatorNode` |
40 | void InitializeWithTuneContext(const TuneContext& context) final { |
41 | this->json_mod_ = SaveJSON(context->mod.value()); |
42 | } |
43 | // Inherit from `MutatorNode` |
44 | Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final; |
45 | // Inherit from `MutatorNode` |
46 | Mutator Clone() const final { |
47 | ObjectPtr<MutateComputeLocationNode> n = make_object<MutateComputeLocationNode>(*this); |
48 | return Mutator(n); |
49 | } |
50 | |
51 | private: |
52 | struct Candidate { |
53 | /*! \brief The SampleComputeLocation instruction */ |
54 | Instruction inst; |
55 | /*! \brief The candidate compute-at locations */ |
56 | std::vector<int> locs; |
57 | |
58 | explicit Candidate(Instruction inst, std::vector<int> locs) |
59 | : inst(std::move(inst)), locs(std::move(locs)) {} |
60 | }; |
61 | |
62 | std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state); |
63 | }; |
64 | |
65 | /*! |
66 | * \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated |
67 | * to at lease one other value |
68 | * \param trace The trace from which to find the instructions |
69 | * \return All the candidate instructions together with the candidate compute-at locations |
70 | */ |
71 | std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates( |
72 | const Trace& trace, TRandState* rand_state) { |
73 | tir::Schedule sch = tir::Schedule::Traced( // |
74 | /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), // |
75 | /*rand_state=*/ForkSeed(rand_state), // |
76 | /*debug_mode=*/0, // |
77 | /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); |
78 | |
79 | static InstructionKind inst_sample_compute_location = |
80 | InstructionKind::Get("SampleComputeLocation" ); |
81 | std::vector<MutateComputeLocationNode::Candidate> candidates; |
82 | |
83 | auto f_decision_provider = [&](const tir::Instruction& inst, // |
84 | const Array<ObjectRef>& inputs, // |
85 | const Array<ObjectRef>& attrs, // |
86 | const ObjectRef& decision) -> ObjectRef { |
87 | if (inst->kind.same_as(inst_sample_compute_location)) { |
88 | // Step 1. Extract the instruction input and the old decision. |
89 | ICHECK_EQ(inputs.size(), 1); |
90 | tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::BlockRV>(inputs[0])); |
91 | int old_decision = Downcast<Integer>(decision)->value; |
92 | |
93 | // Step 2. Collect all the compute_at locations. |
94 | auto [location_srefs, location_indices] = CollectComputeLocation(sch->state(), block_sref); |
95 | // Step 3. Remove the old decision. |
96 | auto it = std::find(location_indices.begin(), location_indices.end(), old_decision); |
97 | if (it != location_indices.end()) { |
98 | location_srefs.erase(location_srefs.begin() + (it - location_indices.begin())); |
99 | location_indices.erase(it); |
100 | } |
101 | ICHECK_EQ(location_srefs.size(), location_indices.size()); |
102 | // Step 4. Add a new candidate if there are at least one remaining compute-at position. |
103 | if (!location_srefs.empty()) { |
104 | candidates.emplace_back(inst, std::move(location_indices)); |
105 | } |
106 | } |
107 | return decision; |
108 | }; |
109 | trace->ApplyToSchedule(sch, // |
110 | /*remove_postproc=*/true, // |
111 | /*decision_provider=*/f_decision_provider); |
112 | return candidates; |
113 | } |
114 | |
115 | Optional<Trace> MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { |
116 | std::vector<Candidate> candidates = FindCandidates(trace, rand_state); |
117 | if (candidates.empty()) { |
118 | return NullOpt; |
119 | } |
120 | const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; |
121 | int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; |
122 | return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); |
123 | } |
124 | |
125 | Mutator Mutator::MutateComputeLocation() { |
126 | return Mutator(make_object<MutateComputeLocationNode>()); |
127 | } |
128 | |
129 | TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); |
130 | TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation" ) |
131 | .set_body_typed(Mutator::MutateComputeLocation); |
132 | |
133 | } // namespace meta_schedule |
134 | } // namespace tvm |
135 | |