1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/feature.h
22 * \brief Feature extraction for the cost model.
23 * We extract one feature vector per BufferStoreNode statement in a TIR Stmt,
 * so we call it the "per-store" feature.
 * The cost model likewise makes one prediction per BufferStoreNode statement and
 * aggregates these predictions into the overall score of a TVM IR (Stmt).
27 *
 * The feature specification is defined by `src/auto_scheduler/feature.cc::FeatureSet`
29 */
30
31#ifndef TVM_AUTO_SCHEDULER_FEATURE_H_
32#define TVM_AUTO_SCHEDULER_FEATURE_H_
33
34#include <tvm/auto_scheduler/compute_dag.h>
35#include <tvm/auto_scheduler/measure.h>
36#include <tvm/tir/function.h>
37
38#include <string>
39#include <vector>
40
41namespace tvm {
42namespace auto_scheduler {
43
44/*!
45 * \brief Get per-store features from a TIR PrimFunc
46 * \param func The input lowered TIR PrimFunc
47 * \param cache_line_size The size of cache line in bytes
48 * \param max_n_bufs The maximum number of extracted buffers for one statement
49 * \param ret The returned feature vector
50 * \param log_scale Should the outputs be scaled by log2(1+x).
51 */
52void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
53 std::vector<float>* ret, bool log_scale = true);
54
/*!
 * \brief Get the names of the elements in the per-store feature vector.
 *        Intended for debugging and inspection.
 * \param max_n_bufs The maximum number of extracted buffers for one statement.
 *        Presumably this should match the value used during feature extraction so that
 *        names line up with feature positions (NOTE(review): confirm against feature.cc).
 * \param ret Output parameter that receives the feature names.
 */
void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret);
61
62/*!
63 * \brief Get per-store feature from states of the same task
64 * \param states The input states
65 * \param task The same search task for all states
66 * \param skip_first_n_feature_extraction Skip feature extraction for the first n states
67 * \param max_n_bufs The maximum number of extracted buffers for one statement
68 * \param features The returned feature vector. The innermost vector contains the
69 * feature vectors for all BufferStoreNode statements
70 */
71void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
72 int skip_first_n_feature_extraction, int max_n_bufs,
73 std::vector<std::vector<float>>* features);
74
75/*!
76 * \brief Get per-store feature from states of different tasks
77 * \param states The input states
78 * \param tasks The search tasks corresponding to the input states
79 * \param skip_first_n_feature_extraction Skip feature extraction for the first n states
80 * \param max_n_bufs The maximum number of extracted buffers for one statement
81 * \param features The returned feature vector. The innermost vector contains the
82 * feature vectors for all BufferStoreNode statements
83 */
84void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
85 int skip_first_n_feature_extraction, int max_n_bufs,
86 std::vector<std::vector<float>>* features);
87
/*!
 * \brief Extract per-store features from the records stored in a log file.
 * \param filename Name of the log file to read.
 * \param max_lines Only the first `max_lines` lines of the file are read.
 * \param max_n_bufs Upper bound on the number of buffers extracted for a single statement.
 * \param features Output parameter. Each inner vector holds the features of all
 *        BufferStoreNode statements of one state.
 * \param normalized_throughputs Output parameter receiving the normalized throughput of
 *        every state.
 * \param task_ids Output parameter receiving the task id of every state.
 */
void GetPerStoreFeaturesFromFile(const std::string& filename,
                                 int max_lines,
                                 int max_n_bufs,
                                 std::vector<std::vector<float>>* features,
                                 std::vector<float>* normalized_throughputs,
                                 std::vector<int>* task_ids);
102
103/*!
104 * \brief Get per-store features from measurement input/result pairs
105 * \param inputs The measurement inputs
106 * \param results The measurement results
107 * \param skip_first_n_feature_extraction Skip feature extraction for the first n measurement pairs
108 * \param max_n_bufs The maximum number of extracted buffers for one statement
109 * \param features The returned feature vector. The innermost vector contains the
110 * feature vectors for all BufferStoreNode statements
111 * \param normalized_throughputs The normalized throughputs for all states
112 * \param task_ids The task ids for all states
113 */
114void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
115 const Array<MeasureResult>& results,
116 int skip_first_n_feature_extraction, int max_n_bufs,
117 std::vector<std::vector<float>>* features,
118 std::vector<float>* normalized_throughputs,
119 std::vector<int>* task_ids);
120
121} // namespace auto_scheduler
122} // namespace tvm
123
124#endif // TVM_AUTO_SCHEDULER_FEATURE_H_
125