1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/logging_ops.h"
17
18#include <iostream>
19
20#include "absl/strings/str_cat.h"
21#include "tensorflow/core/framework/logging.h"
22#include "tensorflow/core/framework/op_kernel.h"
23#include "tensorflow/core/lib/core/status.h"
24#include "tensorflow/core/lib/strings/str_util.h"
25#include "tensorflow/core/util/determinism.h"
26
27namespace tensorflow {
28
29namespace {
30
31// If the following string is found at the beginning of an output stream, it
32// will be interpreted as a file path.
33const char kOutputStreamEscapeStr[] = "file://";
34
35// A mutex that guards appending strings to files.
36static mutex* file_mutex = new mutex();
37
38// Appends the given data to the specified file. It will create the file if it
39// doesn't already exist.
40Status AppendStringToFile(const std::string& fname, StringPiece data,
41 Env* env) {
42 // TODO(ckluk): If opening and closing on every log causes performance issues,
43 // we can reimplement using reference counters.
44 mutex_lock l(*file_mutex);
45 std::unique_ptr<WritableFile> file;
46 TF_RETURN_IF_ERROR(env->NewAppendableFile(fname, &file));
47 Status a = file->Append(data);
48 Status c = file->Close();
49 return a.ok() ? c : a;
50}
51
52} // namespace
53
54AssertOp::AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
55 OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
56}
57
58void AssertOp::Compute(OpKernelContext* ctx) {
59 const Tensor& cond = ctx->input(0);
60 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cond.shape()),
61 errors::InvalidArgument("In[0] should be a scalar: ",
62 cond.shape().DebugString()));
63
64 if (cond.scalar<bool>()()) {
65 return;
66 }
67 string msg = "assertion failed: ";
68 for (int i = 1; i < ctx->num_inputs(); ++i) {
69 strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
70 "]");
71 if (i < ctx->num_inputs() - 1) strings::StrAppend(&msg, " ");
72 }
73 ctx->SetStatus(errors::InvalidArgument(msg));
74}
75
76REGISTER_KERNEL_BUILDER(Name("Assert")
77 .Device(DEVICE_DEFAULT)
78 .HostMemory("condition")
79 .HostMemory("data"),
80 AssertOp);
81
82class PrintOp : public OpKernel {
83 public:
84 explicit PrintOp(OpKernelConstruction* ctx)
85 : OpKernel(ctx), call_counter_(0) {
86 OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &message_));
87 OP_REQUIRES_OK(ctx, ctx->GetAttr("first_n", &first_n_));
88 OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
89 }
90
91 void Compute(OpKernelContext* ctx) override {
92 if (IsRefType(ctx->input_dtype(0))) {
93 ctx->forward_ref_input_to_ref_output(0, 0);
94 } else {
95 ctx->set_output(0, ctx->input(0));
96 }
97 if (first_n_ >= 0) {
98 mutex_lock l(mu_);
99 if (call_counter_ >= first_n_) return;
100 call_counter_++;
101 }
102 string msg;
103 strings::StrAppend(&msg, message_);
104 for (int i = 1; i < ctx->num_inputs(); ++i) {
105 strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
106 "]");
107 }
108 std::cerr << msg << std::endl;
109 }
110
111 private:
112 mutex mu_;
113 int64_t call_counter_ TF_GUARDED_BY(mu_) = 0;
114 int64_t first_n_ = 0;
115 int32 summarize_ = 0;
116 string message_;
117};
118
119REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
120
121class PrintV2Op : public OpKernel {
122 public:
123 explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
124 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
125 OP_REQUIRES_OK(ctx, ctx->GetAttr("end", &end_));
126
127 SetFilePathIfAny();
128 if (!file_path_.empty()) return;
129
130 auto output_stream_index =
131 std::find(std::begin(valid_output_streams_),
132 std::end(valid_output_streams_), output_stream_);
133
134 if (output_stream_index == std::end(valid_output_streams_)) {
135 string error_msg = strings::StrCat(
136 "Unknown output stream: ", output_stream_, ", Valid streams are:");
137 for (auto valid_stream : valid_output_streams_) {
138 strings::StrAppend(&error_msg, " ", valid_stream);
139 }
140 OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
141 }
142 }
143
144 void Compute(OpKernelContext* ctx) override {
145 const Tensor* input_;
146 OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
147 OP_REQUIRES(
148 ctx, TensorShapeUtils::IsScalar(input_->shape()),
149 errors::InvalidArgument("Input is expected to be scalar, but got ",
150 input_->shape()));
151 const string& msg = input_->scalar<tstring>()();
152
153 string ended_msg = strings::StrCat(msg, end_);
154
155 if (!file_path_.empty()) {
156 // Outputs to a file at the specified path.
157 OP_REQUIRES_OK(ctx,
158 AppendStringToFile(file_path_, ended_msg, ctx->env()));
159 return;
160 }
161
162 if (logging::LogToListeners(ended_msg, "")) {
163 return;
164 }
165
166 if (output_stream_ == "stdout") {
167 std::cout << ended_msg << std::flush;
168 } else if (output_stream_ == "stderr") {
169 std::cerr << ended_msg << std::flush;
170 } else if (output_stream_ == "log(info)") {
171 LOG(INFO) << ended_msg << std::flush;
172 } else if (output_stream_ == "log(warning)") {
173 LOG(WARNING) << ended_msg << std::flush;
174 } else if (output_stream_ == "log(error)") {
175 LOG(ERROR) << ended_msg << std::flush;
176 } else {
177 string error_msg = strings::StrCat(
178 "Unknown output stream: ", output_stream_, ", Valid streams are:");
179 for (auto valid_stream : valid_output_streams_) {
180 strings::StrAppend(&error_msg, " ", valid_stream);
181 }
182 strings::StrAppend(&error_msg, ", or file://<filename>");
183 OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
184 }
185 }
186
187 const char* valid_output_streams_[5] = {"stdout", "stderr", "log(info)",
188 "log(warning)", "log(error)"};
189
190 private:
191 string end_;
192 // Either output_stream_ or file_path_ (but not both) will be non-empty.
193 string output_stream_;
194 string file_path_;
195
196 // If output_stream_ is a file path, extracts it to file_path_ and clears
197 // output_stream_; otherwise sets file_paths_ to "".
198 void SetFilePathIfAny() {
199 if (absl::StartsWith(output_stream_, kOutputStreamEscapeStr)) {
200 file_path_ = output_stream_.substr(strlen(kOutputStreamEscapeStr));
201 output_stream_ = "";
202 } else {
203 file_path_ = "";
204 }
205 }
206};
207
208REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
209
210class TimestampOp : public OpKernel {
211 public:
212 explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
213
214 void Compute(OpKernelContext* context) override {
215 OP_REQUIRES(context, !OpDeterminismRequired(),
216 errors::FailedPrecondition(
217 "Timestamp cannot be called when determinism is enabled"));
218 TensorShape output_shape; // Default shape is 0 dim, 1 element
219 Tensor* output_tensor = nullptr;
220 OP_REQUIRES_OK(context,
221 context->allocate_output(0, output_shape, &output_tensor));
222
223 auto output_scalar = output_tensor->scalar<double>();
224 double now_us = static_cast<double>(Env::Default()->NowMicros());
225 double now_s = now_us / 1000000;
226 output_scalar() = now_s;
227 }
228};
229
230REGISTER_KERNEL_BUILDER(Name("Timestamp").Device(DEVICE_CPU), TimestampOp);
231
232} // end namespace tensorflow
233