1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/simple_propagator_state.h" |
16 | |
17 | #include <atomic> |
18 | |
19 | #include "tensorflow/core/common_runtime/propagator_debug_utils.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/profiler/lib/traceme.h" |
22 | |
23 | namespace tensorflow { |
24 | |
// Convenience constructor: delegates to the frame-aware constructor below,
// using the executor's root frame info (SimplePropagatorState handles graphs
// with a single frame, so the root frame is the only frame).
SimplePropagatorState::SimplePropagatorState(
    const ImmutableExecutorState& immutable_state, int64_t step_id, bool vlog)
    : SimplePropagatorState(immutable_state, step_id,
                            immutable_state.get_root_frame_info(), vlog) {}
29 | |
// Main constructor: sizes all per-step state from the graph view and the
// frame's precomputed metadata.
SimplePropagatorState::SimplePropagatorState(
    const ImmutableExecutorState& immutable_state, int64_t step_id,
    const ImmutableExecutorState::FrameInfo& finfo, bool vlog)
    : immutable_state_(immutable_state),
      step_id_(step_id),
      // Verbose mode is on if explicitly requested or VLOG(1) is enabled.
      vlog_(vlog || VLOG_IS_ON(1)),
      // One Entry slot per graph input; indexed by EdgeInfo::input_slot in
      // PropagateOutputs().
      input_tensors_(finfo.total_inputs),
      // Per-node atomic pending-input counters, decremented as inputs become
      // available (see PropagateOutputs()).
      pending_(
          new std::atomic<int32>[immutable_state.graph_view().num_nodes()]),
      // `active_` is allocated only in verbose mode and is nullptr otherwise,
      // so readers must null-check before dereferencing.
      // NOTE(review): this initializer reads vlog_, which relies on vlog_
      // being declared before active_ in the header (member initialization
      // follows declaration order) — confirm against the class declaration.
      active_(vlog_ ? new std::vector<bool>(
                          immutable_state.graph_view().num_nodes())
                    : nullptr),
      nodes_(finfo.nodes.get()) {
  // Seed the counters from the precomputed per-node pending counts.
  immutable_state_.copy_pending_counts(pending_.get());
}
45 | |
46 | SimplePropagatorState::~SimplePropagatorState() {} |
47 | |
48 | void SimplePropagatorState::ActivateRoots( |
49 | gtl::ArraySlice<const NodeItem*> roots, TaggedNodeSeq* ready) { |
50 | for (const NodeItem* item : roots) { |
51 | DCHECK_EQ(item->num_inputs, 0); |
52 | ready->push_back(TaggedNode{item}); |
53 | } |
54 | } |
55 | |
// Propagates the outputs of `tagged_node` along its data and control edges,
// moving/copying each output Entry into the destination's input slot and
// decrementing the destination's pending-input count. Any node whose count
// drops to zero is appended to `ready`.
void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
                                             EntryVector* outputs,
                                             TaggedNodeSeq* ready) {
  // Emit a (lazily-formatted) profiler event describing this propagation.
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorPropagateOutputs#" , "id=" , step_id_,
            ",kernel_name=" , tagged_node.node_item->kernel->name_view(),
            ",num_output_edges=" , tagged_node.node_item->num_output_edges,
            ",num_output_control_edges=" ,
            tagged_node.node_item->num_output_control_edges, "#" );
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  DCHECK(ready->empty());

  const GraphView& gview = immutable_state_.graph_view();
  const NodeItem* item = tagged_node.node_item;

  // Data edges: deliver each produced output to its consumer's input slot.
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const int src_slot = e.output_slot;
    const int dst_loc = e.input_slot;

    // NOTE(mrry): The write to `input_tensors_[dst_loc]` must happen before
    // the pending count update, or else one thread might conclude that the
    // count has dropped to zero before another thread finishes updating the
    // input.
    if (e.is_last) {
      // Last consumer of this output slot: transfer ownership by move.
      input_tensors_[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      // Other consumers still need the entry, so copy it.
      input_tensors_[dst_loc] = (*outputs)[src_slot];
    }

    // Release ordering publishes the input-slot write above to whichever
    // thread observes the count reaching zero.
    // NOTE(review): the matching acquire presumably happens where the ready
    // node's inputs are consumed — confirm in the executor's read path.
    int32_t previous_num_pending =
        pending_[dst_id].fetch_sub(1, std::memory_order_release);
    // We observed the final decrement, so this node is now runnable.
    if (previous_num_pending == 1) ready->emplace_back(&gview.node_ref(dst_id));
  }

  // Control edges: no data to deliver, only a pending-count decrement.
  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;

    int32_t previous_num_pending =
        pending_[dst_id].fetch_sub(1, std::memory_order_release);
    if (previous_num_pending == 1) ready->emplace_back(&gview.node_ref(dst_id));
  }
}
105 | |
106 | void SimplePropagatorState::DumpState() { |
107 | mutex_lock l(mu_); |
108 | // Dump any waiting nodes that are holding on to tensors. |
109 | for (const NodeItem* node : *nodes_) { |
110 | if (pending_[node->node_id]) { |
111 | DumpPendingNodeState(*node, input_tensors_.data(), false); |
112 | } |
113 | } |
114 | // Then the active nodes. |
115 | for (const NodeItem* node : *nodes_) { |
116 | if ((*active_)[node->node_id]) { |
117 | DumpActiveNodeState(*node, input_tensors_.data()); |
118 | } |
119 | } |
120 | // Show all input tensors in use. |
121 | size_t total_bytes = 0; |
122 | for (size_t i = 0; i < input_tensors_.size(); ++i) { |
123 | const Entry& input = input_tensors_[i]; |
124 | const Tensor* tensor = GetTensorValueForDump(input); |
125 | if (tensor && tensor->IsInitialized()) { |
126 | LOG(WARNING) << " Input " << i << ": " |
127 | << strings::StrCat( |
128 | "Tensor<type: " , DataTypeString(tensor->dtype()), |
129 | " shape: " , tensor->shape().DebugString(), |
130 | ", bytes: " , tensor->TotalBytes(), ">" ); |
131 | total_bytes += tensor->TotalBytes(); |
132 | } |
133 | } |
134 | LOG(WARNING) << " Total bytes " << total_bytes; |
135 | } |
136 | |
137 | } // namespace tensorflow |
138 | |