1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_ |
20 | #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ |
21 | |
22 | #include <tvm/ir/expr.h> |
23 | #include <tvm/ir/module.h> |
24 | #include <tvm/meta_schedule/builder.h> |
25 | #include <tvm/meta_schedule/runner.h> |
26 | #include <tvm/meta_schedule/search_strategy.h> |
27 | #include <tvm/meta_schedule/space_generator.h> |
28 | #include <tvm/node/reflection.h> |
29 | #include <tvm/runtime/container/array.h> |
30 | #include <tvm/runtime/container/map.h> |
31 | #include <tvm/runtime/container/optional.h> |
32 | #include <tvm/runtime/container/string.h> |
33 | #include <tvm/runtime/object.h> |
34 | #include <tvm/runtime/packed_func.h> |
35 | #include <tvm/support/random_engine.h> |
36 | #include <tvm/target/target.h> |
37 | |
38 | namespace tvm { |
39 | namespace meta_schedule { |
40 | |
// Forward declarations.
// `TuneContext` must be declared here because TuneContextNode::Clone() returns it
// before the TuneContext class itself is defined further down in this header.
// NOTE(review): MeasureCallback and TaskSchedulerNode are not referenced in this
// header; presumably they are declared for the benefit of includers — verify.
class TuneContext;
class MeasureCallback;
class TaskSchedulerNode;
44 | |
45 | /*! \brief The auto tuning context. */ |
46 | class TuneContextNode : public runtime::Object { |
47 | public: |
48 | using TRandState = support::LinearCongruentialEngine::TRandState; |
49 | |
50 | /*! \brief The workload to be tuned. */ |
51 | Optional<IRModule> mod; |
52 | /*! \brief The target to be tuned for. */ |
53 | Optional<Target> target; |
54 | /*! \brief The design space generator. */ |
55 | Optional<SpaceGenerator> space_generator; |
56 | /*! \brief The search strategy. */ |
57 | Optional<SearchStrategy> search_strategy; |
58 | /*! \brief The name of the tuning task. */ |
59 | Optional<String> task_name; |
60 | /*! \brief The number of threads to be used. */ |
61 | int num_threads; |
62 | /*! \brief The random state. */ |
63 | TRandState rand_state; |
64 | /*! \brief The tuning task's logging function. t*/ |
65 | PackedFunc logger; |
66 | |
67 | void VisitAttrs(tvm::AttrVisitor* v) { |
68 | v->Visit("mod" , &mod); |
69 | v->Visit("target" , &target); |
70 | v->Visit("space_generator" , &space_generator); |
71 | v->Visit("search_strategy" , &search_strategy); |
72 | v->Visit("task_name" , &task_name); |
73 | v->Visit("num_threads" , &num_threads); |
74 | v->Visit("rand_state" , &rand_state); |
75 | // `logger` is not visited |
76 | } |
77 | /*! |
78 | * \brief Initialize members that needs initialization with tune context. |
79 | */ |
80 | void Initialize(); |
81 | /*! |
82 | * \brief Clone the tune context. |
83 | * \return The cloned tune context. |
84 | */ |
85 | TuneContext Clone() const; |
86 | |
87 | static constexpr const char* _type_key = "meta_schedule.TuneContext" ; |
88 | TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); |
89 | }; |
90 | |
91 | /*! |
92 | * \brief Managed reference to TuneContextNode. |
93 | * \sa TuneContextNode |
94 | */ |
95 | class TuneContext : public runtime::ObjectRef { |
96 | public: |
97 | using TRandState = support::LinearCongruentialEngine::TRandState; |
98 | /*! |
99 | * \brief Constructor. |
100 | * \param mod The workload to be tuned. |
101 | * \param target The target to be tuned for. |
102 | * \param space_generator The design space generator. |
103 | * \param search_strategy The search strategy. |
104 | * \param task_name The name of the tuning task. |
105 | * \param num_threads The number of threads to be used. |
106 | * \param rand_state The random state. |
107 | * \param logger The tuning task's logging function. |
108 | */ |
109 | TVM_DLL explicit TuneContext(Optional<IRModule> mod, Optional<Target> target, |
110 | Optional<SpaceGenerator> space_generator, |
111 | Optional<SearchStrategy> search_strategy, Optional<String> task_name, |
112 | int num_threads, TRandState rand_state, PackedFunc logger); |
113 | TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); |
114 | }; |
115 | |
116 | } // namespace meta_schedule |
117 | } // namespace tvm |
118 | |
119 | #endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_ |
120 | |