1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2Licensed under the Apache License, Version 2.0 (the "License");
3you may not use this file except in compliance with the License.
4You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6Unless required by applicable law or agreed to in writing, software
7distributed under the License is distributed on an "AS IS" BASIS,
8WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9See the License for the specific language governing permissions and
10limitations under the License.
11==============================================================================*/
12
13#include <sstream>
14#include <string>
15
16#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
17#include "tensorflow/c/kernels.h"
18#include "tensorflow/c/tf_status.h"
19#include "tensorflow/c/tf_tensor.h"
20#include "tensorflow/core/framework/registration/registration.h"
21#include "tensorflow/core/framework/summary.pb.h"
22#include "tensorflow/core/framework/types.h"
23#include "tensorflow/core/lib/histogram/histogram.h"
24#include "tensorflow/core/platform/bfloat16.h"
25#include "tensorflow/core/platform/logging.h"
26#include "tensorflow/core/platform/macros.h"
27#include "tensorflow/core/platform/protobuf.h"
28#include "tensorflow/core/platform/tstring.h"
29#include "tensorflow/core/platform/types.h"
30
31namespace {
32
33// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status.
34struct TFTensorDeleter {
35 void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
36};
37
38struct TFStatusDeleter {
39 void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
40};
41
42// Struct that wraps TF_Tensor and TF_Status to delete once out of scope.
43using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
44using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
45
// Per-kernel state allocated by HistogramSummaryOp_Create and handed back to
// HistogramSummaryOp_Compute. It records the name of the op node so that
// error messages raised during computation can identify which node failed.
struct HistogramSummaryOp {
  std::string op_node_name{};
};
51
52void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) {
53 HistogramSummaryOp* kernel = new HistogramSummaryOp;
54 TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx);
55 kernel->op_node_name =
56 std::string(string_view_name.data, string_view_name.len);
57 return kernel;
58}
59
60void HistogramSummaryOp_Delete(void* kernel) {
61 delete static_cast<HistogramSummaryOp*>(kernel);
62}
63
64template <typename T>
65void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
66 HistogramSummaryOp* k = static_cast<HistogramSummaryOp*>(kernel);
67 TF_Tensor* tags;
68 TF_Tensor* values;
69 Safe_TF_StatusPtr status(TF_NewStatus());
70 TF_GetInput(ctx, 0, &tags, status.get());
71 Safe_TF_TensorPtr safe_tags_ptr(tags);
72 if (TF_GetCode(status.get()) != TF_OK) {
73 TF_OpKernelContext_Failure(ctx, status.get());
74 return;
75 }
76 TF_GetInput(ctx, 1, &values, status.get());
77 Safe_TF_TensorPtr safe_values_ptr(values);
78 if (TF_GetCode(status.get()) != TF_OK) {
79 TF_OpKernelContext_Failure(ctx, status.get());
80 return;
81 }
82 if (TF_NumDims(safe_tags_ptr.get()) != 0) {
83 TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar");
84 TF_OpKernelContext_Failure(ctx, status.get());
85 return;
86 }
87 // Cast values to array to access tensor elements by index
88 auto values_array = static_cast<T*>(TF_TensorData(safe_values_ptr.get()));
89 tensorflow::histogram::Histogram histo;
90 for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) {
91 const double double_val = static_cast<double>(values_array[i]);
92 if (Eigen::numext::isnan(double_val)) {
93 std::ostringstream err;
94 err << "Nan in summary histogram for: " << k->op_node_name;
95 TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
96 TF_OpKernelContext_Failure(ctx, status.get());
97 return;
98 } else if (Eigen::numext::isinf(double_val)) {
99 std::ostringstream err;
100 err << "Infinity in Histogram for: " << k->op_node_name;
101 TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
102 TF_OpKernelContext_Failure(ctx, status.get());
103 return;
104 }
105 histo.Add(double_val);
106 }
107 tensorflow::Summary s;
108 tensorflow::Summary::Value* v = s.add_value();
109 const tensorflow::tstring& tag =
110 *(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get())));
111 v->set_tag(tag.data(), tag.size());
112 histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
113
114 Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
115 /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
116 /*dims=*/nullptr, /*num_dims=*/0,
117 /*len=*/sizeof(tensorflow::tstring), status.get()));
118
119 if (TF_GetCode(status.get()) != TF_OK) {
120 TF_OpKernelContext_Failure(ctx, status.get());
121 return;
122 }
123 tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
124 TF_TensorData(summary_tensor.get()));
125 CHECK(SerializeToTString(s, output_tstring));
126}
127
128template <typename T>
129void RegisterHistogramSummaryOpKernel() {
130 TF_Status* status = TF_NewStatus();
131 {
132 auto* builder = TF_NewKernelBuilder(
133 "HistogramSummary", tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create,
134 &HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete);
135 TF_KernelBuilder_TypeConstraint(
136 builder, "T",
137 static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
138 CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
139 TF_RegisterKernelBuilder("HistogramSummary", builder, status);
140 CHECK_EQ(TF_OK, TF_GetCode(status))
141 << "Error while registering Histogram Summmary kernel";
142 }
143 TF_DeleteStatus(status);
144}
145
146// A dummy static variable initialized by a lambda whose side-effect is to
147// register the Histogram Summary kernel.
148TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() {
149 if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) {
150 RegisterHistogramSummaryOpKernel<int64_t>();
151 RegisterHistogramSummaryOpKernel<tensorflow::uint64>();
152 RegisterHistogramSummaryOpKernel<tensorflow::int32>();
153 RegisterHistogramSummaryOpKernel<tensorflow::uint32>();
154 RegisterHistogramSummaryOpKernel<tensorflow::uint16>();
155 RegisterHistogramSummaryOpKernel<tensorflow::int16>();
156 RegisterHistogramSummaryOpKernel<tensorflow::int8>();
157 RegisterHistogramSummaryOpKernel<tensorflow::uint8>();
158 RegisterHistogramSummaryOpKernel<Eigen::half>();
159 RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>();
160 RegisterHistogramSummaryOpKernel<float>();
161 RegisterHistogramSummaryOpKernel<double>();
162 }
163 return true;
164}();
165} // namespace
166