1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/executor.cc
22 * \brief Executor Registry
23 */
24
25#include <tvm/relay/executor.h>
26
27#include "../../node/attr_registry.h"
28namespace tvm {
29namespace relay {
30
// Register ExecutorNode as a TVM node type (presumably so it participates in TVM's
// reflection / serialization machinery — see the macro definition to confirm).
TVM_REGISTER_NODE_TYPE(ExecutorNode);
32
33TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
34 .set_dispatch<ExecutorNode>([](const ObjectRef& obj, ReprPrinter* p) {
35 const Executor& executor = Downcast<Executor>(obj);
36 p->stream << executor->name;
37 p->stream << executor->attrs;
38 });
39
40/********** Registry-related code **********/
41
// Registry type holding one ExecutorRegEntry per executor name. Its Global()
// instance is what the lookup, listing and registration functions in this file use.
using ExecutorRegistry = AttrRegistry<ExecutorRegEntry, Executor>;
43
44Executor Executor::Create(String name, Map<String, ObjectRef> attrs) {
45 const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name);
46 if (reg == nullptr) {
47 throw Error("Executor \"" + name + "\" is not defined");
48 }
49
50 for (const auto& kv : attrs) {
51 if (!reg->key2vtype_.count(kv.first)) {
52 throw Error("Attribute \"" + kv.first + "\" is not available on this Executor");
53 }
54 std::string expected_type = reg->key2vtype_.at(kv.first).type_key;
55 std::string actual_type = kv.second->GetTypeKey();
56 if (expected_type != actual_type) {
57 throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type +
58 "\" but instead found \"" + actual_type + "\"");
59 }
60 }
61
62 for (const auto& kv : reg->key2default_) {
63 if (!attrs.count(kv.first)) {
64 attrs.Set(kv.first, kv.second);
65 }
66 }
67
68 return Executor(name, DictAttrs(attrs));
69}
70
71Array<String> Executor::ListExecutors() { return ExecutorRegistry::Global()->ListAllNames(); }
72
73Map<String, String> Executor::ListExecutorOptions(const String& name) {
74 Map<String, String> options;
75 const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name);
76 if (reg == nullptr) {
77 throw Error("Executor \"" + name + "\" is not defined");
78 }
79 for (const auto& kv : reg->key2vtype_) {
80 options.Set(kv.first, kv.second.type_key);
81 }
82 return options;
83}
84
85ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) {
86 return ExecutorRegistry::Global()->RegisterOrGet(name);
87}
88
89/********** Register Executors and options **********/
90
91TVM_REGISTER_EXECUTOR("aot")
92 .add_attr_option<Bool>("link-params", Bool(true))
93 .add_attr_option<Bool>("unpacked-api")
94 .add_attr_option<String>("interface-api")
95 .add_attr_option<Integer>("workspace-byte-alignment")
96 .add_attr_option<Integer>("constant-byte-alignment");
97
98TVM_REGISTER_EXECUTOR("graph").add_attr_option<Bool>("link-params", Bool(false));
99
100/********** Registry **********/
101
102TVM_REGISTER_GLOBAL("relay.backend.CreateExecutor").set_body_typed(Executor::Create);
103TVM_REGISTER_GLOBAL("relay.backend.GetExecutorAttrs").set_body_typed([](const Executor& executor) {
104 return executor->attrs->dict;
105});
106
107TVM_REGISTER_GLOBAL("relay.backend.ListExecutors").set_body_typed(Executor::ListExecutors);
108TVM_REGISTER_GLOBAL("relay.backend.ListExecutorOptions")
109 .set_body_typed(Executor::ListExecutorOptions);
110
111} // namespace relay
112} // namespace tvm
113