1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
17#define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
18
19#include <cstddef>
20#include <functional>
21#include <memory>
22#include <string>
23#include <unordered_map>
24#include <unordered_set>
25#include <vector>
26
27#include "tensorflow/core/debug/debug_node_key.h"
28#include "tensorflow/core/framework/tensor.h"
29#include "tensorflow/core/graph/graph.h"
30#include "tensorflow/core/lib/core/status.h"
31#include "tensorflow/core/lib/gtl/array_slice.h"
32#include "tensorflow/core/platform/env.h"
33#include "tensorflow/core/util/event.pb.h"
34
35namespace tensorflow {
36
37Status ReadEventFromFile(const string& dump_file_path, Event* event);
38
// Immutable record describing one debug op attached to a Copy node: the watch
// key of the debug op, the debug URL it publishes to, and whether it carries
// the gated_grpc attribute. Consumed by DebugIO::IsCopyNodeGateOpen().
struct DebugWatchAndURLSpec {
  DebugWatchAndURLSpec(const string& watch_key, const string& url,
                       bool gated_grpc)
      : watch_key(watch_key),
        url(url),
        gated_grpc(gated_grpc) {}

  // Debug tensor watch key, in the format of tensor_name:debug_op, e.g.,
  // "Weights:0:DebugIdentity".
  const string watch_key;
  // Debug target URL, e.g., "grpc://localhost:3333", "file:///tmp/tfdbg_1".
  const string url;
  // Whether the debug op is subject to gRPC gating.
  const bool gated_grpc;
};
48
49// TODO(cais): Put static functions and members in a namespace, not a class.
50class DebugIO {
51 public:
52 static const char* const kDebuggerPluginName;
53
54 static const char* const kCoreMetadataTag;
55 static const char* const kGraphTag;
56 static const char* const kHashTag;
57
58 static const char* const kFileURLScheme;
59 static const char* const kGrpcURLScheme;
60 static const char* const kMemoryURLScheme;
61
62 static Status PublishDebugMetadata(
63 const int64_t global_step, const int64_t session_run_index,
64 const int64_t executor_step_index, const std::vector<string>& input_names,
65 const std::vector<string>& output_names,
66 const std::vector<string>& target_nodes,
67 const std::unordered_set<string>& debug_urls);
68
69 // Publishes a tensor to a debug target URL.
70 //
71 // Args:
72 // debug_node_key: A DebugNodeKey identifying the debug node.
73 // tensor: The Tensor object being published.
74 // wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
75 // debug_urls: An array of debug target URLs, e.g.,
76 // "file:///foo/tfdbg_dump", "grpc://localhost:11011"
77 // gated_grpc: Whether this call is subject to gRPC gating.
78 static Status PublishDebugTensor(const DebugNodeKey& debug_node_key,
79 const Tensor& tensor,
80 const uint64 wall_time_us,
81 const gtl::ArraySlice<string> debug_urls,
82 const bool gated_grpc);
83
84 // Convenience overload of the method above for no gated_grpc by default.
85 static Status PublishDebugTensor(const DebugNodeKey& debug_node_key,
86 const Tensor& tensor,
87 const uint64 wall_time_us,
88 const gtl::ArraySlice<string> debug_urls);
89
90 // Publishes a graph to a set of debug URLs.
91 //
92 // Args:
93 // graph: The graph to be published.
94 // debug_urls: The set of debug URLs to publish the graph to.
95 static Status PublishGraph(const Graph& graph, const string& device_name,
96 const std::unordered_set<string>& debug_urls);
97
98 // Determines whether a copy node needs to perform deep-copy of input tensor.
99 //
100 // The input arguments contain sufficient information about the attached
101 // downstream debug ops for this method to determine whether all the said
102 // ops are disabled given the current status of the gRPC gating.
103 //
104 // Args:
105 // specs: A vector of DebugWatchAndURLSpec carrying information about the
106 // debug ops attached to the Copy node, their debug URLs and whether
107 // they have the attribute value gated_grpc == True.
108 //
109 // Returns:
110 // Whether any of the attached downstream debug ops is enabled given the
111 // current status of the gRPC gating.
112 static bool IsCopyNodeGateOpen(
113 const std::vector<DebugWatchAndURLSpec>& specs);
114
115 // Determines whether a debug node needs to proceed given the current gRPC
116 // gating status.
117 //
118 // Args:
119 // watch_key: debug tensor watch key, in the format of
120 // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
121 // debug_urls: the debug URLs of the debug node.
122 //
123 // Returns:
124 // Whether this debug op should proceed.
125 static bool IsDebugNodeGateOpen(const string& watch_key,
126 const std::vector<string>& debug_urls);
127
128 // Determines whether debug information should be sent through a grpc://
129 // debug URL given the current gRPC gating status.
130 //
131 // Args:
132 // watch_key: debug tensor watch key, in the format of
133 // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
134 // debug_url: the debug URL, e.g., "grpc://localhost:3333",
135 // "file:///tmp/tfdbg_1".
136 //
137 // Returns:
138 // Whether the sending of debug data to the debug_url should
139 // proceed.
140 static bool IsDebugURLGateOpen(const string& watch_key,
141 const string& debug_url);
142
143 static Status CloseDebugURL(const string& debug_url);
144};
145
146// Helper class for debug ops.
147class DebugFileIO {
148 public:
149 // Encapsulates the Tensor in an Event protobuf and write it to a directory.
150 // The actual path of the dump file will be a contactenation of
151 // dump_root_dir, tensor_name, along with the wall_time.
152 //
153 // For example:
154 // let dump_root_dir = "/tmp/tfdbg_dump",
155 // node_name = "foo/bar",
156 // output_slot = 0,
157 // debug_op = DebugIdentity,
158 // and wall_time_us = 1467891234512345,
159 // the dump file will be generated at path:
160 // /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345.
161 //
162 // Args:
163 // debug_node_key: A DebugNodeKey identifying the debug node.
164 // wall_time_us: Wall time at which the Tensor is generated during graph
165 // execution. Unit: microseconds (us).
166 // dump_root_dir: Root directory for dumping the tensor.
167 // dump_file_path: The actual dump file path (passed as reference).
168 static Status DumpTensorToDir(const DebugNodeKey& debug_node_key,
169 const Tensor& tensor, const uint64 wall_time_us,
170 const string& dump_root_dir,
171 string* dump_file_path);
172
173 // Get the full path to the dump file.
174 //
175 // Args:
176 // dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump
177 // node_name: Name of the node from which the dumped tensor is generated,
178 // e.g., foo/bar/node_a
179 // output_slot: Output slot index of the said node, e.g., 0.
180 // debug_op: Name of the debug op, e.g., DebugIdentity.
181 // wall_time_us: Time stamp of the dumped tensor, in microseconds (us).
182 static string GetDumpFilePath(const string& dump_root_dir,
183 const DebugNodeKey& debug_node_key,
184 const uint64 wall_time_us);
185
186 // Dumps an Event proto to a file.
187 //
188 // Args:
189 // event_prot: The Event proto to be dumped.
190 // dir_name: Directory path.
191 // file_name: Base file name.
192 static Status DumpEventProtoToFile(const Event& event_proto,
193 const string& dir_name,
194 const string& file_name);
195
196 // Request additional bytes to be dumped to the file system.
197 //
198 // Does not actually dump the bytes, but instead just performs the
199 // bookkeeping necessary to prevent the total dumped amount of data from
200 // exceeding the limit (default 100 GBytes or set customly through the
201 // environment variable TFDBG_DISK_BYTES_LIMIT).
202 //
203 // Args:
204 // bytes: Number of bytes to request.
205 //
206 // Returns:
207 // Whether the request is approved given the total dumping
208 // limit.
209 static bool requestDiskByteUsage(uint64 bytes);
210
211 // Reset the disk byte usage to zero.
212 static void resetDiskByteUsage();
213
214 static uint64 global_disk_bytes_limit_;
215
216 private:
217 // Encapsulates the Tensor in an Event protobuf and write it to file.
218 static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
219 const Tensor& tensor,
220 const uint64 wall_time_us,
221 const string& file_path);
222
223 // Implemented ad hoc here for now.
224 // TODO(cais): Replace with shared implementation once http://b/30497715 is
225 // fixed.
226 static Status RecursiveCreateDir(Env* env, const string& dir);
227
228 // Tracks how much disk has been used so far.
229 static uint64 disk_bytes_used_;
230 // Mutex for thread-safe access to disk_bytes_used_.
231 static mutex bytes_mu_;
232 // Default limit for the disk space.
233 static const uint64 kDefaultGlobalDiskBytesLimit;
234
235 friend class DiskUsageLimitTest;
236};
237
238} // namespace tensorflow
239
240namespace std {
241
242template <>
243struct hash<::tensorflow::DebugNodeKey> {
244 size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
245 return ::tensorflow::Hash64(
246 ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":",
247 k.output_slot, ":", k.debug_op, ":"));
248 }
249};
250
251} // namespace std
252
253// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
254// genrule becomes available. See b/23796275.
255#ifndef PLATFORM_WINDOWS
256#include "grpcpp/channel.h"
257#include "tensorflow/core/debug/debug_service.grpc.pb.h"
258
259namespace tensorflow {
260
261class DebugGrpcChannel {
262 public:
263 // Constructor of DebugGrpcChannel.
264 //
265 // Args:
266 // server_stream_addr: Address (host name and port) of the debug stream
267 // server implementing the EventListener service (see
268 // debug_service.proto). E.g., "127.0.0.1:12345".
269 explicit DebugGrpcChannel(const string& server_stream_addr);
270
271 virtual ~DebugGrpcChannel() {}
272
273 // Attempt to establish connection with server.
274 //
275 // Args:
276 // timeout_micros: Timeout (in microseconds) for the attempt to establish
277 // the connection.
278 //
279 // Returns:
280 // OK Status iff connection is successfully established before timeout,
281 // otherwise return an error Status.
282 Status Connect(const int64_t timeout_micros);
283
284 // Write an Event proto to the debug gRPC stream.
285 //
286 // Thread-safety: Safe with respect to other calls to the same method and
287 // calls to ReadEventReply() and Close().
288 //
289 // Args:
290 // event: The event proto to be written to the stream.
291 //
292 // Returns:
293 // True iff the write is successful.
294 bool WriteEvent(const Event& event);
295
296 // Read an EventReply proto from the debug gRPC stream.
297 //
298 // This method blocks and waits for an EventReply from the server.
299 // Thread-safety: Safe with respect to other calls to the same method and
300 // calls to WriteEvent() and Close().
301 //
302 // Args:
303 // event_reply: the to-be-modified EventReply proto passed as reference.
304 //
305 // Returns:
306 // True iff the read is successful.
307 bool ReadEventReply(EventReply* event_reply);
308
309 // Receive and process EventReply protos from the gRPC debug server.
310 //
311 // The processing includes setting debug watch key states using the
312 // DebugOpStateChange fields of the EventReply.
313 //
314 // Args:
315 // max_replies: Maximum number of replies to receive. Will receive all
316 // remaining replies iff max_replies == 0.
317 void ReceiveAndProcessEventReplies(size_t max_replies);
318
319 // Receive EventReplies from server (if any) and close the stream and the
320 // channel.
321 Status ReceiveServerRepliesAndClose();
322
323 private:
324 string server_stream_addr_;
325 string url_;
326 ::grpc::ClientContext ctx_;
327 std::shared_ptr<::grpc::Channel> channel_;
328 std::unique_ptr<grpc::EventListener::Stub> stub_;
329 std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>>
330 reader_writer_;
331
332 mutex mu_;
333};
334
335class DebugGrpcIO {
336 public:
337 static const size_t kGrpcMessageSizeLimitBytes;
338 static const size_t kGrpcMaxVarintLengthSize;
339
340 // Sends a tensor through a debug gRPC stream.
341 static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key,
342 const Tensor& tensor,
343 const uint64 wall_time_us,
344 const string& grpc_stream_url,
345 const bool gated);
346
347 // Sends an Event proto through a debug gRPC stream.
348 // Thread-safety: Safe with respect to other calls to the same method and
349 // calls to CloseGrpcStream().
350 //
351 // Args:
352 // event_proto: The Event proto to be sent.
353 // grpc_stream_url: The grpc:// URL of the stream to use, e.g.,
354 // "grpc://localhost:11011", "localhost:22022".
355 // receive_reply: Whether an EventReply proto will be read after event_proto
356 // is sent and before the function returns.
357 //
358 // Returns:
359 // The Status of the operation.
360 static Status SendEventProtoThroughGrpcStream(
361 const Event& event_proto, const string& grpc_stream_url,
362 const bool receive_reply = false);
363
364 // Receive an EventReply proto through a debug gRPC stream.
365 static Status ReceiveEventReplyProtoThroughGrpcStream(
366 EventReply* event_reply, const string& grpc_stream_url);
367
368 // Check whether a debug watch key is read-activated at a given gRPC URL.
369 static bool IsReadGateOpen(const string& grpc_debug_url,
370 const string& watch_key);
371
372 // Check whether a debug watch key is write-activated (i.e., read- and
373 // write-activated) at a given gRPC URL.
374 static bool IsWriteGateOpen(const string& grpc_debug_url,
375 const string& watch_key);
376
377 // Closes a gRPC stream to the given address, if it exists.
378 // Thread-safety: Safe with respect to other calls to the same method and
379 // calls to SendTensorThroughGrpcStream().
380 static Status CloseGrpcStream(const string& grpc_stream_url);
381
382 // Set the gRPC state of a debug node key.
383 // TODO(cais): Include device information in watch_key.
384 static void SetDebugNodeKeyGrpcState(
385 const string& grpc_debug_url, const string& watch_key,
386 const EventReply::DebugOpStateChange::State new_state);
387
388 private:
389 using DebugNodeName2State =
390 std::unordered_map<string, EventReply::DebugOpStateChange::State>;
391
392 // Returns a global map from grpc debug URLs to the corresponding
393 // DebugGrpcChannels.
394 static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
395 GetStreamChannels();
396
397 // Get a DebugGrpcChannel object at a given URL, creating one if necessary.
398 //
399 // Args:
400 // grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064"
401 // debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a
402 // a pointer to the pointer. The DebugGrpcChannel object is owned
403 // statically elsewhere, not by the caller of this function.
404 //
405 // Returns:
406 // Status of this operation.
407 static Status GetOrCreateDebugGrpcChannel(
408 const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel);
409
410 // Returns a map from debug URL to a map from debug op name to enabled state.
411 static std::unordered_map<string, DebugNodeName2State>*
412 GetEnabledDebugOpStates();
413
414 // Returns a map from debug op names to enabled state, for a given debug URL.
415 static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
416 const string& grpc_debug_url);
417
418 // Clear enabled debug op state from all debug URLs (if any).
419 static void ClearEnabledWatchKeys();
420
421 static mutex streams_mu_;
422 static int64_t channel_connection_timeout_micros_;
423
424 friend class GrpcDebugTest;
425 friend class DebugNumericSummaryOpTest;
426};
427
428} // namespace tensorflow
429#endif // #ifndef(PLATFORM_WINDOWS)
430
431#endif // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
432