1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ |
17 | #define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ |
18 | |
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/event.pb.h"
34 | |
35 | namespace tensorflow { |
36 | |
// Reads a dumped Event proto back from `dump_file_path` into `*event` and
// returns the Status of the operation.
// NOTE(review): behavior inferred from the name and signature only; the
// implementation lives in the .cc file -- TODO confirm (e.g. whether it
// expects a file written by DebugFileIO::DumpEventProtoToFile()).
Status ReadEventFromFile(const string& dump_file_path, Event* event);
38 | |
// Describes one debug op attached to a Copy node: the tensor watch key it
// observes, one debug URL it reports to, and whether it is subject to gRPC
// gating. All fields are const, so a spec is fixed once constructed (and is
// consequently not copy-assignable).
struct DebugWatchAndURLSpec {
  DebugWatchAndURLSpec(const string& watch_key, const string& url,
                       const bool gated_grpc)
      : watch_key{watch_key}, url{url}, gated_grpc{gated_grpc} {}

  // Debug tensor watch key, formatted as tensor_name:debug_op,
  // e.g. "Weights:0:DebugIdentity".
  const string watch_key;

  // Debug URL, e.g. "grpc://localhost:3333" or "file:///tmp/tfdbg_1".
  const string url;

  // True iff the debug op carries the attribute value gated_grpc == True.
  const bool gated_grpc;
};
48 | |
49 | // TODO(cais): Put static functions and members in a namespace, not a class. |
50 | class DebugIO { |
51 | public: |
52 | static const char* const kDebuggerPluginName; |
53 | |
54 | static const char* const kCoreMetadataTag; |
55 | static const char* const kGraphTag; |
56 | static const char* const kHashTag; |
57 | |
58 | static const char* const kFileURLScheme; |
59 | static const char* const kGrpcURLScheme; |
60 | static const char* const kMemoryURLScheme; |
61 | |
62 | static Status PublishDebugMetadata( |
63 | const int64_t global_step, const int64_t session_run_index, |
64 | const int64_t executor_step_index, const std::vector<string>& input_names, |
65 | const std::vector<string>& output_names, |
66 | const std::vector<string>& target_nodes, |
67 | const std::unordered_set<string>& debug_urls); |
68 | |
69 | // Publishes a tensor to a debug target URL. |
70 | // |
71 | // Args: |
72 | // debug_node_key: A DebugNodeKey identifying the debug node. |
73 | // tensor: The Tensor object being published. |
74 | // wall_time_us: Time stamp for the Tensor. Unit: microseconds (us). |
75 | // debug_urls: An array of debug target URLs, e.g., |
76 | // "file:///foo/tfdbg_dump", "grpc://localhost:11011" |
77 | // gated_grpc: Whether this call is subject to gRPC gating. |
78 | static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, |
79 | const Tensor& tensor, |
80 | const uint64 wall_time_us, |
81 | const gtl::ArraySlice<string> debug_urls, |
82 | const bool gated_grpc); |
83 | |
84 | // Convenience overload of the method above for no gated_grpc by default. |
85 | static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, |
86 | const Tensor& tensor, |
87 | const uint64 wall_time_us, |
88 | const gtl::ArraySlice<string> debug_urls); |
89 | |
90 | // Publishes a graph to a set of debug URLs. |
91 | // |
92 | // Args: |
93 | // graph: The graph to be published. |
94 | // debug_urls: The set of debug URLs to publish the graph to. |
95 | static Status PublishGraph(const Graph& graph, const string& device_name, |
96 | const std::unordered_set<string>& debug_urls); |
97 | |
98 | // Determines whether a copy node needs to perform deep-copy of input tensor. |
99 | // |
100 | // The input arguments contain sufficient information about the attached |
101 | // downstream debug ops for this method to determine whether all the said |
102 | // ops are disabled given the current status of the gRPC gating. |
103 | // |
104 | // Args: |
105 | // specs: A vector of DebugWatchAndURLSpec carrying information about the |
106 | // debug ops attached to the Copy node, their debug URLs and whether |
107 | // they have the attribute value gated_grpc == True. |
108 | // |
109 | // Returns: |
110 | // Whether any of the attached downstream debug ops is enabled given the |
111 | // current status of the gRPC gating. |
112 | static bool IsCopyNodeGateOpen( |
113 | const std::vector<DebugWatchAndURLSpec>& specs); |
114 | |
115 | // Determines whether a debug node needs to proceed given the current gRPC |
116 | // gating status. |
117 | // |
118 | // Args: |
119 | // watch_key: debug tensor watch key, in the format of |
120 | // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". |
121 | // debug_urls: the debug URLs of the debug node. |
122 | // |
123 | // Returns: |
124 | // Whether this debug op should proceed. |
125 | static bool IsDebugNodeGateOpen(const string& watch_key, |
126 | const std::vector<string>& debug_urls); |
127 | |
128 | // Determines whether debug information should be sent through a grpc:// |
129 | // debug URL given the current gRPC gating status. |
130 | // |
131 | // Args: |
132 | // watch_key: debug tensor watch key, in the format of |
133 | // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". |
134 | // debug_url: the debug URL, e.g., "grpc://localhost:3333", |
135 | // "file:///tmp/tfdbg_1". |
136 | // |
137 | // Returns: |
138 | // Whether the sending of debug data to the debug_url should |
139 | // proceed. |
140 | static bool IsDebugURLGateOpen(const string& watch_key, |
141 | const string& debug_url); |
142 | |
143 | static Status CloseDebugURL(const string& debug_url); |
144 | }; |
145 | |
146 | // Helper class for debug ops. |
147 | class DebugFileIO { |
148 | public: |
149 | // Encapsulates the Tensor in an Event protobuf and write it to a directory. |
150 | // The actual path of the dump file will be a contactenation of |
151 | // dump_root_dir, tensor_name, along with the wall_time. |
152 | // |
153 | // For example: |
154 | // let dump_root_dir = "/tmp/tfdbg_dump", |
155 | // node_name = "foo/bar", |
156 | // output_slot = 0, |
157 | // debug_op = DebugIdentity, |
158 | // and wall_time_us = 1467891234512345, |
159 | // the dump file will be generated at path: |
160 | // /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345. |
161 | // |
162 | // Args: |
163 | // debug_node_key: A DebugNodeKey identifying the debug node. |
164 | // wall_time_us: Wall time at which the Tensor is generated during graph |
165 | // execution. Unit: microseconds (us). |
166 | // dump_root_dir: Root directory for dumping the tensor. |
167 | // dump_file_path: The actual dump file path (passed as reference). |
168 | static Status DumpTensorToDir(const DebugNodeKey& debug_node_key, |
169 | const Tensor& tensor, const uint64 wall_time_us, |
170 | const string& dump_root_dir, |
171 | string* dump_file_path); |
172 | |
173 | // Get the full path to the dump file. |
174 | // |
175 | // Args: |
176 | // dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump |
177 | // node_name: Name of the node from which the dumped tensor is generated, |
178 | // e.g., foo/bar/node_a |
179 | // output_slot: Output slot index of the said node, e.g., 0. |
180 | // debug_op: Name of the debug op, e.g., DebugIdentity. |
181 | // wall_time_us: Time stamp of the dumped tensor, in microseconds (us). |
182 | static string GetDumpFilePath(const string& dump_root_dir, |
183 | const DebugNodeKey& debug_node_key, |
184 | const uint64 wall_time_us); |
185 | |
186 | // Dumps an Event proto to a file. |
187 | // |
188 | // Args: |
189 | // event_prot: The Event proto to be dumped. |
190 | // dir_name: Directory path. |
191 | // file_name: Base file name. |
192 | static Status DumpEventProtoToFile(const Event& event_proto, |
193 | const string& dir_name, |
194 | const string& file_name); |
195 | |
196 | // Request additional bytes to be dumped to the file system. |
197 | // |
198 | // Does not actually dump the bytes, but instead just performs the |
199 | // bookkeeping necessary to prevent the total dumped amount of data from |
200 | // exceeding the limit (default 100 GBytes or set customly through the |
201 | // environment variable TFDBG_DISK_BYTES_LIMIT). |
202 | // |
203 | // Args: |
204 | // bytes: Number of bytes to request. |
205 | // |
206 | // Returns: |
207 | // Whether the request is approved given the total dumping |
208 | // limit. |
209 | static bool requestDiskByteUsage(uint64 bytes); |
210 | |
211 | // Reset the disk byte usage to zero. |
212 | static void resetDiskByteUsage(); |
213 | |
214 | static uint64 global_disk_bytes_limit_; |
215 | |
216 | private: |
217 | // Encapsulates the Tensor in an Event protobuf and write it to file. |
218 | static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, |
219 | const Tensor& tensor, |
220 | const uint64 wall_time_us, |
221 | const string& file_path); |
222 | |
223 | // Implemented ad hoc here for now. |
224 | // TODO(cais): Replace with shared implementation once http://b/30497715 is |
225 | // fixed. |
226 | static Status RecursiveCreateDir(Env* env, const string& dir); |
227 | |
228 | // Tracks how much disk has been used so far. |
229 | static uint64 disk_bytes_used_; |
230 | // Mutex for thread-safe access to disk_bytes_used_. |
231 | static mutex bytes_mu_; |
232 | // Default limit for the disk space. |
233 | static const uint64 kDefaultGlobalDiskBytesLimit; |
234 | |
235 | friend class DiskUsageLimitTest; |
236 | }; |
237 | |
238 | } // namespace tensorflow |
239 | |
240 | namespace std { |
241 | |
242 | template <> |
243 | struct hash<::tensorflow::DebugNodeKey> { |
244 | size_t operator()(const ::tensorflow::DebugNodeKey& k) const { |
245 | return ::tensorflow::Hash64( |
246 | ::tensorflow::strings::StrCat(k.device_name, ":" , k.node_name, ":" , |
247 | k.output_slot, ":" , k.debug_op, ":" )); |
248 | } |
249 | }; |
250 | |
251 | } // namespace std |
252 | |
253 | // TODO(cais): Support grpc:// debug URLs in open source once Python grpc |
254 | // genrule becomes available. See b/23796275. |
255 | #ifndef PLATFORM_WINDOWS |
256 | #include "grpcpp/channel.h" |
257 | #include "tensorflow/core/debug/debug_service.grpc.pb.h" |
258 | |
259 | namespace tensorflow { |
260 | |
261 | class DebugGrpcChannel { |
262 | public: |
263 | // Constructor of DebugGrpcChannel. |
264 | // |
265 | // Args: |
266 | // server_stream_addr: Address (host name and port) of the debug stream |
267 | // server implementing the EventListener service (see |
268 | // debug_service.proto). E.g., "127.0.0.1:12345". |
269 | explicit DebugGrpcChannel(const string& server_stream_addr); |
270 | |
271 | virtual ~DebugGrpcChannel() {} |
272 | |
273 | // Attempt to establish connection with server. |
274 | // |
275 | // Args: |
276 | // timeout_micros: Timeout (in microseconds) for the attempt to establish |
277 | // the connection. |
278 | // |
279 | // Returns: |
280 | // OK Status iff connection is successfully established before timeout, |
281 | // otherwise return an error Status. |
282 | Status Connect(const int64_t timeout_micros); |
283 | |
284 | // Write an Event proto to the debug gRPC stream. |
285 | // |
286 | // Thread-safety: Safe with respect to other calls to the same method and |
287 | // calls to ReadEventReply() and Close(). |
288 | // |
289 | // Args: |
290 | // event: The event proto to be written to the stream. |
291 | // |
292 | // Returns: |
293 | // True iff the write is successful. |
294 | bool WriteEvent(const Event& event); |
295 | |
296 | // Read an EventReply proto from the debug gRPC stream. |
297 | // |
298 | // This method blocks and waits for an EventReply from the server. |
299 | // Thread-safety: Safe with respect to other calls to the same method and |
300 | // calls to WriteEvent() and Close(). |
301 | // |
302 | // Args: |
303 | // event_reply: the to-be-modified EventReply proto passed as reference. |
304 | // |
305 | // Returns: |
306 | // True iff the read is successful. |
307 | bool ReadEventReply(EventReply* event_reply); |
308 | |
309 | // Receive and process EventReply protos from the gRPC debug server. |
310 | // |
311 | // The processing includes setting debug watch key states using the |
312 | // DebugOpStateChange fields of the EventReply. |
313 | // |
314 | // Args: |
315 | // max_replies: Maximum number of replies to receive. Will receive all |
316 | // remaining replies iff max_replies == 0. |
317 | void ReceiveAndProcessEventReplies(size_t max_replies); |
318 | |
319 | // Receive EventReplies from server (if any) and close the stream and the |
320 | // channel. |
321 | Status ReceiveServerRepliesAndClose(); |
322 | |
323 | private: |
324 | string server_stream_addr_; |
325 | string url_; |
326 | ::grpc::ClientContext ctx_; |
327 | std::shared_ptr<::grpc::Channel> channel_; |
328 | std::unique_ptr<grpc::EventListener::Stub> stub_; |
329 | std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>> |
330 | reader_writer_; |
331 | |
332 | mutex mu_; |
333 | }; |
334 | |
335 | class DebugGrpcIO { |
336 | public: |
337 | static const size_t kGrpcMessageSizeLimitBytes; |
338 | static const size_t kGrpcMaxVarintLengthSize; |
339 | |
340 | // Sends a tensor through a debug gRPC stream. |
341 | static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key, |
342 | const Tensor& tensor, |
343 | const uint64 wall_time_us, |
344 | const string& grpc_stream_url, |
345 | const bool gated); |
346 | |
347 | // Sends an Event proto through a debug gRPC stream. |
348 | // Thread-safety: Safe with respect to other calls to the same method and |
349 | // calls to CloseGrpcStream(). |
350 | // |
351 | // Args: |
352 | // event_proto: The Event proto to be sent. |
353 | // grpc_stream_url: The grpc:// URL of the stream to use, e.g., |
354 | // "grpc://localhost:11011", "localhost:22022". |
355 | // receive_reply: Whether an EventReply proto will be read after event_proto |
356 | // is sent and before the function returns. |
357 | // |
358 | // Returns: |
359 | // The Status of the operation. |
360 | static Status SendEventProtoThroughGrpcStream( |
361 | const Event& event_proto, const string& grpc_stream_url, |
362 | const bool receive_reply = false); |
363 | |
364 | // Receive an EventReply proto through a debug gRPC stream. |
365 | static Status ReceiveEventReplyProtoThroughGrpcStream( |
366 | EventReply* event_reply, const string& grpc_stream_url); |
367 | |
368 | // Check whether a debug watch key is read-activated at a given gRPC URL. |
369 | static bool IsReadGateOpen(const string& grpc_debug_url, |
370 | const string& watch_key); |
371 | |
372 | // Check whether a debug watch key is write-activated (i.e., read- and |
373 | // write-activated) at a given gRPC URL. |
374 | static bool IsWriteGateOpen(const string& grpc_debug_url, |
375 | const string& watch_key); |
376 | |
377 | // Closes a gRPC stream to the given address, if it exists. |
378 | // Thread-safety: Safe with respect to other calls to the same method and |
379 | // calls to SendTensorThroughGrpcStream(). |
380 | static Status CloseGrpcStream(const string& grpc_stream_url); |
381 | |
382 | // Set the gRPC state of a debug node key. |
383 | // TODO(cais): Include device information in watch_key. |
384 | static void SetDebugNodeKeyGrpcState( |
385 | const string& grpc_debug_url, const string& watch_key, |
386 | const EventReply::DebugOpStateChange::State new_state); |
387 | |
388 | private: |
389 | using DebugNodeName2State = |
390 | std::unordered_map<string, EventReply::DebugOpStateChange::State>; |
391 | |
392 | // Returns a global map from grpc debug URLs to the corresponding |
393 | // DebugGrpcChannels. |
394 | static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* |
395 | GetStreamChannels(); |
396 | |
397 | // Get a DebugGrpcChannel object at a given URL, creating one if necessary. |
398 | // |
399 | // Args: |
400 | // grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064" |
401 | // debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a |
402 | // a pointer to the pointer. The DebugGrpcChannel object is owned |
403 | // statically elsewhere, not by the caller of this function. |
404 | // |
405 | // Returns: |
406 | // Status of this operation. |
407 | static Status GetOrCreateDebugGrpcChannel( |
408 | const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel); |
409 | |
410 | // Returns a map from debug URL to a map from debug op name to enabled state. |
411 | static std::unordered_map<string, DebugNodeName2State>* |
412 | GetEnabledDebugOpStates(); |
413 | |
414 | // Returns a map from debug op names to enabled state, for a given debug URL. |
415 | static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl( |
416 | const string& grpc_debug_url); |
417 | |
418 | // Clear enabled debug op state from all debug URLs (if any). |
419 | static void ClearEnabledWatchKeys(); |
420 | |
421 | static mutex streams_mu_; |
422 | static int64_t channel_connection_timeout_micros_; |
423 | |
424 | friend class GrpcDebugTest; |
425 | friend class DebugNumericSummaryOpTest; |
426 | }; |
427 | |
428 | } // namespace tensorflow |
429 | #endif // #ifndef(PLATFORM_WINDOWS) |
430 | |
431 | #endif // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ |
432 | |