1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/graph/node_builder.h"
17
18#include <unordered_map>
19#include <vector>
20
21#include "tensorflow/core/framework/node_def_util.h"
22#include "tensorflow/core/framework/types.pb.h"
23#include "tensorflow/core/framework/versions.pb.h"
24#include "tensorflow/core/lib/core/errors.h"
25#include "tensorflow/core/platform/statusor.h"
26#include "tensorflow/core/protobuf/error_codes.pb.h"
27
28namespace tensorflow {
29
30NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i) // NOLINT(runtime/explicit)
31 : node(n),
32 error(false),
33 name(node != nullptr ? node->name() : (error = true, "")),
34 index(i),
35 dt(SafeGetOutput(node, i, &error)) {}
36
37NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
38
39NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t)
40 : node(nullptr), error(false), name(n), index(i), dt(t) {}
41
42NodeBuilder::NodeOut::NodeOut()
43 : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
44
45NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
46 const OpRegistryInterface* op_registry,
47 const NodeDebugInfo* debug)
48 : def_builder_(name, op_name, op_registry, debug) {}
49
50NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
51 : def_builder_(name, op_def) {}
52
53NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
54 : def_builder_(def_builder) {}
55
56NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
57 inputs_.emplace_back(src_node, src_index);
58 DataType dt;
59 if (GetOutputType(src_node, src_index, &dt)) {
60 def_builder_.Input(src_node->name(), src_index, dt);
61 }
62 return *this;
63}
64
65NodeBuilder& NodeBuilder::Input(NodeOut src) {
66 if (src.error) {
67 AddIndexError(src.node, src.index);
68 } else {
69 inputs_.emplace_back(src.node, src.index);
70 def_builder_.Input(src.name, src.index, src.dt);
71 }
72 return *this;
73}
74
75NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
76 std::vector<NodeDefBuilder::NodeOut> srcs;
77 srcs.reserve(src_list.size());
78 for (const auto& node_out : src_list) {
79 if (node_out.error) {
80 AddIndexError(node_out.node, node_out.index);
81 } else {
82 srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
83 inputs_.emplace_back(node_out.node, node_out.index);
84 }
85 }
86 def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
87 return *this;
88}
89
90NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
91 control_inputs_.emplace_back(src_node);
92 def_builder_.ControlInput(src_node->name());
93 return *this;
94}
95
96NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
97 control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
98 src_nodes.end());
99 for (const Node* src_node : src_nodes) {
100 def_builder_.ControlInput(src_node->name());
101 }
102 return *this;
103}
104
105NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
106 def_builder_.Device(device_spec);
107 return *this;
108}
109
110NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
111 assigned_device_ = string(device);
112 return *this;
113}
114
115NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
116 def_builder_.Attr("_XlaCluster", xla_cluster);
117 return *this;
118}
119
120StatusOr<Node*> NodeBuilder::Finalize(Graph* graph, bool consume) {
121 Node* out;
122 TF_RETURN_IF_ERROR(Finalize(graph, &out, consume));
123 return out;
124}
125
126Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
127 // In case of error, set *created_node to nullptr.
128 if (created_node != nullptr) {
129 *created_node = nullptr;
130 }
131 if (!errors_.empty()) {
132 return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
133 }
134
135 NodeDef node_def;
136 TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
137 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
138 TF_RETURN_IF_ERROR(
139 CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
140
141 TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def)));
142
143 node->set_assigned_device_name(assigned_device_);
144
145 for (size_t i = 0; i < inputs_.size(); ++i) {
146 if (inputs_[i].node != nullptr) { // Skip back edges.
147 graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
148 }
149 }
150 for (Node* control_input : control_inputs_) {
151 graph->AddControlEdge(control_input, node);
152 }
153
154 if (created_node != nullptr) *created_node = node;
155
156 return OkStatus();
157}
158
159void NodeBuilder::AddIndexError(const Node* node, int i) {
160 if (node == nullptr) {
161 errors_.emplace_back(
162 strings::StrCat("Attempt to add nullptr Node to node with type ",
163 def_builder_.op_def().name()));
164 } else {
165 errors_.emplace_back(strings::StrCat(
166 "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
167 node->num_outputs(), ") to node with type ",
168 def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
169 }
170}
171
172bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
173 bool error;
174 *dt = SafeGetOutput(node, i, &error);
175 if (error) AddIndexError(node, i);
176 return !error;
177}
178
179} // namespace tensorflow
180