/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/client/session_ref.h"

#include <stdlib.h>
#include <memory>
#include <utility>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/replay_log.pb.h"

namespace tensorflow {

namespace {

// Scope helper to track active calls and manage session lifetime.
// SessionRef blocks closing until all active calls complete or are cancelled.
struct RunCounter {
  std::shared_ptr<Session> session;
  uint64* value;
  mutex* m;
  condition_variable* cv;

  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
                      condition_variable* cv)
      : session(std::move(s)), value(v), m(m), cv(cv) {
    mutex_lock l(*m);
    ++*value;
  }

  ~RunCounter() {
    mutex_lock l(*m);
    if (--*value == 0) {
      cv->notify_all();
    }
  }
};
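// Illustrative sketch of how the pieces fit together (mirroring SessionRef
// further below): each logged operation holds a RunCounter on the stack, and
// Close() waits on the paired condition variable until the count returns to
// zero before the underlying session is released.
//
//   {
//     RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
//     rc.session->Run(...);  // count stays non-zero while the call runs
//   }  // ~RunCounter decrements the count and notifies a waiting Close()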

std::string SessionToHandle(Session* session) {
  return strings::Printf("%llu", static_cast<unsigned long long>(
                                     reinterpret_cast<uintptr_t>(session)));
}

// The Session interface has many methods of the form:
//
// X(a, b);
// X(RunOptions, a, b);
//
// Not all sessions support the second case (with an empty RunOptions()).
// We use the address of this singleton as a sentinel to dispatch to the
// correct call.
RunOptions* kEmptyRunOptions() {
  static RunOptions* options = new RunOptions();
  return options;
}
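// Illustrative sketch: the Record* methods below compare the address of the
// incoming RunOptions against this sentinel to pick the right overload,
// roughly:
//
//   if (&run_options == kEmptyRunOptions()) {
//     RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
//                        outputs);
//   } else {
//     RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
//                        target_node_names, outputs, run_metadata);
//   }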

}  // namespace

// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request. This should be run _after_ all request information has been
// added to the current op.
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }
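// For instance, RUN_WITH_TIMESTAMP(Create, graph) expands to (roughly):
//
//   op.set_start_time_us(Env::Default()->NowMicros());
//   Status status = session->Create(graph);
//   op.set_end_time_us(Env::Default()->NowMicros());
//   if (!status.ok()) {
//     Flush(op).IgnoreError();
//     return status;
//   }
//
// i.e. it declares a local `status` in the enclosing Record* method and may
// return early from it.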

// Records requests (and optionally responses) performed against a session.
// The resulting replay log can be used with the `tf_replay` tool to replicate
// the operations against a simulated environment, without requiring the
// original code or cluster setup.
//
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
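//
// Illustrative usage sketch (the log path and script name are hypothetical):
//
//   TF_REPLAY_LOG_FILE=/tmp/replay.log python train.py
//
// With the variable set, every SessionRef created by the Python client logs
// its requests to /tmp/replay.log as serialized ReplayOp records.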
class SessionLogger {
 public:
  SessionLogger() {
    const char* log_file_env = getenv("TF_REPLAY_LOG_FILE");
    std::string log_name = log_file_env ? std::string(log_file_env) : ".";
    LOG(INFO) << "Constructing new session logger for " << log_name;
    TF_CHECK_OK(
        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
    Env::Default()->DeleteFile(log_name).IgnoreError();

    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
  }

  ~SessionLogger() {
    log_writer_->Close().IgnoreError();
    log_writer_.release();
    log_file_->Close().IgnoreError();
  }

  Status RecordNewSession(Session* session) {
    ReplayOp op;
    NewReplaySession* req = op.mutable_new_replay_session();
    req->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordRun(Session* session,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs) {
    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
                     target_node_names, outputs, nullptr);
  }

  Status RecordRun(Session* session, const RunOptions& run_options,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();

    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = run_options;

    for (const auto& it : inputs) {
      NamedTensorProto* feed = req->add_feed();
      feed->set_name(it.first);
      it.second.AsProtoField(feed->mutable_tensor());
    }

    // Build an index from fetch tensor name to first index in
    // output_tensor_names.
    std::unordered_map<string, int> output_name_to_offset;
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
        req->add_fetch(name);
      }
    }
    for (const string& target : target_node_names) {
      req->add_target(target);
    }

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
                         outputs);
    } else {
      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
    }

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_tensor_names[i]);
    }

    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }

    return Flush(op);
  }

  Status RecordCreate(Session* session, const GraphDef& graph) {
    return RecordCreate(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
  Status RecordCreate(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    CreateSessionRequest* req = op.mutable_create_session();
    *req->mutable_graph_def() = graph;

    CreateSessionResponse* resp = op.mutable_create_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Create, graph);
    } else {
      RUN_WITH_TIMESTAMP(Create, run_options, graph);
    }
    resp->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordExtend(Session* session, const GraphDef& graph) {
    return RecordExtend(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
  Status RecordExtend(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    ExtendSessionRequest* req = op.mutable_extend_session();
    op.mutable_extend_session_response();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_graph_def() = graph;
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Extend, graph);
    } else {
      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
    }

    return Flush(op);
  }

  Status RecordClose(Session* session) {
    return RecordClose(session, *kEmptyRunOptions());
  }

  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
  Status RecordClose(Session* session, const RunOptions& run_options) {
    ReplayOp op;
    CloseSessionRequest* req = op.mutable_close_session();
    req->set_session_handle(SessionToHandle(session));
    op.mutable_close_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Close);
    } else {
      RUN_WITH_TIMESTAMP(Close, run_options);
    }
    return Flush(op);
  }

  Status RecordListDevices(Session* session,
                           std::vector<DeviceAttributes>* response) {
    ReplayOp op;
    ListDevicesRequest* req = op.mutable_list_devices();
    ListDevicesResponse* resp = op.mutable_list_devices_response();
    req->set_session_handle(SessionToHandle(session));
    RUN_WITH_TIMESTAMP(ListDevices, response);

    // TODO(power) -- local vs remote device distinction is lost here!
    *resp->mutable_local_device() = {response->begin(), response->end()};
    return Flush(op);
  }

  Status RecordPRunSetup(Session* session,
                         const std::vector<string>& input_names,
                         const std::vector<string>& output_names,
                         const std::vector<string>& target_nodes,
                         string* handle) {
    ReplayOp op;
    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
    req->set_session_handle(SessionToHandle(session));
    for (auto& input : input_names) {
      req->add_feed(input);
    }
    for (auto& output : output_names) {
      req->add_fetch(output);
    }
    for (auto& target : target_nodes) {
      req->add_target(target);
    }
    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
                       handle);
    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
    return Flush(op);
  }

  Status RecordPRun(Session* session, const string& handle,
                    const std::vector<std::pair<string, Tensor> >& inputs,
                    const std::vector<string>& output_names,
                    std::vector<Tensor>* outputs) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();
    req->set_session_handle(SessionToHandle(session));

    // Mark this step as a partial run for replay.
    req->set_partial_run_handle(handle);
    for (auto& input : inputs) {
      auto* feed = req->add_feed();
      feed->set_name(input.first);
      input.second.AsProtoField(feed->mutable_tensor());
    }

    for (auto& output : output_names) {
      req->add_fetch(output);
    }

    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_names[i]);
    }

    return Flush(op);
  }

  Status RecordMakeCallable(Session* session,
                            const CallableOptions& callable_options,
                            Session::CallableHandle* handle) {
    ReplayOp op;
    MakeCallableRequest* req = op.mutable_make_callable();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = callable_options;

    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);

    MakeCallableResponse* resp = op.mutable_make_callable_response();
    resp->set_handle(*handle);

    return Flush(op);
  }

  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
                           const std::vector<Tensor>& feed_tensors,
                           std::vector<Tensor>* fetch_tensors,
                           RunMetadata* run_metadata) {
    ReplayOp op;
    RunCallableRequest* req = op.mutable_run_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    for (auto& tensor : feed_tensors) {
      tensor.AsProtoField(req->add_feed());
    }
    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
                       run_metadata);

    RunCallableResponse* resp = op.mutable_run_callable_response();
    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }
    for (const Tensor& tensor : *fetch_tensors) {
      tensor.AsProtoTensorContent(resp->add_fetch());
    }
    return Flush(op);
  }

  Status RecordReleaseCallable(Session* session,
                               Session::CallableHandle handle) {
    ReplayOp op;
    ReleaseCallableRequest* req = op.mutable_release_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
    return Flush(op);
  }

 private:
  Status Flush(const ReplayOp& op) {
    mutex_lock l(log_mutex_);

    string buf;
    op.SerializeToString(&buf);
    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));

    // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
    return log_file_->Sync();
  }

  std::unique_ptr<WritableFile> log_file_;
  std::unique_ptr<io::RecordWriter> log_writer_;
  mutex log_mutex_;
};

static SessionLogger* global_session_logger() {
  static SessionLogger* logger = new SessionLogger();
  return logger;
}

SessionRef::SessionRef(Session* session) : session_(session) {
  if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
    logger_ = global_session_logger();
    logger_->RecordNewSession(this->session_.get()).IgnoreError();
  } else {
    logger_ = nullptr;
  }
}

SessionRef::~SessionRef() = default;

Status SessionRef::CheckNotClosed() {
  mutex_lock l(run_lock_);
  if (session_ == nullptr) {
    return errors::Cancelled("Session has been closed.");
  }
  return OkStatus();
}

// If logging is active, record the operation (request, response, and timing)
// via the SessionLogger; otherwise run it directly on the underlying session.
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
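// For instance, SessionRef::Create(graph) below reduces to (roughly):
//
//   TF_RETURN_IF_ERROR(CheckNotClosed());
//   RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
//   if (!logger_) {
//     return rc.session->Create(graph);
//   }
//   return logger_->RecordCreate(rc.session.get(), graph);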

Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}

Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}

Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}

Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}

Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}

Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}

Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}

Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}

Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}

Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}

Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}

Status SessionRef::ReleaseCallable(CallableHandle handle) {
  {
    mutex_lock l(run_lock_);
    if (session_ == nullptr) {
      // Session already closed. Do nothing.
      return OkStatus();
    }
  }
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}

Status SessionRef::Close(const RunOptions& run_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get(), run_options);
  } else {
    status = session_->Close(run_options);
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

Status SessionRef::Close() {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get());
  } else {
    status = session_->Close();
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

}  // namespace tensorflow