1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24class OrderedUnionDatabaseNode : public DatabaseNode {
25 public:
26 Array<Database> databases;
27
28 void VisitAttrs(AttrVisitor* v) { v->Visit("databases", &databases); }
29
30 static constexpr const char* _type_key = "meta_schedule.OrderedUnionDatabase";
31 TVM_DECLARE_FINAL_OBJECT_INFO(OrderedUnionDatabaseNode, DatabaseNode);
32
33 public:
34 Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
35 const String& task_name) final {
36 for (const Database& db : databases) {
37 if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target, task_name)) {
38 return record;
39 }
40 }
41 return NullOpt;
42 }
43
44 bool HasWorkload(const IRModule& mod) final {
45 LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.HasWorkload";
46 throw;
47 }
48
49 Workload CommitWorkload(const IRModule& mod) final {
50 LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.CommitWorkload";
51 throw;
52 }
53
54 void CommitTuningRecord(const TuningRecord& record) final {
55 LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.CommitTuningRecord";
56 throw;
57 }
58
59 Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
60 LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetTopK";
61 throw;
62 }
63
64 Array<TuningRecord> GetAllTuningRecords() final {
65 LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.GetAllTuningRecords";
66 throw;
67 }
68
69 int64_t Size() final {
70 LOG(FATAL) << "NotImplementedError: OrderedUnionDatabase.size";
71 throw;
72 }
73};
74
75Database Database::OrderedUnionDatabase(Array<Database> databases) {
76 ObjectPtr<OrderedUnionDatabaseNode> n = make_object<OrderedUnionDatabaseNode>();
77 n->databases = std::move(databases);
78 return Database(n);
79}
80
// Register the node type with TVM's object/reflection system under its `_type_key`
// ("meta_schedule.OrderedUnionDatabase").
TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode);
// Expose the factory as the global function "meta_schedule.DatabaseOrderedUnionDatabase".
// NOTE(review): presumably this is what the Python-side OrderedUnionDatabase constructor calls;
// confirm against the frontend binding.
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase")
    .set_body_typed(Database::OrderedUnionDatabase);
84
85} // namespace meta_schedule
86} // namespace tvm
87