1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <cstring>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mem.h"
24 | |
25 | namespace tensorflow { |
26 | |
27 | namespace { |
28 | |
29 | struct Library { |
30 | void* handle = nullptr; |
31 | OpList op_list; |
32 | }; |
33 | |
34 | } // namespace |
35 | |
36 | // Load a dynamic library. |
37 | // On success, returns the handle to library in result, copies the serialized |
38 | // OpList of OpDefs registered in the library to *buf and the length to *len, |
39 | // and returns OK from the function. Otherwise return nullptr in result |
40 | // and an error status from the function, leaving buf and len untouched. |
41 | // |
42 | // If `library_filename` has already been loaded, we return a cached handle |
43 | // and OpList. Ops and kernels are registered as globals when a library is |
44 | // loaded for the first time. Without caching, every subsequent load would not |
45 | // perform initialization again, so the OpList would be empty. |
46 | Status LoadDynamicLibrary(const char* library_filename, void** result, |
47 | const void** buf, size_t* len) { |
48 | static mutex mu(LINKER_INITIALIZED); |
49 | static std::unordered_map<string, Library> loaded_libs; |
50 | Env* env = Env::Default(); |
51 | Library library; |
52 | std::unordered_set<string> seen_op_names; |
53 | { |
54 | mutex_lock lock(mu); |
55 | if (loaded_libs.find(library_filename) != loaded_libs.end()) { |
56 | library = loaded_libs[library_filename]; |
57 | } else { |
58 | Status s = OpRegistry::Global()->ProcessRegistrations(); |
59 | if (!s.ok()) { |
60 | return s; |
61 | } |
62 | TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher( |
63 | [&library, &seen_op_names](const Status& s, |
64 | const OpDef& opdef) -> Status { |
65 | if (errors::IsAlreadyExists(s)) { |
66 | if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { |
67 | // Over writing a registration of an op not in this custom op |
68 | // library. Treat this as not an error. |
69 | return OkStatus(); |
70 | } |
71 | } |
72 | if (s.ok()) { |
73 | *library.op_list.add_op() = opdef; |
74 | seen_op_names.insert(opdef.name()); |
75 | } |
76 | return s; |
77 | })); |
78 | OpRegistry::Global()->DeferRegistrations(); |
79 | s = env->LoadDynamicLibrary(library_filename, &library.handle); |
80 | if (s.ok()) { |
81 | s = OpRegistry::Global()->ProcessRegistrations(); |
82 | } |
83 | if (!s.ok()) { |
84 | OpRegistry::Global()->ClearDeferredRegistrations(); |
85 | TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); |
86 | return s; |
87 | } |
88 | TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); |
89 | |
90 | loaded_libs[library_filename] = library; |
91 | } |
92 | } |
93 | string str; |
94 | library.op_list.SerializeToString(&str); |
95 | char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length())); |
96 | memcpy(str_buf, str.data(), str.length()); |
97 | *buf = str_buf; |
98 | *len = str.length(); |
99 | |
100 | *result = library.handle; |
101 | return OkStatus(); |
102 | } |
103 | |
104 | } // namespace tensorflow |
105 | |