1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/search_task.h
22 * \brief Meta information and hardware parameters for a search task.
23 */
24
25#ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
26#define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
27
28#include <tvm/auto_scheduler/compute_dag.h>
29#include <tvm/runtime/ndarray.h>
30#include <tvm/target/target.h>
31
32namespace tvm {
33namespace auto_scheduler {
34
35class HardwareParams;
36
37/*! \brief The parameters of target hardware used to guide the SearchPolicy. */
38class HardwareParamsNode : public Object {
39 public:
40 /*! \brief The number of cores. */
41 int num_cores;
42 /*! \brief The width of vector units in bytes. */
43 int vector_unit_bytes;
44 /*! \brief The size of cache line in bytes. */
45 int cache_line_bytes;
46
47 // GPU related parameters got from device query API
48 /*! \brief The max shared memory per block in bytes. */
49 int max_shared_memory_per_block;
50 /*! \brief The max local memory per block in bytes. */
51 int max_local_memory_per_block;
52 /*! \brief The max number of threads per block. */
53 int max_threads_per_block;
54 /*! \brief The max vthread extent. */
55 int max_vthread_extent;
56 /*! \brief The thread numbers of a warp. */
57 int warp_size;
58
59 void VisitAttrs(tvm::AttrVisitor* v) {
60 v->Visit("num_cores", &num_cores);
61 v->Visit("vector_unit_bytes", &vector_unit_bytes);
62 v->Visit("cache_line_bytes", &cache_line_bytes);
63 v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
64 v->Visit("max_local_memory_per_block", &max_local_memory_per_block);
65 v->Visit("max_threads_per_block", &max_threads_per_block);
66 v->Visit("max_vthread_extent", &max_vthread_extent);
67 v->Visit("warp_size", &warp_size);
68 }
69
70 /*!
71 * \brief Get the default hardware params.
72 * \param target A `tvm.target`.
73 * \param target_host A `tvm.target` for host device.
74 * \return A HardwareParams object.
75 */
76 static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host);
77
78 static constexpr const char* _type_key = "auto_scheduler.HardwareParams";
79 TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object);
80};
81
82/*!
83 * \brief Managed reference to HardwareParamsNode.
84 * \sa HardwareParamsNode
85 */
86class HardwareParams : public ObjectRef {
87 public:
88 /*!
89 * \brief The constructor.
90 * \param num_cores The number of cores.
91 * \param vector_unit_bytes The width of vector units in bytes.
92 * \param cache_line_bytes The size of cache line in bytes.
93 * \param max_shared_memory_per_block The max amount of shared memory per block for GPU.
94 * \param max_local_memory_per_block The max amount of local memory per block for GPU.
95 * \param max_threads_per_block The max number of threads per block for GPU.
96 * \param max_vthread_extent The max extent of vthread for GPU.
97 * \param warp_size The warp size for GPU
98 */
99 HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
100 int max_shared_memory_per_block, int max_local_memory_per_block,
101 int max_threads_per_block, int max_vthread_extent, int warp_size);
102
103 TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
104 TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode);
105};
106
107/*!
108 * \brief The computation information and hardware parameters for a specific schedule search task.
109 */
110class SearchTaskNode : public Object {
111 public:
112 /*! \brief The ComputeDAG for the compute declaration. */
113 ComputeDAG compute_dag;
114 /*! \brief The workload key for the compute declaration. */
115 String workload_key;
116 /*! \brief The description string of this task. */
117 String desc;
118 /*! \brief The target device of this search task. */
119 Target target;
120 /*! \brief The target host device of this search task. */
121 Target target_host;
122 /*! \brief Hardware parameters used in this search task. */
123 HardwareParams hardware_params;
124 /*! \brief The layout rewrite option used for measuring programs. */
125 LayoutRewriteOption layout_rewrite_option;
126 /*! \brief Names of some user defined input data used in program measuring. */
127 Array<String> task_input_names;
128
129 void VisitAttrs(tvm::AttrVisitor* v) {
130 v->Visit("compute_dag", &compute_dag);
131 v->Visit("workload_key", &workload_key);
132 v->Visit("desc", &desc);
133 v->Visit("target", &target);
134 v->Visit("target_host", &target_host);
135 v->Visit("hardware_params", &hardware_params);
136 v->Visit("layout_rewrite_option", &layout_rewrite_option);
137 v->Visit("task_input_names", &task_input_names);
138 }
139
140 static constexpr const char* _type_key = "auto_scheduler.SearchTask";
141 TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object);
142};
143
144/*!
145 * \brief Managed reference to SearchTaskNode.
146 * \sa SearchTaskNode
147 */
148class SearchTask : public ObjectRef {
149 public:
150 /*!
151 * \brief The constructor.
152 * \param compute_dag The ComputeDAG for the compute declaration.
153 * \param workload_key The workload key for the compute declaration.
154 * \param target The target device of this search task.
155 * \param target_host The target host device of this search task.
156 * \param hardware_params Hardware parameters used in this search task.
157 * \param layout_rewrite_option The layout rewrite option used for measuring programs.
158 * \param task_input_names Names of some user defined input data used in program measuring.
159 * \param desc The description string of this task.
160 */
161 SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
162 Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option,
163 Array<String> task_input_names, String desc = "");
164
165 TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
166};
167
168} // namespace auto_scheduler
169} // namespace tvm
170
171#endif // TVM_AUTO_SCHEDULER_SEARCH_TASK_H_
172