1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/debug/debug_io_utils.h" |
17 | |
18 | #include <stddef.h> |
19 | #include <string.h> |
20 | #include <cmath> |
21 | #include <cstdlib> |
22 | #include <cstring> |
23 | #include <limits> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #ifndef PLATFORM_WINDOWS |
28 | #include "grpcpp/create_channel.h" |
29 | #else |
30 | #endif // #ifndef PLATFORM_WINDOWS |
31 | |
32 | #include "absl/strings/ascii.h" |
33 | #include "absl/strings/match.h" |
34 | #include "tensorflow/core/debug/debug_callback_registry.h" |
35 | #include "tensorflow/core/debug/debugger_event_metadata.pb.h" |
36 | #include "tensorflow/core/framework/graph.pb.h" |
37 | #include "tensorflow/core/framework/summary.pb.h" |
38 | #include "tensorflow/core/framework/tensor.pb.h" |
39 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
40 | #include "tensorflow/core/lib/core/bits.h" |
41 | #include "tensorflow/core/lib/hash/hash.h" |
42 | #include "tensorflow/core/lib/io/path.h" |
43 | #include "tensorflow/core/lib/strings/str_util.h" |
44 | #include "tensorflow/core/lib/strings/stringprintf.h" |
45 | #include "tensorflow/core/platform/protobuf.h" |
46 | #include "tensorflow/core/util/event.pb.h" |
47 | |
// Expands to a `return` statement that yields an Unimplemented error stating
// that the gRPC debug URL scheme is not available on Windows. It is used in
// the PLATFORM_WINDOWS (#else) branches of this file. The expansion refers to
// `kGrpcURLScheme` without qualification, so it must be used in a scope where
// that name resolves (e.g., inside DebugIO member functions).
#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
  return errors::Unimplemented(             \
      kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")
51 | |
52 | namespace tensorflow { |
53 | |
54 | namespace { |
55 | |
56 | // Creates an Event proto representing a chunk of a Tensor. This method only |
57 | // populates the field of the Event proto that represent the envelope |
58 | // information (e.g., timestamp, device_name, num_chunks, chunk_index, dtype, |
59 | // shape). It does not set the value.tensor field, which should be set by the |
60 | // caller separately. |
61 | Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key, |
62 | const uint64 wall_time_us, const size_t num_chunks, |
63 | const size_t chunk_index, |
64 | const DataType& tensor_dtype, |
65 | const TensorShapeProto& tensor_shape) { |
66 | Event event; |
67 | event.set_wall_time(static_cast<double>(wall_time_us)); |
68 | Summary::Value* value = event.mutable_summary()->add_value(); |
69 | |
70 | // Create the debug node_name in the Summary proto. |
71 | // For example, if tensor_name = "foo/node_a:0", and the debug_op is |
72 | // "DebugIdentity", the debug node_name in the Summary proto will be |
73 | // "foo/node_a:0:DebugIdentity". |
74 | value->set_node_name(debug_node_key.debug_node_name); |
75 | |
76 | // Tag by the node name. This allows TensorBoard to quickly fetch data |
77 | // per op. |
78 | value->set_tag(debug_node_key.node_name); |
79 | |
80 | // Store data within debugger metadata to be stored for each event. |
81 | third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; |
82 | metadata.set_device(debug_node_key.device_name); |
83 | metadata.set_output_slot(debug_node_key.output_slot); |
84 | metadata.set_num_chunks(num_chunks); |
85 | metadata.set_chunk_index(chunk_index); |
86 | |
87 | // Encode the data in JSON. |
88 | string json_output; |
89 | tensorflow::protobuf::util::JsonPrintOptions json_options; |
90 | json_options.always_print_primitive_fields = true; |
91 | auto status = tensorflow::protobuf::util::MessageToJsonString( |
92 | metadata, &json_output, json_options); |
93 | if (status.ok()) { |
94 | // Store summary metadata. Set the plugin to use this data as "debugger". |
95 | SummaryMetadata::PluginData* plugin_data = |
96 | value->mutable_metadata()->mutable_plugin_data(); |
97 | plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName); |
98 | plugin_data->set_content(json_output); |
99 | } else { |
100 | LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. " |
101 | << "The debug_node_name is " << debug_node_key.debug_node_name |
102 | << "." ; |
103 | } |
104 | |
105 | value->mutable_tensor()->set_dtype(tensor_dtype); |
106 | *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape; |
107 | |
108 | return event; |
109 | } |
110 | |
// Estimates how many bytes `str` occupies when stored as a bytes field in a
// protobuf message.
//
// On PLATFORM_GOOGLE builds the estimate is the string's size plus
// DebugGrpcIO::kGrpcMaxVarintLengthSize, an allowance for the varint-encoded
// length prefix (there is no portable function to compute the exact length).
// On all other builds the estimate is just the string's size.
const size_t StringValMaxBytesInProto(const string& str) {
  size_t estimated_bytes = str.size();
#if defined(PLATFORM_GOOGLE)
  estimated_bytes += DebugGrpcIO::kGrpcMaxVarintLengthSize;
#endif
  return estimated_bytes;
}
123 | |
124 | // Breaks a string Tensor (represented as a TensorProto) as a vector of Event |
125 | // protos. |
126 | Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key, |
127 | const uint64 wall_time_us, |
128 | const size_t chunk_size_limit, |
129 | TensorProto* tensor_proto, |
130 | std::vector<Event>* events) { |
131 | const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val(); |
132 | const size_t num_strs = strs.size(); |
133 | const size_t chunk_size_ub = chunk_size_limit > 0 |
134 | ? chunk_size_limit |
135 | : std::numeric_limits<size_t>::max(); |
136 | |
137 | // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges: |
138 | // [0:a), [a:b), [c:<end>]. |
139 | std::vector<size_t> cutoffs; |
140 | size_t chunk_size = 0; |
141 | for (size_t i = 0; i < num_strs; ++i) { |
142 | // Take into account the extra bytes in proto buffer. |
143 | if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) { |
144 | return errors::FailedPrecondition( |
145 | "string value at index " , i, " from debug node " , |
146 | debug_node_key.debug_node_name, |
147 | " does not fit gRPC message size limit (" , chunk_size_ub, ")" ); |
148 | } |
149 | if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) { |
150 | cutoffs.push_back(i); |
151 | chunk_size = 0; |
152 | } |
153 | chunk_size += StringValMaxBytesInProto(strs[i]); |
154 | } |
155 | cutoffs.push_back(num_strs); |
156 | const size_t num_chunks = cutoffs.size(); |
157 | |
158 | for (size_t i = 0; i < num_chunks; ++i) { |
159 | Event event = PrepareChunkEventProto(debug_node_key, wall_time_us, |
160 | num_chunks, i, tensor_proto->dtype(), |
161 | tensor_proto->tensor_shape()); |
162 | Summary::Value* value = event.mutable_summary()->mutable_value(0); |
163 | |
164 | if (cutoffs.size() == 1) { |
165 | value->mutable_tensor()->mutable_string_val()->Swap( |
166 | tensor_proto->mutable_string_val()); |
167 | } else { |
168 | const size_t begin = (i == 0) ? 0 : cutoffs[i - 1]; |
169 | const size_t end = cutoffs[i]; |
170 | for (size_t j = begin; j < end; ++j) { |
171 | value->mutable_tensor()->add_string_val(strs[j]); |
172 | } |
173 | } |
174 | |
175 | events->push_back(std::move(event)); |
176 | } |
177 | |
178 | return OkStatus(); |
179 | } |
180 | |
181 | // Encapsulates the tensor value inside a vector of Event protos. Large tensors |
182 | // are broken up to multiple protos to fit the chunk_size_limit. In each Event |
183 | // proto the field summary.tensor carries the content of the tensor. |
184 | // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a |
185 | // length-1 vector will be returned, regardless of the size of the tensor. |
186 | Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key, |
187 | const Tensor& tensor, const uint64 wall_time_us, |
188 | const size_t chunk_size_limit, |
189 | std::vector<Event>* events) { |
190 | TensorProto tensor_proto; |
191 | if (tensor.dtype() == DT_STRING) { |
192 | // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can |
193 | // convert the TensorProto to string-type numpy array. MakeNdarray does not |
194 | // work with strings encoded by AsProtoTensorContent() in tensor_content. |
195 | tensor.AsProtoField(&tensor_proto); |
196 | |
197 | TF_RETURN_IF_ERROR(WrapStringTensorAsEvents( |
198 | debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events)); |
199 | } else { |
200 | tensor.AsProtoTensorContent(&tensor_proto); |
201 | |
202 | const size_t total_length = tensor_proto.tensor_content().size(); |
203 | const size_t chunk_size_ub = |
204 | chunk_size_limit > 0 ? chunk_size_limit : total_length; |
205 | const size_t num_chunks = |
206 | (total_length == 0) |
207 | ? 1 |
208 | : (total_length + chunk_size_ub - 1) / chunk_size_ub; |
209 | for (size_t i = 0; i < num_chunks; ++i) { |
210 | const size_t pos = i * chunk_size_ub; |
211 | const size_t len = |
212 | (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub; |
213 | Event event = PrepareChunkEventProto(debug_node_key, wall_time_us, |
214 | num_chunks, i, tensor_proto.dtype(), |
215 | tensor_proto.tensor_shape()); |
216 | event.mutable_summary() |
217 | ->mutable_value(0) |
218 | ->mutable_tensor() |
219 | ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len)); |
220 | events->push_back(std::move(event)); |
221 | } |
222 | } |
223 | |
224 | return OkStatus(); |
225 | } |
226 | |
227 | // Appends an underscore and a timestamp to a file path. If the path already |
228 | // exists on the file system, append a hyphen and a 1-up index. Consecutive |
229 | // values of the index will be tried until the first unused one is found. |
230 | // TOCTOU race condition is not of concern here due to the fact that tfdbg |
231 | // sets parallel_iterations attribute of all while_loops to 1 to prevent |
232 | // the same node from between executed multiple times concurrently. |
233 | string AppendTimestampToFilePath(const string& in, const uint64 timestamp) { |
234 | string out = strings::StrCat(in, "_" , timestamp); |
235 | |
236 | uint64 i = 1; |
237 | while (Env::Default()->FileExists(out).ok()) { |
238 | out = strings::StrCat(in, "_" , timestamp, "-" , i); |
239 | ++i; |
240 | } |
241 | return out; |
242 | } |
243 | |
244 | #ifndef PLATFORM_WINDOWS |
245 | // Publishes encoded GraphDef through a gRPC debugger stream, in chunks, |
246 | // conforming to the gRPC message size limit. |
247 | Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, |
248 | const string& device_name, |
249 | const int64_t wall_time, |
250 | const string& debug_url) { |
251 | const uint64 hash = ::tensorflow::Hash64(encoded_graph_def); |
252 | const size_t total_length = encoded_graph_def.size(); |
253 | const size_t num_chunks = |
254 | static_cast<size_t>(std::ceil(static_cast<float>(total_length) / |
255 | DebugGrpcIO::kGrpcMessageSizeLimitBytes)); |
256 | for (size_t i = 0; i < num_chunks; ++i) { |
257 | const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes; |
258 | const size_t len = (i == num_chunks - 1) |
259 | ? (total_length - pos) |
260 | : DebugGrpcIO::kGrpcMessageSizeLimitBytes; |
261 | Event event; |
262 | event.set_wall_time(static_cast<double>(wall_time)); |
263 | // Prefix the chunk with |
264 | // <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|. |
265 | // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks |
266 | // and chunk_index, instead. |
267 | event.set_graph_def(strings::StrCat(hash, "," , device_name, "," , wall_time, |
268 | "|" , i, "|" , num_chunks, "|" , |
269 | encoded_graph_def.substr(pos, len))); |
270 | const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream( |
271 | event, debug_url, num_chunks - 1 == i); |
272 | if (!s.ok()) { |
273 | return errors::FailedPrecondition( |
274 | "Failed to send chunk " , i, " of " , num_chunks, |
275 | " of encoded GraphDef of size " , encoded_graph_def.size(), " bytes, " , |
276 | "due to: " , s.error_message()); |
277 | } |
278 | } |
279 | return OkStatus(); |
280 | } |
281 | #endif // #ifndef PLATFORM_WINDOWS |
282 | |
283 | } // namespace |
284 | |
// Plugin name recorded in the SummaryMetadata plugin data of debug tensor
// events (see PrepareChunkEventProto).
const char* const DebugIO::kDebuggerPluginName = "debugger";

// Component of the file name used for core metadata dumps (see
// PublishDebugMetadata).
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";

// Component of the file name used for graph dumps (see PublishGraph).
const char* const DebugIO::kGraphTag = "graph_";

// Precedes the graph hash in the file name used for graph dumps (see
// PublishGraph).
const char* const DebugIO::kHashTag = "hash";
292 | |
293 | Status ReadEventFromFile(const string& dump_file_path, Event* event) { |
294 | Env* env(Env::Default()); |
295 | |
296 | string content; |
297 | uint64 file_size = 0; |
298 | |
299 | Status s = env->GetFileSize(dump_file_path, &file_size); |
300 | if (!s.ok()) { |
301 | return s; |
302 | } |
303 | |
304 | content.resize(file_size); |
305 | |
306 | std::unique_ptr<RandomAccessFile> file; |
307 | s = env->NewRandomAccessFile(dump_file_path, &file); |
308 | if (!s.ok()) { |
309 | return s; |
310 | } |
311 | |
312 | StringPiece result; |
313 | s = file->Read(0, file_size, &result, &(content)[0]); |
314 | if (!s.ok()) { |
315 | return s; |
316 | } |
317 | |
318 | event->ParseFromString(content); |
319 | return OkStatus(); |
320 | } |
321 | |
// URL scheme prefixes that select where debug data is published: dumps on the
// file system, a gRPC debug stream, or a callback registered in
// DebugCallbackRegistry, respectively (see PublishDebugTensor).
const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
const char* const DebugIO::kMemoryURLScheme = "memcbk://";
325 | |
326 | // Publishes debug metadata to a set of debug URLs. |
327 | Status DebugIO::PublishDebugMetadata( |
328 | const int64_t global_step, const int64_t session_run_index, |
329 | const int64_t executor_step_index, const std::vector<string>& input_names, |
330 | const std::vector<string>& output_names, |
331 | const std::vector<string>& target_nodes, |
332 | const std::unordered_set<string>& debug_urls) { |
333 | std::ostringstream oss; |
334 | |
335 | // Construct a JSON string to carry the metadata. |
336 | oss << "{" ; |
337 | oss << "\"global_step\":" << global_step << "," ; |
338 | oss << "\"session_run_index\":" << session_run_index << "," ; |
339 | oss << "\"executor_step_index\":" << executor_step_index << "," ; |
340 | oss << "\"input_names\":[" ; |
341 | for (size_t i = 0; i < input_names.size(); ++i) { |
342 | oss << "\"" << input_names[i] << "\"" ; |
343 | if (i < input_names.size() - 1) { |
344 | oss << "," ; |
345 | } |
346 | } |
347 | oss << "]," ; |
348 | oss << "\"output_names\":[" ; |
349 | for (size_t i = 0; i < output_names.size(); ++i) { |
350 | oss << "\"" << output_names[i] << "\"" ; |
351 | if (i < output_names.size() - 1) { |
352 | oss << "," ; |
353 | } |
354 | } |
355 | oss << "]," ; |
356 | oss << "\"target_nodes\":[" ; |
357 | for (size_t i = 0; i < target_nodes.size(); ++i) { |
358 | oss << "\"" << target_nodes[i] << "\"" ; |
359 | if (i < target_nodes.size() - 1) { |
360 | oss << "," ; |
361 | } |
362 | } |
363 | oss << "]" ; |
364 | oss << "}" ; |
365 | |
366 | const string json_metadata = oss.str(); |
367 | Event event; |
368 | event.set_wall_time(static_cast<double>(Env::Default()->NowMicros())); |
369 | LogMessage* log_message = event.mutable_log_message(); |
370 | log_message->set_message(json_metadata); |
371 | |
372 | Status status; |
373 | for (const string& url : debug_urls) { |
374 | if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) { |
375 | #ifndef PLATFORM_WINDOWS |
376 | Event grpc_event; |
377 | |
378 | // Determine the path (if any) in the grpc:// URL, and add it as a field |
379 | // of the JSON string. |
380 | const string address = url.substr(strlen(DebugIO::kFileURLScheme)); |
381 | const string path = address.find('/') == string::npos |
382 | ? "" |
383 | : address.substr(address.find('/')); |
384 | grpc_event.set_wall_time(event.wall_time()); |
385 | LogMessage* log_message_grpc = grpc_event.mutable_log_message(); |
386 | log_message_grpc->set_message( |
387 | strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1), |
388 | ",\"grpc_path\":\"" , path, "\"}" )); |
389 | |
390 | status.Update( |
391 | DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true)); |
392 | #else |
393 | GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; |
394 | #endif |
395 | } else if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) { |
396 | const string dump_root_dir = url.substr(strlen(kFileURLScheme)); |
397 | const string core_metadata_path = AppendTimestampToFilePath( |
398 | io::JoinPath(dump_root_dir, |
399 | strings::StrCat( |
400 | DebugNodeKey::kMetadataFilePrefix, |
401 | DebugIO::kCoreMetadataTag, "sessionrun" , |
402 | strings::Printf("%.14lld" , static_cast<long long>( |
403 | session_run_index)))), |
404 | Env::Default()->NowMicros()); |
405 | status.Update(DebugFileIO::DumpEventProtoToFile( |
406 | event, string(io::Dirname(core_metadata_path)), |
407 | string(io::Basename(core_metadata_path)))); |
408 | } |
409 | } |
410 | |
411 | return status; |
412 | } |
413 | |
414 | Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, |
415 | const Tensor& tensor, |
416 | const uint64 wall_time_us, |
417 | const gtl::ArraySlice<string> debug_urls, |
418 | const bool gated_grpc) { |
419 | int32_t num_failed_urls = 0; |
420 | std::vector<Status> fail_statuses; |
421 | for (const string& url : debug_urls) { |
422 | if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) { |
423 | const string dump_root_dir = url.substr(strlen(kFileURLScheme)); |
424 | |
425 | const int64_t tensorBytes = |
426 | tensor.IsInitialized() ? tensor.TotalBytes() : 0; |
427 | if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) { |
428 | return errors::ResourceExhausted( |
429 | "TensorFlow Debugger has exhausted file-system byte-size " |
430 | "allowance (" , |
431 | DebugFileIO::global_disk_bytes_limit_, "), therefore it cannot " , |
432 | "dump an additional " , tensorBytes, " byte(s) of tensor data " , |
433 | "for the debug tensor " , debug_node_key.node_name, ":" , |
434 | debug_node_key.output_slot, ". You may use the environment " , |
435 | "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit." ); |
436 | } |
437 | |
438 | Status s = DebugFileIO::DumpTensorToDir( |
439 | debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr); |
440 | if (!s.ok()) { |
441 | num_failed_urls++; |
442 | fail_statuses.push_back(s); |
443 | } |
444 | } else if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) { |
445 | #ifndef PLATFORM_WINDOWS |
446 | Status s = DebugGrpcIO::SendTensorThroughGrpcStream( |
447 | debug_node_key, tensor, wall_time_us, url, gated_grpc); |
448 | |
449 | if (!s.ok()) { |
450 | num_failed_urls++; |
451 | fail_statuses.push_back(s); |
452 | } |
453 | #else |
454 | GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; |
455 | #endif |
456 | } else if (absl::StartsWith(absl::AsciiStrToLower(url), kMemoryURLScheme)) { |
457 | const string dump_root_dir = url.substr(strlen(kMemoryURLScheme)); |
458 | auto* callback_registry = DebugCallbackRegistry::singleton(); |
459 | auto* callback = callback_registry->GetCallback(dump_root_dir); |
460 | CHECK(callback) << "No callback registered for: " << dump_root_dir; |
461 | (*callback)(debug_node_key, tensor); |
462 | } else { |
463 | return Status(error::UNAVAILABLE, |
464 | strings::StrCat("Invalid debug target URL: " , url)); |
465 | } |
466 | } |
467 | |
468 | if (num_failed_urls == 0) { |
469 | return OkStatus(); |
470 | } else { |
471 | string error_message = strings::StrCat( |
472 | "Publishing to " , num_failed_urls, " of " , debug_urls.size(), |
473 | " debug target URLs failed, due to the following errors:" ); |
474 | for (Status& status : fail_statuses) { |
475 | error_message = |
476 | strings::StrCat(error_message, " " , status.error_message(), ";" ); |
477 | } |
478 | |
479 | return Status(error::INTERNAL, error_message); |
480 | } |
481 | } |
482 | |
483 | Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, |
484 | const Tensor& tensor, |
485 | const uint64 wall_time_us, |
486 | const gtl::ArraySlice<string> debug_urls) { |
487 | return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls, |
488 | false); |
489 | } |
490 | |
491 | Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, |
492 | const std::unordered_set<string>& debug_urls) { |
493 | GraphDef graph_def; |
494 | graph.ToGraphDef(&graph_def); |
495 | |
496 | string buf; |
497 | graph_def.SerializeToString(&buf); |
498 | |
499 | const int64_t now_micros = Env::Default()->NowMicros(); |
500 | Event event; |
501 | event.set_wall_time(static_cast<double>(now_micros)); |
502 | event.set_graph_def(buf); |
503 | |
504 | Status status = OkStatus(); |
505 | for (const string& debug_url : debug_urls) { |
506 | if (absl::StartsWith(debug_url, kFileURLScheme)) { |
507 | const string dump_root_dir = |
508 | io::JoinPath(debug_url.substr(strlen(kFileURLScheme)), |
509 | DebugNodeKey::DeviceNameToDevicePath(device_name)); |
510 | const uint64 graph_hash = ::tensorflow::Hash64(buf); |
511 | const string file_name = |
512 | strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag, |
513 | DebugIO::kHashTag, graph_hash, "_" , now_micros); |
514 | |
515 | status.Update( |
516 | DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name)); |
517 | } else if (absl::StartsWith(debug_url, kGrpcURLScheme)) { |
518 | #ifndef PLATFORM_WINDOWS |
519 | status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros, |
520 | debug_url)); |
521 | #else |
522 | GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; |
523 | #endif |
524 | } |
525 | } |
526 | |
527 | return status; |
528 | } |
529 | |
530 | bool DebugIO::IsCopyNodeGateOpen( |
531 | const std::vector<DebugWatchAndURLSpec>& specs) { |
532 | #ifndef PLATFORM_WINDOWS |
533 | for (const DebugWatchAndURLSpec& spec : specs) { |
534 | if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme), |
535 | DebugIO::kGrpcURLScheme)) { |
536 | return true; |
537 | } else { |
538 | if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) { |
539 | return true; |
540 | } |
541 | } |
542 | } |
543 | return false; |
544 | #else |
545 | return true; |
546 | #endif |
547 | } |
548 | |
549 | bool DebugIO::IsDebugNodeGateOpen(const string& watch_key, |
550 | const std::vector<string>& debug_urls) { |
551 | #ifndef PLATFORM_WINDOWS |
552 | for (const string& debug_url : debug_urls) { |
553 | if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme), |
554 | DebugIO::kGrpcURLScheme)) { |
555 | return true; |
556 | } else { |
557 | if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) { |
558 | return true; |
559 | } |
560 | } |
561 | } |
562 | return false; |
563 | #else |
564 | return true; |
565 | #endif |
566 | } |
567 | |
568 | bool DebugIO::IsDebugURLGateOpen(const string& watch_key, |
569 | const string& debug_url) { |
570 | #ifndef PLATFORM_WINDOWS |
571 | if (debug_url != kGrpcURLScheme) { |
572 | return true; |
573 | } else { |
574 | return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key); |
575 | } |
576 | #else |
577 | return true; |
578 | #endif |
579 | } |
580 | |
581 | Status DebugIO::CloseDebugURL(const string& debug_url) { |
582 | if (absl::StartsWith(debug_url, DebugIO::kGrpcURLScheme)) { |
583 | #ifndef PLATFORM_WINDOWS |
584 | return DebugGrpcIO::CloseGrpcStream(debug_url); |
585 | #else |
586 | GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; |
587 | #endif |
588 | } else { |
589 | // No-op for non-gRPC URLs. |
590 | return OkStatus(); |
591 | } |
592 | } |
593 | |
594 | Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key, |
595 | const Tensor& tensor, |
596 | const uint64 wall_time_us, |
597 | const string& dump_root_dir, |
598 | string* dump_file_path) { |
599 | const string file_path = |
600 | GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us); |
601 | |
602 | if (dump_file_path != nullptr) { |
603 | *dump_file_path = file_path; |
604 | } |
605 | |
606 | return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path); |
607 | } |
608 | |
609 | string DebugFileIO::GetDumpFilePath(const string& dump_root_dir, |
610 | const DebugNodeKey& debug_node_key, |
611 | const uint64 wall_time_us) { |
612 | return AppendTimestampToFilePath( |
613 | io::JoinPath(dump_root_dir, debug_node_key.device_path, |
614 | strings::StrCat(debug_node_key.node_name, "_" , |
615 | debug_node_key.output_slot, "_" , |
616 | debug_node_key.debug_op)), |
617 | wall_time_us); |
618 | } |
619 | |
620 | Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto, |
621 | const string& dir_name, |
622 | const string& file_name) { |
623 | Env* env(Env::Default()); |
624 | |
625 | Status s = RecursiveCreateDir(env, dir_name); |
626 | if (!s.ok()) { |
627 | return Status(error::FAILED_PRECONDITION, |
628 | strings::StrCat("Failed to create directory " , dir_name, |
629 | ", due to: " , s.error_message())); |
630 | } |
631 | |
632 | const string file_path = io::JoinPath(dir_name, file_name); |
633 | |
634 | string event_str; |
635 | event_proto.SerializeToString(&event_str); |
636 | |
637 | std::unique_ptr<WritableFile> f = nullptr; |
638 | TF_CHECK_OK(env->NewWritableFile(file_path, &f)); |
639 | f->Append(event_str).IgnoreError(); |
640 | TF_CHECK_OK(f->Close()); |
641 | |
642 | return OkStatus(); |
643 | } |
644 | |
645 | Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key, |
646 | const Tensor& tensor, |
647 | const uint64 wall_time_us, |
648 | const string& file_path) { |
649 | std::vector<Event> events; |
650 | TF_RETURN_IF_ERROR( |
651 | WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events)); |
652 | return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)), |
653 | string(io::Basename(file_path))); |
654 | } |
655 | |
656 | Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { |
657 | if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) { |
658 | // The path already exists as a directory. Return OK right away. |
659 | return OkStatus(); |
660 | } |
661 | |
662 | string parent_dir(io::Dirname(dir)); |
663 | if (!env->FileExists(parent_dir).ok()) { |
664 | // The parent path does not exist yet, create it first. |
665 | Status s = RecursiveCreateDir(env, parent_dir); // Recursive call |
666 | if (!s.ok()) { |
667 | return Status( |
668 | error::FAILED_PRECONDITION, |
669 | strings::StrCat("Failed to create directory " , parent_dir)); |
670 | } |
671 | } else if (env->FileExists(parent_dir).ok() && |
672 | !env->IsDirectory(parent_dir).ok()) { |
673 | // The path exists, but it is a file. |
674 | return Status(error::FAILED_PRECONDITION, |
675 | strings::StrCat("Failed to create directory " , parent_dir, |
676 | " because the path exists as a file " )); |
677 | } |
678 | |
679 | env->CreateDir(dir).IgnoreError(); |
680 | // Guard against potential race in creating directories by doing a check |
681 | // after the CreateDir call. |
682 | if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) { |
683 | return OkStatus(); |
684 | } else { |
685 | return Status(error::ABORTED, |
686 | strings::StrCat("Failed to create directory " , parent_dir)); |
687 | } |
688 | } |
689 | |
// Default total disk usage limit: 100 GBytes (100 * 1024^3 bytes).
const uint64 DebugFileIO::kDefaultGlobalDiskBytesLimit = 107374182400L;
// Effective disk usage limit in bytes. A value of 0 means "not initialized
// yet"; requestDiskByteUsage() sets it on demand.
uint64 DebugFileIO::global_disk_bytes_limit_ = 0;
// Number of bytes granted so far by requestDiskByteUsage().
uint64 DebugFileIO::disk_bytes_used_ = 0;

// Held by requestDiskByteUsage() and resetDiskByteUsage() while they access
// the two counters above.
mutex DebugFileIO::bytes_mu_(LINKER_INITIALIZED);
696 | |
697 | bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { |
698 | mutex_lock l(bytes_mu_); |
699 | if (global_disk_bytes_limit_ == 0) { |
700 | const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT" ); |
701 | if (env_tfdbg_disk_bytes_limit == nullptr || |
702 | strlen(env_tfdbg_disk_bytes_limit) == 0) { |
703 | global_disk_bytes_limit_ = kDefaultGlobalDiskBytesLimit; |
704 | } else { |
705 | strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit), |
706 | &global_disk_bytes_limit_); |
707 | } |
708 | } |
709 | |
710 | if (bytes == 0) { |
711 | return true; |
712 | } |
713 | if (disk_bytes_used_ + bytes < global_disk_bytes_limit_) { |
714 | disk_bytes_used_ += bytes; |
715 | return true; |
716 | } else { |
717 | return false; |
718 | } |
719 | } |
720 | |
721 | void DebugFileIO::resetDiskByteUsage() { |
722 | mutex_lock l(bytes_mu_); |
723 | disk_bytes_used_ = 0; |
724 | } |
725 | |
726 | #ifndef PLATFORM_WINDOWS |
// Records the server address and the corresponding debug URL, i.e., the
// address prefixed with the gRPC URL scheme. No connection is made here; see
// Connect().
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
    : server_stream_addr_(server_stream_addr),
      url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
730 | |
731 | Status DebugGrpcChannel::Connect(const int64_t timeout_micros) { |
732 | ::grpc::ChannelArguments args; |
733 | args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max()); |
734 | // Avoid problems where default reconnect backoff is too long (e.g., 20 s). |
735 | args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000); |
736 | channel_ = ::grpc::CreateCustomChannel( |
737 | server_stream_addr_, ::grpc::InsecureChannelCredentials(), args); |
738 | if (!channel_->WaitForConnected( |
739 | gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), |
740 | gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) { |
741 | return errors::FailedPrecondition( |
742 | "Failed to connect to gRPC channel at " , server_stream_addr_, |
743 | " within a timeout of " , timeout_micros / 1e6, " s." ); |
744 | } |
745 | stub_ = grpc::EventListener::NewStub(channel_); |
746 | reader_writer_ = stub_->SendEvents(&ctx_); |
747 | |
748 | return OkStatus(); |
749 | } |
750 | |
751 | bool DebugGrpcChannel::WriteEvent(const Event& event) { |
752 | mutex_lock l(mu_); |
753 | return reader_writer_->Write(event); |
754 | } |
755 | |
756 | bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) { |
757 | mutex_lock l(mu_); |
758 | return reader_writer_->Read(event_reply); |
759 | } |
760 | |
761 | void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) { |
762 | EventReply event_reply; |
763 | size_t num_replies = 0; |
764 | while ((max_replies == 0 || ++num_replies <= max_replies) && |
765 | ReadEventReply(&event_reply)) { |
766 | for (const EventReply::DebugOpStateChange& debug_op_state_change : |
767 | event_reply.debug_op_state_changes()) { |
768 | string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":" , |
769 | debug_op_state_change.output_slot(), |
770 | ":" , debug_op_state_change.debug_op()); |
771 | DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key, |
772 | debug_op_state_change.state()); |
773 | } |
774 | } |
775 | } |
776 | |
777 | Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { |
778 | reader_writer_->WritesDone(); |
779 | // Read all EventReply messages (if any) from the server. |
780 | ReceiveAndProcessEventReplies(0); |
781 | |
782 | if (reader_writer_->Finish().ok()) { |
783 | return OkStatus(); |
784 | } else { |
785 | return Status(error::FAILED_PRECONDITION, |
786 | "Failed to close debug GRPC stream." ); |
787 | } |
788 | } |
789 | |
// NOTE(review): presumably guards the map returned by GetStreamChannels();
// its users are outside this part of the file -- confirm.
mutex DebugGrpcIO::streams_mu_(LINKER_INITIALIZED);

// Timeout for connecting a gRPC channel, in microseconds (900 seconds).
int64_t DebugGrpcIO::channel_connection_timeout_micros_ = 900 * 1000 * 1000;
// TODO(cais): Make this configurable?

// Upper bound, in bytes, on the payload carried by one Event sent over gRPC.
// Used as the chunk size for tensors and encoded GraphDefs.
const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;

// Per-string allowance, in bytes, for the varint-encoded length prefix (see
// StringValMaxBytesInProto).
const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
798 | |
799 | std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* |
800 | DebugGrpcIO::GetStreamChannels() { |
801 | static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* |
802 | stream_channels = |
803 | new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>(); |
804 | return stream_channels; |
805 | } |
806 | |
807 | Status DebugGrpcIO::SendTensorThroughGrpcStream( |
808 | const DebugNodeKey& debug_node_key, const Tensor& tensor, |
809 | const uint64 wall_time_us, const string& grpc_stream_url, |
810 | const bool gated) { |
811 | if (gated && |
812 | !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) { |
813 | return OkStatus(); |
814 | } else { |
815 | std::vector<Event> events; |
816 | TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, |
817 | kGrpcMessageSizeLimitBytes, &events)); |
818 | for (const Event& event : events) { |
819 | TF_RETURN_IF_ERROR( |
820 | SendEventProtoThroughGrpcStream(event, grpc_stream_url)); |
821 | } |
822 | if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) { |
823 | DebugGrpcChannel* debug_grpc_channel = nullptr; |
824 | TF_RETURN_IF_ERROR( |
825 | GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel)); |
826 | debug_grpc_channel->ReceiveAndProcessEventReplies(1); |
827 | // TODO(cais): Support new tensor value carried in the EventReply for |
828 | // overriding the value of the tensor being published. |
829 | } |
830 | return OkStatus(); |
831 | } |
832 | } |
833 | |
834 | Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream( |
835 | EventReply* event_reply, const string& grpc_stream_url) { |
836 | DebugGrpcChannel* debug_grpc_channel = nullptr; |
837 | TF_RETURN_IF_ERROR( |
838 | GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel)); |
839 | if (debug_grpc_channel->ReadEventReply(event_reply)) { |
840 | return OkStatus(); |
841 | } else { |
842 | return errors::Cancelled(strings::StrCat( |
843 | "Reading EventReply from stream URL " , grpc_stream_url, " failed." )); |
844 | } |
845 | } |
846 | |
847 | Status DebugGrpcIO::GetOrCreateDebugGrpcChannel( |
848 | const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) { |
849 | const string addr_with_path = |
850 | absl::StartsWith(grpc_stream_url, DebugIO::kGrpcURLScheme) |
851 | ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme)) |
852 | : grpc_stream_url; |
853 | const string server_stream_addr = |
854 | addr_with_path.substr(0, addr_with_path.find('/')); |
855 | { |
856 | mutex_lock l(streams_mu_); |
857 | std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* |
858 | stream_channels = GetStreamChannels(); |
859 | if (stream_channels->find(grpc_stream_url) == stream_channels->end()) { |
860 | std::unique_ptr<DebugGrpcChannel> channel( |
861 | new DebugGrpcChannel(server_stream_addr)); |
862 | TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros_)); |
863 | stream_channels->insert( |
864 | std::make_pair(grpc_stream_url, std::move(channel))); |
865 | } |
866 | *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get(); |
867 | } |
868 | return OkStatus(); |
869 | } |
870 | |
871 | Status DebugGrpcIO::SendEventProtoThroughGrpcStream( |
872 | const Event& event_proto, const string& grpc_stream_url, |
873 | const bool receive_reply) { |
874 | DebugGrpcChannel* debug_grpc_channel; |
875 | TF_RETURN_IF_ERROR( |
876 | GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel)); |
877 | |
878 | bool write_ok = debug_grpc_channel->WriteEvent(event_proto); |
879 | if (!write_ok) { |
880 | return errors::Cancelled(strings::StrCat("Write event to stream URL " , |
881 | grpc_stream_url, " failed." )); |
882 | } |
883 | |
884 | if (receive_reply) { |
885 | debug_grpc_channel->ReceiveAndProcessEventReplies(1); |
886 | } |
887 | |
888 | return OkStatus(); |
889 | } |
890 | |
891 | bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url, |
892 | const string& watch_key) { |
893 | const DebugNodeName2State* enabled_node_to_state = |
894 | GetEnabledDebugOpStatesAtUrl(grpc_debug_url); |
895 | return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end(); |
896 | } |
897 | |
898 | bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url, |
899 | const string& watch_key) { |
900 | const DebugNodeName2State* enabled_node_to_state = |
901 | GetEnabledDebugOpStatesAtUrl(grpc_debug_url); |
902 | auto it = enabled_node_to_state->find(watch_key); |
903 | if (it == enabled_node_to_state->end()) { |
904 | return false; |
905 | } else { |
906 | return it->second == EventReply::DebugOpStateChange::READ_WRITE; |
907 | } |
908 | } |
909 | |
910 | Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) { |
911 | mutex_lock l(streams_mu_); |
912 | |
913 | std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* |
914 | stream_channels = GetStreamChannels(); |
915 | if (stream_channels->find(grpc_stream_url) != stream_channels->end()) { |
916 | // Stream of the specified address exists. Close it and remove it from |
917 | // record. |
918 | Status s = |
919 | (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose(); |
920 | (*stream_channels).erase(grpc_stream_url); |
921 | return s; |
922 | } else { |
923 | // Stream of the specified address does not exist. No action. |
924 | return OkStatus(); |
925 | } |
926 | } |
927 | |
928 | std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>* |
929 | DebugGrpcIO::GetEnabledDebugOpStates() { |
930 | static std::unordered_map<string, DebugNodeName2State>* |
931 | enabled_debug_op_states = |
932 | new std::unordered_map<string, DebugNodeName2State>(); |
933 | return enabled_debug_op_states; |
934 | } |
935 | |
936 | DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl( |
937 | const string& grpc_debug_url) { |
938 | static mutex* debug_ops_state_mu = new mutex(); |
939 | std::unordered_map<string, DebugNodeName2State>* states = |
940 | GetEnabledDebugOpStates(); |
941 | |
942 | mutex_lock l(*debug_ops_state_mu); |
943 | if (states->find(grpc_debug_url) == states->end()) { |
944 | DebugNodeName2State url_enabled_debug_op_states; |
945 | (*states)[grpc_debug_url] = url_enabled_debug_op_states; |
946 | } |
947 | return &(*states)[grpc_debug_url]; |
948 | } |
949 | |
950 | void DebugGrpcIO::SetDebugNodeKeyGrpcState( |
951 | const string& grpc_debug_url, const string& watch_key, |
952 | const EventReply::DebugOpStateChange::State new_state) { |
953 | DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url); |
954 | if (new_state == EventReply::DebugOpStateChange::DISABLED) { |
955 | if (states->find(watch_key) == states->end()) { |
956 | LOG(ERROR) << "Attempt to disable a watch key that is not currently " |
957 | << "enabled at " << grpc_debug_url << ": " << watch_key; |
958 | } else { |
959 | states->erase(watch_key); |
960 | } |
961 | } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) { |
962 | (*states)[watch_key] = new_state; |
963 | } |
964 | } |
965 | |
966 | void DebugGrpcIO::ClearEnabledWatchKeys() { |
967 | GetEnabledDebugOpStates()->clear(); |
968 | } |
969 | |
970 | #endif // #ifndef PLATFORM_WINDOWS |
971 | |
972 | } // namespace tensorflow |
973 | |