1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24Array<tvm::runtime::NDArray> PyFeatureExtractorNode::ExtractFrom(
25 const TuneContext& context, const Array<MeasureCandidate>& candidates) {
26 ICHECK(f_extract_from != nullptr) << "PyFeatureExtractor's ExtractFrom method not implemented!";
27 return f_extract_from(context, candidates);
28}
29
30FeatureExtractor FeatureExtractor::PyFeatureExtractor(
31 PyFeatureExtractorNode::FExtractFrom f_extract_from, //
32 PyFeatureExtractorNode::FAsString f_as_string) {
33 ObjectPtr<PyFeatureExtractorNode> n = make_object<PyFeatureExtractorNode>();
34 n->f_extract_from = std::move(f_extract_from);
35 n->f_as_string = std::move(f_as_string);
36 return FeatureExtractor(n);
37}
38
39TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
40 .set_dispatch<PyFeatureExtractorNode>([](const ObjectRef& n, ReprPrinter* p) {
41 const auto* self = n.as<PyFeatureExtractorNode>();
42 ICHECK(self);
43 PyFeatureExtractorNode::FAsString f_as_string = (*self).f_as_string;
44 ICHECK(f_as_string != nullptr) << "PyFeatureExtractor's AsString method not implemented!";
45 p->stream << f_as_string();
46 });
47
// Type registration for the feature-extractor node classes.
// NOTE(review): both macros are defined outside this file. Presumably
// TVM_REGISTER_OBJECT_TYPE only registers the type, while TVM_REGISTER_NODE_TYPE
// additionally enables reflection-based creation; confirm against the macro definitions.
TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode);
TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode);
50
// Global function registrations for the feature extractor.
// NOTE(review): presumably these names are looked up by the language frontend
// (e.g. Python bindings); confirm against the callers.

// "meta_schedule.FeatureExtractorExtractFrom" is bound to the member function
// FeatureExtractorNode::ExtractFrom, invoked through a FeatureExtractor reference.
TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorExtractFrom")
    .set_body_method<FeatureExtractor>(&FeatureExtractorNode::ExtractFrom);
// "meta_schedule.FeatureExtractorPyFeatureExtractor" is bound to the
// FeatureExtractor::PyFeatureExtractor factory.
TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPyFeatureExtractor")
    .set_body_typed(FeatureExtractor::PyFeatureExtractor);
55
56} // namespace meta_schedule
57} // namespace tvm
58