1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/feature.h |
22 | * \brief Feature extraction for the cost model. |
 * We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
 * so we call this feature the "per-store" feature.
 * The cost model also makes a prediction for each BufferStoreNode statement and aggregates
 * the predictions into the score for the whole TVM IR (Stmt).
 *
 * The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`.
29 | */ |
30 | |
31 | #ifndef TVM_AUTO_SCHEDULER_FEATURE_H_ |
32 | #define TVM_AUTO_SCHEDULER_FEATURE_H_ |
33 | |
34 | #include <tvm/auto_scheduler/compute_dag.h> |
35 | #include <tvm/auto_scheduler/measure.h> |
36 | #include <tvm/tir/function.h> |
37 | |
38 | #include <string> |
39 | #include <vector> |
40 | |
41 | namespace tvm { |
42 | namespace auto_scheduler { |
43 | |
44 | /*! |
45 | * \brief Get per-store features from a TIR PrimFunc |
46 | * \param func The input lowered TIR PrimFunc |
47 | * \param cache_line_size The size of cache line in bytes |
48 | * \param max_n_bufs The maximum number of extracted buffers for one statement |
49 | * \param ret The returned feature vector |
50 | * \param log_scale Should the outputs be scaled by log2(1+x). |
51 | */ |
52 | void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs, |
53 | std::vector<float>* ret, bool log_scale = true); |
54 | |
55 | /* |
56 | * \brief Get the names of elements in the feature vector. Use this for debug and inspection. |
57 | * \param max_n_bufs The maximum number of extracted buffers for one statement |
58 | * \param ret The returned names. |
59 | */ |
60 | void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret); |
61 | |
62 | /*! |
63 | * \brief Get per-store feature from states of the same task |
64 | * \param states The input states |
65 | * \param task The same search task for all states |
66 | * \param skip_first_n_feature_extraction Skip feature extraction for the first n states |
67 | * \param max_n_bufs The maximum number of extracted buffers for one statement |
68 | * \param features The returned feature vector. The innermost vector contains the |
69 | * feature vectors for all BufferStoreNode statements |
70 | */ |
71 | void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task, |
72 | int , int max_n_bufs, |
73 | std::vector<std::vector<float>>* features); |
74 | |
75 | /*! |
76 | * \brief Get per-store feature from states of different tasks |
77 | * \param states The input states |
78 | * \param tasks The search tasks corresponding to the input states |
79 | * \param skip_first_n_feature_extraction Skip feature extraction for the first n states |
80 | * \param max_n_bufs The maximum number of extracted buffers for one statement |
81 | * \param features The returned feature vector. The innermost vector contains the |
82 | * feature vectors for all BufferStoreNode statements |
83 | */ |
84 | void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks, |
85 | int , int max_n_bufs, |
86 | std::vector<std::vector<float>>* features); |
87 | |
88 | /*! |
89 | * \brief Get per-store features from a log file |
90 | * \param filename The name of log file |
91 | * \param max_lines Only read the first n lines of the file |
92 | * \param max_n_bufs The maximum number of extracted buffers for one statement |
93 | * \param features The returned feature vector. The innermost vector contains the |
94 | * feature vectors for all BufferStoreNode statements |
95 | * \param normalized_throughputs The normalized throughputs for all states |
96 | * \param task_ids The task ids for all states |
97 | */ |
98 | void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs, |
99 | std::vector<std::vector<float>>* features, |
100 | std::vector<float>* normalized_throughputs, |
101 | std::vector<int>* task_ids); |
102 | |
103 | /*! |
104 | * \brief Get per-store features from measurement input/result pairs |
105 | * \param inputs The measurement inputs |
106 | * \param results The measurement results |
107 | * \param skip_first_n_feature_extraction Skip feature extraction for the first n measurement pairs |
108 | * \param max_n_bufs The maximum number of extracted buffers for one statement |
109 | * \param features The returned feature vector. The innermost vector contains the |
110 | * feature vectors for all BufferStoreNode statements |
111 | * \param normalized_throughputs The normalized throughputs for all states |
112 | * \param task_ids The task ids for all states |
113 | */ |
114 | void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs, |
115 | const Array<MeasureResult>& results, |
116 | int , int max_n_bufs, |
117 | std::vector<std::vector<float>>* features, |
118 | std::vector<float>* normalized_throughputs, |
119 | std::vector<int>* task_ids); |
120 | |
121 | } // namespace auto_scheduler |
122 | } // namespace tvm |
123 | |
124 | #endif // TVM_AUTO_SCHEDULER_FEATURE_H_ |
125 | |