1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/client/session_ref.h" |
16 | |
17 | #include <stdlib.h> |
18 | #include <memory> |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/lib/io/path.h" |
22 | #include "tensorflow/core/lib/io/record_writer.h" |
23 | #include "tensorflow/core/lib/strings/stringprintf.h" |
24 | #include "tensorflow/core/protobuf/master.pb.h" |
25 | #include "tensorflow/core/protobuf/named_tensor.pb.h" |
26 | #include "tensorflow/core/protobuf/replay_log.pb.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | namespace { |
31 | |
32 | // Scope helper to track active calls and manage session lifetime. |
33 | // SessionRef blocks closing until all active calls complete or are cancelled. |
34 | struct RunCounter { |
35 | std::shared_ptr<Session> session; |
36 | uint64* value; |
37 | mutex* m; |
38 | condition_variable* cv; |
39 | |
40 | explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m, |
41 | condition_variable* cv) |
42 | : session(std::move(s)), value(v), m(m), cv(cv) { |
43 | mutex_lock l(*m); |
44 | ++*value; |
45 | } |
46 | |
47 | ~RunCounter() { |
48 | mutex_lock l(*m); |
49 | if (--*value == 0) { |
50 | cv->notify_all(); |
51 | } |
52 | } |
53 | }; |
54 | |
55 | std::string SessionToHandle(Session* session) { |
56 | return strings::Printf("%llu" , static_cast<unsigned long long>( |
57 | reinterpret_cast<uintptr_t>(session))); |
58 | } |
59 | |
60 | // The Session interface has many methods of the form: |
61 | // |
62 | // X(a, b); |
63 | // X(RunOptions, a, b); |
64 | // |
65 | // Not all sessions support the second case (with an empty RunOptions()). |
66 | // We use this variable as a sentinel to dispatch to the correct call. |
67 | RunOptions* kEmptyRunOptions() { |
68 | static RunOptions* options = new RunOptions(); |
69 | return options; |
70 | } |
71 | |
72 | } // namespace |
73 | |
// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request. This should be run _after_ all request information has been
// added to the current op.
//
// Expands where a ReplayOp named `op` and a `session` pointer are in scope.
// Declares a local `Status status`, so it can be used at most once per
// enclosing block; on success, `status` is left for the caller to inspect.
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }
86 | |
87 | // Records requests (and optionally responses) performed against a session. |
88 | // The resulting replay log can be used with the `tf_replay` tool to replicate |
89 | // the operations against a simulated environment, without requiring the |
90 | // original code or cluster setup. |
91 | // |
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
93 | class SessionLogger { |
94 | public: |
95 | SessionLogger() { |
96 | const char* log_file_env = getenv("TF_REPLAY_LOG_FILE" ); |
97 | std::string log_name = log_file_env ? std::string(log_file_env) : "." ; |
98 | LOG(INFO) << "Constructing new session logger for " << log_name; |
99 | TF_CHECK_OK( |
100 | Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name)))); |
101 | Env::Default()->DeleteFile(log_name).IgnoreError(); |
102 | |
103 | TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_)); |
104 | log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get()); |
105 | } |
106 | |
107 | ~SessionLogger() { |
108 | log_writer_->Close().IgnoreError(); |
109 | log_writer_.release(); |
110 | log_file_->Close().IgnoreError(); |
111 | } |
112 | |
113 | Status RecordNewSession(Session* session) { |
114 | ReplayOp op; |
115 | NewReplaySession* req = op.mutable_new_replay_session(); |
116 | req->set_session_handle(SessionToHandle(session)); |
117 | return Flush(op); |
118 | } |
119 | |
120 | Status RecordRun(Session* session, |
121 | const std::vector<std::pair<string, Tensor> >& inputs, |
122 | const std::vector<string>& output_tensor_names, |
123 | const std::vector<string>& target_node_names, |
124 | std::vector<Tensor>* outputs) { |
125 | return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names, |
126 | target_node_names, outputs, nullptr); |
127 | } |
128 | |
129 | Status RecordRun(Session* session, const RunOptions& run_options, |
130 | const std::vector<std::pair<string, Tensor> >& inputs, |
131 | const std::vector<string>& output_tensor_names, |
132 | const std::vector<string>& target_node_names, |
133 | std::vector<Tensor>* outputs, RunMetadata* run_metadata) { |
134 | ReplayOp op; |
135 | RunStepRequest* req = op.mutable_run_step(); |
136 | RunStepResponse* resp = op.mutable_run_step_response(); |
137 | |
138 | req->set_session_handle(SessionToHandle(session)); |
139 | *req->mutable_options() = run_options; |
140 | |
141 | for (const auto& it : inputs) { |
142 | NamedTensorProto* feed = req->add_feed(); |
143 | feed->set_name(it.first); |
144 | it.second.AsProtoField(feed->mutable_tensor()); |
145 | } |
146 | |
147 | // Build an index from fetch tensor name to first index in |
148 | // output_tensor_names. |
149 | std::unordered_map<string, int> output_name_to_offset; |
150 | for (int i = 0, end = output_tensor_names.size(); i < end; ++i) { |
151 | const string& name = output_tensor_names[i]; |
152 | if (output_name_to_offset.insert(std::make_pair(name, i)).second) { |
153 | req->add_fetch(name); |
154 | } |
155 | } |
156 | for (const string& target : target_node_names) { |
157 | req->add_target(target); |
158 | } |
159 | |
160 | if (&run_options == kEmptyRunOptions()) { |
161 | RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names, |
162 | outputs); |
163 | } else { |
164 | RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names, |
165 | target_node_names, outputs, run_metadata); |
166 | } |
167 | |
168 | for (size_t i = 0; i < outputs->size(); ++i) { |
169 | const Tensor& tensor = (*outputs)[i]; |
170 | NamedTensorProto* tproto = resp->add_tensor(); |
171 | tensor.AsProtoField(tproto->mutable_tensor()); |
172 | tproto->set_name(output_tensor_names[i]); |
173 | } |
174 | |
175 | if (run_metadata) { |
176 | *resp->mutable_metadata() = *run_metadata; |
177 | } |
178 | |
179 | return Flush(op); |
180 | } |
181 | |
182 | Status RecordCreate(Session* session, const GraphDef& graph) { |
183 | return RecordCreate(session, *kEmptyRunOptions(), graph); |
184 | } |
185 | |
186 | // N.B. RunOptions is not stored (it has no entry in CreateRequest) |
187 | Status RecordCreate(Session* session, const RunOptions& run_options, |
188 | const GraphDef& graph) { |
189 | ReplayOp op; |
190 | CreateSessionRequest* req = op.mutable_create_session(); |
191 | *req->mutable_graph_def() = graph; |
192 | |
193 | CreateSessionResponse* resp = op.mutable_create_session_response(); |
194 | if (&run_options == kEmptyRunOptions()) { |
195 | RUN_WITH_TIMESTAMP(Create, graph); |
196 | } else { |
197 | RUN_WITH_TIMESTAMP(Create, run_options, graph); |
198 | } |
199 | resp->set_session_handle(SessionToHandle(session)); |
200 | return Flush(op); |
201 | } |
202 | |
203 | Status RecordExtend(Session* session, const GraphDef& graph) { |
204 | return RecordExtend(session, *kEmptyRunOptions(), graph); |
205 | } |
206 | |
207 | // N.B. RunOptions is not stored (it has no entry in ExtendRequest) |
208 | Status RecordExtend(Session* session, const RunOptions& run_options, |
209 | const GraphDef& graph) { |
210 | ReplayOp op; |
211 | ExtendSessionRequest* req = op.mutable_extend_session(); |
212 | op.mutable_extend_session_response(); |
213 | req->set_session_handle(SessionToHandle(session)); |
214 | *req->mutable_graph_def() = graph; |
215 | if (&run_options == kEmptyRunOptions()) { |
216 | RUN_WITH_TIMESTAMP(Extend, graph); |
217 | } else { |
218 | RUN_WITH_TIMESTAMP(Extend, run_options, graph); |
219 | } |
220 | |
221 | return Flush(op); |
222 | } |
223 | |
224 | Status RecordClose(Session* session) { |
225 | return RecordClose(session, *kEmptyRunOptions()); |
226 | } |
227 | |
228 | // N.B. RunOptions is not stored (it has no entry in CloseRequest) |
229 | Status RecordClose(Session* session, const RunOptions& run_options) { |
230 | ReplayOp op; |
231 | CloseSessionRequest* req = op.mutable_close_session(); |
232 | req->set_session_handle(SessionToHandle(session)); |
233 | op.mutable_close_session_response(); |
234 | if (&run_options == kEmptyRunOptions()) { |
235 | RUN_WITH_TIMESTAMP(Close); |
236 | } else { |
237 | RUN_WITH_TIMESTAMP(Close, run_options); |
238 | } |
239 | return Flush(op); |
240 | } |
241 | |
242 | Status RecordListDevices(Session* session, |
243 | std::vector<DeviceAttributes>* response) { |
244 | ReplayOp op; |
245 | ListDevicesRequest* req = op.mutable_list_devices(); |
246 | ListDevicesResponse* resp = op.mutable_list_devices_response(); |
247 | req->set_session_handle(SessionToHandle(session)); |
248 | RUN_WITH_TIMESTAMP(ListDevices, response); |
249 | |
250 | // TODO(power) -- local vs remote device distinction is lost here! |
251 | *resp->mutable_local_device() = {response->begin(), response->end()}; |
252 | return Flush(op); |
253 | } |
254 | |
255 | Status RecordPRunSetup(Session* session, |
256 | const std::vector<string>& input_names, |
257 | const std::vector<string>& output_names, |
258 | const std::vector<string>& target_nodes, |
259 | string* handle) { |
260 | ReplayOp op; |
261 | PartialRunSetupRequest* req = op.mutable_partial_run_setup(); |
262 | req->set_session_handle(SessionToHandle(session)); |
263 | for (auto& input : input_names) { |
264 | req->add_feed(input); |
265 | } |
266 | for (auto& output : output_names) { |
267 | req->add_fetch(output); |
268 | } |
269 | for (auto& target : target_nodes) { |
270 | req->add_target(target); |
271 | } |
272 | RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes, |
273 | handle); |
274 | op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle); |
275 | return Flush(op); |
276 | } |
277 | |
278 | Status RecordPRun(Session* session, const string& handle, |
279 | const std::vector<std::pair<string, Tensor> >& inputs, |
280 | const std::vector<string>& output_names, |
281 | std::vector<Tensor>* outputs) { |
282 | ReplayOp op; |
283 | RunStepRequest* req = op.mutable_run_step(); |
284 | RunStepResponse* resp = op.mutable_run_step_response(); |
285 | req->set_session_handle(SessionToHandle(session)); |
286 | |
287 | // Mark this step as a partial run for replay. |
288 | req->set_partial_run_handle(handle); |
289 | for (auto& input : inputs) { |
290 | auto* feed = req->add_feed(); |
291 | feed->set_name(input.first); |
292 | input.second.AsProtoField(feed->mutable_tensor()); |
293 | } |
294 | |
295 | for (auto& output : output_names) { |
296 | req->add_fetch(output); |
297 | } |
298 | |
299 | RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs); |
300 | |
301 | for (size_t i = 0; i < outputs->size(); ++i) { |
302 | const Tensor& tensor = (*outputs)[i]; |
303 | NamedTensorProto* tproto = resp->add_tensor(); |
304 | tensor.AsProtoField(tproto->mutable_tensor()); |
305 | tproto->set_name(output_names[i]); |
306 | } |
307 | |
308 | return Flush(op); |
309 | } |
310 | |
311 | Status RecordMakeCallable(Session* session, |
312 | const CallableOptions& callable_options, |
313 | Session::CallableHandle* handle) { |
314 | ReplayOp op; |
315 | MakeCallableRequest* req = op.mutable_make_callable(); |
316 | req->set_session_handle(SessionToHandle(session)); |
317 | *req->mutable_options() = callable_options; |
318 | |
319 | RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle); |
320 | |
321 | MakeCallableResponse* resp = op.mutable_make_callable_response(); |
322 | resp->set_handle(*handle); |
323 | |
324 | return Flush(op); |
325 | } |
326 | |
327 | Status RecordRunCallable(Session* session, Session::CallableHandle handle, |
328 | const std::vector<Tensor>& feed_tensors, |
329 | std::vector<Tensor>* fetch_tensors, |
330 | RunMetadata* run_metadata) { |
331 | ReplayOp op; |
332 | RunCallableRequest* req = op.mutable_run_callable(); |
333 | req->set_session_handle(SessionToHandle(session)); |
334 | req->set_handle(handle); |
335 | for (auto& tensor : feed_tensors) { |
336 | tensor.AsProtoField(req->add_feed()); |
337 | } |
338 | RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors, |
339 | run_metadata); |
340 | |
341 | RunCallableResponse* resp = op.mutable_run_callable_response(); |
342 | if (run_metadata) { |
343 | *resp->mutable_metadata() = *run_metadata; |
344 | } |
345 | for (const Tensor& tensor : *fetch_tensors) { |
346 | tensor.AsProtoTensorContent(resp->add_fetch()); |
347 | } |
348 | return Flush(op); |
349 | } |
350 | |
351 | Status RecordReleaseCallable(Session* session, |
352 | Session::CallableHandle handle) { |
353 | ReplayOp op; |
354 | ReleaseCallableRequest* req = op.mutable_release_callable(); |
355 | req->set_session_handle(SessionToHandle(session)); |
356 | req->set_handle(handle); |
357 | RUN_WITH_TIMESTAMP(ReleaseCallable, handle); |
358 | return Flush(op); |
359 | } |
360 | |
361 | private: |
362 | Status Flush(const ReplayOp& op) { |
363 | mutex_lock l(log_mutex_); |
364 | |
365 | string buf; |
366 | op.SerializeToString(&buf); |
367 | TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf)); |
368 | |
369 | // TODO(b/116624106): Not all file-systems respect calls to `Sync()` |
370 | return log_file_->Sync(); |
371 | } |
372 | |
373 | std::unique_ptr<WritableFile> log_file_; |
374 | std::unique_ptr<io::RecordWriter> log_writer_; |
375 | mutex log_mutex_; |
376 | }; |
377 | |
378 | static SessionLogger* global_session_logger() { |
379 | static SessionLogger* logger = new SessionLogger(); |
380 | return logger; |
381 | } |
382 | |
383 | SessionRef::SessionRef(Session* session) : session_(session) { |
384 | if (getenv("TF_REPLAY_LOG_FILE" ) != nullptr) { |
385 | logger_ = global_session_logger(); |
386 | logger_->RecordNewSession(this->session_.get()).IgnoreError(); |
387 | } else { |
388 | logger_ = nullptr; |
389 | } |
390 | } |
391 | |
// Default destructor; session_ (a shared_ptr) releases the session unless an
// in-flight call still holds a reference via RunCounter.
SessionRef::~SessionRef() = default;
393 | |
394 | Status SessionRef::CheckNotClosed() { |
395 | mutex_lock l(run_lock_); |
396 | if (session_ == nullptr) return errors::Cancelled("Session has been closed." ); |
397 | return OkStatus(); |
398 | } |
399 | |
// If logging is active, log the start and end time of the operation along with
// the request and response.
//
// The RunCounter holds a shared_ptr copy of session_, keeping the session
// alive and blocking Close() until this call finishes.
// NOTE(review): session_ is copied into the RunCounter without holding
// run_lock_; a Close() racing between CheckNotClosed() and that copy could
// leave rc.session null — confirm whether callers serialize Run/Close.
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
409 | |
// Runs one step with explicit RunOptions/RunMetadata, logging when active.
Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}

// Runs one step without RunOptions, logging when active.
Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}
427 | |
// Creates the graph in the underlying session, logging when active.
Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}

Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}

// Extends the session's graph, logging when active.
Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}

Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}
445 | |
// Lists the session's devices, logging when active.
Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}

// Sets up a partial run; *handle identifies it in later PRun calls.
Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}

// Continues a partial run identified by `handle`.
Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}
464 | |
// Creates a callable in the underlying session, logging when active.
Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}

// Invokes a previously-made callable, logging when active.
Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}
477 | |
// Releases a callable handle. Unlike the other operations, a release after
// Close() is deliberately a silent no-op rather than an error, since closing
// the session already invalidated all callables.
Status SessionRef::ReleaseCallable(CallableHandle handle) {
  {
    mutex_lock l(run_lock_);
    if (session_ == nullptr) {
      // Session already closed. Do nothing.
      return OkStatus();
    }
  }
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}
488 | |
489 | Status SessionRef::Close(const RunOptions& run_options) { |
490 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
491 | mutex_lock l(run_lock_); |
492 | Status status; |
493 | if (logger_) { |
494 | status = logger_->RecordClose(session_.get(), run_options); |
495 | } else { |
496 | status = session_->Close(run_options); |
497 | } |
498 | session_.reset(); |
499 | while (run_count_ > 0) { |
500 | run_finished_.wait(l); |
501 | } |
502 | return status; |
503 | } |
504 | |
505 | Status SessionRef::Close() { |
506 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
507 | mutex_lock l(run_lock_); |
508 | Status status; |
509 | if (logger_) { |
510 | status = logger_->RecordClose(session_.get()); |
511 | } else { |
512 | status = session_->Close(); |
513 | } |
514 | session_.reset(); |
515 | while (run_count_ > 0) { |
516 | run_finished_.wait(l); |
517 | } |
518 | return status; |
519 | } |
520 | |
521 | } // namespace tensorflow |
522 | |