1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/distributed_runtime/partial_run_mgr.h" |
17 | |
18 | #include "tensorflow/core/util/ptr_util.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | bool PartialRunMgr::FindOrCreate(int step_id, |
23 | CancellationManager** cancellation_manager) { |
24 | mutex_lock l(mu_); |
25 | auto it = step_id_to_partial_run_.find(step_id); |
26 | if (it != step_id_to_partial_run_.end()) { |
27 | *cancellation_manager = it->second->cancellation_manager.get(); |
28 | return false; |
29 | } |
30 | |
31 | std::unique_ptr<PartialRunState> partial_run = |
32 | tensorflow::MakeUnique<PartialRunState>(); |
33 | partial_run->cancellation_manager = |
34 | tensorflow::MakeUnique<CancellationManager>(); |
35 | *cancellation_manager = partial_run->cancellation_manager.get(); |
36 | step_id_to_partial_run_[step_id] = std::move(partial_run); |
37 | return true; |
38 | } |
39 | |
40 | void PartialRunMgr::ExecutorDone(int step_id, const Status& executor_status) { |
41 | StatusCallback done; |
42 | Status callback_status; |
43 | { |
44 | mutex_lock l(mu_); |
45 | auto run_it = step_id_to_partial_run_.find(step_id); |
46 | if (run_it == step_id_to_partial_run_.end()) { |
47 | return; |
48 | } |
49 | // If we found the partial_run, we call the final callback, if it |
50 | // exists. |
51 | // It is guaranteed that run_it->second->final_callback is left empty |
52 | // after the std::move call. |
53 | done = std::move(run_it->second->final_callback); |
54 | if (!executor_status.ok()) { |
55 | run_it->second->final_status = executor_status; |
56 | } |
57 | callback_status = run_it->second->final_status; |
58 | run_it->second->executor_done = true; |
59 | } |
60 | if (done != nullptr) { |
61 | done(callback_status); |
62 | mutex_lock l(mu_); |
63 | step_id_to_partial_run_.erase(step_id); |
64 | } |
65 | } |
66 | |
67 | void PartialRunMgr::PartialRunDone(int step_id, StatusCallback done, |
68 | const Status& status) { |
69 | Status callback_status; |
70 | { |
71 | mutex_lock l(mu_); |
72 | auto run_it = step_id_to_partial_run_.find(step_id); |
73 | if (run_it == step_id_to_partial_run_.end()) { |
74 | return; |
75 | } |
76 | run_it->second->final_status.Update(status); |
77 | if (!run_it->second->executor_done) { |
78 | // If we found the partial_run, we set the final callback to call only |
79 | // when the executor is completely done. |
80 | run_it->second->final_callback = std::move(done); |
81 | return; |
82 | } |
83 | callback_status = run_it->second->final_status; |
84 | } |
85 | // Otherwise we call the callback immediately. |
86 | done(callback_status); |
87 | mutex_lock l(mu_); |
88 | step_id_to_partial_run_.erase(step_id); |
89 | } |
90 | |
91 | } // namespace tensorflow |
92 | |