1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tsl/framework/allocator_registry.h" |
17 | |
18 | #include <string> |
19 | |
20 | #include "tensorflow/tsl/platform/logging.h" |
21 | |
22 | namespace tsl { |
23 | |
24 | // static |
25 | AllocatorFactoryRegistry* AllocatorFactoryRegistry::singleton() { |
26 | static AllocatorFactoryRegistry* singleton = new AllocatorFactoryRegistry; |
27 | return singleton; |
28 | } |
29 | |
30 | const AllocatorFactoryRegistry::FactoryEntry* |
31 | AllocatorFactoryRegistry::FindEntry(const string& name, int priority) const { |
32 | for (auto& entry : factories_) { |
33 | if (!name.compare(entry.name) && priority == entry.priority) { |
34 | return &entry; |
35 | } |
36 | } |
37 | return nullptr; |
38 | } |
39 | |
40 | void AllocatorFactoryRegistry::Register(const char* source_file, |
41 | int source_line, const string& name, |
42 | int priority, |
43 | AllocatorFactory* factory) { |
44 | mutex_lock l(mu_); |
45 | CHECK(!first_alloc_made_) << "Attempt to register an AllocatorFactory " |
46 | << "after call to GetAllocator()" ; |
47 | CHECK(!name.empty()) << "Need a valid name for Allocator" ; |
48 | CHECK_GE(priority, 0) << "Priority needs to be non-negative" ; |
49 | |
50 | const FactoryEntry* existing = FindEntry(name, priority); |
51 | if (existing != nullptr) { |
52 | // Duplicate registration is a hard failure. |
53 | LOG(FATAL) << "New registration for AllocatorFactory with name=" << name |
54 | << " priority=" << priority << " at location " << source_file |
55 | << ":" << source_line |
56 | << " conflicts with previous registration at location " |
57 | << existing->source_file << ":" << existing->source_line; |
58 | } |
59 | |
60 | FactoryEntry entry; |
61 | entry.source_file = source_file; |
62 | entry.source_line = source_line; |
63 | entry.name = name; |
64 | entry.priority = priority; |
65 | entry.factory.reset(factory); |
66 | factories_.push_back(std::move(entry)); |
67 | } |
68 | |
69 | Allocator* AllocatorFactoryRegistry::GetAllocator() { |
70 | mutex_lock l(mu_); |
71 | first_alloc_made_ = true; |
72 | FactoryEntry* best_entry = nullptr; |
73 | for (auto& entry : factories_) { |
74 | if (best_entry == nullptr) { |
75 | best_entry = &entry; |
76 | } else if (entry.priority > best_entry->priority) { |
77 | best_entry = &entry; |
78 | } |
79 | } |
80 | if (best_entry) { |
81 | if (!best_entry->allocator) { |
82 | best_entry->allocator.reset(best_entry->factory->CreateAllocator()); |
83 | } |
84 | return best_entry->allocator.get(); |
85 | } else { |
86 | LOG(FATAL) << "No registered CPU AllocatorFactory" ; |
87 | return nullptr; |
88 | } |
89 | } |
90 | |
91 | SubAllocator* AllocatorFactoryRegistry::GetSubAllocator(int numa_node) { |
92 | mutex_lock l(mu_); |
93 | first_alloc_made_ = true; |
94 | FactoryEntry* best_entry = nullptr; |
95 | for (auto& entry : factories_) { |
96 | if (best_entry == nullptr) { |
97 | best_entry = &entry; |
98 | } else if (best_entry->factory->NumaEnabled()) { |
99 | if (entry.factory->NumaEnabled() && |
100 | (entry.priority > best_entry->priority)) { |
101 | best_entry = &entry; |
102 | } |
103 | } else { |
104 | DCHECK(!best_entry->factory->NumaEnabled()); |
105 | if (entry.factory->NumaEnabled() || |
106 | (entry.priority > best_entry->priority)) { |
107 | best_entry = &entry; |
108 | } |
109 | } |
110 | } |
111 | if (best_entry) { |
112 | int index = 0; |
113 | if (numa_node != port::kNUMANoAffinity) { |
114 | CHECK_LE(numa_node, port::NUMANumNodes()); |
115 | index = 1 + numa_node; |
116 | } |
117 | if (best_entry->sub_allocators.size() < static_cast<size_t>(index + 1)) { |
118 | best_entry->sub_allocators.resize(index + 1); |
119 | } |
120 | if (!best_entry->sub_allocators[index].get()) { |
121 | best_entry->sub_allocators[index].reset( |
122 | best_entry->factory->CreateSubAllocator(numa_node)); |
123 | } |
124 | return best_entry->sub_allocators[index].get(); |
125 | } else { |
126 | LOG(FATAL) << "No registered CPU AllocatorFactory" ; |
127 | return nullptr; |
128 | } |
129 | } |
130 | |
131 | } // namespace tsl |
132 | |