1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/memory_types.h" |
16 | |
17 | #include <utility> |
18 | |
19 | #include "tensorflow/core/framework/device_factory.h" |
20 | #include "tensorflow/core/framework/memory_types.h" |
21 | #include "tensorflow/core/framework/node_def_builder.h" |
22 | #include "tensorflow/core/graph/node_builder.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/gtl/map_util.h" |
25 | #include "tensorflow/core/lib/hash/hash.h" |
26 | #include "tensorflow/core/platform/types.h" |
27 | #include "tensorflow/core/util/dump_graph.h" |
28 | |
29 | namespace tensorflow { |
30 | |
// Identifies one output slot of a node: the pair {node id, output index}.
// Used as the key of the per-slot memory-type maps below. Note: call sites
// aggregate-initialize this as {id, index}, so field order matters.
struct Endpoint {
  int node_id;       // Node::id() of the producing node.
  int output_index;  // Output slot on that node (or input slot, for inputs).
};
35 | |
36 | struct EndpointHash { |
37 | uint32 operator()(const Endpoint& x) const { |
38 | return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int), |
39 | x.output_index); |
40 | } |
41 | }; |
42 | |
43 | struct EndpointEq { |
44 | uint32 operator()(const Endpoint& x, const Endpoint& y) const { |
45 | return (x.node_id == y.node_id) && (x.output_index == y.output_index); |
46 | } |
47 | }; |
48 | |
49 | static Status ProcessMemoryTypes( |
50 | const DeviceType& device_type, const Graph* g, |
51 | const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) { |
52 | if (device_type != DEVICE_GPU && |
53 | !DeviceFactory::IsPluggableDevice(device_type.type_string())) { |
54 | // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible. |
55 | return OkStatus(); |
56 | } |
57 | // For GPU, HOST_MEMORY and DEVICE_MEMORY is not compatible. I.e., a |
58 | // conversion/transfer must be done. |
59 | // |
60 | // {node id, slot id} -> memory type. |
61 | typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq> |
62 | MemTypeMap; |
63 | MemTypeMap inp; |
64 | MemTypeMap out; |
65 | MemoryTypeVector inp_mvec; |
66 | MemoryTypeVector out_mvec; |
67 | for (const Node* n : g->nodes()) { |
68 | TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, |
69 | n->def(), &inp_mvec, &out_mvec)); |
70 | for (size_t i = 0; i < inp_mvec.size(); ++i) { |
71 | VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i]; |
72 | inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i]; |
73 | } |
74 | for (size_t i = 0; i < out_mvec.size(); ++i) { |
75 | VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i]; |
76 | out[{n->id(), static_cast<int>(i)}] = out_mvec[i]; |
77 | } |
78 | } |
79 | for (const Edge* e : g->edges()) { |
80 | if (e->IsControlEdge()) { |
81 | continue; |
82 | } |
83 | MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()}, |
84 | DEVICE_MEMORY); |
85 | MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()}, |
86 | DEVICE_MEMORY); |
87 | VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> " |
88 | << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> " |
89 | << dm; |
90 | TF_RETURN_IF_ERROR(fn(e, sm, dm)); |
91 | } |
92 | return OkStatus(); |
93 | } |
94 | |
95 | Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) { |
96 | return ProcessMemoryTypes( |
97 | device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) { |
98 | if (sm == dm) { |
99 | return OkStatus(); |
100 | } |
101 | return errors::Internal("Memory type mismatch (" , sm, " " , dm, |
102 | ") between :" , e->src()->id(), ":" , |
103 | e->src_output(), " and " , e->dst()->id(), ":" , |
104 | e->dst_input(), " : from " , |
105 | FormatNodeForError(*e->src()), " to " , |
106 | FormatNodeForError(*e->dst())); |
107 | }); |
108 | } |
109 | |
// Given an Edge whose two endpoints have different memory types and
// for which we are going to insert a pair of HostSend/Recv or
// Send/HostRecv nodes, GetTensorName() returns a unique string that we
// can use as part of the rendezvous key. The returned string is
// guaranteed to be unique within this process. That is sufficient
// because EnsureMemoryTypes is only used on a TensorFlow graph that is
// going to be executed on a single tf device (hence within a single
// process).
117 | static string GetTensorName(const Edge* edge) { |
118 | static std::atomic<int64_t> counter(0); |
119 | return strings::StrCat("memtype_" , counter.fetch_add(1), "_" , |
120 | edge->src()->name()); |
121 | } |
122 | |
123 | static Node* Send(Graph* g, const string& tensor_name, |
124 | const string& device_name, bool host, const Edge* edge) { |
125 | Node* ret; |
126 | TF_CHECK_OK(NodeBuilder(g->NewName("n" ), host ? "_HostSend" : "_Send" ) |
127 | .Input(edge->src(), edge->src_output()) |
128 | .Attr("tensor_name" , tensor_name) |
129 | .Attr("send_device" , device_name) |
130 | .Attr("send_device_incarnation" , 0) // Do not care. |
131 | .Attr("recv_device" , device_name) |
132 | .Attr("_hostmem_sendrecv" , true) |
133 | .Attr("_src" , edge->src()->name()) |
134 | .Attr("_dst" , edge->dst()->name()) |
135 | .Finalize(g, &ret)); |
136 | return ret; |
137 | } |
138 | |
139 | static Node* Recv(Graph* g, const string& tensor_name, |
140 | const string& device_name, bool host, const Edge* edge) { |
141 | Node* ret; |
142 | TF_CHECK_OK( |
143 | NodeBuilder(g->NewName("n" ), host ? "_HostRecv" : "_Recv" ) |
144 | .Attr("tensor_type" , edge->src()->output_type(edge->src_output())) |
145 | .Attr("tensor_name" , tensor_name) |
146 | .Attr("send_device" , device_name) |
147 | .Attr("send_device_incarnation" , 0) |
148 | .Attr("recv_device" , device_name) |
149 | .Attr("_hostmem_sendrecv" , true) |
150 | .Attr("_src" , edge->src()->name()) |
151 | .Attr("_dst" , edge->dst()->name()) |
152 | .Finalize(g, &ret)); |
153 | return ret; |
154 | } |
155 | |
// Rewrites `g` so that every data edge connects endpoints with matching
// memory types, by inserting HostSend/Recv or Send/HostRecv pairs across
// each HOST_MEMORY<->DEVICE_MEMORY edge. Returns an error if an edge with
// an unexpected memory-type pair is found, or if validation fails after
// the rewrite.
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g) {
  // An incompatible edge together with the source (sm) and destination
  // (dm) memory types ProcessMemoryTypes reported for it.
  struct Item {
    const Edge* edge;
    MemoryType sm;
    MemoryType dm;
  };
  // First pass: collect every edge crossing the host/device boundary.
  std::vector<Item> edges;
  TF_RETURN_IF_ERROR(ProcessMemoryTypes(
      device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return OkStatus();
        }
        if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
            ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
          edges.push_back({e, sm, dm});
          return OkStatus();
        }
        return errors::Internal("Unexpected memory type pair on an edge: " , sm,
                                " vs. " , dm);
      }));

  // edges contains edges in 'g' that memtype is not
  // compatible. Therefore, if we found any, we need to insert
  // HostSend/Recv and Send/HostRecv pairs. recv_nodes records all
  // nodes we added so that we don't copy the same tensor more than
  // once.
  if (!edges.empty()) {
    // Cache of Recv nodes keyed by source endpoint, so multiple consumers
    // of the same output share one transfer.
    std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
    for (const auto& item : edges) {
      const Edge* e = item.edge;
      const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
      Node* recv = nullptr;
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        // No transfer exists for this source slot yet: build a Send/Recv
        // pair. The "host" flag on each side follows its memory type.
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache if there is no ref is involved.
          recv_nodes[key] = recv;
        }
        // Ensure the Recv runs after its matching Send.
        g->AddControlEdge(send, recv);
      } else {
        recv = iter->second;  // Reuse the cached transfer.
      }
      // Reroute the consumer to read from the Recv's (only) output and
      // drop the incompatible edge.
      g->AddEdge(recv, 0, e->dst(), e->dst_input());
      g->RemoveEdge(e);
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Dumped graph after EnsureMemoryTypes to "
            << DumpGraphToFile("EnsureMemoryTypes" , *g);
  }

  // Re-check: after the rewrite every edge must have matching types.
  return ValidateMemoryTypes(device_type, g);
}
216 | |
217 | Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, |
218 | const Node* n, int index, MemoryType* memory_type) { |
219 | MemoryTypeVector inp_mvec; |
220 | MemoryTypeVector out_mvec; |
221 | TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(), |
222 | &inp_mvec, &out_mvec)); |
223 | if (out_mvec.size() <= index) { |
224 | return errors::Internal("Trying to get the memory type for " , index, |
225 | "'th output of node " , FormatNodeForError(*n), |
226 | " that has only " , out_mvec.size(), " outputs" ); |
227 | } |
228 | *memory_type = out_mvec[index]; |
229 | return OkStatus(); |
230 | } |
231 | |
232 | } // end namespace tensorflow |
233 | |