1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_BASE_PLACEHOLDERBINDINGS_H
17#define GLOW_BASE_PLACEHOLDERBINDINGS_H
18
19#include "glow/ExecutionContext/TraceEvents.h"
20#include "glow/Graph/Graph.h"
21#include "llvm/ADT/ArrayRef.h"
22
23#include <list>
24#include <unordered_map>
25
26namespace glow {
27
28class Tensor;
29class Placeholder;
30
31/// This class provides a mapping between some graph nodes, which are a symbolic
32/// representation of some computation, and concrete tensors that represent the
33/// inputs and outputs to the graph. The PlaceholderBindings owns the tensors
34/// and the graph uses these values as runtime. This is useful for the
35/// multi-threaded execution of code, where each thread has a different
36/// execution context. The difference between this class and a regular map is
37/// that the PlaceholderBindings owns the Tensors (not only the pointers) and
38/// manages their lifetime.
39class PlaceholderBindings final {
40public:
41 /// Maps placeholders to the tensors that back them.
42 using PlaceholderMap = std::unordered_map<Placeholder *, Tensor>;
43 using PlaceholderMapIterator = PlaceholderMap::iterator;
44
45private:
46 /// Maps Placeholders to Tensors.
47 PlaceholderMap map_;
48
49public:
50 /// \returns true if \p A and \p B contain the same Placeholders mapped to
51 /// equivalent Tensors. \p allowedError is used when comparing each
52 /// Placeholder's backing payload data.
53 static bool compare(const PlaceholderBindings *A,
54 const PlaceholderBindings *B,
55 float allowedError = 0.0001);
56
57 /// \returns the tensor that corresponds to Placeholder \p P or Null if the
58 /// tensor is not found.
59 Tensor *get(Placeholder *P);
60 const Tensor *get(Placeholder *P) const;
61
62 /// \returns the Placeholder named \name or null of the Placeholder is not
63 /// found. Note that this uses a linear search path. If you want to seatch by
64 /// name more quickly, consider building a map yourself.
65 Placeholder *getPlaceholderByNameSlow(llvm::StringRef name) const;
66
67 /// Inserts the Placeholder-Tensor pair. This takes ownership of the Tensor.
68 PlaceholderMapIterator insert(Placeholder *P, Tensor &&T);
69
70 /// Copy values from this PlaceholderBindings to another, \p dst, by \p name.
71 /// This is useful when trained weights need to be transferred between
72 /// bindings of two different modules.
73 void copyToTarget(llvm::StringRef name, PlaceholderBindings &dst);
74
75 /// Transfer all trainable weights to target PlaceholderBindings \p dst.
76 void copyTrainableWeightsTo(PlaceholderBindings &dst);
77
78 /// Allocates a tensor to back the placeholder \p P. The new tensor has the
79 /// type of P.
80 Tensor *allocate(Placeholder *P);
81
82 /// Allocates zero-initialized backing tensors to all placeholders in \p lst
83 /// that are not currently allocated in the bindings.
84 /// \returns the number of tensors that were allocated.
85 unsigned allocate(const std::list<Placeholder *> &lst);
86
87 /// \returns the first placeholder in \p list that is not allocated by this
88 /// bindings. This method returns null if all placeholders in the list are
89 /// allocated.
90 Placeholder *getFirstUnallocated(const std::list<Placeholder *> &lst) const;
91
92 /// \returns True if \p P is a registered Placeholder.
93 size_t count(Placeholder *P) const;
94
95 /// Deletes all tensors and clears the mapping between Placeholders and
96 /// tensors.
97 void clear();
98
99 /// Removes the Tensor backing Placeholder \p P;
100 /// \p P must be a valid Placeholder registered in the bindings.
101 void erase(Placeholder *P);
102
103 /// Removes the existing Tensor backing Placeholder \p P; Bind \p T to \P.
104 /// \p P must be a valid Placeholder registered in the bindings.
105 void update(Placeholder *P, Tensor &&T);
106
107 /// \returns a copy of the PlaceholderBindings, with each placeholder mapped
108 /// to a new Tensor, with their own memory.
109 PlaceholderBindings clone() const;
110
111 /// \returns a copy of the PlaceholderBindings, with each placeholder mapped
112 /// to a new Tensor, with their own memory. However instead of the current
113 /// Placeholders in the current mapping, use the Placeholder with the same
114 /// name found in \p newPHs.
115 PlaceholderBindings clone(const PlaceholderList &newPHs) const;
116
117 /// \returns the mapping between placeholder to tensors.
118 PlaceholderMap &pairs() { return map_; }
119 const PlaceholderMap &pairs() const { return map_; }
120
121 /// \returns the size in bytes of allocated Tensors owned by
122 /// PlaceholderBindings.
123 uint64_t getDataSize() const;
124
125 /// Copies all Device Resident Tensors back to the host.
126 void ensureOnHost() {
127 for (auto &ph : map_) {
128 ph.second.ensureOnHost();
129 }
130 }
131
132 PlaceholderBindings() = default;
133
134 /// Construct the PlaceholderBindings with an initial mapping between \p
135 /// placeholders and \p inputs;
136 PlaceholderBindings(llvm::ArrayRef<Placeholder *> placeholders,
137 llvm::ArrayRef<Tensor *> inputs);
138
139 PlaceholderBindings(PlaceholderBindings &&other)
140 : map_(std::move(other.map_)) {}
141
142 ~PlaceholderBindings() { clear(); };
143
144 // Don't copy this class around.
145 PlaceholderBindings(const PlaceholderBindings &other) = delete;
146 PlaceholderBindings &operator=(const PlaceholderBindings &other) = delete;
147};
148
149} // namespace glow
150
151#endif // GLOW_BASE_PLACEHOLDERBINDINGS_H
152