1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/executor.cc |
22 | * \brief Executor Registry |
23 | */ |
24 | |
25 | #include <tvm/relay/executor.h> |
26 | |
27 | #include "../../node/attr_registry.h" |
28 | namespace tvm { |
29 | namespace relay { |
30 | |
// Register ExecutorNode as a node type with TVM's object system.
TVM_REGISTER_NODE_TYPE(ExecutorNode);
32 | |
33 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
34 | .set_dispatch<ExecutorNode>([](const ObjectRef& obj, ReprPrinter* p) { |
35 | const Executor& executor = Downcast<Executor>(obj); |
36 | p->stream << executor->name; |
37 | p->stream << executor->attrs; |
38 | }); |
39 | |
40 | /********** Registry-related code **********/ |
41 | |
// Global registry of ExecutorRegEntry objects, looked up by executor name
// (see the Get(name) / RegisterOrGet(name) calls in this file).
using ExecutorRegistry = AttrRegistry<ExecutorRegEntry, Executor>;
43 | |
44 | Executor Executor::Create(String name, Map<String, ObjectRef> attrs) { |
45 | const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name); |
46 | if (reg == nullptr) { |
47 | throw Error("Executor \"" + name + "\" is not defined" ); |
48 | } |
49 | |
50 | for (const auto& kv : attrs) { |
51 | if (!reg->key2vtype_.count(kv.first)) { |
52 | throw Error("Attribute \"" + kv.first + "\" is not available on this Executor" ); |
53 | } |
54 | std::string expected_type = reg->key2vtype_.at(kv.first).type_key; |
55 | std::string actual_type = kv.second->GetTypeKey(); |
56 | if (expected_type != actual_type) { |
57 | throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type + |
58 | "\" but instead found \"" + actual_type + "\"" ); |
59 | } |
60 | } |
61 | |
62 | for (const auto& kv : reg->key2default_) { |
63 | if (!attrs.count(kv.first)) { |
64 | attrs.Set(kv.first, kv.second); |
65 | } |
66 | } |
67 | |
68 | return Executor(name, DictAttrs(attrs)); |
69 | } |
70 | |
71 | Array<String> Executor::ListExecutors() { return ExecutorRegistry::Global()->ListAllNames(); } |
72 | |
73 | Map<String, String> Executor::ListExecutorOptions(const String& name) { |
74 | Map<String, String> options; |
75 | const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name); |
76 | if (reg == nullptr) { |
77 | throw Error("Executor \"" + name + "\" is not defined" ); |
78 | } |
79 | for (const auto& kv : reg->key2vtype_) { |
80 | options.Set(kv.first, kv.second.type_key); |
81 | } |
82 | return options; |
83 | } |
84 | |
85 | ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { |
86 | return ExecutorRegistry::Global()->RegisterOrGet(name); |
87 | } |
88 | |
89 | /********** Register Executors and options **********/ |
90 | |
91 | TVM_REGISTER_EXECUTOR("aot" ) |
92 | .add_attr_option<Bool>("link-params" , Bool(true)) |
93 | .add_attr_option<Bool>("unpacked-api" ) |
94 | .add_attr_option<String>("interface-api" ) |
95 | .add_attr_option<Integer>("workspace-byte-alignment" ) |
96 | .add_attr_option<Integer>("constant-byte-alignment" ); |
97 | |
98 | TVM_REGISTER_EXECUTOR("graph" ).add_attr_option<Bool>("link-params" , Bool(false)); |
99 | |
/********** Global function registrations **********/
101 | |
102 | TVM_REGISTER_GLOBAL("relay.backend.CreateExecutor" ).set_body_typed(Executor::Create); |
103 | TVM_REGISTER_GLOBAL("relay.backend.GetExecutorAttrs" ).set_body_typed([](const Executor& executor) { |
104 | return executor->attrs->dict; |
105 | }); |
106 | |
107 | TVM_REGISTER_GLOBAL("relay.backend.ListExecutors" ).set_body_typed(Executor::ListExecutors); |
108 | TVM_REGISTER_GLOBAL("relay.backend.ListExecutorOptions" ) |
109 | .set_body_typed(Executor::ListExecutorOptions); |
110 | |
111 | } // namespace relay |
112 | } // namespace tvm |
113 | |