1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/auto_schedule.cc |
22 | * \brief The user interface and tuning options of the TVM auto-scheduler. |
23 | */ |
24 | |
25 | #include <tvm/auto_scheduler/auto_schedule.h> |
26 | #include <tvm/runtime/registry.h> |
27 | |
28 | #include "utils.h" |
29 | |
30 | namespace tvm { |
31 | namespace auto_scheduler { |
32 | |
// Register TuningOptionsNode as a node type with TVM's object system so instances of it
// can be handled generically (presumably via reflection / the Python frontend — confirm
// against the macro's definition).
TVM_REGISTER_NODE_TYPE(TuningOptionsNode);
34 | |
35 | TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, |
36 | int verbose, ProgramBuilder builder, ProgramRunner runner, |
37 | Optional<Array<MeasureCallback>> measure_callbacks) { |
38 | auto node = make_object<TuningOptionsNode>(); |
39 | node->num_measure_trials = num_measure_trials; |
40 | node->early_stopping = early_stopping; |
41 | node->num_measures_per_round = num_measures_per_round; |
42 | node->verbose = verbose; |
43 | node->builder = std::move(builder); |
44 | node->runner = std::move(runner); |
45 | node->measure_callbacks = std::move(measure_callbacks); |
46 | data_ = std::move(node); |
47 | } |
48 | |
49 | std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy, |
50 | TuningOptions tuning_options) { |
51 | // Create a ProgramMeasurer to handle the schedule build and performance measure |
52 | ProgramMeasurer measurer = |
53 | ProgramMeasurer(tuning_options->builder, tuning_options->runner, |
54 | tuning_options->measure_callbacks, tuning_options->verbose); |
55 | // Search for the best schedule |
56 | State state = |
57 | search_policy->Search(tuning_options->num_measure_trials, tuning_options->early_stopping, |
58 | tuning_options->num_measures_per_round, measurer); |
59 | if (state.defined()) { |
60 | return search_policy->search_task->compute_dag.ApplySteps(state->transform_steps); |
61 | } else { |
62 | StdCout(tuning_options->verbose) |
63 | << "No valid state found in this search round. Check if it has traversed all of the " |
64 | << "search space." << std::endl; |
65 | // Return the default schedule |
66 | return {te::Schedule(search_policy->search_task->compute_dag->ops), |
67 | search_policy->search_task->compute_dag->tensors}; |
68 | } |
69 | } |
70 | |
71 | TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions" ) |
72 | .set_body_typed([](int num_measure_trials, int early_stopping, int num_measures_per_round, |
73 | int verbose, ProgramBuilder builder, ProgramRunner runner, |
74 | Optional<Array<MeasureCallback>> measure_callbacks) { |
75 | return TuningOptions(num_measure_trials, early_stopping, num_measures_per_round, verbose, |
76 | builder, runner, measure_callbacks); |
77 | }); |
78 | |
79 | TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule" ) |
80 | .set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) { |
81 | auto [sch, return_tensors] = AutoSchedule(search_policy, tuning_options); |
82 | return Array<ObjectRef>{sch, return_tensors}; |
83 | }); |
84 | } // namespace auto_scheduler |
85 | } // namespace tvm |
86 | |