1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file src/target/generic_func.cc |
21 | */ |
22 | #include <dmlc/thread_local.h> |
23 | #include <tvm/node/node.h> |
24 | #include <tvm/node/repr_printer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/target/generic_func.h> |
27 | #include <tvm/target/target.h> |
28 | #include <tvm/tir/expr.h> |
29 | |
30 | #include <algorithm> |
31 | #include <mutex> |
32 | #include <stack> |
33 | |
34 | #include "../runtime/object_internal.h" |
35 | |
36 | namespace tvm { |
37 | |
38 | TVM_REGISTER_NODE_TYPE(GenericFuncNode); |
39 | |
40 | struct GenericFunc::Manager { |
41 | std::unordered_map<std::string, GenericFunc> fmap; |
42 | // mutex |
43 | std::mutex mutex; |
44 | |
45 | Manager() {} |
46 | |
47 | static Manager* Global() { |
48 | static Manager inst; |
49 | return &inst; |
50 | } |
51 | }; |
52 | |
53 | GenericFunc GenericFunc::Get(const std::string& name) { |
54 | Manager* m = Manager::Global(); |
55 | std::lock_guard<std::mutex> lock(m->mutex); |
56 | auto it = m->fmap.find(name); |
57 | if (it == m->fmap.end()) { |
58 | auto f = make_object<GenericFuncNode>(); |
59 | f->name_ = name; |
60 | auto gf = GenericFunc(f); |
61 | m->fmap[name] = gf; |
62 | return gf; |
63 | } else { |
64 | return it->second; |
65 | } |
66 | } |
67 | |
68 | void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) { |
69 | Manager* m = Manager::Global(); |
70 | std::lock_guard<std::mutex> lock(m->mutex); |
71 | auto it = m->fmap.find(name); |
72 | ICHECK(it == m->fmap.end()) << "GenericFunc already registered " << name; |
73 | func->name_ = name; |
74 | m->fmap[name] = func; |
75 | } |
76 | |
77 | GenericFunc& GenericFunc::set_default(const PackedFunc value, bool allow_override) { |
78 | auto node = static_cast<GenericFuncNode*>(operator->()); |
79 | if (!allow_override) { |
80 | ICHECK(node->generic_func_ == nullptr) |
81 | << "Generic function already registered for " << node->name_; |
82 | } |
83 | node->generic_func_ = value; |
84 | return *this; |
85 | } |
86 | |
87 | GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags, |
88 | const PackedFunc value, bool allow_override) { |
89 | for (auto& t : tags) { |
90 | if (!allow_override) { |
91 | auto iter = (*this)->dispatch_dict_.find(t); |
92 | ICHECK(iter == (*this)->dispatch_dict_.end()) |
93 | << "Tag " << t << " already registered for schedule factory " << (*this)->name_; |
94 | } |
95 | (*this)->dispatch_dict_[t] = value; |
96 | } |
97 | return *this; |
98 | } |
99 | |
100 | void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { |
101 | auto node = static_cast<const GenericFuncNode*>(get()); |
102 | auto target = Target::Current(true); |
103 | PackedFunc func; |
104 | |
105 | if (target.defined()) { |
106 | for (auto& k : target->GetKeys()) { |
107 | auto iter = node->dispatch_dict_.find(k); |
108 | if (iter != node->dispatch_dict_.end()) { |
109 | func = iter->second; |
110 | break; |
111 | } |
112 | } |
113 | } |
114 | |
115 | if (func == nullptr) { |
116 | ICHECK(node->generic_func_ != nullptr) << "No generic function registered for " << node->name_; |
117 | func = node->generic_func_; |
118 | } |
119 | |
120 | func.CallPacked(args, ret); |
121 | } |
122 | |
123 | PackedFunc GenericFunc::GetPacked() const { |
124 | auto node = static_cast<const GenericFuncNode*>(get()); |
125 | auto target = Target::Current(true); |
126 | if (target.defined()) { |
127 | for (auto& k : target->GetKeys()) { |
128 | auto iter = node->dispatch_dict_.find(k); |
129 | if (iter != node->dispatch_dict_.end()) { |
130 | return iter->second; |
131 | } |
132 | } |
133 | } |
134 | return node->generic_func_; |
135 | } |
136 | |
137 | TVM_REGISTER_GLOBAL("target.GenericFuncCreate" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
138 | *ret = GenericFunc(make_object<GenericFuncNode>()); |
139 | }); |
140 | |
141 | TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
142 | std::string func_name = args[0]; |
143 | *ret = GenericFunc::Get(func_name); |
144 | }); |
145 | |
146 | TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
147 | GenericFunc generic_func = args[0]; |
148 | PackedFunc func = args[1]; |
149 | bool allow_override = args[2]; |
150 | // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown |
151 | runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get())); |
152 | generic_func.set_default(func, allow_override); |
153 | }); |
154 | |
155 | TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
156 | GenericFunc generic_func = args[0]; |
157 | PackedFunc func = args[1]; |
158 | Array<runtime::String> tags = args[2]; |
159 | bool allow_override = args[3]; |
160 | // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown |
161 | runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get())); |
162 | std::vector<std::string> tags_vector; |
163 | for (auto& tag : tags) { |
164 | tags_vector.push_back(tag); |
165 | } |
166 | |
167 | generic_func.register_func(tags_vector, func, allow_override); |
168 | }); |
169 | |
170 | TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
171 | GenericFunc generic_func = args[0]; |
172 | TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); |
173 | |
174 | generic_func.CallPacked(func_args, ret); |
175 | }); |
176 | |
177 | TVM_REGISTER_GLOBAL("target.GenericFuncGetPackedFunc" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
178 | GenericFunc generic_func = args[0]; |
179 | *ret = generic_func.GetPacked(); |
180 | }); |
181 | |
182 | } // namespace tvm |
183 | |