1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_BASE_PLACEHOLDERBINDINGS_H |
17 | #define GLOW_BASE_PLACEHOLDERBINDINGS_H |
18 | |
19 | #include "glow/ExecutionContext/TraceEvents.h" |
20 | #include "glow/Graph/Graph.h" |
21 | #include "llvm/ADT/ArrayRef.h" |
22 | |
23 | #include <list> |
24 | #include <unordered_map> |
25 | |
26 | namespace glow { |
27 | |
28 | class Tensor; |
29 | class Placeholder; |
30 | |
31 | /// This class provides a mapping between some graph nodes, which are a symbolic |
32 | /// representation of some computation, and concrete tensors that represent the |
33 | /// inputs and outputs to the graph. The PlaceholderBindings owns the tensors |
34 | /// and the graph uses these values as runtime. This is useful for the |
35 | /// multi-threaded execution of code, where each thread has a different |
36 | /// execution context. The difference between this class and a regular map is |
37 | /// that the PlaceholderBindings owns the Tensors (not only the pointers) and |
38 | /// manages their lifetime. |
39 | class PlaceholderBindings final { |
40 | public: |
41 | /// Maps placeholders to the tensors that back them. |
42 | using PlaceholderMap = std::unordered_map<Placeholder *, Tensor>; |
43 | using PlaceholderMapIterator = PlaceholderMap::iterator; |
44 | |
45 | private: |
46 | /// Maps Placeholders to Tensors. |
47 | PlaceholderMap map_; |
48 | |
49 | public: |
50 | /// \returns true if \p A and \p B contain the same Placeholders mapped to |
51 | /// equivalent Tensors. \p allowedError is used when comparing each |
52 | /// Placeholder's backing payload data. |
53 | static bool compare(const PlaceholderBindings *A, |
54 | const PlaceholderBindings *B, |
55 | float allowedError = 0.0001); |
56 | |
57 | /// \returns the tensor that corresponds to Placeholder \p P or Null if the |
58 | /// tensor is not found. |
59 | Tensor *get(Placeholder *P); |
60 | const Tensor *get(Placeholder *P) const; |
61 | |
62 | /// \returns the Placeholder named \name or null of the Placeholder is not |
63 | /// found. Note that this uses a linear search path. If you want to seatch by |
64 | /// name more quickly, consider building a map yourself. |
65 | Placeholder *getPlaceholderByNameSlow(llvm::StringRef name) const; |
66 | |
67 | /// Inserts the Placeholder-Tensor pair. This takes ownership of the Tensor. |
68 | PlaceholderMapIterator insert(Placeholder *P, Tensor &&T); |
69 | |
70 | /// Copy values from this PlaceholderBindings to another, \p dst, by \p name. |
71 | /// This is useful when trained weights need to be transferred between |
72 | /// bindings of two different modules. |
73 | void copyToTarget(llvm::StringRef name, PlaceholderBindings &dst); |
74 | |
75 | /// Transfer all trainable weights to target PlaceholderBindings \p dst. |
76 | void copyTrainableWeightsTo(PlaceholderBindings &dst); |
77 | |
78 | /// Allocates a tensor to back the placeholder \p P. The new tensor has the |
79 | /// type of P. |
80 | Tensor *allocate(Placeholder *P); |
81 | |
82 | /// Allocates zero-initialized backing tensors to all placeholders in \p lst |
83 | /// that are not currently allocated in the bindings. |
84 | /// \returns the number of tensors that were allocated. |
85 | unsigned allocate(const std::list<Placeholder *> &lst); |
86 | |
87 | /// \returns the first placeholder in \p list that is not allocated by this |
88 | /// bindings. This method returns null if all placeholders in the list are |
89 | /// allocated. |
90 | Placeholder *getFirstUnallocated(const std::list<Placeholder *> &lst) const; |
91 | |
92 | /// \returns True if \p P is a registered Placeholder. |
93 | size_t count(Placeholder *P) const; |
94 | |
95 | /// Deletes all tensors and clears the mapping between Placeholders and |
96 | /// tensors. |
97 | void clear(); |
98 | |
99 | /// Removes the Tensor backing Placeholder \p P; |
100 | /// \p P must be a valid Placeholder registered in the bindings. |
101 | void erase(Placeholder *P); |
102 | |
103 | /// Removes the existing Tensor backing Placeholder \p P; Bind \p T to \P. |
104 | /// \p P must be a valid Placeholder registered in the bindings. |
105 | void update(Placeholder *P, Tensor &&T); |
106 | |
107 | /// \returns a copy of the PlaceholderBindings, with each placeholder mapped |
108 | /// to a new Tensor, with their own memory. |
109 | PlaceholderBindings clone() const; |
110 | |
111 | /// \returns a copy of the PlaceholderBindings, with each placeholder mapped |
112 | /// to a new Tensor, with their own memory. However instead of the current |
113 | /// Placeholders in the current mapping, use the Placeholder with the same |
114 | /// name found in \p newPHs. |
115 | PlaceholderBindings clone(const PlaceholderList &newPHs) const; |
116 | |
117 | /// \returns the mapping between placeholder to tensors. |
118 | PlaceholderMap &pairs() { return map_; } |
119 | const PlaceholderMap &pairs() const { return map_; } |
120 | |
121 | /// \returns the size in bytes of allocated Tensors owned by |
122 | /// PlaceholderBindings. |
123 | uint64_t getDataSize() const; |
124 | |
125 | /// Copies all Device Resident Tensors back to the host. |
126 | void ensureOnHost() { |
127 | for (auto &ph : map_) { |
128 | ph.second.ensureOnHost(); |
129 | } |
130 | } |
131 | |
132 | PlaceholderBindings() = default; |
133 | |
134 | /// Construct the PlaceholderBindings with an initial mapping between \p |
135 | /// placeholders and \p inputs; |
136 | PlaceholderBindings(llvm::ArrayRef<Placeholder *> placeholders, |
137 | llvm::ArrayRef<Tensor *> inputs); |
138 | |
139 | PlaceholderBindings(PlaceholderBindings &&other) |
140 | : map_(std::move(other.map_)) {} |
141 | |
142 | ~PlaceholderBindings() { clear(); }; |
143 | |
144 | // Don't copy this class around. |
145 | PlaceholderBindings(const PlaceholderBindings &other) = delete; |
146 | PlaceholderBindings &operator=(const PlaceholderBindings &other) = delete; |
147 | }; |
148 | |
149 | } // namespace glow |
150 | |
151 | #endif // GLOW_BASE_PLACEHOLDERBINDINGS_H |
152 | |