1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/auto_schedule.cc
22 * \brief The user interface and tuning options of the TVM auto-scheduler.
23 */
24
25#include <tvm/auto_scheduler/auto_schedule.h>
26#include <tvm/runtime/registry.h>
27
28#include "utils.h"
29
30namespace tvm {
31namespace auto_scheduler {
32
// Register TuningOptionsNode as a TVM node type so TuningOptions objects defined in this file
// can be handled through TVM's node/object machinery.
TVM_REGISTER_NODE_TYPE(TuningOptionsNode);
34
35TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round,
36 int verbose, ProgramBuilder builder, ProgramRunner runner,
37 Optional<Array<MeasureCallback>> measure_callbacks) {
38 auto node = make_object<TuningOptionsNode>();
39 node->num_measure_trials = num_measure_trials;
40 node->early_stopping = early_stopping;
41 node->num_measures_per_round = num_measures_per_round;
42 node->verbose = verbose;
43 node->builder = std::move(builder);
44 node->runner = std::move(runner);
45 node->measure_callbacks = std::move(measure_callbacks);
46 data_ = std::move(node);
47}
48
49std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy,
50 TuningOptions tuning_options) {
51 // Create a ProgramMeasurer to handle the schedule build and performance measure
52 ProgramMeasurer measurer =
53 ProgramMeasurer(tuning_options->builder, tuning_options->runner,
54 tuning_options->measure_callbacks, tuning_options->verbose);
55 // Search for the best schedule
56 State state =
57 search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping,
58 tuning_options->num_measures_per_round, measurer);
59 if (state.defined()) {
60 return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps);
61 } else {
62 StdCout(tuning_options->verbose)
63 << "No valid state found in this search round. Check if it has traversed all of the "
64 << "search space." << std::endl;
65 // Return the default schedule
66 return {te::Schedule(search_policy->search_task->compute_dag->ops),
67 search_policy->search_task->compute_dag->tensors};
68 }
69}
70
71TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions")
72 .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round,
73 int verbose, ProgramBuilder builder, ProgramRunner runner,
74 Optional<Array<MeasureCallback>> measure_callbacks) {
75 return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose,
76 builder, runner, measure_callbacks);
77 });
78
79TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule")
80 .set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) {
81 auto [sch, return_tensors] = AutoSchedule(search_policy, tuning_options);
82 return Array<ObjectRef>{sch, return_tensors};
83 });
84} // namespace auto_scheduler
85} // namespace tvm
86