1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ |
21 | #define |
22 | |
23 | #include <tvm/meta_schedule/measure_candidate.h> |
24 | #include <tvm/node/reflection.h> |
25 | #include <tvm/runtime/container/array.h> |
26 | #include <tvm/runtime/container/string.h> |
27 | #include <tvm/runtime/ndarray.h> |
28 | #include <tvm/runtime/object.h> |
29 | #include <tvm/runtime/packed_func.h> |
30 | |
31 | namespace tvm { |
32 | namespace meta_schedule { |
33 | |
34 | class TuneContext; |
35 | |
36 | /*! \brief Extractor for features from measure candidates for use in cost model. */ |
37 | class : public runtime::Object { |
38 | public: |
39 | /*! \brief Virtual destructor. */ |
40 | virtual () = default; |
41 | |
42 | void (tvm::AttrVisitor* v) {} |
43 | |
44 | /*! |
45 | * \brief Extract features from the given measure candidate. |
46 | * \param context The tuning context for feature extraction. |
47 | * \param candidates The measure candidates to extract features from. |
48 | * \return The feature ndarray extracted. |
49 | */ |
50 | virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context, |
51 | const Array<MeasureCandidate>& candidates) = 0; |
52 | |
53 | static constexpr const char* = "meta_schedule.FeatureExtractor" ; |
54 | TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object); |
55 | }; |
56 | |
57 | /*! \brief The feature extractor with customized methods on the python-side. */ |
58 | class : public FeatureExtractorNode { |
59 | public: |
60 | /*! |
61 | * \brief Extract features from the given measure candidate. |
62 | * \param context The tuning context for feature extraction. |
63 | * \param candidates The measure candidates to extract features from. |
64 | * \return The feature ndarray extracted. |
65 | */ |
66 | using = runtime::TypedPackedFunc<Array<tvm::runtime::NDArray>( |
67 | const TuneContext& context, const Array<MeasureCandidate>& candidates)>; |
68 | /*! |
69 | * \brief Get the feature extractor as string with name. |
70 | * \return The string of the feature extractor. |
71 | */ |
72 | using = runtime::TypedPackedFunc<String()>; |
73 | |
74 | /*! \brief The packed function to the `ExtractFrom` function. */ |
75 | FExtractFrom ; |
76 | /*! \brief The packed function to the `AsString` function. */ |
77 | FAsString ; |
78 | |
79 | void (tvm::AttrVisitor* v) { |
80 | // `f_extract_from` is not visited |
81 | // `f_as_string` is not visited |
82 | } |
83 | |
84 | Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context, |
85 | const Array<MeasureCandidate>& candidates) final; |
86 | |
87 | static constexpr const char* = "meta_schedule.PyFeatureExtractor" ; |
88 | TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode); |
89 | }; |
90 | |
91 | /*! |
92 | * \brief Managed reference to FeatureExtractorNode |
93 | * \sa FeatureExtractorNode |
94 | */ |
95 | class : public runtime::ObjectRef { |
96 | public: |
97 | /*! |
98 | * \brief Create a feature extractor that extracts features from each BufferStore |
99 | * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if |
100 | * necessary. |
101 | * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity |
102 | * curve. |
103 | * \param cache_line_bytes The number of bytes in a cache line. |
104 | * \param extract_workload Whether to extract features in the workload in tuning context or not. |
105 | * \return The feature extractor created. |
106 | */ |
107 | TVM_DLL static FeatureExtractor (int buffers_per_store = 5, |
108 | int arith_intensity_curve_num_samples = 10, |
109 | int cache_line_bytes = 64, |
110 | bool = false); |
111 | /*! |
112 | * \brief Create a feature extractor with customized methods on the python-side. |
113 | * \param f_extract_from The packed function of `ExtractFrom`. |
114 | * \param f_as_string The packed function of `AsString`. |
115 | * \return The feature extractor created. |
116 | */ |
117 | TVM_DLL static FeatureExtractor ( |
118 | PyFeatureExtractorNode::FExtractFrom , |
119 | PyFeatureExtractorNode::FAsString f_as_string); |
120 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(, ObjectRef, FeatureExtractorNode); |
121 | }; |
122 | |
123 | } // namespace meta_schedule |
124 | } // namespace tvm |
125 | |
126 | #endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_ |
127 | |