1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | Licensed under the Apache License, Version 2.0 (the "License"); |
3 | you may not use this file except in compliance with the License. |
4 | You may obtain a copy of the License at |
5 | http://www.apache.org/licenses/LICENSE-2.0 |
6 | Unless required by applicable law or agreed to in writing, software |
7 | distributed under the License is distributed on an "AS IS" BASIS, |
8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
9 | See the License for the specific language governing permissions and |
10 | limitations under the License. |
11 | ==============================================================================*/ |
12 | |
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
30 | |
31 | namespace { |
32 | |
33 | // Operators used to create a std::unique_ptr for TF_Tensor and TF_Status. |
34 | struct TFTensorDeleter { |
35 | void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); } |
36 | }; |
37 | |
38 | struct TFStatusDeleter { |
39 | void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } |
40 | }; |
41 | |
42 | // Struct that wraps TF_Tensor and TF_Status to delete once out of scope. |
43 | using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>; |
44 | using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>; |
45 | |
// Per-kernel state created at kernel construction and handed back to every
// compute call. It only records the op node's name, which is embedded in the
// error messages produced while computing the histogram.
struct HistogramSummaryOp {
  std::string op_node_name{};
};
51 | |
52 | void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) { |
53 | HistogramSummaryOp* kernel = new HistogramSummaryOp; |
54 | TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx); |
55 | kernel->op_node_name = |
56 | std::string(string_view_name.data, string_view_name.len); |
57 | return kernel; |
58 | } |
59 | |
60 | void HistogramSummaryOp_Delete(void* kernel) { |
61 | delete static_cast<HistogramSummaryOp*>(kernel); |
62 | } |
63 | |
64 | template <typename T> |
65 | void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { |
66 | HistogramSummaryOp* k = static_cast<HistogramSummaryOp*>(kernel); |
67 | TF_Tensor* tags; |
68 | TF_Tensor* values; |
69 | Safe_TF_StatusPtr status(TF_NewStatus()); |
70 | TF_GetInput(ctx, 0, &tags, status.get()); |
71 | Safe_TF_TensorPtr safe_tags_ptr(tags); |
72 | if (TF_GetCode(status.get()) != TF_OK) { |
73 | TF_OpKernelContext_Failure(ctx, status.get()); |
74 | return; |
75 | } |
76 | TF_GetInput(ctx, 1, &values, status.get()); |
77 | Safe_TF_TensorPtr safe_values_ptr(values); |
78 | if (TF_GetCode(status.get()) != TF_OK) { |
79 | TF_OpKernelContext_Failure(ctx, status.get()); |
80 | return; |
81 | } |
82 | if (TF_NumDims(safe_tags_ptr.get()) != 0) { |
83 | TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar" ); |
84 | TF_OpKernelContext_Failure(ctx, status.get()); |
85 | return; |
86 | } |
87 | // Cast values to array to access tensor elements by index |
88 | auto values_array = static_cast<T*>(TF_TensorData(safe_values_ptr.get())); |
89 | tensorflow::histogram::Histogram histo; |
90 | for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) { |
91 | const double double_val = static_cast<double>(values_array[i]); |
92 | if (Eigen::numext::isnan(double_val)) { |
93 | std::ostringstream err; |
94 | err << "Nan in summary histogram for: " << k->op_node_name; |
95 | TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); |
96 | TF_OpKernelContext_Failure(ctx, status.get()); |
97 | return; |
98 | } else if (Eigen::numext::isinf(double_val)) { |
99 | std::ostringstream err; |
100 | err << "Infinity in Histogram for: " << k->op_node_name; |
101 | TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); |
102 | TF_OpKernelContext_Failure(ctx, status.get()); |
103 | return; |
104 | } |
105 | histo.Add(double_val); |
106 | } |
107 | tensorflow::Summary s; |
108 | tensorflow::Summary::Value* v = s.add_value(); |
109 | const tensorflow::tstring& tag = |
110 | *(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get()))); |
111 | v->set_tag(tag.data(), tag.size()); |
112 | histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); |
113 | |
114 | Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput( |
115 | /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0), |
116 | /*dims=*/nullptr, /*num_dims=*/0, |
117 | /*len=*/sizeof(tensorflow::tstring), status.get())); |
118 | |
119 | if (TF_GetCode(status.get()) != TF_OK) { |
120 | TF_OpKernelContext_Failure(ctx, status.get()); |
121 | return; |
122 | } |
123 | tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>( |
124 | TF_TensorData(summary_tensor.get())); |
125 | CHECK(SerializeToTString(s, output_tstring)); |
126 | } |
127 | |
128 | template <typename T> |
129 | void RegisterHistogramSummaryOpKernel() { |
130 | TF_Status* status = TF_NewStatus(); |
131 | { |
132 | auto* builder = TF_NewKernelBuilder( |
133 | "HistogramSummary" , tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create, |
134 | &HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete); |
135 | TF_KernelBuilder_TypeConstraint( |
136 | builder, "T" , |
137 | static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status); |
138 | CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint" ; |
139 | TF_RegisterKernelBuilder("HistogramSummary" , builder, status); |
140 | CHECK_EQ(TF_OK, TF_GetCode(status)) |
141 | << "Error while registering Histogram Summmary kernel" ; |
142 | } |
143 | TF_DeleteStatus(status); |
144 | } |
145 | |
146 | // A dummy static variable initialized by a lambda whose side-effect is to |
147 | // register the Histogram Summary kernel. |
148 | TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() { |
149 | if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary" )) { |
150 | RegisterHistogramSummaryOpKernel<int64_t>(); |
151 | RegisterHistogramSummaryOpKernel<tensorflow::uint64>(); |
152 | RegisterHistogramSummaryOpKernel<tensorflow::int32>(); |
153 | RegisterHistogramSummaryOpKernel<tensorflow::uint32>(); |
154 | RegisterHistogramSummaryOpKernel<tensorflow::uint16>(); |
155 | RegisterHistogramSummaryOpKernel<tensorflow::int16>(); |
156 | RegisterHistogramSummaryOpKernel<tensorflow::int8>(); |
157 | RegisterHistogramSummaryOpKernel<tensorflow::uint8>(); |
158 | RegisterHistogramSummaryOpKernel<Eigen::half>(); |
159 | RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>(); |
160 | RegisterHistogramSummaryOpKernel<float>(); |
161 | RegisterHistogramSummaryOpKernel<double>(); |
162 | } |
163 | return true; |
164 | }(); |
165 | } // namespace |
166 | |