1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/graph/node_builder.h" |
17 | |
18 | #include <unordered_map> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/node_def_util.h" |
22 | #include "tensorflow/core/framework/types.pb.h" |
23 | #include "tensorflow/core/framework/versions.pb.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/platform/statusor.h" |
26 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
27 | |
28 | namespace tensorflow { |
29 | |
// Refers to output `i` of node `n`.
//
// `error` starts out false. It is set to true by the comma expression in
// `name`'s initializer when `n` is nullptr; in that case `name` becomes "".
// `&error` is then also handed to SafeGetOutput() when computing `dt`.
// NOTE(review): presumably SafeGetOutput() writes `error` to say whether
// output `i` exists; confirm in node_builder.h.
//
// NOTE(review): these initializers read `node` and write/read `error`. That
// is only well-defined if `node` and `error` are declared before `name` and
// `dt` in the NodeOut struct (members initialize in declaration order, not
// list order). Confirm against node_builder.h before reordering members.
NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i)  // NOLINT(runtime/explicit)
    : node(n),
      error(false),
      name(node != nullptr ? node->name() : (error = true, "")),
      index(i),
      dt(SafeGetOutput(node, i, &error)) {}
36 | |
// Converts an OutputTensor by delegating to the (Node*, index) constructor,
// so the same nullptr / output-index checks apply.
NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}

// Refers to an output only by node name, output index and dtype, with no
// Node* attached (`node` is nullptr, `error` is false). When such a NodeOut is
// used as an input, Finalize() adds no graph edge for it, because it skips
// inputs whose node is null (the "back edge" case).
NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t)
    : node(nullptr), error(false), name(n), index(i), dt(t) {}

// A default-constructed NodeOut is flagged as an error. Passing it to
// NodeBuilder::Input() records an error instead of adding an input.
NodeBuilder::NodeOut::NodeOut()
    : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
44 | |
// Builds a node named `name` for op `op_name`. All arguments are forwarded
// unchanged to the underlying NodeDefBuilder.
// NOTE(review): presumably NodeDefBuilder looks `op_name` up in `op_registry`
// and attaches `debug` to the NodeDef; confirm in node_def_builder.h.
NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
                         const OpRegistryInterface* op_registry,
                         const NodeDebugInfo* debug)
    : def_builder_(name, op_name, op_registry, debug) {}

// Builds a node named `name` from an explicitly supplied OpDef.
NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
    : def_builder_(name, op_def) {}

// Builds a node from a copy of an already-configured NodeDefBuilder.
NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
    : def_builder_(def_builder) {}
55 | |
56 | NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { |
57 | inputs_.emplace_back(src_node, src_index); |
58 | DataType dt; |
59 | if (GetOutputType(src_node, src_index, &dt)) { |
60 | def_builder_.Input(src_node->name(), src_index, dt); |
61 | } |
62 | return *this; |
63 | } |
64 | |
65 | NodeBuilder& NodeBuilder::Input(NodeOut src) { |
66 | if (src.error) { |
67 | AddIndexError(src.node, src.index); |
68 | } else { |
69 | inputs_.emplace_back(src.node, src.index); |
70 | def_builder_.Input(src.name, src.index, src.dt); |
71 | } |
72 | return *this; |
73 | } |
74 | |
75 | NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) { |
76 | std::vector<NodeDefBuilder::NodeOut> srcs; |
77 | srcs.reserve(src_list.size()); |
78 | for (const auto& node_out : src_list) { |
79 | if (node_out.error) { |
80 | AddIndexError(node_out.node, node_out.index); |
81 | } else { |
82 | srcs.emplace_back(node_out.name, node_out.index, node_out.dt); |
83 | inputs_.emplace_back(node_out.node, node_out.index); |
84 | } |
85 | } |
86 | def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs)); |
87 | return *this; |
88 | } |
89 | |
90 | NodeBuilder& NodeBuilder::ControlInput(Node* src_node) { |
91 | control_inputs_.emplace_back(src_node); |
92 | def_builder_.ControlInput(src_node->name()); |
93 | return *this; |
94 | } |
95 | |
96 | NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) { |
97 | control_inputs_.insert(control_inputs_.end(), src_nodes.begin(), |
98 | src_nodes.end()); |
99 | for (const Node* src_node : src_nodes) { |
100 | def_builder_.ControlInput(src_node->name()); |
101 | } |
102 | return *this; |
103 | } |
104 | |
105 | NodeBuilder& NodeBuilder::Device(StringPiece device_spec) { |
106 | def_builder_.Device(device_spec); |
107 | return *this; |
108 | } |
109 | |
110 | NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) { |
111 | assigned_device_ = string(device); |
112 | return *this; |
113 | } |
114 | |
115 | NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) { |
116 | def_builder_.Attr("_XlaCluster" , xla_cluster); |
117 | return *this; |
118 | } |
119 | |
120 | StatusOr<Node*> NodeBuilder::Finalize(Graph* graph, bool consume) { |
121 | Node* out; |
122 | TF_RETURN_IF_ERROR(Finalize(graph, &out, consume)); |
123 | return out; |
124 | } |
125 | |
126 | Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) { |
127 | // In case of error, set *created_node to nullptr. |
128 | if (created_node != nullptr) { |
129 | *created_node = nullptr; |
130 | } |
131 | if (!errors_.empty()) { |
132 | return errors::InvalidArgument(absl::StrJoin(errors_, "\n" )); |
133 | } |
134 | |
135 | NodeDef node_def; |
136 | TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); |
137 | TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); |
138 | TF_RETURN_IF_ERROR( |
139 | CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); |
140 | |
141 | TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def))); |
142 | |
143 | node->set_assigned_device_name(assigned_device_); |
144 | |
145 | for (size_t i = 0; i < inputs_.size(); ++i) { |
146 | if (inputs_[i].node != nullptr) { // Skip back edges. |
147 | graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i); |
148 | } |
149 | } |
150 | for (Node* control_input : control_inputs_) { |
151 | graph->AddControlEdge(control_input, node); |
152 | } |
153 | |
154 | if (created_node != nullptr) *created_node = node; |
155 | |
156 | return OkStatus(); |
157 | } |
158 | |
159 | void NodeBuilder::AddIndexError(const Node* node, int i) { |
160 | if (node == nullptr) { |
161 | errors_.emplace_back( |
162 | strings::StrCat("Attempt to add nullptr Node to node with type " , |
163 | def_builder_.op_def().name())); |
164 | } else { |
165 | errors_.emplace_back(strings::StrCat( |
166 | "Attempt to add output " , i, " of " , node->name(), " not in range [0, " , |
167 | node->num_outputs(), ") to node with type " , |
168 | def_builder_.op_def().name(), ". Node: " , FormatNodeForError(*node))); |
169 | } |
170 | } |
171 | |
172 | bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) { |
173 | bool error; |
174 | *dt = SafeGetOutput(node, i, &error); |
175 | if (error) AddIndexError(node, i); |
176 | return !error; |
177 | } |
178 | |
179 | } // namespace tensorflow |
180 | |