1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/runtime.cc
22 * \brief Runtime Registry
23 */
24
25#include <tvm/relay/runtime.h>
26
27#include "../../node/attr_registry.h"
28
29namespace tvm {
30namespace relay {
31
32TVM_REGISTER_NODE_TYPE(RuntimeNode);
33
34TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
35 .set_dispatch<RuntimeNode>([](const ObjectRef& obj, ReprPrinter* p) {
36 const Runtime& runtime = Downcast<Runtime>(obj);
37 p->stream << runtime->name;
38 });
39
40/********** Registry-related code **********/
41
42using RuntimeRegistry = AttrRegistry<RuntimeRegEntry, Runtime>;
43
44Runtime Runtime::Create(String name, Map<String, ObjectRef> attrs) {
45 const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name);
46 if (reg == nullptr) {
47 throw Error("Runtime \"" + name + "\" is not defined");
48 }
49
50 for (const auto& kv : attrs) {
51 if (!reg->key2vtype_.count(kv.first)) {
52 throw Error("Attribute \"" + kv.first + "\" is not available on this Runtime");
53 }
54 std::string expected_type = reg->key2vtype_.at(kv.first).type_key;
55 std::string actual_type = kv.second->GetTypeKey();
56 if (expected_type != actual_type) {
57 throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type +
58 "\" but instead found \"" + actual_type + "\"");
59 }
60 }
61
62 for (const auto& kv : reg->key2default_) {
63 if (!attrs.count(kv.first)) {
64 attrs.Set(kv.first, kv.second);
65 }
66 }
67
68 return Runtime(name, DictAttrs(attrs));
69}
70
71Array<String> Runtime::ListRuntimes() { return RuntimeRegistry::Global()->ListAllNames(); }
72
73Map<String, String> Runtime::ListRuntimeOptions(const String& name) {
74 Map<String, String> options;
75 const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name);
76 if (reg == nullptr) {
77 throw Error("Runtime \"" + name + "\" is not defined");
78 }
79 for (const auto& kv : reg->key2vtype_) {
80 options.Set(kv.first, kv.second.type_key);
81 }
82 return options;
83}
84
85RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) {
86 return RuntimeRegistry::Global()->RegisterOrGet(name);
87}
88
89/********** Register Runtimes and options **********/
90
91TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option<Bool>("system-lib");
92
93TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option<Bool>("system-lib");
94
95/********** Registry **********/
96
97TVM_REGISTER_GLOBAL("relay.backend.CreateRuntime").set_body_typed(Runtime::Create);
98TVM_REGISTER_GLOBAL("relay.backend.GetRuntimeAttrs").set_body_typed([](const Runtime& runtime) {
99 return runtime->attrs->dict;
100});
101
102TVM_REGISTER_GLOBAL("relay.backend.ListRuntimes").set_body_typed(Runtime::ListRuntimes);
103TVM_REGISTER_GLOBAL("relay.backend.ListRuntimeOptions").set_body_typed(Runtime::ListRuntimeOptions);
104
105} // namespace relay
106} // namespace tvm
107