1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file src/target/generic_func.cc
21 */
22#include <dmlc/thread_local.h>
23#include <tvm/node/node.h>
24#include <tvm/node/repr_printer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/target/generic_func.h>
27#include <tvm/target/target.h>
28#include <tvm/tir/expr.h>
29
30#include <algorithm>
31#include <mutex>
32#include <stack>
33
34#include "../runtime/object_internal.h"
35
36namespace tvm {
37
// Register GenericFuncNode with TVM's node-type registry (macro from the tvm/node headers).
// NOTE(review): presumably required so the node can be reflected/created by type key — confirm.
TVM_REGISTER_NODE_TYPE(GenericFuncNode);
39
40struct GenericFunc::Manager {
41 std::unordered_map<std::string, GenericFunc> fmap;
42 // mutex
43 std::mutex mutex;
44
45 Manager() {}
46
47 static Manager* Global() {
48 static Manager inst;
49 return &inst;
50 }
51};
52
53GenericFunc GenericFunc::Get(const std::string& name) {
54 Manager* m = Manager::Global();
55 std::lock_guard<std::mutex> lock(m->mutex);
56 auto it = m->fmap.find(name);
57 if (it == m->fmap.end()) {
58 auto f = make_object<GenericFuncNode>();
59 f->name_ = name;
60 auto gf = GenericFunc(f);
61 m->fmap[name] = gf;
62 return gf;
63 } else {
64 return it->second;
65 }
66}
67
68void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) {
69 Manager* m = Manager::Global();
70 std::lock_guard<std::mutex> lock(m->mutex);
71 auto it = m->fmap.find(name);
72 ICHECK(it == m->fmap.end()) << "GenericFunc already registered " << name;
73 func->name_ = name;
74 m->fmap[name] = func;
75}
76
77GenericFunc& GenericFunc::set_default(const PackedFunc value, bool allow_override) {
78 auto node = static_cast<GenericFuncNode*>(operator->());
79 if (!allow_override) {
80 ICHECK(node->generic_func_ == nullptr)
81 << "Generic function already registered for " << node->name_;
82 }
83 node->generic_func_ = value;
84 return *this;
85}
86
87GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
88 const PackedFunc value, bool allow_override) {
89 for (auto& t : tags) {
90 if (!allow_override) {
91 auto iter = (*this)->dispatch_dict_.find(t);
92 ICHECK(iter == (*this)->dispatch_dict_.end())
93 << "Tag " << t << " already registered for schedule factory " << (*this)->name_;
94 }
95 (*this)->dispatch_dict_[t] = value;
96 }
97 return *this;
98}
99
100void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
101 auto node = static_cast<const GenericFuncNode*>(get());
102 auto target = Target::Current(true);
103 PackedFunc func;
104
105 if (target.defined()) {
106 for (auto& k : target->GetKeys()) {
107 auto iter = node->dispatch_dict_.find(k);
108 if (iter != node->dispatch_dict_.end()) {
109 func = iter->second;
110 break;
111 }
112 }
113 }
114
115 if (func == nullptr) {
116 ICHECK(node->generic_func_ != nullptr) << "No generic function registered for " << node->name_;
117 func = node->generic_func_;
118 }
119
120 func.CallPacked(args, ret);
121}
122
123PackedFunc GenericFunc::GetPacked() const {
124 auto node = static_cast<const GenericFuncNode*>(get());
125 auto target = Target::Current(true);
126 if (target.defined()) {
127 for (auto& k : target->GetKeys()) {
128 auto iter = node->dispatch_dict_.find(k);
129 if (iter != node->dispatch_dict_.end()) {
130 return iter->second;
131 }
132 }
133 }
134 return node->generic_func_;
135}
136
137TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
138 *ret = GenericFunc(make_object<GenericFuncNode>());
139});
140
141TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVMRetValue* ret) {
142 std::string func_name = args[0];
143 *ret = GenericFunc::Get(func_name);
144});
145
146TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) {
147 GenericFunc generic_func = args[0];
148 PackedFunc func = args[1];
149 bool allow_override = args[2];
150 // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
151 runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get()));
152 generic_func.set_default(func, allow_override);
153});
154
155TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
156 GenericFunc generic_func = args[0];
157 PackedFunc func = args[1];
158 Array<runtime::String> tags = args[2];
159 bool allow_override = args[3];
160 // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
161 runtime::ObjectInternal::ObjectRetain((TVMObjectHandle)(func.get()));
162 std::vector<std::string> tags_vector;
163 for (auto& tag : tags) {
164 tags_vector.push_back(tag);
165 }
166
167 generic_func.register_func(tags_vector, func, allow_override);
168});
169
170TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
171 GenericFunc generic_func = args[0];
172 TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
173
174 generic_func.CallPacked(func_args, ret);
175});
176
177TVM_REGISTER_GLOBAL("target.GenericFuncGetPackedFunc").set_body([](TVMArgs args, TVMRetValue* ret) {
178 GenericFunc generic_func = args[0];
179 *ret = generic_func.GetPacked();
180});
181
182} // namespace tvm
183