1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/memory_types.h"
16
17#include <utility>
18
19#include "tensorflow/core/framework/device_factory.h"
20#include "tensorflow/core/framework/memory_types.h"
21#include "tensorflow/core/framework/node_def_builder.h"
22#include "tensorflow/core/graph/node_builder.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/gtl/map_util.h"
25#include "tensorflow/core/lib/hash/hash.h"
26#include "tensorflow/core/platform/types.h"
27#include "tensorflow/core/util/dump_graph.h"
28
29namespace tensorflow {
30
// A {node id, slot index} pair naming one slot of a graph node. Used as the
// key type of the memory-type and Recv-node hash maps in this file.
struct Endpoint {
  // Node::id() of the node that owns the slot.
  int node_id;
  // Index of the slot on that node (an input or an output slot, depending on
  // which map the key is stored in).
  int output_index;
};
35
36struct EndpointHash {
37 uint32 operator()(const Endpoint& x) const {
38 return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
39 x.output_index);
40 }
41};
42
43struct EndpointEq {
44 uint32 operator()(const Endpoint& x, const Endpoint& y) const {
45 return (x.node_id == y.node_id) && (x.output_index == y.output_index);
46 }
47};
48
49static Status ProcessMemoryTypes(
50 const DeviceType& device_type, const Graph* g,
51 const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
52 if (device_type != DEVICE_GPU &&
53 !DeviceFactory::IsPluggableDevice(device_type.type_string())) {
54 // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
55 return OkStatus();
56 }
57 // For GPU, HOST_MEMORY and DEVICE_MEMORY is not compatible. I.e., a
58 // conversion/transfer must be done.
59 //
60 // {node id, slot id} -> memory type.
61 typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
62 MemTypeMap;
63 MemTypeMap inp;
64 MemTypeMap out;
65 MemoryTypeVector inp_mvec;
66 MemoryTypeVector out_mvec;
67 for (const Node* n : g->nodes()) {
68 TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
69 n->def(), &inp_mvec, &out_mvec));
70 for (size_t i = 0; i < inp_mvec.size(); ++i) {
71 VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
72 inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
73 }
74 for (size_t i = 0; i < out_mvec.size(); ++i) {
75 VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
76 out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
77 }
78 }
79 for (const Edge* e : g->edges()) {
80 if (e->IsControlEdge()) {
81 continue;
82 }
83 MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()},
84 DEVICE_MEMORY);
85 MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()},
86 DEVICE_MEMORY);
87 VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
88 << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
89 << dm;
90 TF_RETURN_IF_ERROR(fn(e, sm, dm));
91 }
92 return OkStatus();
93}
94
95Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
96 return ProcessMemoryTypes(
97 device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
98 if (sm == dm) {
99 return OkStatus();
100 }
101 return errors::Internal("Memory type mismatch (", sm, " ", dm,
102 ") between :", e->src()->id(), ":",
103 e->src_output(), " and ", e->dst()->id(), ":",
104 e->dst_input(), " : from ",
105 FormatNodeForError(*e->src()), " to ",
106 FormatNodeForError(*e->dst()));
107 });
108}
109
110// Given an Edge whose two endpoints have different memory types and
111// are gonna to insert a pair of HostSend/Recv or Send/HostRecv nodes,
112// GetTensorName() returns a unique string that we can use as part of
113// the rendezvous key. The return string is guaranteed to be unique
114// within this process. That is sufficient because EnsureMemoryTypes
115// is only used on a TensorFlow graph that is gonna to be executed in
116// a single tf device (hence within a single process).
117static string GetTensorName(const Edge* edge) {
118 static std::atomic<int64_t> counter(0);
119 return strings::StrCat("memtype_", counter.fetch_add(1), "_",
120 edge->src()->name());
121}
122
123static Node* Send(Graph* g, const string& tensor_name,
124 const string& device_name, bool host, const Edge* edge) {
125 Node* ret;
126 TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
127 .Input(edge->src(), edge->src_output())
128 .Attr("tensor_name", tensor_name)
129 .Attr("send_device", device_name)
130 .Attr("send_device_incarnation", 0) // Do not care.
131 .Attr("recv_device", device_name)
132 .Attr("_hostmem_sendrecv", true)
133 .Attr("_src", edge->src()->name())
134 .Attr("_dst", edge->dst()->name())
135 .Finalize(g, &ret));
136 return ret;
137}
138
139static Node* Recv(Graph* g, const string& tensor_name,
140 const string& device_name, bool host, const Edge* edge) {
141 Node* ret;
142 TF_CHECK_OK(
143 NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
144 .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
145 .Attr("tensor_name", tensor_name)
146 .Attr("send_device", device_name)
147 .Attr("send_device_incarnation", 0)
148 .Attr("recv_device", device_name)
149 .Attr("_hostmem_sendrecv", true)
150 .Attr("_src", edge->src()->name())
151 .Attr("_dst", edge->dst()->name())
152 .Finalize(g, &ret));
153 return ret;
154}
155
156Status EnsureMemoryTypes(const DeviceType& device_type,
157 const string& device_name, Graph* g) {
158 struct Item {
159 const Edge* edge;
160 MemoryType sm;
161 MemoryType dm;
162 };
163 std::vector<Item> edges;
164 TF_RETURN_IF_ERROR(ProcessMemoryTypes(
165 device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
166 if (sm == dm) {
167 return OkStatus();
168 }
169 if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
170 ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
171 edges.push_back({e, sm, dm});
172 return OkStatus();
173 }
174 return errors::Internal("Unexpected memory type pair on an edge: ", sm,
175 " vs. ", dm);
176 }));
177
178 // edges contains edges in 'g' that memtype is not
179 // compatible. Therefore, if we found any, we need to insert
180 // HostSend/Recv and Send/HostRecv pairs. recv_nodes records all
181 // nodes we added so that we don't copy the same tensor more than
182 // once.
183 if (!edges.empty()) {
184 std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
185 for (const auto& item : edges) {
186 const Edge* e = item.edge;
187 const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
188 Node* recv = nullptr;
189 Endpoint key{e->src()->id(), e->src_output()};
190 auto iter = recv_nodes.find(key);
191 if (iter == recv_nodes.end()) {
192 const string tensor_name = GetTensorName(e);
193 Node* send =
194 Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
195 recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
196 if (!has_ref) {
197 // We only cache if there is no ref is involved.
198 recv_nodes[key] = recv;
199 }
200 g->AddControlEdge(send, recv);
201 } else {
202 recv = iter->second;
203 }
204 g->AddEdge(recv, 0, e->dst(), e->dst_input());
205 g->RemoveEdge(e);
206 }
207 }
208
209 if (VLOG_IS_ON(2)) {
210 VLOG(2) << "Dumped graph after EnsureMemoryTypes to "
211 << DumpGraphToFile("EnsureMemoryTypes", *g);
212 }
213
214 return ValidateMemoryTypes(device_type, g);
215}
216
217Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
218 const Node* n, int index, MemoryType* memory_type) {
219 MemoryTypeVector inp_mvec;
220 MemoryTypeVector out_mvec;
221 TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
222 &inp_mvec, &out_mvec));
223 if (out_mvec.size() <= index) {
224 return errors::Internal("Trying to get the memory type for ", index,
225 "'th output of node ", FormatNodeForError(*n),
226 " that has only ", out_mvec.size(), " outputs");
227 }
228 *memory_type = out_mvec[index];
229 return OkStatus();
230}
231
232} // end namespace tensorflow
233