1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/auto_scheduler/auto_schedule.h |
22 | * \brief The user interface of the auto scheduler. |
23 | */ |
24 | |
25 | #ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ |
26 | #define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ |
27 | |
28 | #include <tvm/auto_scheduler/measure.h> |
29 | #include <tvm/auto_scheduler/search_policy.h> |
30 | |
31 | #include <utility> |
32 | |
33 | namespace tvm { |
34 | namespace auto_scheduler { |
35 | |
36 | /*! \brief Tuning and measurement options. */ |
37 | class TuningOptionsNode : public Object { |
38 | public: |
39 | /*! \brief The number of total measurement trials. */ |
40 | int num_measure_trials; |
41 | /*! \brief Stops the tuning early if no improvement after n measurements. */ |
42 | int early_stopping; |
43 | /*! \brief The number of programs to be measured at each search round. */ |
44 | int num_measures_per_round; |
45 | /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */ |
46 | int verbose; |
47 | /*! \brief ProgramBuilder which builds the program */ |
48 | ProgramBuilder builder; |
49 | /*! \brief ProgramRunner which runs the program and measures time costs */ |
50 | ProgramRunner runner; |
51 | /*! \brief MeasureCallback functions to be called after each measure batch */ |
52 | Optional<Array<MeasureCallback>> measure_callbacks; |
53 | |
54 | void VisitAttrs(tvm::AttrVisitor* v) { |
55 | v->Visit("num_measure_trials" , &num_measure_trials); |
56 | v->Visit("early_stopping" , &early_stopping); |
57 | v->Visit("num_measures_per_round" , &num_measures_per_round); |
58 | v->Visit("verbose" , &verbose); |
59 | v->Visit("builder" , &builder); |
60 | v->Visit("runner" , &runner); |
61 | v->Visit("measure_callbacks" , &measure_callbacks); |
62 | } |
63 | |
64 | static constexpr const char* _type_key = "auto_scheduler.TuningOptions" ; |
65 | TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object); |
66 | }; |
67 | |
68 | /*! |
69 | * \brief Managed reference to TuningOptionsNode. |
70 | * \sa TuningOptionsNode |
71 | */ |
72 | class TuningOptions : public ObjectRef { |
73 | public: |
74 | /*! |
75 | * \brief The constructor |
76 | * \param num_measure_trials The number of total measurement trials. |
77 | * \param early_stopping Stops the tuning early if no improvement after n measurements. |
78 | * \param num_measures_per_round The number of programs to be measured at each search round. |
79 | * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule |
80 | * search. |
81 | * \param builder ProgramBuilder which builds the program. |
82 | * \param runner ProgramRunner which runs the program and measure time costs. |
83 | * \param measure_callbacks MeasureCallback functions to be called after each measure batch. |
84 | */ |
85 | TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, |
86 | ProgramBuilder builder, ProgramRunner runner, |
87 | Optional<Array<MeasureCallback>> measure_callbacks); |
88 | |
89 | TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode); |
90 | }; |
91 | |
92 | /*! |
93 | * \brief Run schedule search for a given compute declaration. |
94 | * \param search_policy The search policy. |
95 | * \param tuning_options Tuning and measurement options. |
96 | * \return A `te::schedule` and an Array of `te::Tensor` to be used in `tvm.lower` or |
97 | * `tvm.build`. |
98 | */ |
99 | TVM_DLL std::pair<te::Schedule, Array<te::Tensor>> AutoSchedule(SearchPolicy search_policy, |
100 | TuningOptions tuning_options); |
101 | } // namespace auto_scheduler |
102 | } // namespace tvm |
103 | |
104 | #endif // TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ |
105 | |