1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <cstdint>
#include <memory>
#include <sstream>
#include <unordered_set>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
29
30namespace {
31
32// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status
33struct TFTensorDeleter {
34 void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
35};
36
37struct TFStatusDeleter {
38 void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
39};
40
41// Struct that wraps TF_Tensor and TF_Status to delete once out of scope
42using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
43using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
44
45// dummy functions used for kernel registration
46void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
47
// Kernel-deletion callback handed to TF_NewKernelBuilder. MergeSummaryOp_Create
// never allocates anything, so there is nothing to free here.
void MergeSummaryOp_Delete(void* kernel) {
  static_cast<void>(kernel);
}
49
50void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
51 tensorflow::Summary s;
52 std::unordered_set<tensorflow::string> tags;
53 Safe_TF_StatusPtr status(TF_NewStatus());
54 for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) {
55 TF_Tensor* input;
56 TF_GetInput(ctx, input_num, &input, status.get());
57 Safe_TF_TensorPtr safe_input_ptr(input);
58 if (TF_GetCode(status.get()) != TF_OK) {
59 TF_OpKernelContext_Failure(ctx, status.get());
60 return;
61 }
62 auto tags_array =
63 static_cast<tensorflow::tstring*>(TF_TensorData(safe_input_ptr.get()));
64 for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) {
65 const tensorflow::tstring& s_in = tags_array[i];
66 tensorflow::Summary summary_in;
67 if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) {
68 TF_SetStatus(status.get(), TF_INVALID_ARGUMENT,
69 "Could not parse one of the summary inputs");
70 TF_OpKernelContext_Failure(ctx, status.get());
71 return;
72 }
73 for (int v = 0; v < summary_in.value_size(); ++v) {
74 // This tag is unused by the TensorSummary op, so no need to check for
75 // duplicates.
76 const tensorflow::string& tag = summary_in.value(v).tag();
77 if ((!tag.empty()) && !tags.insert(tag).second) {
78 std::ostringstream err;
79 err << "Duplicate tag " << tag << " found in summary inputs ";
80 TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
81 TF_OpKernelContext_Failure(ctx, status.get());
82 return;
83 }
84 *s.add_value() = summary_in.value(v);
85 }
86 }
87 }
88 Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
89 /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
90 /*dims=*/nullptr, /*num_dims=*/0,
91 /*len=*/sizeof(tensorflow::tstring), status.get()));
92 if (TF_GetCode(status.get()) != TF_OK) {
93 TF_OpKernelContext_Failure(ctx, status.get());
94 return;
95 }
96 tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
97 TF_TensorData(summary_tensor.get()));
98 CHECK(SerializeToTString(s, output_tstring));
99}
100
101void RegisterMergeSummaryOpKernel() {
102 TF_Status* status = TF_NewStatus();
103 {
104 auto* builder = TF_NewKernelBuilder(
105 "MergeSummary", tensorflow::DEVICE_CPU, &MergeSummaryOp_Create,
106 &MergeSummaryOp_Compute, &MergeSummaryOp_Delete);
107 TF_RegisterKernelBuilder("MergeSummary", builder, status);
108 CHECK_EQ(TF_OK, TF_GetCode(status))
109 << "Error while registering Merge Summmary kernel";
110 }
111 TF_DeleteStatus(status);
112}
113
114// A dummy static variable initialized by a lambda whose side-effect is to
115// register the Histogram Summary kernel.
116TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() {
117 if (SHOULD_REGISTER_OP_KERNEL("MergeSummary")) {
118 RegisterMergeSummaryOpKernel();
119 }
120 return true;
121}();
122
123} // namespace
124