1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/tsl/framework/allocator_registry.h"
17
18#include <string>
19
20#include "tensorflow/tsl/platform/logging.h"
21
22namespace tsl {
23
24// static
25AllocatorFactoryRegistry* AllocatorFactoryRegistry::singleton() {
26 static AllocatorFactoryRegistry* singleton = new AllocatorFactoryRegistry;
27 return singleton;
28}
29
30const AllocatorFactoryRegistry::FactoryEntry*
31AllocatorFactoryRegistry::FindEntry(const string& name, int priority) const {
32 for (auto& entry : factories_) {
33 if (!name.compare(entry.name) && priority == entry.priority) {
34 return &entry;
35 }
36 }
37 return nullptr;
38}
39
40void AllocatorFactoryRegistry::Register(const char* source_file,
41 int source_line, const string& name,
42 int priority,
43 AllocatorFactory* factory) {
44 mutex_lock l(mu_);
45 CHECK(!first_alloc_made_) << "Attempt to register an AllocatorFactory "
46 << "after call to GetAllocator()";
47 CHECK(!name.empty()) << "Need a valid name for Allocator";
48 CHECK_GE(priority, 0) << "Priority needs to be non-negative";
49
50 const FactoryEntry* existing = FindEntry(name, priority);
51 if (existing != nullptr) {
52 // Duplicate registration is a hard failure.
53 LOG(FATAL) << "New registration for AllocatorFactory with name=" << name
54 << " priority=" << priority << " at location " << source_file
55 << ":" << source_line
56 << " conflicts with previous registration at location "
57 << existing->source_file << ":" << existing->source_line;
58 }
59
60 FactoryEntry entry;
61 entry.source_file = source_file;
62 entry.source_line = source_line;
63 entry.name = name;
64 entry.priority = priority;
65 entry.factory.reset(factory);
66 factories_.push_back(std::move(entry));
67}
68
69Allocator* AllocatorFactoryRegistry::GetAllocator() {
70 mutex_lock l(mu_);
71 first_alloc_made_ = true;
72 FactoryEntry* best_entry = nullptr;
73 for (auto& entry : factories_) {
74 if (best_entry == nullptr) {
75 best_entry = &entry;
76 } else if (entry.priority > best_entry->priority) {
77 best_entry = &entry;
78 }
79 }
80 if (best_entry) {
81 if (!best_entry->allocator) {
82 best_entry->allocator.reset(best_entry->factory->CreateAllocator());
83 }
84 return best_entry->allocator.get();
85 } else {
86 LOG(FATAL) << "No registered CPU AllocatorFactory";
87 return nullptr;
88 }
89}
90
91SubAllocator* AllocatorFactoryRegistry::GetSubAllocator(int numa_node) {
92 mutex_lock l(mu_);
93 first_alloc_made_ = true;
94 FactoryEntry* best_entry = nullptr;
95 for (auto& entry : factories_) {
96 if (best_entry == nullptr) {
97 best_entry = &entry;
98 } else if (best_entry->factory->NumaEnabled()) {
99 if (entry.factory->NumaEnabled() &&
100 (entry.priority > best_entry->priority)) {
101 best_entry = &entry;
102 }
103 } else {
104 DCHECK(!best_entry->factory->NumaEnabled());
105 if (entry.factory->NumaEnabled() ||
106 (entry.priority > best_entry->priority)) {
107 best_entry = &entry;
108 }
109 }
110 }
111 if (best_entry) {
112 int index = 0;
113 if (numa_node != port::kNUMANoAffinity) {
114 CHECK_LE(numa_node, port::NUMANumNodes());
115 index = 1 + numa_node;
116 }
117 if (best_entry->sub_allocators.size() < static_cast<size_t>(index + 1)) {
118 best_entry->sub_allocators.resize(index + 1);
119 }
120 if (!best_entry->sub_allocators[index].get()) {
121 best_entry->sub_allocators[index].reset(
122 best_entry->factory->CreateSubAllocator(numa_node));
123 }
124 return best_entry->sub_allocators[index].get();
125 } else {
126 LOG(FATAL) << "No registered CPU AllocatorFactory";
127 return nullptr;
128 }
129}
130
131} // namespace tsl
132