/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/debug/debug_io_utils.h"

#include <stddef.h>
#include <string.h>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <utility>
#include <vector>

#ifndef PLATFORM_WINDOWS
#include "grpcpp/create_channel.h"
#else
#endif  // #ifndef PLATFORM_WINDOWS

#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "tensorflow/core/debug/debug_callback_registry.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/event.pb.h"

#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
  return errors::Unimplemented(              \
      kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")

namespace tensorflow {

namespace {

// Creates an Event proto representing a chunk of a Tensor. This method only
// populates the fields of the Event proto that represent the envelope
// information (e.g., timestamp, device_name, num_chunks, chunk_index, dtype,
// shape). It does not set the value.tensor field, which should be set by the
// caller separately.
Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
                             const uint64 wall_time_us, const size_t num_chunks,
                             const size_t chunk_index,
                             const DataType& tensor_dtype,
                             const TensorShapeProto& tensor_shape) {
  Event event;
  event.set_wall_time(static_cast<double>(wall_time_us));
  Summary::Value* value = event.mutable_summary()->add_value();

  // Create the debug node_name in the Summary proto.
  // For example, if tensor_name = "foo/node_a:0", and the debug_op is
  // "DebugIdentity", the debug node_name in the Summary proto will be
  // "foo/node_a:0:DebugIdentity".
  value->set_node_name(debug_node_key.debug_node_name);

  // Tag by the node name. This allows TensorBoard to quickly fetch data
  // per op.
  value->set_tag(debug_node_key.node_name);

  // Store additional metadata about the event in a DebuggerEventMetadata proto.
  third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
  metadata.set_device(debug_node_key.device_name);
  metadata.set_output_slot(debug_node_key.output_slot);
  metadata.set_num_chunks(num_chunks);
  metadata.set_chunk_index(chunk_index);

  // Encode the data in JSON.
  string json_output;
  tensorflow::protobuf::util::JsonPrintOptions json_options;
  json_options.always_print_primitive_fields = true;
  auto status = tensorflow::protobuf::util::MessageToJsonString(
      metadata, &json_output, json_options);
  if (status.ok()) {
    // Store summary metadata. Set the plugin to use this data as "debugger".
    SummaryMetadata::PluginData* plugin_data =
        value->mutable_metadata()->mutable_plugin_data();
    plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName);
    plugin_data->set_content(json_output);
  } else {
    LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. "
                 << "The debug_node_name is " << debug_node_key.debug_node_name
                 << ".";
  }

  value->mutable_tensor()->set_dtype(tensor_dtype);
  *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape;

  return event;
}

// Translates the length of a string to the number of bytes it occupies when
// encoded as bytes in a protobuf. Note that this makes a conservative estimate
// (i.e., an estimate that is usually too large, but never too small under the
// gRPC message size limit) of the Varint-encoded length, to work around the
// lack of a portable length function.
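// For example, under PLATFORM_GOOGLE a 10-byte string is budgeted as
// 10 + kGrpcMaxVarintLengthSize = 16 bytes of proto payload.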
const size_t StringValMaxBytesInProto(const string& str) {
#if defined(PLATFORM_GOOGLE)
  return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize;
#else
  return str.size();
#endif
}

// Breaks a string Tensor (represented as a TensorProto) into a vector of Event
// protos.
Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
                                const uint64 wall_time_us,
                                const size_t chunk_size_limit,
                                TensorProto* tensor_proto,
                                std::vector<Event>* events) {
  const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val();
  const size_t num_strs = strs.size();
  const size_t chunk_size_ub = chunk_size_limit > 0
                                   ? chunk_size_limit
                                   : std::numeric_limits<size_t>::max();

  // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges:
  //   [0:j), [j:k), [k:l).
  std::vector<size_t> cutoffs;
  size_t chunk_size = 0;
  for (size_t i = 0; i < num_strs; ++i) {
    // Take into account the extra bytes in proto buffer.
    if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
      return errors::FailedPrecondition(
          "string value at index ", i, " from debug node ",
          debug_node_key.debug_node_name,
          " does not fit gRPC message size limit (", chunk_size_ub, ")");
    }
    if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
      cutoffs.push_back(i);
      chunk_size = 0;
    }
    chunk_size += StringValMaxBytesInProto(strs[i]);
  }
  cutoffs.push_back(num_strs);
  const size_t num_chunks = cutoffs.size();

  for (size_t i = 0; i < num_chunks; ++i) {
    Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
                                         num_chunks, i, tensor_proto->dtype(),
                                         tensor_proto->tensor_shape());
    Summary::Value* value = event.mutable_summary()->mutable_value(0);

    if (cutoffs.size() == 1) {
      value->mutable_tensor()->mutable_string_val()->Swap(
          tensor_proto->mutable_string_val());
    } else {
      const size_t begin = (i == 0) ? 0 : cutoffs[i - 1];
      const size_t end = cutoffs[i];
      for (size_t j = begin; j < end; ++j) {
        value->mutable_tensor()->add_string_val(strs[j]);
      }
    }

    events->push_back(std::move(event));
  }

  return OkStatus();
}

// Encapsulates the tensor value inside a vector of Event protos. Large tensors
// are broken up into multiple protos to fit within chunk_size_limit. In each
// Event proto the field summary.tensor carries the content of the tensor.
// If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a
// length-1 vector will be returned, regardless of the size of the tensor.
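// For example, a non-string tensor whose serialized tensor_content is 10 MB,
// published with a chunk_size_limit of 4 MB, is wrapped into
// ceil(10 / 4) = 3 Event protos, each carrying one slice of tensor_content.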
Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
                          const Tensor& tensor, const uint64 wall_time_us,
                          const size_t chunk_size_limit,
                          std::vector<Event>* events) {
  TensorProto tensor_proto;
  if (tensor.dtype() == DT_STRING) {
    // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can
    // convert the TensorProto to string-type numpy array. MakeNdarray does not
    // work with strings encoded by AsProtoTensorContent() in tensor_content.
    tensor.AsProtoField(&tensor_proto);

    TF_RETURN_IF_ERROR(WrapStringTensorAsEvents(
        debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events));
  } else {
    tensor.AsProtoTensorContent(&tensor_proto);

    const size_t total_length = tensor_proto.tensor_content().size();
    const size_t chunk_size_ub =
        chunk_size_limit > 0 ? chunk_size_limit : total_length;
    const size_t num_chunks =
        (total_length == 0)
            ? 1
            : (total_length + chunk_size_ub - 1) / chunk_size_ub;
    for (size_t i = 0; i < num_chunks; ++i) {
      const size_t pos = i * chunk_size_ub;
      const size_t len =
          (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub;
      Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
                                           num_chunks, i, tensor_proto.dtype(),
                                           tensor_proto.tensor_shape());
      event.mutable_summary()
          ->mutable_value(0)
          ->mutable_tensor()
          ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len));
      events->push_back(std::move(event));
    }
  }

  return OkStatus();
}

// Appends an underscore and a timestamp to a file path. If the resulting path
// already exists on the file system, also appends a hyphen and a 1-up index,
// trying consecutive values of the index until the first unused one is found.
// A TOCTOU race condition is not a concern here because tfdbg sets the
// parallel_iterations attribute of all while_loops to 1 to prevent the same
// node from being executed multiple times concurrently.
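// For example, AppendTimestampToFilePath("/tmp/dump/node_0_DebugIdentity",
// 1001) returns "/tmp/dump/node_0_DebugIdentity_1001", or
// "/tmp/dump/node_0_DebugIdentity_1001-1" if that path is already taken.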
string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
  string out = strings::StrCat(in, "_", timestamp);

  uint64 i = 1;
  while (Env::Default()->FileExists(out).ok()) {
    out = strings::StrCat(in, "_", timestamp, "-", i);
    ++i;
  }
  return out;
}

#ifndef PLATFORM_WINDOWS
// Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
// conforming to the gRPC message size limit.
Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
                                      const string& device_name,
                                      const int64_t wall_time,
                                      const string& debug_url) {
  const uint64 hash = ::tensorflow::Hash64(encoded_graph_def);
  const size_t total_length = encoded_graph_def.size();
  const size_t num_chunks =
      static_cast<size_t>(std::ceil(static_cast<float>(total_length) /
                                    DebugGrpcIO::kGrpcMessageSizeLimitBytes));
  for (size_t i = 0; i < num_chunks; ++i) {
    const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes;
    const size_t len = (i == num_chunks - 1)
                           ? (total_length - pos)
                           : DebugGrpcIO::kGrpcMessageSizeLimitBytes;
    Event event;
    event.set_wall_time(static_cast<double>(wall_time));
    // Prefix the chunk with
    //   <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|.
    // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks
    // and chunk_index, instead.
    event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time,
                                        "|", i, "|", num_chunks, "|",
                                        encoded_graph_def.substr(pos, len)));
    const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream(
        event, debug_url, num_chunks - 1 == i);
    if (!s.ok()) {
      return errors::FailedPrecondition(
          "Failed to send chunk ", i, " of ", num_chunks,
          " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ",
          "due to: ", s.error_message());
    }
  }
  return OkStatus();
}
#endif  // #ifndef PLATFORM_WINDOWS

}  // namespace

const char* const DebugIO::kDebuggerPluginName = "debugger";

const char* const DebugIO::kCoreMetadataTag = "core_metadata_";

const char* const DebugIO::kGraphTag = "graph_";

const char* const DebugIO::kHashTag = "hash";

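// Reads a binary-serialized Event proto from the file at dump_file_path into
// *event. Returns a non-OK status if the file cannot be sized, opened, or
// read; parse failures are silently ignored.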
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
  Env* env(Env::Default());

  string content;
  uint64 file_size = 0;

  Status s = env->GetFileSize(dump_file_path, &file_size);
  if (!s.ok()) {
    return s;
  }

  content.resize(file_size);

  std::unique_ptr<RandomAccessFile> file;
  s = env->NewRandomAccessFile(dump_file_path, &file);
  if (!s.ok()) {
    return s;
  }

  StringPiece result;
  s = file->Read(0, file_size, &result, &(content)[0]);
  if (!s.ok()) {
    return s;
  }

  event->ParseFromString(content);
  return OkStatus();
}

const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
const char* const DebugIO::kMemoryURLScheme = "memcbk://";

// Publishes debug metadata to a set of debug URLs.
Status DebugIO::PublishDebugMetadata(
    const int64_t global_step, const int64_t session_run_index,
    const int64_t executor_step_index, const std::vector<string>& input_names,
    const std::vector<string>& output_names,
    const std::vector<string>& target_nodes,
    const std::unordered_set<string>& debug_urls) {
  std::ostringstream oss;

  // Construct a JSON string to carry the metadata.
  oss << "{";
  oss << "\"global_step\":" << global_step << ",";
  oss << "\"session_run_index\":" << session_run_index << ",";
  oss << "\"executor_step_index\":" << executor_step_index << ",";
  oss << "\"input_names\":[";
  for (size_t i = 0; i < input_names.size(); ++i) {
    oss << "\"" << input_names[i] << "\"";
    if (i < input_names.size() - 1) {
      oss << ",";
    }
  }
  oss << "],";
  oss << "\"output_names\":[";
  for (size_t i = 0; i < output_names.size(); ++i) {
    oss << "\"" << output_names[i] << "\"";
    if (i < output_names.size() - 1) {
      oss << ",";
    }
  }
  oss << "],";
  oss << "\"target_nodes\":[";
  for (size_t i = 0; i < target_nodes.size(); ++i) {
    oss << "\"" << target_nodes[i] << "\"";
    if (i < target_nodes.size() - 1) {
      oss << ",";
    }
  }
  oss << "]";
  oss << "}";

  const string json_metadata = oss.str();
  Event event;
  event.set_wall_time(static_cast<double>(Env::Default()->NowMicros()));
  LogMessage* log_message = event.mutable_log_message();
  log_message->set_message(json_metadata);

  Status status;
  for (const string& url : debug_urls) {
    if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
#ifndef PLATFORM_WINDOWS
      Event grpc_event;

      // Determine the path (if any) in the grpc:// URL, and add it as a field
      // of the JSON string.
      const string address = url.substr(strlen(DebugIO::kGrpcURLScheme));
      const string path = address.find('/') == string::npos
                              ? ""
                              : address.substr(address.find('/'));
      grpc_event.set_wall_time(event.wall_time());
      LogMessage* log_message_grpc = grpc_event.mutable_log_message();
      log_message_grpc->set_message(
          strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1),
                          ",\"grpc_path\":\"", path, "\"}"));

      status.Update(
          DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true));
#else
      GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
    } else if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
      const string dump_root_dir = url.substr(strlen(kFileURLScheme));
      const string core_metadata_path = AppendTimestampToFilePath(
          io::JoinPath(dump_root_dir,
                       strings::StrCat(
                           DebugNodeKey::kMetadataFilePrefix,
                           DebugIO::kCoreMetadataTag, "sessionrun",
                           strings::Printf("%.14lld", static_cast<long long>(
                                                          session_run_index)))),
          Env::Default()->NowMicros());
      status.Update(DebugFileIO::DumpEventProtoToFile(
          event, string(io::Dirname(core_metadata_path)),
          string(io::Basename(core_metadata_path))));
    }
  }

  return status;
}

Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
                                   const Tensor& tensor,
                                   const uint64 wall_time_us,
                                   const gtl::ArraySlice<string> debug_urls,
                                   const bool gated_grpc) {
  int32_t num_failed_urls = 0;
  std::vector<Status> fail_statuses;
  for (const string& url : debug_urls) {
    if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
      const string dump_root_dir = url.substr(strlen(kFileURLScheme));

      const int64_t tensorBytes =
          tensor.IsInitialized() ? tensor.TotalBytes() : 0;
      if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
        return errors::ResourceExhausted(
            "TensorFlow Debugger has exhausted file-system byte-size "
            "allowance (",
            DebugFileIO::global_disk_bytes_limit_, "), therefore it cannot ",
            "dump an additional ", tensorBytes, " byte(s) of tensor data ",
            "for the debug tensor ", debug_node_key.node_name, ":",
            debug_node_key.output_slot, ". You may use the environment ",
            "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
      }

      Status s = DebugFileIO::DumpTensorToDir(
          debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
      if (!s.ok()) {
        num_failed_urls++;
        fail_statuses.push_back(s);
      }
    } else if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
#ifndef PLATFORM_WINDOWS
      Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
          debug_node_key, tensor, wall_time_us, url, gated_grpc);

      if (!s.ok()) {
        num_failed_urls++;
        fail_statuses.push_back(s);
      }
#else
      GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
    } else if (absl::StartsWith(absl::AsciiStrToLower(url), kMemoryURLScheme)) {
      const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
      auto* callback_registry = DebugCallbackRegistry::singleton();
      auto* callback = callback_registry->GetCallback(dump_root_dir);
      CHECK(callback) << "No callback registered for: " << dump_root_dir;
      (*callback)(debug_node_key, tensor);
    } else {
      return Status(error::UNAVAILABLE,
                    strings::StrCat("Invalid debug target URL: ", url));
    }
  }

  if (num_failed_urls == 0) {
    return OkStatus();
  } else {
    string error_message = strings::StrCat(
        "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
        " debug target URLs failed, due to the following errors:");
    for (Status& status : fail_statuses) {
      error_message =
          strings::StrCat(error_message, " ", status.error_message(), ";");
    }

    return Status(error::INTERNAL, error_message);
  }
}

Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
                                   const Tensor& tensor,
                                   const uint64 wall_time_us,
                                   const gtl::ArraySlice<string> debug_urls) {
  return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls,
                            false);
}

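// Serializes the graph as a GraphDef, wraps it in an Event proto, and
// publishes it to each debug URL: dumped to a file under the device-specific
// dump directory for file:// URLs, or sent in chunks over the debugger gRPC
// stream for grpc:// URLs.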
Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
                             const std::unordered_set<string>& debug_urls) {
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);

  string buf;
  graph_def.SerializeToString(&buf);

  const int64_t now_micros = Env::Default()->NowMicros();
  Event event;
  event.set_wall_time(static_cast<double>(now_micros));
  event.set_graph_def(buf);

  Status status = OkStatus();
  for (const string& debug_url : debug_urls) {
    if (absl::StartsWith(debug_url, kFileURLScheme)) {
      const string dump_root_dir =
          io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
                       DebugNodeKey::DeviceNameToDevicePath(device_name));
      const uint64 graph_hash = ::tensorflow::Hash64(buf);
      const string file_name =
          strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
                          DebugIO::kHashTag, graph_hash, "_", now_micros);

      status.Update(
          DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
    } else if (absl::StartsWith(debug_url, kGrpcURLScheme)) {
#ifndef PLATFORM_WINDOWS
      status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
                                                   debug_url));
#else
      GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
    }
  }

  return status;
}

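// Returns true if any of the given watch-and-URL specs requires data to be
// sent: i.e., the spec is not gated, its URL is not a grpc:// URL, or the read
// gate for its watch key is currently open at that URL.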
bool DebugIO::IsCopyNodeGateOpen(
    const std::vector<DebugWatchAndURLSpec>& specs) {
#ifndef PLATFORM_WINDOWS
  for (const DebugWatchAndURLSpec& spec : specs) {
    if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
                                             DebugIO::kGrpcURLScheme)) {
      return true;
    } else {
      if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
        return true;
      }
    }
  }
  return false;
#else
  return true;
#endif
}

bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
                                  const std::vector<string>& debug_urls) {
#ifndef PLATFORM_WINDOWS
  for (const string& debug_url : debug_urls) {
    if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
                          DebugIO::kGrpcURLScheme)) {
      return true;
    } else {
      if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
        return true;
      }
    }
  }
  return false;
#else
  return true;
#endif
}

bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
                                 const string& debug_url) {
#ifndef PLATFORM_WINDOWS
  if (!absl::StartsWith(debug_url, kGrpcURLScheme)) {
    // The gate is always open for non-gRPC debug URLs.
    return true;
  } else {
    return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
  }
#else
  return true;
#endif
}

Status DebugIO::CloseDebugURL(const string& debug_url) {
  if (absl::StartsWith(debug_url, DebugIO::kGrpcURLScheme)) {
#ifndef PLATFORM_WINDOWS
    return DebugGrpcIO::CloseGrpcStream(debug_url);
#else
    GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
  } else {
    // No-op for non-gRPC URLs.
    return OkStatus();
  }
}

Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
                                    const Tensor& tensor,
                                    const uint64 wall_time_us,
                                    const string& dump_root_dir,
                                    string* dump_file_path) {
  const string file_path =
      GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us);

  if (dump_file_path != nullptr) {
    *dump_file_path = file_path;
  }

  return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
}

string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
                                    const DebugNodeKey& debug_node_key,
                                    const uint64 wall_time_us) {
  return AppendTimestampToFilePath(
      io::JoinPath(dump_root_dir, debug_node_key.device_path,
                   strings::StrCat(debug_node_key.node_name, "_",
                                   debug_node_key.output_slot, "_",
                                   debug_node_key.debug_op)),
      wall_time_us);
}

Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
                                         const string& dir_name,
                                         const string& file_name) {
  Env* env(Env::Default());

  Status s = RecursiveCreateDir(env, dir_name);
  if (!s.ok()) {
    return Status(error::FAILED_PRECONDITION,
                  strings::StrCat("Failed to create directory ", dir_name,
                                  ", due to: ", s.error_message()));
  }

  const string file_path = io::JoinPath(dir_name, file_name);

  string event_str;
  event_proto.SerializeToString(&event_str);

  std::unique_ptr<WritableFile> f = nullptr;
  TF_CHECK_OK(env->NewWritableFile(file_path, &f));
  f->Append(event_str).IgnoreError();
  TF_CHECK_OK(f->Close());

  return OkStatus();
}

Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
                                          const Tensor& tensor,
                                          const uint64 wall_time_us,
                                          const string& file_path) {
  std::vector<Event> events;
  TF_RETURN_IF_ERROR(
      WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
  return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
                              string(io::Basename(file_path)));
}

Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
  if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
    // The path already exists as a directory. Return OK right away.
    return OkStatus();
  }

  string parent_dir(io::Dirname(dir));
  if (!env->FileExists(parent_dir).ok()) {
    // The parent path does not exist yet; create it first.
    Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
    if (!s.ok()) {
      return Status(
          error::FAILED_PRECONDITION,
          strings::StrCat("Failed to create directory ", parent_dir));
    }
  } else if (env->FileExists(parent_dir).ok() &&
             !env->IsDirectory(parent_dir).ok()) {
    // The parent path exists, but it is a file.
    return Status(error::FAILED_PRECONDITION,
                  strings::StrCat("Failed to create directory ", parent_dir,
                                  " because the path exists as a file"));
  }

  env->CreateDir(dir).IgnoreError();
  // Guard against a potential race in creating directories by doing a check
  // after the CreateDir call.
  if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
    return OkStatus();
  } else {
    return Status(error::ABORTED,
                  strings::StrCat("Failed to create directory ", dir));
  }
}

// Default total disk usage limit: 100 GBytes
const uint64 DebugFileIO::kDefaultGlobalDiskBytesLimit = 107374182400L;
uint64 DebugFileIO::global_disk_bytes_limit_ = 0;
uint64 DebugFileIO::disk_bytes_used_ = 0;

mutex DebugFileIO::bytes_mu_(LINKER_INITIALIZED);

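// Requests permission to write `bytes` additional bytes of dump data to disk.
// Lazily initializes the global byte limit from the TFDBG_DISK_BYTES_LIMIT
// environment variable (falling back to kDefaultGlobalDiskBytesLimit), then
// grants the request and records the usage only if it stays under the limit.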
bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
  mutex_lock l(bytes_mu_);
  if (global_disk_bytes_limit_ == 0) {
    const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
    if (env_tfdbg_disk_bytes_limit == nullptr ||
        strlen(env_tfdbg_disk_bytes_limit) == 0) {
      global_disk_bytes_limit_ = kDefaultGlobalDiskBytesLimit;
    } else {
      strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
                             &global_disk_bytes_limit_);
    }
  }

  if (bytes == 0) {
    return true;
  }
  if (disk_bytes_used_ + bytes < global_disk_bytes_limit_) {
    disk_bytes_used_ += bytes;
    return true;
  } else {
    return false;
  }
}

void DebugFileIO::resetDiskByteUsage() {
  mutex_lock l(bytes_mu_);
  disk_bytes_used_ = 0;
}

#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
    : server_stream_addr_(server_stream_addr),
      url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}

Status DebugGrpcChannel::Connect(const int64_t timeout_micros) {
  ::grpc::ChannelArguments args;
  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
  // Avoid problems where default reconnect backoff is too long (e.g., 20 s).
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
  channel_ = ::grpc::CreateCustomChannel(
      server_stream_addr_, ::grpc::InsecureChannelCredentials(), args);
  if (!channel_->WaitForConnected(
          gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
                       gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) {
    return errors::FailedPrecondition(
        "Failed to connect to gRPC channel at ", server_stream_addr_,
        " within a timeout of ", timeout_micros / 1e6, " s.");
  }
  stub_ = grpc::EventListener::NewStub(channel_);
  reader_writer_ = stub_->SendEvents(&ctx_);

  return OkStatus();
}

bool DebugGrpcChannel::WriteEvent(const Event& event) {
  mutex_lock l(mu_);
  return reader_writer_->Write(event);
}

bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
  mutex_lock l(mu_);
  return reader_writer_->Read(event_reply);
}

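// Reads and processes EventReply protos from the server. If max_replies == 0,
// keeps reading until the stream yields no more replies; otherwise reads at
// most max_replies replies. Each DebugOpStateChange carried in a reply updates
// the gating state of the corresponding watch key at this channel's URL.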
void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
  EventReply event_reply;
  size_t num_replies = 0;
  while ((max_replies == 0 || ++num_replies <= max_replies) &&
         ReadEventReply(&event_reply)) {
    for (const EventReply::DebugOpStateChange& debug_op_state_change :
         event_reply.debug_op_state_changes()) {
      string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
                                         debug_op_state_change.output_slot(),
                                         ":", debug_op_state_change.debug_op());
      DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
                                            debug_op_state_change.state());
    }
  }
}

Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
  reader_writer_->WritesDone();
  // Read all EventReply messages (if any) from the server.
  ReceiveAndProcessEventReplies(0);

  if (reader_writer_->Finish().ok()) {
    return OkStatus();
  } else {
    return Status(error::FAILED_PRECONDITION,
                  "Failed to close debug GRPC stream.");
  }
}

mutex DebugGrpcIO::streams_mu_(LINKER_INITIALIZED);

// TODO(cais): Make this configurable?
int64_t DebugGrpcIO::channel_connection_timeout_micros_ =
    900 * 1000 * 1000;  // 15 minutes.

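// 4000 * 1024 = 4,096,000 bytes, just under gRPC's default 4-MB
// (4 * 1024 * 1024 = 4,194,304-byte) maximum message size.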
const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;

const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;

std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
DebugGrpcIO::GetStreamChannels() {
  static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
      stream_channels =
          new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
  return stream_channels;
}

Status DebugGrpcIO::SendTensorThroughGrpcStream(
    const DebugNodeKey& debug_node_key, const Tensor& tensor,
    const uint64 wall_time_us, const string& grpc_stream_url,
    const bool gated) {
  if (gated &&
      !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
    return OkStatus();
  } else {
    std::vector<Event> events;
    TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us,
                                          kGrpcMessageSizeLimitBytes, &events));
    for (const Event& event : events) {
      TF_RETURN_IF_ERROR(
          SendEventProtoThroughGrpcStream(event, grpc_stream_url));
    }
    if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
      DebugGrpcChannel* debug_grpc_channel = nullptr;
      TF_RETURN_IF_ERROR(
          GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
      debug_grpc_channel->ReceiveAndProcessEventReplies(1);
      // TODO(cais): Support new tensor value carried in the EventReply for
      // overriding the value of the tensor being published.
    }
    return OkStatus();
  }
}

Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
    EventReply* event_reply, const string& grpc_stream_url) {
  DebugGrpcChannel* debug_grpc_channel = nullptr;
  TF_RETURN_IF_ERROR(
      GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
  if (debug_grpc_channel->ReadEventReply(event_reply)) {
    return OkStatus();
  } else {
    return errors::Cancelled(strings::StrCat(
        "Reading EventReply from stream URL ", grpc_stream_url, " failed."));
  }
}

Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
    const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
  const string addr_with_path =
      absl::StartsWith(grpc_stream_url, DebugIO::kGrpcURLScheme)
          ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
          : grpc_stream_url;
  const string server_stream_addr =
      addr_with_path.substr(0, addr_with_path.find('/'));
  {
    mutex_lock l(streams_mu_);
    std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
        stream_channels = GetStreamChannels();
    if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
      std::unique_ptr<DebugGrpcChannel> channel(
          new DebugGrpcChannel(server_stream_addr));
      TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros_));
      stream_channels->insert(
          std::make_pair(grpc_stream_url, std::move(channel)));
    }
    *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get();
  }
  return OkStatus();
}

Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
    const Event& event_proto, const string& grpc_stream_url,
    const bool receive_reply) {
  DebugGrpcChannel* debug_grpc_channel;
  TF_RETURN_IF_ERROR(
      GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));

  bool write_ok = debug_grpc_channel->WriteEvent(event_proto);
  if (!write_ok) {
    return errors::Cancelled(strings::StrCat("Write event to stream URL ",
                                             grpc_stream_url, " failed."));
  }

  if (receive_reply) {
    debug_grpc_channel->ReceiveAndProcessEventReplies(1);
  }

  return OkStatus();
}

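// The "read gate" for a watch key is open if the key has been enabled at the
// given gRPC debug URL in any state; the "write gate" is open only if the
// enabled state is READ_WRITE.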
bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
                                 const string& watch_key) {
  const DebugNodeName2State* enabled_node_to_state =
      GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
  return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
}

bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
                                  const string& watch_key) {
  const DebugNodeName2State* enabled_node_to_state =
      GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
  auto it = enabled_node_to_state->find(watch_key);
  if (it == enabled_node_to_state->end()) {
    return false;
  } else {
    return it->second == EventReply::DebugOpStateChange::READ_WRITE;
  }
}

Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
  mutex_lock l(streams_mu_);

  std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
      stream_channels = GetStreamChannels();
  if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
    // A stream for the specified address exists. Close it and remove it from
    // the record.
    Status s =
        (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
    (*stream_channels).erase(grpc_stream_url);
    return s;
  } else {
    // No stream exists for the specified address. No action.
    return OkStatus();
  }
}

std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
DebugGrpcIO::GetEnabledDebugOpStates() {
  static std::unordered_map<string, DebugNodeName2State>*
      enabled_debug_op_states =
          new std::unordered_map<string, DebugNodeName2State>();
  return enabled_debug_op_states;
}

DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
    const string& grpc_debug_url) {
  static mutex* debug_ops_state_mu = new mutex();
  std::unordered_map<string, DebugNodeName2State>* states =
      GetEnabledDebugOpStates();

  mutex_lock l(*debug_ops_state_mu);
  if (states->find(grpc_debug_url) == states->end()) {
    DebugNodeName2State url_enabled_debug_op_states;
    (*states)[grpc_debug_url] = url_enabled_debug_op_states;
  }
  return &(*states)[grpc_debug_url];
}

void DebugGrpcIO::SetDebugNodeKeyGrpcState(
    const string& grpc_debug_url, const string& watch_key,
    const EventReply::DebugOpStateChange::State new_state) {
  DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
  if (new_state == EventReply::DebugOpStateChange::DISABLED) {
    if (states->find(watch_key) == states->end()) {
      LOG(ERROR) << "Attempt to disable a watch key that is not currently "
                 << "enabled at " << grpc_debug_url << ": " << watch_key;
    } else {
      states->erase(watch_key);
    }
  } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
    (*states)[watch_key] = new_state;
  }
}

void DebugGrpcIO::ClearEnabledWatchKeys() {
  GetEnabledDebugOpStates()->clear();
}

#endif  // #ifndef PLATFORM_WINDOWS

}  // namespace tensorflow