1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_
17
18#include <atomic>
19#include <deque>
20#include <memory>
21#include <vector>
22
23#include "absl/container/flat_hash_map.h"
24#include "tensorflow/core/common_runtime/graph_view.h"
25#include "tensorflow/core/common_runtime/local_executor_params.h"
26#include "tensorflow/core/common_runtime/pending_counts.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/lib/core/status.h"
29#include "tensorflow/core/lib/gtl/flatmap.h"
30#include "tensorflow/core/lib/gtl/flatset.h"
31#include "tensorflow/core/platform/macros.h"
32#include "tensorflow/core/platform/types.h"
33
34namespace tensorflow {
35
36class Graph;
37
38// Represents the state of an executor (graph and control flow information)
39// that is immutable throughout execution.
40//
41// TODO(b/152651962): Add independent unit tests for this class.
42class ImmutableExecutorState {
43 public:
44 struct FrameInfo {
45 explicit FrameInfo(string name)
46 : name(std::move(name)),
47 input_count(0),
48 total_inputs(0),
49 pending_counts(nullptr),
50 nodes(nullptr),
51 parallel_iterations(-1) {}
52
53 // The name of the frame.
54 string name;
55
56 // The total number of inputs to a frame.
57 int input_count;
58
59 // The total number of input tensors of a frame.
60 // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
61 int total_inputs;
62
63 // Used to determine the next place to allocate space in the
64 // pending_counts data structure we'll eventually construct
65 PendingCounts::Layout pending_counts_layout;
66
67 // Each frame has its own PendingCounts only for the nodes in the frame.
68 std::unique_ptr<PendingCounts> pending_counts;
69
70 // The nodes in a frame. Used only for debugging.
71 std::unique_ptr<std::vector<const NodeItem*>> nodes;
72
73 // The number of iterations of this frame that can execute concurrently.
74 int32 parallel_iterations;
75 };
76
77 explicit ImmutableExecutorState(const LocalExecutorParams& p)
78 : params_(p), gview_() {}
79 ~ImmutableExecutorState();
80
81 Status Initialize(const Graph& graph);
82
83 // Process all Nodes in the current graph, attempting to infer the
84 // memory allocation attributes to be used wherever they may allocate
85 // a tensor buffer.
86 Status SetAllocAttrs();
87
88 const LocalExecutorParams& params() const { return params_; }
89 const GraphView& graph_view() const { return gview_; }
90 const std::vector<PendingCounts::Handle>& pending_ids() const {
91 return pending_ids_;
92 }
93 const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; }
94
95 const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
96
97 const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
98 DCHECK(node_item.is_enter);
99 return *enter_frame_info_[node_item.node_id];
100 }
101
102 bool requires_control_flow_support() const { return requires_control_flow_; }
103
104 // Copies the pending counts for nodes in this graph to the given array.
105 //
106 // This method provides a more efficient way of initializing
107 // `SimplePropagatorState` than individually accessing the pending counts from
108 // `get_root_frame_info().counts`.
109 //
110 // REQUIRES: `!requires_control_flow_support && len(dest) ==
111 // graph_view().num_nodes()`.
112 void copy_pending_counts(std::atomic<int32>* dest) const {
113 DCHECK(!requires_control_flow_);
114 memcpy(dest, atomic_pending_counts_.get(),
115 graph_view().num_nodes() * sizeof(std::atomic<int32>));
116 std::atomic_thread_fence(std::memory_order_release);
117 }
118
119 private:
120 struct ControlFlowInfo {
121 gtl::FlatSet<string> unique_frame_names;
122 std::vector<string> frame_names;
123 };
124
125 static Status BuildControlFlowInfo(const Graph* graph,
126 ControlFlowInfo* cf_info);
127 void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
128
129 FrameInfo* EnsureFrameInfo(const string& fname);
130
131 // Owned.
132 LocalExecutorParams params_;
133 GraphView gview_;
134 bool requires_control_flow_;
135 std::vector<PendingCounts::Handle> pending_ids_;
136
137 // Root nodes (with no in edges) that should form the initial ready queue
138 std::vector<const NodeItem*> root_nodes_;
139
140 // Mapping from frame name to static information about the frame.
141 // TODO(yuanbyu): We could cache it along with the graph so to avoid
142 // the overhead of constructing it for each executor instance.
143 absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>>
144 frame_info_;
145 const FrameInfo* root_frame_info_; // Not owned.
146
147 // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps
148 // dense node IDs to the corresponding FrameInfo.
149 std::vector<FrameInfo*> enter_frame_info_;
150
151 // If `requires_control_flow_` is false, this points to an array of initial
152 // pending counts for the nodes in the graph, indexed by node ID.
153 std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;
154
155 // Shallow copies of the constant tensors used in the graph.
156 std::vector<Tensor> const_tensors_;
157
158 TF_DISALLOW_COPY_AND_ASSIGN(ImmutableExecutorState);
159};
160
161} // namespace tensorflow
162#endif // TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_
163