1
2/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15==============================================================================*/
16
17#include <sstream>
18#include <string>
19
20#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21#include "tensorflow/c/kernels.h"
22#include "tensorflow/c/kernels/tensor_shape_utils.h"
23#include "tensorflow/c/tf_status.h"
24#include "tensorflow/c/tf_tensor.h"
25#include "tensorflow/core/framework/registration/registration.h"
26#include "tensorflow/core/framework/summary.pb.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/platform/bfloat16.h"
29#include "tensorflow/core/platform/logging.h"
30#include "tensorflow/core/platform/macros.h"
31#include "tensorflow/core/platform/protobuf.h"
32#include "tensorflow/core/platform/strcat.h"
33#include "tensorflow/core/platform/tstring.h"
34#include "tensorflow/core/platform/types.h"
35
36namespace {
37
38// Struct that stores the status and TF_Tensor inputs to the opkernel.
39// Used to delete tensor and status in its destructor upon kernel return.
40struct Params {
41 TF_Tensor* tags;
42 TF_Tensor* values;
43 TF_Status* status;
44 explicit Params(TF_OpKernelContext* ctx)
45 : tags(nullptr), values(nullptr), status(nullptr) {
46 status = TF_NewStatus();
47 TF_GetInput(ctx, 0, &tags, status);
48 if (TF_GetCode(status) == TF_OK) {
49 TF_GetInput(ctx, 1, &values, status);
50 }
51 }
52 ~Params() {
53 TF_DeleteStatus(status);
54 TF_DeleteTensor(tags);
55 TF_DeleteTensor(values);
56 }
57};
58
59// dummy functions used for kernel registration
60void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
61
62void ScalarSummaryOp_Delete(void* kernel) {}
63
64// Helper functions for compute method
65bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
66// Returns a string representation of a single tag or empty string if there
67// are multiple tags
68std::string SingleTag(TF_Tensor* tags);
69
70template <typename T>
71void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
72 Params params(ctx);
73 if (TF_GetCode(params.status) != TF_OK) {
74 TF_OpKernelContext_Failure(ctx, params.status);
75 return;
76 }
77 if (!IsSameSize(params.tags, params.values)) {
78 std::ostringstream err;
79 err << "tags and values are not the same shape: "
80 << tensorflow::ShapeDebugString(params.tags)
81 << " != " << tensorflow::ShapeDebugString(params.values)
82 << SingleTag(params.tags);
83 TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
84 TF_OpKernelContext_Failure(ctx, params.status);
85 return;
86 }
87 // Convert tags and values tensor to array to access elements by index
88 tensorflow::Summary s;
89 auto tags_array =
90 static_cast<tensorflow::tstring*>(TF_TensorData(params.tags));
91 auto values_array = static_cast<T*>(TF_TensorData(params.values));
92 // Copy tags and values into summary protobuf
93 for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
94 tensorflow::Summary::Value* v = s.add_value();
95 const tensorflow::tstring& Ttags_i = tags_array[i];
96 v->set_tag(Ttags_i.data(), Ttags_i.size());
97 v->set_simple_value(static_cast<float>(values_array[i]));
98 }
99 TF_Tensor* summary_tensor =
100 TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
101 sizeof(tensorflow::tstring), params.status);
102 if (TF_GetCode(params.status) != TF_OK) {
103 TF_DeleteTensor(summary_tensor);
104 TF_OpKernelContext_Failure(ctx, params.status);
105 return;
106 }
107 tensorflow::tstring* output_tstring =
108 reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor));
109 CHECK(SerializeToTString(s, output_tstring));
110 TF_DeleteTensor(summary_tensor);
111}
112
113bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
114 if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) {
115 return false;
116 }
117 for (int d = 0; d < TF_NumDims(tensor1); d++) {
118 if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) {
119 return false;
120 }
121 }
122 return true;
123}
124
125std::string SingleTag(TF_Tensor* tags) {
126 if (TF_TensorElementCount(tags) == 1) {
127 const char* single_tag =
128 static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str();
129 return tensorflow::strings::StrCat(" (tag '", single_tag, "')");
130 } else {
131 return "";
132 }
133}
134
135template <typename T>
136void RegisterScalarSummaryOpKernel() {
137 TF_Status* status = TF_NewStatus();
138 {
139 auto* builder = TF_NewKernelBuilder(
140 "ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create,
141 &ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete);
142 TF_KernelBuilder_TypeConstraint(
143 builder, "T",
144 static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
145 CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
146 TF_RegisterKernelBuilder("ScalarSummary", builder, status);
147 CHECK_EQ(TF_OK, TF_GetCode(status))
148 << "Error while registering Scalar Summmary kernel";
149 }
150 TF_DeleteStatus(status);
151}
152
153// A dummy static variable initialized by a lambda whose side-effect is to
154// register the ScalarSummary kernel.
155TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
156 if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) {
157 RegisterScalarSummaryOpKernel<int64_t>();
158 RegisterScalarSummaryOpKernel<tensorflow::uint64>();
159 RegisterScalarSummaryOpKernel<tensorflow::int32>();
160 RegisterScalarSummaryOpKernel<tensorflow::uint32>();
161 RegisterScalarSummaryOpKernel<tensorflow::uint16>();
162 RegisterScalarSummaryOpKernel<tensorflow::int16>();
163 RegisterScalarSummaryOpKernel<tensorflow::int8>();
164 RegisterScalarSummaryOpKernel<tensorflow::uint8>();
165 RegisterScalarSummaryOpKernel<Eigen::half>();
166 RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
167 RegisterScalarSummaryOpKernel<float>();
168 RegisterScalarSummaryOpKernel<double>();
169 }
170 return true;
171}();
172} // namespace
173