1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_ |
18 | |
19 | #include <functional> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/attr_value_util.h" |
23 | #include "tensorflow/core/framework/node_def.pb.h" |
24 | #include "tensorflow/core/framework/node_def_util.h" |
25 | #include "tensorflow/core/framework/op.h" |
26 | #include "tensorflow/core/framework/op_def.pb.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | #include "tensorflow/core/graph/graph.h" |
29 | #include "tensorflow/core/graph/graph_node_util.h" |
30 | #include "tensorflow/core/lib/core/status.h" |
31 | #include "tensorflow/core/lib/gtl/array_slice.h" |
32 | #include "tensorflow/core/lib/strings/strcat.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | class NodeDefBuilder; |
37 | typedef std::function<Status(const OpDef&, int, const NodeDef&, |
38 | NodeDefBuilder*)> |
39 | FakeInputFunctor; |
40 | |
41 | // This is a helper for creating a NodeDef. Automatically sets attrs |
42 | // that can be inferred from the inputs, and uses default values |
43 | // (where they exist) for unspecified attrs. Example usage: |
44 | // |
45 | // NodeDef node_def; |
46 | // Status status = NodeDefBuilder(node_name, op_name) |
47 | // .Input(...) |
48 | // .Attr(...) |
49 | // .Finalize(&node_def); |
50 | // if (!status.ok()) return status; |
51 | // // Use node_def here. |
52 | class NodeDefBuilder { |
53 | public: |
54 | // To specify an output to be consumed by one of the Input() methods below. |
55 | struct NodeOut { |
56 | NodeOut(StringPiece n, int i, DataType dt); |
57 | NodeOut(); // uninitialized, call Reset() before use. |
58 | void Reset(StringPiece n, int i, DataType dt); |
59 | string node; |
60 | int index; |
61 | DataType data_type; |
62 | }; |
63 | |
64 | // Specify the name and the Op (either via an OpDef or the name of |
65 | // the Op plus a registry) for the NodeDef. Other fields are |
66 | // specified by calling the methods below. |
67 | // REQUIRES: The OpDef must satisfy ValidateOpDef(). |
68 | NodeDefBuilder(StringPiece name, StringPiece op_name, |
69 | const OpRegistryInterface* op_registry = OpRegistry::Global(), |
70 | const NodeDebugInfo* debug = nullptr); |
71 | NodeDefBuilder(StringPiece name, StringPiece op_name, |
72 | const NodeDebugInfo& debug); |
73 | // REQUIRES: in addition, *op_def must outlive *this. |
74 | NodeDefBuilder(StringPiece name, const OpDef* op_def); |
75 | |
76 | // You must call one Input() function per input_arg in the Op, |
77 | // *and in the same order as the input_args appear in the OpDef.* |
78 | |
79 | // For inputs that take a single tensor. |
80 | NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt); |
81 | NodeDefBuilder& Input(const NodeOut& src); |
82 | |
83 | // For inputs that take a list of tensors. |
84 | NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list); |
85 | |
86 | // To create inputs in tests, see fake_input.h. |
87 | NodeDefBuilder& Input(FakeInputFunctor fake_input); |
88 | |
89 | // Specify that this node must only run after src_node. |
90 | NodeDefBuilder& ControlInput(StringPiece src_node); |
91 | |
92 | // Constrains what devices this node may be scheduled on. |
93 | NodeDefBuilder& Device(StringPiece device_spec); |
94 | |
95 | // Sets the attr, if not already set. If already set with a different |
96 | // value, an error will be returned from Finalize(). |
97 | NodeDefBuilder& Attr(StringPiece name, const AttrValue& value); |
98 | NodeDefBuilder& Attr(StringPiece name, AttrValue&& value); |
99 | NodeDefBuilder& Attr(StringPiece name, StringPiece value); |
100 | NodeDefBuilder& Attr(StringPiece name, const char* value); |
101 | NodeDefBuilder& Attr(StringPiece name, int32_t value); |
102 | NodeDefBuilder& Attr(StringPiece name, int64_t value); |
103 | NodeDefBuilder& Attr(StringPiece name, float value); |
104 | NodeDefBuilder& Attr(StringPiece name, double value); |
105 | NodeDefBuilder& Attr(StringPiece name, bool value); |
106 | NodeDefBuilder& Attr(StringPiece name, DataType value); |
107 | NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value); |
108 | NodeDefBuilder& Attr(StringPiece name, const Tensor& value); |
109 | NodeDefBuilder& Attr(StringPiece name, const TensorProto& value); |
110 | NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value); |
111 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value); |
112 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value); |
113 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value); |
114 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<tstring> value); |
115 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value); |
116 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64_t> value); |
117 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value); |
118 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value); |
119 | NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value); |
120 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value); |
121 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value); |
122 | NodeDefBuilder& Attr(StringPiece name, |
123 | gtl::ArraySlice<PartialTensorShape> value); |
124 | NodeDefBuilder& Attr(StringPiece name, |
125 | gtl::ArraySlice<TensorShapeProto> value); |
126 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value); |
127 | NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value); |
128 | |
129 | template <class T> |
130 | NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) { |
131 | return Attr(name, gtl::ArraySlice<T>(value)); |
132 | } |
133 | |
134 | // Finish building the NodeDef, returning any errors or setting |
135 | // *node_def if none. |
136 | // If `consume` is true, the builder state will be moved into `node_def`, |
137 | // and the builder will be left in an undefined state. |
138 | // WARNING: Not all problems are detected! The resulting NodeDef may |
139 | // not be valid! Call ValidateNodeDef() from node_def_utils to be sure. |
140 | Status Finalize(NodeDef* node_def, bool consume = false); |
141 | |
142 | // Accessors for the values set in the constructor. |
143 | const string& node_name() const { return node_def_.name(); } |
144 | const OpDef& op_def() const { return *op_def_; } |
145 | |
146 | private: |
147 | // Called in the constructors. |
148 | void Initialize(); |
149 | |
150 | // Get the current ArgDef and advance to the next one. Returns nullptr |
151 | // if no more inputs are available. |
152 | const OpDef::ArgDef* NextArgDef(); |
153 | |
154 | // Returns true if there is still an input_arg available in *op_def_, |
155 | // otherwise adds to error_ and returns false. |
156 | bool NextArgAvailable(); |
157 | |
158 | // These do the main work of the Input() methods. |
159 | void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node, |
160 | int src_index, DataType dt); |
161 | void ListInput(const OpDef::ArgDef* input_arg, |
162 | gtl::ArraySlice<NodeOut> src_list); |
163 | |
164 | // Add "src_node:src_index" to the list of inputs in the node_def_. |
165 | void AddInput(StringPiece src_node, int src_index); |
166 | |
167 | // Generate an error if you can't pass dt when expected is expected. |
168 | void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected, |
169 | DataType dt); |
170 | |
171 | // If input_arg->is_ref() is true, generate an error if dt is not a ref. |
172 | void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt); |
173 | |
174 | // Makes dt a ref type if that is what the input_arg specifies. |
175 | DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) { |
176 | return input_arg->is_ref() ? MakeRefType(dt) : dt; |
177 | } |
178 | |
179 | // Returns true if an attr named `name` is already present in the node_def_. |
180 | // If such an attr is already present and `value` is not equal to the present |
181 | // value, an error is generated. |
182 | bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value); |
183 | |
184 | const OpDef* op_def_; |
185 | NodeDef node_def_; |
186 | int inputs_specified_; |
187 | std::vector<string> control_inputs_; |
188 | std::vector<string> errors_; |
189 | }; |
190 | |
191 | } // namespace tensorflow |
192 | |
193 | #endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_ |
194 | |