1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
21#define TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
22
23#include <tvm/meta_schedule/measure_candidate.h>
24#include <tvm/node/reflection.h>
25#include <tvm/runtime/container/array.h>
26#include <tvm/runtime/container/string.h>
27#include <tvm/runtime/ndarray.h>
28#include <tvm/runtime/object.h>
29#include <tvm/runtime/packed_func.h>
30
31namespace tvm {
32namespace meta_schedule {
33
34class TuneContext;
35
36/*! \brief Extractor for features from measure candidates for use in cost model. */
37class FeatureExtractorNode : public runtime::Object {
38 public:
39 /*! \brief Virtual destructor. */
40 virtual ~FeatureExtractorNode() = default;
41
42 void VisitAttrs(tvm::AttrVisitor* v) {}
43
44 /*!
45 * \brief Extract features from the given measure candidate.
46 * \param context The tuning context for feature extraction.
47 * \param candidates The measure candidates to extract features from.
48 * \return The feature ndarray extracted.
49 */
50 virtual Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
51 const Array<MeasureCandidate>& candidates) = 0;
52
53 static constexpr const char* _type_key = "meta_schedule.FeatureExtractor";
54 TVM_DECLARE_BASE_OBJECT_INFO(FeatureExtractorNode, Object);
55};
56
57/*! \brief The feature extractor with customized methods on the python-side. */
58class PyFeatureExtractorNode : public FeatureExtractorNode {
59 public:
60 /*!
61 * \brief Extract features from the given measure candidate.
62 * \param context The tuning context for feature extraction.
63 * \param candidates The measure candidates to extract features from.
64 * \return The feature ndarray extracted.
65 */
66 using FExtractFrom = runtime::TypedPackedFunc<Array<tvm::runtime::NDArray>(
67 const TuneContext& context, const Array<MeasureCandidate>& candidates)>;
68 /*!
69 * \brief Get the feature extractor as string with name.
70 * \return The string of the feature extractor.
71 */
72 using FAsString = runtime::TypedPackedFunc<String()>;
73
74 /*! \brief The packed function to the `ExtractFrom` function. */
75 FExtractFrom f_extract_from;
76 /*! \brief The packed function to the `AsString` function. */
77 FAsString f_as_string;
78
79 void VisitAttrs(tvm::AttrVisitor* v) {
80 // `f_extract_from` is not visited
81 // `f_as_string` is not visited
82 }
83
84 Array<tvm::runtime::NDArray> ExtractFrom(const TuneContext& context,
85 const Array<MeasureCandidate>& candidates) final;
86
87 static constexpr const char* _type_key = "meta_schedule.PyFeatureExtractor";
88 TVM_DECLARE_FINAL_OBJECT_INFO(PyFeatureExtractorNode, FeatureExtractorNode);
89};
90
91/*!
92 * \brief Managed reference to FeatureExtractorNode
93 * \sa FeatureExtractorNode
94 */
95class FeatureExtractor : public runtime::ObjectRef {
96 public:
97 /*!
98 * \brief Create a feature extractor that extracts features from each BufferStore
99 * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if
100 * necessary.
101 * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity
102 * curve.
103 * \param cache_line_bytes The number of bytes in a cache line.
104 * \param extract_workload Whether to extract features in the workload in tuning context or not.
105 * \return The feature extractor created.
106 */
107 TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5,
108 int arith_intensity_curve_num_samples = 10,
109 int cache_line_bytes = 64,
110 bool extract_workload = false);
111 /*!
112 * \brief Create a feature extractor with customized methods on the python-side.
113 * \param f_extract_from The packed function of `ExtractFrom`.
114 * \param f_as_string The packed function of `AsString`.
115 * \return The feature extractor created.
116 */
117 TVM_DLL static FeatureExtractor PyFeatureExtractor(
118 PyFeatureExtractorNode::FExtractFrom f_extract_from,
119 PyFeatureExtractorNode::FAsString f_as_string);
120 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FeatureExtractor, ObjectRef, FeatureExtractorNode);
121};
122
123} // namespace meta_schedule
124} // namespace tvm
125
126#endif // TVM_META_SCHEDULE_FEATURE_EXTRACTOR_H_
127