1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24class UnionDatabaseNode : public DatabaseNode {
25 public:
26 Array<Database> databases;
27
28 void VisitAttrs(AttrVisitor* v) { v->Visit("databases", &databases); }
29
30 static constexpr const char* _type_key = "meta_schedule.UnionDatabase";
31 TVM_DECLARE_FINAL_OBJECT_INFO(UnionDatabaseNode, DatabaseNode);
32
33 public:
34 Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
35 const String& task_name) final {
36 std::vector<TuningRecord> results;
37 results.reserve(databases.size());
38 for (const Database& db : databases) {
39 if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) {
40 results.push_back(record.value());
41 }
42 }
43 std::stable_sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
44 return results.empty() ? Optional<TuningRecord>(NullOpt) : results[0];
45 }
46
47 bool HasWorkload(const IRModule& mod) final {
48 LOG(FATAL) << "NotImplementedError: UnionDatabase.HasWorkload";
49 throw;
50 }
51
52 Workload CommitWorkload(const IRModule& mod) final {
53 LOG(FATAL) << "NotImplementedError: UnionDatabase.CommitWorkload";
54 throw;
55 }
56
57 void CommitTuningRecord(const TuningRecord& record) final {
58 LOG(FATAL) << "NotImplementedError: UnionDatabase.CommitTuningRecord";
59 throw;
60 }
61
62 Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
63 LOG(FATAL) << "NotImplementedError: UnionDatabase.GetTopK";
64 throw;
65 }
66
67 Array<TuningRecord> GetAllTuningRecords() final {
68 LOG(FATAL) << "NotImplementedError: UnionDatabase.GetAllTuningRecords";
69 throw;
70 }
71
72 int64_t Size() final {
73 LOG(FATAL) << "NotImplementedError: UnionDatabase.size";
74 throw;
75 }
76};
77
78Database Database::UnionDatabase(Array<Database> databases) {
79 ObjectPtr<UnionDatabaseNode> n = make_object<UnionDatabaseNode>();
80 n->databases = std::move(databases);
81 return Database(n);
82}
83
84TVM_REGISTER_NODE_TYPE(UnionDatabaseNode);
85TVM_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase").set_body_typed(Database::UnionDatabase);
86
87} // namespace meta_schedule
88} // namespace tvm
89