1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24using tir::Instruction;
25using tir::InstructionKind;
26using tir::Trace;
27
28/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */
29class MutateComputeLocationNode : public MutatorNode {
30 public:
31 /*! \brief JSON representation of the workload */
32 std::string json_mod_;
33
34 void VisitAttrs(tvm::AttrVisitor* v) {}
35 static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation";
36 TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode);
37
38 public:
39 // Inherit from `MutatorNode`
40 void InitializeWithTuneContext(const TuneContext& context) final {
41 this->json_mod_ = SaveJSON(context->mod.value());
42 }
43 // Inherit from `MutatorNode`
44 Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
45 // Inherit from `MutatorNode`
46 Mutator Clone() const final {
47 ObjectPtr<MutateComputeLocationNode> n = make_object<MutateComputeLocationNode>(*this);
48 return Mutator(n);
49 }
50
51 private:
52 struct Candidate {
53 /*! \brief The SampleComputeLocation instruction */
54 Instruction inst;
55 /*! \brief The candidate compute-at locations */
56 std::vector<int> locs;
57
58 explicit Candidate(Instruction inst, std::vector<int> locs)
59 : inst(std::move(inst)), locs(std::move(locs)) {}
60 };
61
62 std::vector<Candidate> FindCandidates(const Trace& trace, TRandState* rand_state);
63};
64
65/*!
66 * \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated
67 * to at lease one other value
68 * \param trace The trace from which to find the instructions
69 * \return All the candidate instructions together with the candidate compute-at locations
70 */
71std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::FindCandidates(
72 const Trace& trace, TRandState* rand_state) {
73 tir::Schedule sch = tir::Schedule::Traced( //
74 /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
75 /*rand_state=*/ForkSeed(rand_state), //
76 /*debug_mode=*/0, //
77 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
78
79 static InstructionKind inst_sample_compute_location =
80 InstructionKind::Get("SampleComputeLocation");
81 std::vector<MutateComputeLocationNode::Candidate> candidates;
82
83 auto f_decision_provider = [&](const tir::Instruction& inst, //
84 const Array<ObjectRef>& inputs, //
85 const Array<ObjectRef>& attrs, //
86 const ObjectRef& decision) -> ObjectRef {
87 if (inst->kind.same_as(inst_sample_compute_location)) {
88 // Step 1. Extract the instruction input and the old decision.
89 ICHECK_EQ(inputs.size(), 1);
90 tir::StmtSRef block_sref = sch->GetSRef(Downcast<tir::BlockRV>(inputs[0]));
91 int old_decision = Downcast<Integer>(decision)->value;
92
93 // Step 2. Collect all the compute_at locations.
94 auto [location_srefs, location_indices] = CollectComputeLocation(sch->state(), block_sref);
95 // Step 3. Remove the old decision.
96 auto it = std::find(location_indices.begin(), location_indices.end(), old_decision);
97 if (it != location_indices.end()) {
98 location_srefs.erase(location_srefs.begin() + (it - location_indices.begin()));
99 location_indices.erase(it);
100 }
101 ICHECK_EQ(location_srefs.size(), location_indices.size());
102 // Step 4. Add a new candidate if there are at least one remaining compute-at position.
103 if (!location_srefs.empty()) {
104 candidates.emplace_back(inst, std::move(location_indices));
105 }
106 }
107 return decision;
108 };
109 trace->ApplyToSchedule(sch, //
110 /*remove_postproc=*/true, //
111 /*decision_provider=*/f_decision_provider);
112 return candidates;
113}
114
115Optional<Trace> MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) {
116 std::vector<Candidate> candidates = FindCandidates(trace, rand_state);
117 if (candidates.empty()) {
118 return NullOpt;
119 }
120 const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())];
121 int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())];
122 return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true);
123}
124
125Mutator Mutator::MutateComputeLocation() {
126 return Mutator(make_object<MutateComputeLocationNode>());
127}
128
129TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode);
130TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation")
131 .set_body_typed(Mutator::MutateComputeLocation);
132
133} // namespace meta_schedule
134} // namespace tvm
135