1 | |
2 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
3 | |
4 | Licensed under the Apache License, Version 2.0 (the "License"); |
5 | you may not use this file except in compliance with the License. |
6 | You may obtain a copy of the License at |
7 | |
8 | http://www.apache.org/licenses/LICENSE-2.0 |
9 | |
10 | Unless required by applicable law or agreed to in writing, software |
11 | distributed under the License is distributed on an "AS IS" BASIS, |
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | See the License for the specific language governing permissions and |
14 | limitations under the License. |
15 | ==============================================================================*/ |
16 | |
17 | #include <sstream> |
18 | #include <string> |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/c/kernels.h" |
22 | #include "tensorflow/c/kernels/tensor_shape_utils.h" |
23 | #include "tensorflow/c/tf_status.h" |
24 | #include "tensorflow/c/tf_tensor.h" |
25 | #include "tensorflow/core/framework/registration/registration.h" |
26 | #include "tensorflow/core/framework/summary.pb.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | #include "tensorflow/core/platform/bfloat16.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | #include "tensorflow/core/platform/macros.h" |
31 | #include "tensorflow/core/platform/protobuf.h" |
32 | #include "tensorflow/core/platform/strcat.h" |
33 | #include "tensorflow/core/platform/tstring.h" |
34 | #include "tensorflow/core/platform/types.h" |
35 | |
36 | namespace { |
37 | |
38 | // Struct that stores the status and TF_Tensor inputs to the opkernel. |
39 | // Used to delete tensor and status in its destructor upon kernel return. |
40 | struct Params { |
41 | TF_Tensor* tags; |
42 | TF_Tensor* values; |
43 | TF_Status* status; |
44 | explicit Params(TF_OpKernelContext* ctx) |
45 | : tags(nullptr), values(nullptr), status(nullptr) { |
46 | status = TF_NewStatus(); |
47 | TF_GetInput(ctx, 0, &tags, status); |
48 | if (TF_GetCode(status) == TF_OK) { |
49 | TF_GetInput(ctx, 1, &values, status); |
50 | } |
51 | } |
52 | ~Params() { |
53 | TF_DeleteStatus(status); |
54 | TF_DeleteTensor(tags); |
55 | TF_DeleteTensor(values); |
56 | } |
57 | }; |
58 | |
59 | // dummy functions used for kernel registration |
60 | void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; } |
61 | |
62 | void ScalarSummaryOp_Delete(void* kernel) {} |
63 | |
64 | // Helper functions for compute method |
65 | bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2); |
66 | // Returns a string representation of a single tag or empty string if there |
67 | // are multiple tags |
68 | std::string SingleTag(TF_Tensor* tags); |
69 | |
70 | template <typename T> |
71 | void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { |
72 | Params params(ctx); |
73 | if (TF_GetCode(params.status) != TF_OK) { |
74 | TF_OpKernelContext_Failure(ctx, params.status); |
75 | return; |
76 | } |
77 | if (!IsSameSize(params.tags, params.values)) { |
78 | std::ostringstream err; |
79 | err << "tags and values are not the same shape: " |
80 | << tensorflow::ShapeDebugString(params.tags) |
81 | << " != " << tensorflow::ShapeDebugString(params.values) |
82 | << SingleTag(params.tags); |
83 | TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str()); |
84 | TF_OpKernelContext_Failure(ctx, params.status); |
85 | return; |
86 | } |
87 | // Convert tags and values tensor to array to access elements by index |
88 | tensorflow::Summary s; |
89 | auto tags_array = |
90 | static_cast<tensorflow::tstring*>(TF_TensorData(params.tags)); |
91 | auto values_array = static_cast<T*>(TF_TensorData(params.values)); |
92 | // Copy tags and values into summary protobuf |
93 | for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) { |
94 | tensorflow::Summary::Value* v = s.add_value(); |
95 | const tensorflow::tstring& Ttags_i = tags_array[i]; |
96 | v->set_tag(Ttags_i.data(), Ttags_i.size()); |
97 | v->set_simple_value(static_cast<float>(values_array[i])); |
98 | } |
99 | TF_Tensor* summary_tensor = |
100 | TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0, |
101 | sizeof(tensorflow::tstring), params.status); |
102 | if (TF_GetCode(params.status) != TF_OK) { |
103 | TF_DeleteTensor(summary_tensor); |
104 | TF_OpKernelContext_Failure(ctx, params.status); |
105 | return; |
106 | } |
107 | tensorflow::tstring* output_tstring = |
108 | reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor)); |
109 | CHECK(SerializeToTString(s, output_tstring)); |
110 | TF_DeleteTensor(summary_tensor); |
111 | } |
112 | |
113 | bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) { |
114 | if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) { |
115 | return false; |
116 | } |
117 | for (int d = 0; d < TF_NumDims(tensor1); d++) { |
118 | if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) { |
119 | return false; |
120 | } |
121 | } |
122 | return true; |
123 | } |
124 | |
125 | std::string SingleTag(TF_Tensor* tags) { |
126 | if (TF_TensorElementCount(tags) == 1) { |
127 | const char* single_tag = |
128 | static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str(); |
129 | return tensorflow::strings::StrCat(" (tag '" , single_tag, "')" ); |
130 | } else { |
131 | return "" ; |
132 | } |
133 | } |
134 | |
135 | template <typename T> |
136 | void RegisterScalarSummaryOpKernel() { |
137 | TF_Status* status = TF_NewStatus(); |
138 | { |
139 | auto* builder = TF_NewKernelBuilder( |
140 | "ScalarSummary" , tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create, |
141 | &ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete); |
142 | TF_KernelBuilder_TypeConstraint( |
143 | builder, "T" , |
144 | static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status); |
145 | CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint" ; |
146 | TF_RegisterKernelBuilder("ScalarSummary" , builder, status); |
147 | CHECK_EQ(TF_OK, TF_GetCode(status)) |
148 | << "Error while registering Scalar Summmary kernel" ; |
149 | } |
150 | TF_DeleteStatus(status); |
151 | } |
152 | |
153 | // A dummy static variable initialized by a lambda whose side-effect is to |
154 | // register the ScalarSummary kernel. |
155 | TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() { |
156 | if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary" )) { |
157 | RegisterScalarSummaryOpKernel<int64_t>(); |
158 | RegisterScalarSummaryOpKernel<tensorflow::uint64>(); |
159 | RegisterScalarSummaryOpKernel<tensorflow::int32>(); |
160 | RegisterScalarSummaryOpKernel<tensorflow::uint32>(); |
161 | RegisterScalarSummaryOpKernel<tensorflow::uint16>(); |
162 | RegisterScalarSummaryOpKernel<tensorflow::int16>(); |
163 | RegisterScalarSummaryOpKernel<tensorflow::int8>(); |
164 | RegisterScalarSummaryOpKernel<tensorflow::uint8>(); |
165 | RegisterScalarSummaryOpKernel<Eigen::half>(); |
166 | RegisterScalarSummaryOpKernel<tensorflow::bfloat16>(); |
167 | RegisterScalarSummaryOpKernel<float>(); |
168 | RegisterScalarSummaryOpKernel<double>(); |
169 | } |
170 | return true; |
171 | }(); |
172 | } // namespace |
173 | |