1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/runtime.cc |
22 | * \brief Runtime Registry |
23 | */ |
24 | |
25 | #include <tvm/relay/runtime.h> |
26 | |
27 | #include "../../node/attr_registry.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | |
32 | TVM_REGISTER_NODE_TYPE(RuntimeNode); |
33 | |
34 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
35 | .set_dispatch<RuntimeNode>([](const ObjectRef& obj, ReprPrinter* p) { |
36 | const Runtime& runtime = Downcast<Runtime>(obj); |
37 | p->stream << runtime->name; |
38 | }); |
39 | |
40 | /********** Registry-related code **********/ |
41 | |
42 | using RuntimeRegistry = AttrRegistry<RuntimeRegEntry, Runtime>; |
43 | |
44 | Runtime Runtime::Create(String name, Map<String, ObjectRef> attrs) { |
45 | const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name); |
46 | if (reg == nullptr) { |
47 | throw Error("Runtime \"" + name + "\" is not defined" ); |
48 | } |
49 | |
50 | for (const auto& kv : attrs) { |
51 | if (!reg->key2vtype_.count(kv.first)) { |
52 | throw Error("Attribute \"" + kv.first + "\" is not available on this Runtime" ); |
53 | } |
54 | std::string expected_type = reg->key2vtype_.at(kv.first).type_key; |
55 | std::string actual_type = kv.second->GetTypeKey(); |
56 | if (expected_type != actual_type) { |
57 | throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type + |
58 | "\" but instead found \"" + actual_type + "\"" ); |
59 | } |
60 | } |
61 | |
62 | for (const auto& kv : reg->key2default_) { |
63 | if (!attrs.count(kv.first)) { |
64 | attrs.Set(kv.first, kv.second); |
65 | } |
66 | } |
67 | |
68 | return Runtime(name, DictAttrs(attrs)); |
69 | } |
70 | |
71 | Array<String> Runtime::ListRuntimes() { return RuntimeRegistry::Global()->ListAllNames(); } |
72 | |
73 | Map<String, String> Runtime::ListRuntimeOptions(const String& name) { |
74 | Map<String, String> options; |
75 | const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name); |
76 | if (reg == nullptr) { |
77 | throw Error("Runtime \"" + name + "\" is not defined" ); |
78 | } |
79 | for (const auto& kv : reg->key2vtype_) { |
80 | options.Set(kv.first, kv.second.type_key); |
81 | } |
82 | return options; |
83 | } |
84 | |
85 | RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { |
86 | return RuntimeRegistry::Global()->RegisterOrGet(name); |
87 | } |
88 | |
89 | /********** Register Runtimes and options **********/ |
90 | |
91 | TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option<Bool>("system-lib" ); |
92 | |
93 | TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option<Bool>("system-lib" ); |
94 | |
95 | /********** Registry **********/ |
96 | |
97 | TVM_REGISTER_GLOBAL("relay.backend.CreateRuntime" ).set_body_typed(Runtime::Create); |
98 | TVM_REGISTER_GLOBAL("relay.backend.GetRuntimeAttrs" ).set_body_typed([](const Runtime& runtime) { |
99 | return runtime->attrs->dict; |
100 | }); |
101 | |
102 | TVM_REGISTER_GLOBAL("relay.backend.ListRuntimes" ).set_body_typed(Runtime::ListRuntimes); |
103 | TVM_REGISTER_GLOBAL("relay.backend.ListRuntimeOptions" ).set_body_typed(Runtime::ListRuntimeOptions); |
104 | |
105 | } // namespace relay |
106 | } // namespace tvm |
107 | |