/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_

#include <numeric>

#include "tensorflow/core/platform/bfloat16.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/util/determinism.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/cuda.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#endif

#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/debug_events_writer.h"

namespace tensorflow {

// Copy op for debugging.
// Performs CPU-to-CPU or GPU-to-GPU deep-copying of the tensor, depending on
// the device on which the tensor is allocated.
class CopyOp : public OpKernel {
 public:
  explicit CopyOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name_));

    std::vector<string> debug_ops_spec;
    OP_REQUIRES_OK(context,
                   context->GetAttr("debug_ops_spec", &debug_ops_spec));
    for (const string& debug_op_spec : debug_ops_spec) {
      // Assume debug_op_spec has the format
      // <debug_op>;<debug_url>;<gated_grpc>, e.g.,
      // DebugIdentity;grpc://localhost:3333;1
      const std::vector<string> items = str_util::Split(debug_op_spec, ";");
      OP_REQUIRES(
          context, items.size() == 3,
          errors::Internal(
              "Unexpected number of semicolons in debug_ops_spec element: ",
              debug_op_spec));
      debug_op_and_url_specs_.push_back(
          DebugWatchAndURLSpec(strings::StrCat(tensor_name_, ":", items[0]),
                               items[1], items[2] == "1"));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& src_tensor = context->input(0);

    if (src_tensor.IsInitialized() &&
        DataTypeCanUseMemcpy(src_tensor.dtype()) &&
        DebugIO::IsCopyNodeGateOpen(debug_op_and_url_specs_)) {
      // Source tensor is initialized and is mem-copyable. Make a copy.
      Tensor* copied_tensor;
      OP_REQUIRES_OK(context, context->allocate_output(0, src_tensor.shape(),
                                                       &copied_tensor));

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      Device* device = static_cast<Device*>(context->device());
      // Determine if the input tensor is not on CPU (e.g., on GPU).
      bool off_host_input = device->device_type() == DEVICE_GPU &&
                            !context->input_alloc_attr(0).on_host();

      if (off_host_input) {
        DeviceContext* device_ctxt = context->op_device_context();
        // Input is not on host: deep-copy it from GPU to the same GPU.
        Notification done_copy;
        GPUUtil::CopyGPUTensorToSameGPU(
            device, device_ctxt, &src_tensor, copied_tensor,
            [&done_copy](const Status& s) { done_copy.Notify(); });
        done_copy.WaitForNotification();
      } else {
        // The input tensor is on the host (CPU): deep-copy from CPU to CPU.
        *copied_tensor = tensor::DeepCopy(src_tensor);
      }
#else
      *copied_tensor = tensor::DeepCopy(src_tensor);
#endif
    } else {
      // Source tensor is NOT initialized and/or is not mem-copyable: Forward
      // the Tensor object.
      context->set_output(0, src_tensor);
    }
  }

  bool IsExpensive() override { return false; }

 private:
  string tensor_name_;
  std::vector<DebugWatchAndURLSpec> debug_op_and_url_specs_;
};

// Base class of all debug ops.
class BaseDebugOp : public OpKernel {
 public:
  explicit BaseDebugOp(const string& debug_op_name,
                       OpKernelConstruction* context)
      : OpKernel(context), debug_op_name_(debug_op_name) {
    OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls_));
    OP_REQUIRES_OK(context, context->GetAttr("gated_grpc", &gated_grpc_));

    string device_name;
    string tensor_name;
    OP_REQUIRES_OK(context, context->GetAttr("device_name", &device_name));
    OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name));

    std::vector<string> name_items = str_util::Split(tensor_name, ':');
    string node_name;
    int32_t output_slot = 0;
    OP_REQUIRES(context, name_items.size() == 1 || name_items.size() == 2,
                errors::InvalidArgument("Failed to parse tensor name: \"",
                                        tensor_name, "\""));
    if (name_items.size() == 2) {
      node_name = name_items[0];
      OP_REQUIRES(
          context, strings::safe_strto32(name_items[1], &output_slot),
          errors::InvalidArgument("Invalid string value for output_slot: \"",
                                  name_items[1], "\""));
    } else if (name_items.size() == 1) {
      node_name = name_items[0];
    }

    debug_watch_key_.reset(
        new DebugNodeKey(device_name, node_name, output_slot, debug_op_name_));
  }

  bool IsExpensive() override { return false; }

 protected:
  // Apply gRPC gating (if the gated_grpc_ attribute is true).
  //
  // Returns false if and only if all grpc:// debug URLs of the debug op are
  // currently disabled (i.e., gated off), in which case the debug op emits an
  // empty (size {0}) tensor of undefined data type.
  bool ApplyGrpcGating(OpKernelContext* context) {
    if (gated_grpc_ && !DebugIO::IsDebugNodeGateOpen(
                           debug_watch_key_->debug_node_name, debug_urls_)) {
      // The entire node is gated off: Output an empty tensor and avoid
      // expensive computation.
      Tensor* output_tensor;
      TensorShape shape({0});
      if (!context->allocate_output(0, shape, &output_tensor).ok()) {
        LOG(ERROR) << "Debug node of watch key "
                   << debug_watch_key_->debug_node_name
                   << " failed to allocate empty tensor under gated-off state.";
      }
      return false;
    } else {
      return true;
    }
  }

  // Publish a tensor to all debug URLs of the debug op.
  // Log an error if the publishing failed.
  Status PublishTensor(const Tensor& tensor) {
    if (debug_urls_.empty()) {
      return OkStatus();
    } else {
      Status status = DebugIO::PublishDebugTensor(*debug_watch_key_, tensor,
                                                  Env::Default()->NowMicros(),
                                                  debug_urls_, gated_grpc_);
      if (!status.ok()) {
        LOG(ERROR) << "Debug node of watch key "
                   << debug_watch_key_->debug_node_name
                   << " failed to publish debug tensor data to all URLs "
                   << str_util::Join(debug_urls_, ", ")
                   << ", due to: " << status.error_message();
      }
      return status;
    }
  }

 private:
  const string debug_op_name_;
  std::unique_ptr<DebugNodeKey> debug_watch_key_;
  std::vector<string> debug_urls_;
  bool gated_grpc_;
};

// Identity op for debugging.
// Output slot 0 carries the debug signal and is always allocated on the
// host (CPU) as a non-Ref tensor. In the case of DebugIdentityOp,
// the debug signal is equal to the input tensor.
class DebugIdentityOp : public BaseDebugOp {
 public:
  explicit DebugIdentityOp(OpKernelConstruction* context)
      : BaseDebugOp("DebugIdentity", context) {}

  void Compute(OpKernelContext* context) override {
    if (!ApplyGrpcGating(context)) {
      return;
    }

    OP_REQUIRES_OK(context, PublishTensor(context->input(0)));
    context->set_output(0, context->input(0));
  }
};

// NaN-counter op for debugging.
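// Outputs a single-element int64 tensor holding the number of NaN elements in
// the input tensor (0 if the input is uninitialized), and publishes that count
// to the configured debug URLs.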
template <typename T>
class DebugNanCountOp : public BaseDebugOp {
 public:
  explicit DebugNanCountOp(OpKernelConstruction* context)
      : BaseDebugOp("DebugNanCount", context) {}

  void Compute(OpKernelContext* context) override {
    if (!ApplyGrpcGating(context)) {
      return;
    }

    Tensor* output_tensor;
    const Tensor& input = context->input(0);

    // Use DT_INT64/int64 to be consistent with TensorShape::num_elements().
    int64_t nan_count = 0;

    // If the input is an uninitialized tensor, let nan_count be 0.
    if (input.IsInitialized()) {
      // Count NaNs.
      const TensorShape& input_shape = input.shape();
      const T* input_flat = input.template flat<T>().data();

      for (int64_t i = 0; i < input_shape.num_elements(); ++i) {
        if (Eigen::numext::isnan(static_cast<double>(input_flat[i]))) {
          nan_count++;
        }
      }
    }

    TensorShape shape({1});
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
    output_tensor->vec<int64_t>()(0) = nan_count;
    OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
  }
};

// Numeric summary op for debugging.
template <typename T>
class DebugNumericSummaryOp : public BaseDebugOp {
 public:
  explicit DebugNumericSummaryOp(OpKernelConstruction* context)
      : BaseDebugOp("DebugNumericSummary", context) {
    OP_REQUIRES_OK(context, context->GetAttr("lower_bound", &lower_bound_));
    OP_REQUIRES_OK(context, context->GetAttr("upper_bound", &upper_bound_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("mute_if_healthy", &mute_if_healthy_));
  }

  void Compute(OpKernelContext* context) override {
    if (!ApplyGrpcGating(context)) {
      return;
    }

    Tensor* output_tensor;
    const Tensor& input = context->input(0);

    int64_t is_initialized = 0;
    int64_t element_count = 0;
    int64_t negative_inf_count = 0;
    int64_t negative_count = 0;
    int64_t zero_count = 0;
    int64_t positive_count = 0;
    int64_t positive_inf_count = 0;
    int64_t nan_count = 0;
    double min = std::numeric_limits<double>::infinity();
    double max = -std::numeric_limits<double>::infinity();
    double sum = 0.0;
    double mean = std::numeric_limits<double>::quiet_NaN();
    double variance = std::numeric_limits<double>::quiet_NaN();

    // Number of finite elements: negative_count + zero_count + positive_count,
    // plus any finite elements clamped by the custom lower/upper bounds.
    int64_t non_inf_nan_count = 0;

    const TensorShape& input_shape = input.shape();
    if (input.IsInitialized()) {
      is_initialized = 1;
      const T* input_flat = input.template flat<T>().data();

      element_count = input_shape.num_elements();
      const bool is_lower_bound_custom = !Eigen::numext::isinf(lower_bound_);
      const bool is_upper_bound_custom = !Eigen::numext::isinf(upper_bound_);

      for (int64_t i = 0; i < element_count; ++i) {
        const double x = static_cast<double>(input_flat[i]);
        if (Eigen::numext::isnan(x)) {
          nan_count++;
        } else if (Eigen::numext::isinf(x)) {
          if (x < 0.0) {
            negative_inf_count++;
          } else {
            positive_inf_count++;
          }
        } else {
          if (is_lower_bound_custom && x <= lower_bound_) {
            negative_inf_count++;
          } else if (is_upper_bound_custom && x >= upper_bound_) {
            positive_inf_count++;
          } else if (x < 0.0) {
            negative_count++;
          } else if (x > 0.0) {
            positive_count++;
          } else {
            zero_count++;
          }

          if (x < min) {
            min = x;
          }
          if (x > max) {
            max = x;
          }

          non_inf_nan_count++;
          sum += x;
        }
      }

      if (non_inf_nan_count > 0) {
        mean = sum / non_inf_nan_count;

        // Do a second pass to compute variance.
        variance = 0.0;
        for (int64_t i = 0; i < element_count; ++i) {
          const double x = static_cast<double>(input_flat[i]);
          if (!Eigen::numext::isnan(x) && !Eigen::numext::isinf(x)) {
            variance += (x - mean) * (x - mean);
          }
        }
        variance /= non_inf_nan_count;
      }
    }

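    // Output layout (element index: content), as written below:
    //   0: is_initialized, 1: element count, 2: NaN count,
    //   3: -inf count (or count of elements <= lower_bound when a custom
    //      lower bound is set), 4: negative count, 5: zero count,
    //   6: positive count, 7: +inf count (or count of elements >= upper_bound
    //      when a custom upper bound is set), 8: min, 9: max, 10: mean,
    //   11: variance, 12: dtype enum value, 13: rank,
    //   14 onward: the dimension sizes of the input shape.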
    TensorShape shape({14 + input_shape.dims()});
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
    output_tensor->vec<double>()(0) = static_cast<double>(is_initialized);
    output_tensor->vec<double>()(1) = static_cast<double>(element_count);
    output_tensor->vec<double>()(2) = static_cast<double>(nan_count);
    output_tensor->vec<double>()(3) = static_cast<double>(negative_inf_count);
    output_tensor->vec<double>()(4) = static_cast<double>(negative_count);
    output_tensor->vec<double>()(5) = static_cast<double>(zero_count);
    output_tensor->vec<double>()(6) = static_cast<double>(positive_count);
    output_tensor->vec<double>()(7) = static_cast<double>(positive_inf_count);
    output_tensor->vec<double>()(8) = min;
    output_tensor->vec<double>()(9) = max;
    output_tensor->vec<double>()(10) = mean;
    output_tensor->vec<double>()(11) = variance;

    output_tensor->vec<double>()(12) = static_cast<double>(input.dtype());
    output_tensor->vec<double>()(13) = static_cast<double>(input_shape.dims());
    for (size_t d = 0; d < input_shape.dims(); ++d) {
      output_tensor->vec<double>()(14 + d) =
          static_cast<double>(input_shape.dim_sizes()[d]);
    }

    bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
                positive_inf_count == 0;
    if (!mute) {
      OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
    }
  }

 private:
  float lower_bound_;
  float upper_bound_;
  bool mute_if_healthy_;
};

// Identity op for tfdbg v2: Writes debug data using DebugEventsWriter.
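// Only debug URLs with the file scheme (DebugIO::kFileURLScheme) are
// supported; the path portion of each such URL becomes a dump root for a
// DebugEventsWriter. On each Compute call, the op writes a graph-execution
// trace of its input to every dump root and forwards the input tensor
// unchanged to output slot 0.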
class DebugIdentityV2Op : public OpKernel {
 public:
  explicit DebugIdentityV2Op(OpKernelConstruction* context)
      : OpKernel(context),
        device_name_(context->device()->name()),
        output_slot_(-1),
        tensor_debug_mode_(0),
        tfdbg_run_id_() {
    std::vector<string> debug_urls;
    OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls));
    for (const string& debug_url : debug_urls) {
      if (absl::StartsWith(debug_url, DebugIO::kFileURLScheme)) {
        dump_roots_.emplace_back(
            debug_url.substr(strlen(DebugIO::kFileURLScheme)));
      } else {
        context->SetStatus(
            errors::Internal("Unsupported debug URL schema in: ", debug_url));
      }
    }
    OP_REQUIRES_OK(context,
                   context->GetAttr("tfdbg_context_id", &tfdbg_context_id_));
    OP_REQUIRES_OK(context, context->GetAttr("op_name", &op_name_));
    OP_REQUIRES_OK(context, context->GetAttr("output_slot", &output_slot_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("tensor_debug_mode", &tensor_debug_mode_));
    if (context->HasAttr("circular_buffer_size")) {
      OP_REQUIRES_OK(context, context->GetAttr("circular_buffer_size",
                                               &circular_buffer_size_));
    } else {
      circular_buffer_size_ =
          tfdbg::DebugEventsWriter::kDefaultCyclicBufferSize;
    }
    if (context->HasAttr("tfdbg_run_id")) {
      OP_REQUIRES_OK(context, context->GetAttr("tfdbg_run_id", &tfdbg_run_id_));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor = context->input(0);
    for (const string& dump_root : dump_roots_) {
      tfdbg::DebugEventsWriter* debug_events_writer =
          tfdbg::DebugEventsWriter::GetDebugEventsWriter(
              dump_root, tfdbg_run_id_, circular_buffer_size_);
      OP_REQUIRES_OK(context, debug_events_writer->WriteGraphExecutionTrace(
                                  tfdbg_context_id_, device_name_, op_name_,
                                  output_slot_, tensor_debug_mode_, tensor));
    }
    context->set_output(0, tensor);
  }

 private:
  std::vector<string> dump_roots_;
  string tfdbg_context_id_;
  string device_name_;
  string op_name_;
  int32 output_slot_;
  int32 tensor_debug_mode_;
  int64_t circular_buffer_size_;
  string tfdbg_run_id_;
};

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
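// Launch functors for the GPU implementations of DebugNumericSummaryV2Op.
// Only the declarations appear in this header; the kernels are compiled in a
// separate GPU (CUDA/ROCm) compilation unit, and the extern template
// declarations below keep the instantiations out of this header.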
template <typename Tin, typename Tout>
struct CurtHealthLaunch {
  void Run(const GPUDevice& d, const Tin* data, int size, Tout output[1]);
};

extern template struct CurtHealthLaunch<Eigen::half, float>;
extern template struct CurtHealthLaunch<float, float>;
extern template struct CurtHealthLaunch<double, float>;
extern template struct CurtHealthLaunch<Eigen::half, double>;
extern template struct CurtHealthLaunch<float, double>;
extern template struct CurtHealthLaunch<double, double>;

template <typename Tin, typename Tout>
struct ConciseHealthLaunch {
  void Run(const GPUDevice& d, const Tin* data, int size, Tout output[3]);
};

extern template struct ConciseHealthLaunch<Eigen::half, float>;
extern template struct ConciseHealthLaunch<float, float>;
extern template struct ConciseHealthLaunch<double, float>;
extern template struct ConciseHealthLaunch<Eigen::half, double>;
extern template struct ConciseHealthLaunch<float, double>;
extern template struct ConciseHealthLaunch<double, double>;

template <typename Tin, typename Tout>
struct FullHealthLaunch {
  void Run(const GPUDevice& d, const Tin* data, int size, Tout output[6]);
};

extern template struct FullHealthLaunch<Eigen::half, float>;
extern template struct FullHealthLaunch<float, float>;
extern template struct FullHealthLaunch<double, float>;
extern template struct FullHealthLaunch<Eigen::half, double>;
extern template struct FullHealthLaunch<float, double>;
extern template struct FullHealthLaunch<double, double>;

template <typename Tin, typename Tout>
struct ReduceInfNanThreeSlotsLaunch {
  void Run(const GPUDevice& d, const Tin* data, int size, Tout output[3]);
};

extern template struct ReduceInfNanThreeSlotsLaunch<Eigen::half, float>;
extern template struct ReduceInfNanThreeSlotsLaunch<float, float>;
extern template struct ReduceInfNanThreeSlotsLaunch<double, float>;
extern template struct ReduceInfNanThreeSlotsLaunch<Eigen::half, double>;
extern template struct ReduceInfNanThreeSlotsLaunch<float, double>;
extern template struct ReduceInfNanThreeSlotsLaunch<double, double>;

#endif

template <typename Device, typename Tin, typename Tout>
class DebugNumericSummaryV2Op;

// Numeric summary op for tfdbg v2: CPU Kernel.
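// Supported values of the tensor_debug_mode attribute and the corresponding
// output shapes (see the branches in Compute below):
//   2 (CURT_HEALTH): shape {2}, 3 (CONCISE_HEALTH): shape {5},
//   4 (FULL_HEALTH): shape {11}, 5 (SHAPE): shape {10},
//   8 (REDUCE_INF_NAN_THREE_SLOTS): shape {3}.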
template <typename Tin, typename Tout>
class DebugNumericSummaryV2Op<CPUDevice, Tin, Tout> : public OpKernel {
 public:
  explicit DebugNumericSummaryV2Op(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("tensor_debug_mode", &tensor_debug_mode_));
    OP_REQUIRES_OK(context, context->GetAttr("tensor_id", &tensor_id_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor = context->input(0);
    auto in = tensor.flat<Tin>();
    const Tin* data = in.data();
    const int64_t size = in.size();
    Tensor* output_tensor;
    Tout tensor_id = static_cast<Tout>(tensor_id_);
    const Tout num_elem = static_cast<Tout>(context->input(0).NumElements());
    // Skip the check for a lossy tensor_id cast if the mode is
    // REDUCE_INF_NAN_THREE_SLOTS, because that mode does not make use of
    // tensor_id.
    if (tensor_debug_mode_ != 8) {
      OP_REQUIRES(
          context, tensor_id_ <= kMaxTensorId,
          errors::InvalidArgument("DebugNumericSummaryV2Op requires "
                                  "tensor_id to be less than or equal to "
                                  "(2^",
                                  std::numeric_limits<Tout>::digits,
                                  "). Given tensor_id:", tensor_id_));
    }

    if (tensor_debug_mode_ == 2) {  // CURT_HEALTH
      TensorShape shape({2});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));
      output_tensor->flat<Tout>()(0) = tensor_id;  // Slot tensor id
      output_tensor->flat<Tout>()(1) = 0.0;        // Has inf or nan
      int fp_props =
          std::accumulate(data, data + size, 0, [](const int x, const Tin& y) {
            return Eigen::numext::isfinite(y) ? x : 1;
          });
      if (fp_props) {
        output_tensor->flat<Tout>()(1) = 1.0;
      }
    } else if (tensor_debug_mode_ == 3) {  // CONCISE_HEALTH
      TensorShape shape({5});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));
      output_tensor->flat<Tout>()(0) = tensor_id;
      output_tensor->flat<Tout>()(1) = num_elem;

      // Accumulator value [neg_inf_count, pos_inf_count, nan_count]
      Tout fp_props[3] = {0.0, 0.0, 0.0};
      std::for_each(data, data + size, [&fp_props](const Tin& y) {
        if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
          // Do nothing: common case.
        } else if (Eigen::numext::isinf(y)) {
          if (y < static_cast<Tin>(0.f)) {
            ++fp_props[0];
          } else {
            ++fp_props[1];
          }
        } else if (Eigen::numext::isnan(y)) {
          ++fp_props[2];
        }
      });
      output_tensor->flat<Tout>()(2) = fp_props[0];  // Slot for -inf count
      output_tensor->flat<Tout>()(3) = fp_props[1];  // Slot for inf count
      output_tensor->flat<Tout>()(4) = fp_props[2];  // Slot for nan count
    } else if (tensor_debug_mode_ == 4) {  // FULL HEALTH
      TensorShape shape({11});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));
      int num_dims = tensor.dims();
      output_tensor->flat<Tout>()(0) = tensor_id;
      output_tensor->flat<Tout>()(1) = -1.0;  // TODO(144919262): Device ID
      output_tensor->flat<Tout>()(2) = static_cast<Tout>(tensor.dtype());
      output_tensor->flat<Tout>()(3) = static_cast<Tout>(num_dims);
      output_tensor->flat<Tout>()(4) = num_elem;

      // Accumulator value [neg_inf_count, pos_inf_count, nan_count, neg_count,
      // zero_count, pos_count]
      Tout fp_props[6] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
      std::for_each(data, data + size, [&fp_props](const Tin& y) {
        if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
          if (y < static_cast<Tin>(0.f)) {
            ++fp_props[3];
          } else if (y == static_cast<Tin>(0.f)) {
            ++fp_props[4];
          } else {
            ++fp_props[5];
          }
        } else if (Eigen::numext::isinf(y)) {
          if (y < static_cast<Tin>(0.f)) {
            ++fp_props[0];
          } else {
            ++fp_props[1];
          }
        } else if (Eigen::numext::isnan(y)) {
          ++fp_props[2];
        }
      });
      output_tensor->flat<Tout>()(5) = fp_props[0];   // Slot for -inf count
      output_tensor->flat<Tout>()(6) = fp_props[1];   // Slot for inf count
      output_tensor->flat<Tout>()(7) = fp_props[2];   // Slot for nan count.
      output_tensor->flat<Tout>()(8) = fp_props[3];   // Slot for neg count.
      output_tensor->flat<Tout>()(9) = fp_props[4];   // Slot for zero count.
      output_tensor->flat<Tout>()(10) = fp_props[5];  // Slot for pos count.
    } else if (tensor_debug_mode_ == 5) {  // SHAPE
      TensorShape shape({10});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));

      int num_dims = tensor.dims();
      output_tensor->flat<Tout>()(0) = tensor_id;
      output_tensor->flat<Tout>()(1) = static_cast<Tout>(tensor.dtype());
      output_tensor->flat<Tout>()(2) = static_cast<Tout>(num_dims);
      output_tensor->flat<Tout>()(3) = num_elem;

      // Tensor shape, stored in up to 6 slots (columns 4-9):
      // If num_dims is less than 6, the shape is right-padded with zeros.
      // If num_dims is greater than 6, the leading (left-most) dimensions are
      // truncated, since they (e.g., the batch size as the first dimension)
      // tend to be more predictable than the trailing ones.
      int dim_idx = 4;
      for (int i = std::max(0, num_dims - kShapeDims);
           i < std::max(6, num_dims); ++i) {
        if (i < num_dims) {
          output_tensor->flat<Tout>()(dim_idx++) =
              static_cast<Tout>(tensor.dim_size(i));
        } else {
          output_tensor->flat<Tout>()(dim_idx++) = 0.0;
        }
      }
    } else if (tensor_debug_mode_ == 8) {  // REDUCE_INF_NAN_THREE_SLOTS.
      TensorShape shape({3});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));
      output_tensor->flat<Tout>()(0) = 0.0;  // Slot for -inf.
      output_tensor->flat<Tout>()(1) = 0.0;  // Slot for inf.
      output_tensor->flat<Tout>()(2) = 0.0;  // Slot for nan.

      int fp_props =
          std::accumulate(data, data + size, 0, [](const int x, const Tin& y) {
            int result = x;
            if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
              // Do nothing: common case.
            } else if (Eigen::numext::isinf(y)) {
              result |= y < static_cast<Tin>(0.f) ? kNegInfBit : kPosInfBit;
            } else if (Eigen::numext::isnan(y)) {
              result |= kNaNBit;
            }
            return result;
          });

      if (fp_props & kNegInfBit) {
        output_tensor->flat<Tout>()(0) = -std::numeric_limits<Tout>::infinity();
      }
      if (fp_props & kPosInfBit) {
        output_tensor->flat<Tout>()(1) = std::numeric_limits<Tout>::infinity();
      }
      if (fp_props & kNaNBit) {
        output_tensor->flat<Tout>()(2) = std::numeric_limits<Tout>::quiet_NaN();
      }
    } else {
      // TODO(cais): Implement other tensor debug modes in debug_event.proto.
      context->SetStatus(errors::Unimplemented(
          "Unimplemented tensor debug mode: ", tensor_debug_mode_));
    }
  }

 private:
  int tensor_debug_mode_;
  int64_t tensor_id_;
  static constexpr int kShapeDims = 6;
  static constexpr int kNegInfBit = 0x01;
  static constexpr int kPosInfBit = 0x02;
  static constexpr int kNaNBit = 0x04;
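  // Largest tensor_id value that is guaranteed to survive the cast to the
  // floating-point output type Tout without loss (see the check in Compute
  // above).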
  static constexpr int64_t kMaxTensorId = 1LL
                                          << std::numeric_limits<Tout>::digits;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

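// Numeric summary op for tfdbg v2: GPU kernel.
// Implemented as an AsyncOpKernel: the output is zero-initialized on the GPU
// stream, the header fields (e.g., tensor_id) are copied in, the reduction
// kernel declared above is launched, and the done callback is scheduled on
// the device's EventMgr once the stream work completes.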
template <typename Tin, typename Tout>
class DebugNumericSummaryV2Op<GPUDevice, Tin, Tout> : public AsyncOpKernel {
 public:
  typedef GPUDevice Device;

  explicit DebugNumericSummaryV2Op(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("tensor_debug_mode", &tensor_debug_mode_));
    OP_REQUIRES_OK(context, context->GetAttr("tensor_id", &tensor_id_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    Tensor* output_tensor;
    Tout tensor_id = static_cast<Tout>(tensor_id_);
    const Tensor& tensor = context->input(0);
    const Tout num_elem = static_cast<Tout>(tensor.NumElements());
    const Device& d = context->eigen_device<Device>();
    auto input = tensor.flat<Tin>();
    auto check_cb = [this, done]() { done(); };
    // Skip the check for a lossy tensor_id cast if the mode is
    // REDUCE_INF_NAN_THREE_SLOTS, because that mode does not make use of
    // tensor_id.
    if (tensor_debug_mode_ != 8) {
      OP_REQUIRES_ASYNC(
          context, tensor_id_ <= kMaxTensorId,
          errors::InvalidArgument("DebugNumericSummaryV2Op requires "
                                  "tensor_id to be less than or equal to "
                                  "(2^",
                                  std::numeric_limits<Tout>::digits,
                                  "). Given tensor_id:", tensor_id_),
          done);
    }

    if (tensor_debug_mode_ == 2) {  // CURT_HEALTH.
      TensorShape shape({2});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));

      auto* stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, stream != nullptr,
                        errors::Internal("No GPU stream available."), done);

      se::DeviceMemoryBase output_tensor_ptr(
          output_tensor->flat<Tout>().data(),
          output_tensor->flat<Tout>().size());
      stream->ThenMemZero(&output_tensor_ptr, 2 * sizeof(Tout));
      // Copy tensor_id to slot zero
      stream->ThenMemcpy(&output_tensor_ptr, &tensor_id, sizeof(Tout));
      if (num_elem == 0) {
        done();
        return;
      }

      // Call the GPU kernels for the numerical (inf/nan) checks.
      auto input = context->input(0).flat<Tin>();
      CurtHealthLaunch<Tin, Tout>().Run(d, input.data(), input.size(),
                                        output_tensor->flat<Tout>().data() + 1);

      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, std::move(check_cb));
    } else if (tensor_debug_mode_ == 3) {  // CONCISE_HEALTH.
      TensorShape shape({5});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));
      OP_REQUIRES_ASYNC(context, !tensorflow::OpDeterminismRequired(),
                        errors::Unimplemented(
                            "Determinism is not yet supported for "
                            "DebugNumericSummaryV2 when tensor_debug_mode is "
                            "CONCISE_HEALTH."),
                        done);

      auto* stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, stream != nullptr,
                        errors::Internal("No GPU stream available."), done);

      se::DeviceMemoryBase output_tensor_ptr(
          output_tensor->flat<Tout>().data(),
          output_tensor->flat<Tout>().size());
      stream->ThenMemset32(&output_tensor_ptr, 0, 5 * sizeof(Tout));
      const Tout static_output[] = {tensor_id, num_elem};
      stream->ThenMemcpy(&output_tensor_ptr, &static_output, 2 * sizeof(Tout));
      if (num_elem == 0) {
        done();
        return;
      }

      // Call the GPU kernels for the numerical (inf/nan) checks.
      ConciseHealthLaunch<Tin, Tout>().Run(
          d, input.data(), input.size(),
          output_tensor->flat<Tout>().data() + 2);

      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, std::move(check_cb));
    } else if (tensor_debug_mode_ == 4) {  // FULL HEALTH
      TensorShape shape({11});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));

      auto* stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, stream != nullptr,
                        errors::Internal("No GPU stream available."), done);
      OP_REQUIRES_ASYNC(context, !tensorflow::OpDeterminismRequired(),
                        errors::Unimplemented(
                            "Determinism is not yet supported for "
                            "DebugNumericSummaryV2 when tensor_debug_mode is "
                            "FULL_HEALTH."),
                        done);

      se::DeviceMemoryBase output_tensor_ptr(
          output_tensor->flat<Tout>().data(),
          output_tensor->flat<Tout>().size());
      stream->ThenMemset32(&output_tensor_ptr, 0, 11 * sizeof(Tout));

      int num_dims = tensor.dims();
      const Tout static_output[] = {tensor_id,
                                    -1.0,  // TODO(144919262): Device ID
                                    static_cast<Tout>(tensor.dtype()),
                                    static_cast<Tout>(num_dims), num_elem};
      stream->ThenMemcpy(&output_tensor_ptr, &static_output, 5 * sizeof(Tout));
      if (num_elem == 0) {
        done();
        return;
      }

      // Call the GPU kernels for the numerical (inf/nan) checks and
      // pos/neg/zero counts.
      FullHealthLaunch<Tin, Tout>().Run(d, input.data(), input.size(),
                                        output_tensor->flat<Tout>().data() + 5);

      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, std::move(check_cb));
    } else if (tensor_debug_mode_ == 5) {  // SHAPE
      TensorShape shape({10});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));

      auto* stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, stream != nullptr,
                        errors::Internal("No GPU stream available."), done);

      se::DeviceMemoryBase output_tensor_ptr(
          output_tensor->flat<Tout>().data(),
          output_tensor->flat<Tout>().size());

      int num_dims = tensor.dims();
      Tout static_output[10] = {tensor_id,
                                static_cast<Tout>(tensor.dtype()),
                                static_cast<Tout>(num_dims),
                                num_elem,
                                0.0,
                                0.0,
                                0.0,
                                0.0,
                                0.0,
                                0.0};
      // Tensor shape: right pad zeros, truncate head
      int dim_idx = 4;
      for (int i = std::max(0, num_dims - 6); i < num_dims; ++i) {
        static_output[dim_idx++] = static_cast<Tout>(tensor.dim_size(i));
      }
      // Write to device stream
      stream->ThenMemcpy(&output_tensor_ptr, &static_output, sizeof(Tout) * 10);
      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, std::move(check_cb));
    } else if (tensor_debug_mode_ == 8) {  // REDUCE_INF_NAN_THREE_SLOTS.
      TensorShape shape({3});
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, shape, &output_tensor));

      auto* stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(context, stream != nullptr,
                        errors::Internal("No GPU stream available."), done);

      se::DeviceMemoryBase output_tensor_ptr(
          output_tensor->flat<Tout>().data(),
          output_tensor->flat<Tout>().size());
      stream->ThenMemset32(&output_tensor_ptr, 0,
                           output_tensor->flat<Tout>().size() * sizeof(Tout));
      if (num_elem == 0) {
        done();
        return;
      }

      // Call the GPU kernels for the numerical (inf/nan) checks.
      auto input = context->input(0).flat<Tin>();
      ReduceInfNanThreeSlotsLaunch<Tin, Tout>().Run(
          d, input.data(), input.size(), output_tensor->flat<Tout>().data());

      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, std::move(check_cb));
    } else {
      // TODO(cais): Implement other tensor debug modes in debug_event.proto.
      context->SetStatus(errors::Unimplemented(
          "Unimplemented tensor debug mode: ", tensor_debug_mode_));
      done();
    }
  }

 private:
  int tensor_debug_mode_;
  int64_t tensor_id_;
  // Use 1LL (as in the CPU kernel above) so the shift is performed on a
  // 64-bit operand regardless of the platform's width of long.
  static constexpr int64_t kMaxTensorId = 1LL
                                          << std::numeric_limits<Tout>::digits;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_