1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
18
#include <memory>
#include <unordered_map>

#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
27
28namespace tensorflow {
29
30// PartialRunMgr keeps track of pending partial run requests, and ensures that
31// the partial run is only marked complete when the corresponding executor is
32// run to completion.
33//
34// In tensorflow workers, the executor runs operations asynchronously until
35// specified fetches (operations that return tensors) or targets (operations
36// that don't return tensors) are reached. A PartialRun has two components: a
37// setup which specifies all desired fetches and targets, and run calls that
38// specify fetch values (from the setup calls) to retrieve.
39// On the last partial run call, it is possible to satisfy the
40// required fetches before the executor has completed running the graph to all
41// the desired targets.
42// PartialRunMgr is used to ensure that we don't complete and return the final
43// partial run call to the user until both the partial run and executor have
44// completed.
45//
46// PartialRunMgr is thread-safe.
47class PartialRunMgr {
48 public:
49 // Find or create the CancellationManager associated with step_id.
50 // The PartialRunMgr owns the cancellation_manager.
51 // Returns true if a new CancellationManager was created
52 // (i.e this is a new partial run).
53 bool FindOrCreate(int step_id, CancellationManager** cancellation_manager);
54
55 // Calls the final callback if the PartialRunRequest has already completed.
56 // Otherwise stores the executor_status to be propagated when the
57 // PartialRunRequest completes (PartialRunDone has been called).
58 void ExecutorDone(int step_id, const Status& executor_status);
59
60 // Calls done if the executor has already completed (ExecutorDone has been
61 // called). Otherwise, stores the status and done callback, calling them when
62 // ExecutorDone is called. The callback will either be called by the calling
63 // thread of either PartialRunDone or ExecutorDone.
64 // If executor_status in ExecutorDone is not OK, it takes precedence over
65 // status and is passed to the done callback.
66 void PartialRunDone(int step_id, StatusCallback done, const Status& status);
67
68 private:
69 // PartialRunState stores state associated with a pending partial run request.
70 // This is protected by the mutex in PartialRunMgr.
71 struct PartialRunState {
72 std::unique_ptr<CancellationManager> cancellation_manager;
73
74 bool executor_done = false;
75 StatusCallback final_callback = nullptr;
76 Status final_status;
77 };
78
79 mutex mu_;
80
81 std::unordered_map<int, std::unique_ptr<PartialRunState>>
82 step_id_to_partial_run_ TF_GUARDED_BY(mu_);
83};
84
85} // namespace tensorflow
86
87#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
88