1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Base/Tensor.h"
#include "glow/Graph/Nodes.h"
#include "glow/Support/TensorPool.h"

#include <glog/logging.h>

#include <algorithm>
23 | |
24 | using namespace glow; |
25 | |
26 | bool PlaceholderBindings::compare(const PlaceholderBindings *A, |
27 | const PlaceholderBindings *B, |
28 | float allowedError) { |
29 | // Trivial cases. |
30 | if (!A && !B) { |
31 | return true; |
32 | } else if ((!A && B) || (A && !B)) { |
33 | return false; |
34 | } |
35 | |
36 | // Get the map of Placeholder -> Tensor mappings within the two |
37 | // PlaceholderBindingss. |
38 | const PlaceholderBindings::PlaceholderMap &phMapA = A->pairs(); |
39 | const PlaceholderBindings::PlaceholderMap &phMapB = B->pairs(); |
40 | |
41 | // If the maps have different sizes, the PlaceholderBindingss cannot match. |
42 | if (phMapA.size() != phMapB.size()) { |
43 | return false; |
44 | } |
45 | |
46 | // Iterate through all Placeholders in A, look up the corresponding tensors |
47 | // in A and B, and check if they match. If not, return false. |
48 | for (const auto &phTensorPair : phMapA) { |
49 | auto *placeholder = phTensorPair.first; |
50 | const auto &tensorA = phTensorPair.second; |
51 | const auto *tensorB = |
52 | B->get(B->getPlaceholderByNameSlow(placeholder->getName())); |
53 | |
54 | if (!tensorB || !tensorA.isEqual(*tensorB, allowedError, |
55 | /* verbose */ true)) { |
56 | return false; |
57 | } |
58 | } |
59 | |
60 | return true; |
61 | } |
62 | |
63 | const Tensor *PlaceholderBindings::get(Placeholder *P) const { |
64 | auto it = map_.find(P); |
65 | if (it == map_.end()) { |
66 | return nullptr; |
67 | } |
68 | |
69 | return &it->second; |
70 | } |
71 | |
72 | Tensor *PlaceholderBindings::get(Placeholder *P) { |
73 | auto it = map_.find(P); |
74 | if (it == map_.end()) { |
75 | return nullptr; |
76 | } |
77 | |
78 | return &it->second; |
79 | } |
80 | |
81 | Placeholder * |
82 | PlaceholderBindings::getPlaceholderByNameSlow(llvm::StringRef name) const { |
83 | for (auto &kv : map_) { |
84 | if (kv.first->getName() == name) { |
85 | return kv.first; |
86 | } |
87 | } |
88 | return nullptr; |
89 | } |
90 | |
91 | PlaceholderBindings::PlaceholderMapIterator |
92 | PlaceholderBindings::insert(Placeholder *P, Tensor &&T) { |
93 | DCHECK(T.getType().isEqual(*P->getType())) |
94 | << "Placeholder " << P->getName().str() << " has type " |
95 | << P->getType()->toString() << " but Tensor has type " |
96 | << T.getType().toString() << "\n" ; |
97 | auto ret = map_.emplace(P, std::move(T)); |
98 | DCHECK(ret.second) << "Placeholder with name \"" << P->getName().str() |
99 | << "\" already registered" ; |
100 | return ret.first; |
101 | } |
102 | |
103 | void PlaceholderBindings::copyToTarget(llvm::StringRef name, |
104 | PlaceholderBindings &dst) { |
105 | auto *srcPH = this->getPlaceholderByNameSlow(name); |
106 | DCHECK(srcPH) << name.str() << " does not exist in source" ; |
107 | auto *dstPH = dst.getPlaceholderByNameSlow(name); |
108 | DCHECK(dstPH) << name.str() << " does not exist in destination" ; |
109 | dst.erase(dstPH); |
110 | dst.insert(dstPH, this->get(srcPH)->clone()); |
111 | } |
112 | |
113 | void PlaceholderBindings::copyTrainableWeightsTo(PlaceholderBindings &dst) { |
114 | for (auto &PH : pairs()) { |
115 | if (PH.first->isTraining()) { |
116 | copyToTarget(PH.first->getName(), dst); |
117 | } |
118 | } |
119 | } |
120 | |
121 | size_t PlaceholderBindings::count(Placeholder *P) const { |
122 | return map_.count(P); |
123 | } |
124 | |
125 | void PlaceholderBindings::clear() { |
126 | // Delete all of the tensors that are owned by the bindings. |
127 | for (auto &PH : map_) { |
128 | if (auto *tensorPool = PH.second.getOwningPool()) { |
129 | tensorPool->reclaim(std::move(PH.second)); |
130 | } |
131 | } |
132 | |
133 | map_.clear(); |
134 | } |
135 | |
136 | void PlaceholderBindings::erase(Placeholder *P) { |
137 | auto &T = map_[P]; |
138 | if (auto *tensorPool = T.getOwningPool()) { |
139 | tensorPool->reclaim(std::move(T)); |
140 | } |
141 | map_.erase(P); |
142 | } |
143 | |
144 | PlaceholderBindings PlaceholderBindings::clone() const { |
145 | PlaceholderBindings cloned; |
146 | for (auto &PH : map_) { |
147 | Placeholder *P = PH.first; |
148 | cloned.insert(P, PH.second.clone()); |
149 | } |
150 | |
151 | return cloned; |
152 | } |
153 | |
154 | PlaceholderBindings |
155 | PlaceholderBindings::clone(const PlaceholderList &newPHs) const { |
156 | PlaceholderBindings cloned; |
157 | for (const auto &PH : map_) { |
158 | Placeholder *P = PH.first; |
159 | const Tensor &T = PH.second; |
160 | auto newPHIt = std::find_if(newPHs.begin(), newPHs.end(), [=](auto *newPH) { |
161 | return newPH->getName() == P->getName(); |
162 | }); |
163 | DCHECK(newPHIt != newPHs.end()) |
164 | << "Expected to find corresponding PH by name " << P->getName().data(); |
165 | cloned.insert(*newPHIt, T.clone()); |
166 | } |
167 | |
168 | return cloned; |
169 | } |
170 | |
171 | Tensor *PlaceholderBindings::allocate(Placeholder *P) { |
172 | DCHECK(!map_.count(P)) << "Placeholder with name \"" << P->getName().str() |
173 | << "\" already registered" ; |
174 | Tensor T(P->getType()); |
175 | |
176 | // If this Tensor needs to start zeroed, then zero it. |
177 | if (P->allocZero()) { |
178 | T.zero(); |
179 | } |
180 | |
181 | auto ret = map_.emplace(P, std::move(T)); |
182 | return &ret.first->second; |
183 | } |
184 | |
185 | unsigned PlaceholderBindings::allocate(const std::list<Placeholder *> &lst) { |
186 | unsigned allocated = 0; |
187 | // For each placeholder in the list: |
188 | for (Placeholder *P : lst) { |
189 | // Don't allocate tensors for placeholders that are already allocated. |
190 | if (this->count(P)) { |
191 | continue; |
192 | } |
193 | |
194 | // Allocate a tensor to back P. |
195 | allocate(P); |
196 | allocated++; |
197 | } |
198 | return allocated; |
199 | } |
200 | |
201 | Placeholder *PlaceholderBindings::getFirstUnallocated( |
202 | const std::list<Placeholder *> &lst) const { |
203 | // For each placeholder in the list: |
204 | for (Placeholder *P : lst) { |
205 | // If we found an unallocated placeholder then return it. |
206 | if (!count(P)) |
207 | return P; |
208 | } |
209 | |
210 | return nullptr; |
211 | } |
212 | |
213 | uint64_t PlaceholderBindings::getDataSize() const { |
214 | uint64_t size = 0; |
215 | for (const auto &PH : map_) { |
216 | const auto &T = PH.second; |
217 | size += T.getSizeInBytes(); |
218 | } |
219 | return size; |
220 | } |
221 | |
222 | PlaceholderBindings::PlaceholderBindings( |
223 | llvm::ArrayRef<Placeholder *> placeholders, |
224 | llvm::ArrayRef<Tensor *> inputs) { |
225 | DCHECK_EQ(placeholders.size(), inputs.size()) |
226 | << "Invalid number of placeholders" ; |
227 | |
228 | for (size_t i = 0, e = placeholders.size(); i < e; i++) { |
229 | auto *orig = inputs[i]; |
230 | /// Create a reference to the original tensor and hand it to the |
231 | /// PlaceholderBindings. |
232 | Tensor ptrT = orig->getUnowned(); |
233 | insert(placeholders[i], std::move(ptrT)); |
234 | } |
235 | } |
236 | |