1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/auto_scheduler/auto_schedule.h
22 * \brief The user interface of the auto scheduler.
23 */
24
25#ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
26#define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
27
28#include <tvm/auto_scheduler/measure.h>
29#include <tvm/auto_scheduler/search_policy.h>
30
31#include <utility>
32
33namespace tvm {
34namespace auto_scheduler {
35
36/*! \brief Tuning and measurement options. */
37class TuningOptionsNode : public Object {
38 public:
39 /*! \brief The number of total measurement trials. */
40 int num_measure_trials;
41 /*! \brief Stops the tuning early if no improvement after n measurements. */
42 int early_stopping;
43 /*! \brief The number of programs to be measured at each search round. */
44 int num_measures_per_round;
45 /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */
46 int verbose;
47 /*! \brief ProgramBuilder which builds the program */
48 ProgramBuilder builder;
49 /*! \brief ProgramRunner which runs the program and measures time costs */
50 ProgramRunner runner;
51 /*! \brief MeasureCallback functions to be called after each measure batch */
52 Optional<Array<MeasureCallback>> measure_callbacks;
53
54 void VisitAttrs(tvm::AttrVisitor* v) {
55 v->Visit("num_measure_trials", &num_measure_trials);
56 v->Visit("early_stopping", &early_stopping);
57 v->Visit("num_measures_per_round", &num_measures_per_round);
58 v->Visit("verbose", &verbose);
59 v->Visit("builder", &builder);
60 v->Visit("runner", &runner);
61 v->Visit("measure_callbacks", &measure_callbacks);
62 }
63
64 static constexpr const char* _type_key = "auto_scheduler.TuningOptions";
65 TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object);
66};
67
68/*!
69 * \brief Managed reference to TuningOptionsNode.
70 * \sa TuningOptionsNode
71 */
72class TuningOptions : public ObjectRef {
73 public:
74 /*!
75 * \brief The constructor
76 * \param num_measure_trials The number of total measurement trials.
77 * \param early_stopping Stops the tuning early if no improvement after n measurements.
78 * \param num_measures_per_round The number of programs to be measured at each search round.
79 * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
80 * search.
81 * \param builder ProgramBuilder which builds the program.
82 * \param runner ProgramRunner which runs the program and measure time costs.
83 * \param measure_callbacks MeasureCallback functions to be called after each measure batch.
84 */
85 TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose,
86 ProgramBuilder builder, ProgramRunner runner,
87 Optional<Array<MeasureCallback>> measure_callbacks);
88
89 TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
90};
91
92/*!
93 * \brief Run schedule search for a given compute declaration.
94 * \param search_policy The search policy.
95 * \param tuning_options Tuning and measurement options.
96 * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or
97 * `tvm.build`.
98 */
99TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy,
100 TuningOptions tuning_options);
101} // namespace auto_scheduler
102} // namespace tvm
103
104#endif // TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_
105