1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_task.h |
22 | * \brief Meta information and hardware parameters for a search task. |
23 | */ |
24 | |
25 | #ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ |
26 | #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ |
27 | |
28 | #include <tvm/auto_scheduler/compute_dag.h> |
29 | #include <tvm/runtime/ndarray.h> |
30 | #include <tvm/target/target.h> |
31 | |
32 | namespace tvm { |
33 | namespace auto_scheduler { |
34 | |
35 | class HardwareParams; |
36 | |
37 | /*! \brief The parameters of target hardware used to guide the SearchPolicy. */ |
38 | class HardwareParamsNode : public Object { |
39 | public: |
40 | /*! \brief The number of cores. */ |
41 | int num_cores; |
42 | /*! \brief The width of vector units in bytes. */ |
43 | int vector_unit_bytes; |
44 | /*! \brief The size of cache line in bytes. */ |
45 | int cache_line_bytes; |
46 | |
47 | // GPU related parameters got from device query API |
48 | /*! \brief The max shared memory per block in bytes. */ |
49 | int max_shared_memory_per_block; |
50 | /*! \brief The max local memory per block in bytes. */ |
51 | int max_local_memory_per_block; |
52 | /*! \brief The max number of threads per block. */ |
53 | int max_threads_per_block; |
54 | /*! \brief The max vthread extent. */ |
55 | int max_vthread_extent; |
56 | /*! \brief The thread numbers of a warp. */ |
57 | int warp_size; |
58 | |
59 | void VisitAttrs(tvm::AttrVisitor* v) { |
60 | v->Visit("num_cores" , &num_cores); |
61 | v->Visit("vector_unit_bytes" , &vector_unit_bytes); |
62 | v->Visit("cache_line_bytes" , &cache_line_bytes); |
63 | v->Visit("max_shared_memory_per_block" , &max_shared_memory_per_block); |
64 | v->Visit("max_local_memory_per_block" , &max_local_memory_per_block); |
65 | v->Visit("max_threads_per_block" , &max_threads_per_block); |
66 | v->Visit("max_vthread_extent" , &max_vthread_extent); |
67 | v->Visit("warp_size" , &warp_size); |
68 | } |
69 | |
70 | /*! |
71 | * \brief Get the default hardware params. |
72 | * \param target A `tvm.target`. |
73 | * \param target_host A `tvm.target` for host device. |
74 | * \return A HardwareParams object. |
75 | */ |
76 | static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); |
77 | |
78 | static constexpr const char* _type_key = "auto_scheduler.HardwareParams" ; |
79 | TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); |
80 | }; |
81 | |
82 | /*! |
83 | * \brief Managed reference to HardwareParamsNode. |
84 | * \sa HardwareParamsNode |
85 | */ |
86 | class HardwareParams : public ObjectRef { |
87 | public: |
88 | /*! |
89 | * \brief The constructor. |
90 | * \param num_cores The number of cores. |
91 | * \param vector_unit_bytes The width of vector units in bytes. |
92 | * \param cache_line_bytes The size of cache line in bytes. |
93 | * \param max_shared_memory_per_block The max amount of shared memory per block for GPU. |
94 | * \param max_local_memory_per_block The max amount of local memory per block for GPU. |
95 | * \param max_threads_per_block The max number of threads per block for GPU. |
96 | * \param max_vthread_extent The max extent of vthread for GPU. |
97 | * \param warp_size The warp size for GPU |
98 | */ |
99 | HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, |
100 | int max_shared_memory_per_block, int max_local_memory_per_block, |
101 | int max_threads_per_block, int max_vthread_extent, int warp_size); |
102 | |
103 | TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); |
104 | TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); |
105 | }; |
106 | |
107 | /*! |
108 | * \brief The computation information and hardware parameters for a specific schedule search task. |
109 | */ |
110 | class SearchTaskNode : public Object { |
111 | public: |
112 | /*! \brief The ComputeDAG for the compute declaration. */ |
113 | ComputeDAG compute_dag; |
114 | /*! \brief The workload key for the compute declaration. */ |
115 | String workload_key; |
116 | /*! \brief The description string of this task. */ |
117 | String desc; |
118 | /*! \brief The target device of this search task. */ |
119 | Target target; |
120 | /*! \brief The target host device of this search task. */ |
121 | Target target_host; |
122 | /*! \brief Hardware parameters used in this search task. */ |
123 | HardwareParams hardware_params; |
124 | /*! \brief The layout rewrite option used for measuring programs. */ |
125 | LayoutRewriteOption layout_rewrite_option; |
126 | /*! \brief Names of some user defined input data used in program measuring. */ |
127 | Array<String> task_input_names; |
128 | |
129 | void VisitAttrs(tvm::AttrVisitor* v) { |
130 | v->Visit("compute_dag" , &compute_dag); |
131 | v->Visit("workload_key" , &workload_key); |
132 | v->Visit("desc" , &desc); |
133 | v->Visit("target" , &target); |
134 | v->Visit("target_host" , &target_host); |
135 | v->Visit("hardware_params" , &hardware_params); |
136 | v->Visit("layout_rewrite_option" , &layout_rewrite_option); |
137 | v->Visit("task_input_names" , &task_input_names); |
138 | } |
139 | |
140 | static constexpr const char* _type_key = "auto_scheduler.SearchTask" ; |
141 | TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); |
142 | }; |
143 | |
144 | /*! |
145 | * \brief Managed reference to SearchTaskNode. |
146 | * \sa SearchTaskNode |
147 | */ |
148 | class SearchTask : public ObjectRef { |
149 | public: |
150 | /*! |
151 | * \brief The constructor. |
152 | * \param compute_dag The ComputeDAG for the compute declaration. |
153 | * \param workload_key The workload key for the compute declaration. |
154 | * \param target The target device of this search task. |
155 | * \param target_host The target host device of this search task. |
156 | * \param hardware_params Hardware parameters used in this search task. |
157 | * \param layout_rewrite_option The layout rewrite option used for measuring programs. |
158 | * \param task_input_names Names of some user defined input data used in program measuring. |
159 | * \param desc The description string of this task. |
160 | */ |
161 | SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host, |
162 | Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option, |
163 | Array<String> task_input_names, String desc = "" ); |
164 | |
165 | TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); |
166 | }; |
167 | |
168 | } // namespace auto_scheduler |
169 | } // namespace tvm |
170 | |
171 | #endif // TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ |
172 | |