1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <memory> |
16 | #include <sstream> |
17 | #include <unordered_set> |
18 | |
19 | #include "tensorflow/c/kernels.h" |
20 | #include "tensorflow/c/tf_status.h" |
21 | #include "tensorflow/c/tf_tensor.h" |
22 | #include "tensorflow/core/framework/registration/registration.h" |
23 | #include "tensorflow/core/framework/summary.pb.h" |
24 | #include "tensorflow/core/framework/types.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | #include "tensorflow/core/platform/macros.h" |
27 | #include "tensorflow/core/platform/protobuf.h" |
28 | #include "tensorflow/core/platform/tstring.h" |
29 | |
30 | namespace { |
31 | |
32 | // Operators used to create a std::unique_ptr for TF_Tensor and TF_Status |
33 | struct TFTensorDeleter { |
34 | void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); } |
35 | }; |
36 | |
37 | struct TFStatusDeleter { |
38 | void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } |
39 | }; |
40 | |
41 | // Struct that wraps TF_Tensor and TF_Status to delete once out of scope |
42 | using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>; |
43 | using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>; |
44 | |
45 | // dummy functions used for kernel registration |
46 | void* MergeSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; } |
47 | |
// Kernel-destruction callback passed to TF_NewKernelBuilder. There is no
// per-kernel state to release (the create callback returns nullptr), so the
// argument is ignored.
void MergeSummaryOp_Delete(void* /*kernel*/) {
  // Intentionally empty.
}
49 | |
50 | void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { |
51 | tensorflow::Summary s; |
52 | std::unordered_set<tensorflow::string> tags; |
53 | Safe_TF_StatusPtr status(TF_NewStatus()); |
54 | for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) { |
55 | TF_Tensor* input; |
56 | TF_GetInput(ctx, input_num, &input, status.get()); |
57 | Safe_TF_TensorPtr safe_input_ptr(input); |
58 | if (TF_GetCode(status.get()) != TF_OK) { |
59 | TF_OpKernelContext_Failure(ctx, status.get()); |
60 | return; |
61 | } |
62 | auto tags_array = |
63 | static_cast<tensorflow::tstring*>(TF_TensorData(safe_input_ptr.get())); |
64 | for (int i = 0; i < TF_TensorElementCount(safe_input_ptr.get()); ++i) { |
65 | const tensorflow::tstring& s_in = tags_array[i]; |
66 | tensorflow::Summary summary_in; |
67 | if (!tensorflow::ParseProtoUnlimited(&summary_in, s_in)) { |
68 | TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, |
69 | "Could not parse one of the summary inputs" ); |
70 | TF_OpKernelContext_Failure(ctx, status.get()); |
71 | return; |
72 | } |
73 | for (int v = 0; v < summary_in.value_size(); ++v) { |
74 | // This tag is unused by the TensorSummary op, so no need to check for |
75 | // duplicates. |
76 | const tensorflow::string& tag = summary_in.value(v).tag(); |
77 | if ((!tag.empty()) && !tags.insert(tag).second) { |
78 | std::ostringstream err; |
79 | err << "Duplicate tag " << tag << " found in summary inputs " ; |
80 | TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str()); |
81 | TF_OpKernelContext_Failure(ctx, status.get()); |
82 | return; |
83 | } |
84 | *s.add_value() = summary_in.value(v); |
85 | } |
86 | } |
87 | } |
88 | Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput( |
89 | /*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0), |
90 | /*dims=*/nullptr, /*num_dims=*/0, |
91 | /*len=*/sizeof(tensorflow::tstring), status.get())); |
92 | if (TF_GetCode(status.get()) != TF_OK) { |
93 | TF_OpKernelContext_Failure(ctx, status.get()); |
94 | return; |
95 | } |
96 | tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>( |
97 | TF_TensorData(summary_tensor.get())); |
98 | CHECK(SerializeToTString(s, output_tstring)); |
99 | } |
100 | |
101 | void RegisterMergeSummaryOpKernel() { |
102 | TF_Status* status = TF_NewStatus(); |
103 | { |
104 | auto* builder = TF_NewKernelBuilder( |
105 | "MergeSummary" , tensorflow::DEVICE_CPU, &MergeSummaryOp_Create, |
106 | &MergeSummaryOp_Compute, &MergeSummaryOp_Delete); |
107 | TF_RegisterKernelBuilder("MergeSummary" , builder, status); |
108 | CHECK_EQ(TF_OK, TF_GetCode(status)) |
109 | << "Error while registering Merge Summmary kernel" ; |
110 | } |
111 | TF_DeleteStatus(status); |
112 | } |
113 | |
114 | // A dummy static variable initialized by a lambda whose side-effect is to |
115 | // register the Histogram Summary kernel. |
116 | TF_ATTRIBUTE_UNUSED static bool IsMergeSummaryOpKernelRegistered = []() { |
117 | if (SHOULD_REGISTER_OP_KERNEL("MergeSummary" )) { |
118 | RegisterMergeSummaryOpKernel(); |
119 | } |
120 | return true; |
121 | }(); |
122 | |
123 | } // namespace |
124 | |