1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24void PyMutatorNode::InitializeWithTuneContext(const TuneContext& context) {
25 ICHECK(f_initialize_with_tune_context != nullptr)
26 << "PyMutator's InitializeWithTuneContext method not implemented!";
27 f_initialize_with_tune_context(context);
28}
29
30Optional<tir::Trace> PyMutatorNode::Apply(
31 const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) {
32 ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
33 return f_apply(trace, *rand_state);
34}
35
36Mutator PyMutatorNode::Clone() const {
37 ICHECK(f_clone != nullptr) << "PyMutator's Clone method not implemented!";
38 return f_clone();
39}
40
41Mutator Mutator::PyMutator(
42 PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
43 PyMutatorNode::FApply f_apply, //
44 PyMutatorNode::FClone f_clone, //
45 PyMutatorNode::FAsString f_as_string) {
46 ObjectPtr<PyMutatorNode> n = make_object<PyMutatorNode>();
47 n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context);
48 n->f_apply = std::move(f_apply);
49 n->f_clone = std::move(f_clone);
50 n->f_as_string = std::move(f_as_string);
51 return Mutator(n);
52}
53
54Map<Mutator, FloatImm> Mutator::DefaultLLVM() {
55 return Map<Mutator, FloatImm>{
56 {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
57 {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)},
58 {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)},
59 {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
60}
61
62Map<Mutator, FloatImm> Mutator::DefaultCUDA() {
63 return Map<Mutator, FloatImm>{
64 {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
65 {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.08)},
66 {Mutator::MutateThreadBinding(), FloatImm(DataType::Float(64), 0.02)}};
67}
68
69Map<Mutator, FloatImm> Mutator::DefaultCUDATensorCore() { return Mutator::DefaultCUDA(); }
70
71Map<Mutator, FloatImm> Mutator::DefaultHexagon() {
72 return Map<Mutator, FloatImm>{
73 {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
74 {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)},
75 {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)},
76 {Mutator::MutateParallel(/*max_jobs_per_core=*/16), FloatImm(DataType::Float(64), 0.02)}};
77}
78
79Map<Mutator, FloatImm> Mutator::DefaultMicro() {
80 return Map<Mutator, FloatImm>{
81 {Mutator::MutateTileSize(), FloatImm(DataType::Float(64), 0.9)},
82 {Mutator::MutateComputeLocation(), FloatImm(DataType::Float(64), 0.05)},
83 {Mutator::MutateUnroll(), FloatImm(DataType::Float(64), 0.03)}};
84}
85
86TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
87 .set_dispatch<PyMutatorNode>([](const ObjectRef& n, ReprPrinter* p) {
88 const auto* self = n.as<PyMutatorNode>();
89 ICHECK(self);
90 PyMutatorNode::FAsString f_as_string = (*self).f_as_string;
91 ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!";
92 p->stream << f_as_string();
93 });
94
// Register the mutator object types with TVM's object system.
TVM_REGISTER_OBJECT_TYPE(MutatorNode);
TVM_REGISTER_NODE_TYPE(PyMutatorNode);
97
98TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext")
99 .set_body_method<Mutator>(&MutatorNode::InitializeWithTuneContext);
100TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply")
101 .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional<tir::Trace> {
102 TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom();
103 return self->Apply(trace, &seed_);
104 });
105TVM_REGISTER_GLOBAL("meta_schedule.MutatorClone").set_body_method<Mutator>(&MutatorNode::Clone);
106TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator);
107TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultLLVM").set_body_typed(Mutator::DefaultLLVM);
108TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDA").set_body_typed(Mutator::DefaultCUDA);
109TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultCUDATensorCore")
110 .set_body_typed(Mutator::DefaultCUDATensorCore);
111TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultHexagon").set_body_typed(Mutator::DefaultHexagon);
112TVM_REGISTER_GLOBAL("meta_schedule.MutatorDefaultMicro").set_body_typed(Mutator::DefaultMicro);
113
114} // namespace meta_schedule
115} // namespace tvm
116