1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Class and associated machinery for specifying an Op's OpDef and shape |
17 | // inference function for Op registration. |
18 | |
19 | #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ |
20 | #define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ |
21 | |
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/macros.h"
32 | |
33 | namespace tensorflow { |
34 | |
35 | // TODO(b/62899350): Refactor without proto dependencies. |
36 | typedef std::function<Status(OpDef* c)> OpTypeConstructor; |
37 | |
38 | typedef std::vector<std::reference_wrapper<const FullTypeDef>> TypeRefVector; |
39 | typedef std::map<std::string, std::reference_wrapper<const FullTypeDef>> |
40 | TypeRefMap; |
41 | |
42 | // A type inference function, called for each node during type inference |
43 | // (possibly multiple times). |
44 | // The first argument (input_types) will hold the type of each of the node's |
45 | // inputs. The second argument (type_vars) will hold the return type of |
46 | // each function referred from any type variable (e.g. `FuncVar`) present |
47 | // in the node's corresponding op definition. |
48 | // |
49 | // TODO(mdan): Consider a vector-in, vector-out contract. |
50 | // TODO(mdan): Rename to just TypeInferenceFn (since it's not always "forward"). |
51 | typedef std::function<StatusOr<FullTypeDef>(const TypeRefVector&, |
52 | const TypeRefMap&)> |
53 | ForwardTypeInferenceFn; |
54 | |
55 | class FunctionDefHelper; |
56 | |
57 | namespace shape_inference { |
58 | class InferenceContext; |
59 | } |
60 | typedef std::function<Status(shape_inference::InferenceContext* c)> |
61 | OpShapeInferenceFn; |
62 | |
63 | struct OpRegistrationData { |
64 | public: |
65 | OpRegistrationData() {} |
66 | OpRegistrationData(const OpDef& def) : op_def(def) {} |
67 | OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn, |
68 | bool is_function = false) |
69 | : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {} |
70 | |
71 | OpDef op_def; |
72 | OpShapeInferenceFn shape_inference_fn; |
73 | |
74 | // Type constructor. This callable initializes the type of this op. |
75 | // It is provided as a programmatic mechanism for defining an op's |
76 | // type, as part of its registration. It is to be eventually replaced by a |
77 | // textual language. |
78 | // |
79 | // Important: historically, op registrations only contained partial |
80 | // input/output type information in non-standardized attribute declarations |
81 | // (e.g. typically, input types were held in a `dtype` attribute). The type |
82 | // constructor currently duplicates such attribute information, with the aim |
83 | // of entirely subsuming it, and eventually deprecating all type-related |
84 | // attributes. |
85 | // |
86 | // Since ops are typically parametrized, the type created by this constructor |
87 | // is also parametric. |
88 | // |
89 | // Example: for an op `Foo(x: T) -> Bar[T]`: |
90 | // |
91 | // * typically, its op registration included a single attribute `T: type`; |
92 | // then the respective input was defined as `x: T`; the output type `Bar` |
93 | // was implied by the op name. |
94 | // * the type constructor creates a FullType object containing `Bar[T]`; this |
95 | // still relies on the `T` attribute which it references. |
96 | // * in the future, the type constructor will create a FullType containing |
97 | // `Callable[(x: T), Bar[T]]`, and the attribute `T` will be deprecated. |
98 | OpTypeConstructor type_ctor; |
99 | |
100 | // Forward type inference function. This callable infers the return type of an |
101 | // op based on its input types. |
102 | // |
103 | // Note that the type constructor and forward inference functions need not be |
104 | // mutually exclusive: if there is some static information that can be set |
105 | // based on attributes, then that should be set in the constructor. If more |
106 | // information can be extracted from inputs, that should be done in the |
107 | // forward inference function. |
108 | // |
109 | // This is similar to the shape function, but is more general, and applied |
110 | // directly to NodeDefs, rather than working on the ShapeAndType structures. |
111 | // Note that the op input/output declarations may specify some implicit type |
112 | // constraints through attribute references (i.e. two inputs pointing to the |
113 | // same type attribute). Those constraints may duplicate what this function |
114 | // specifies in its body. That's intended, for a gradual transition to a more |
115 | // formal type system. |
116 | // |
117 | // These type inference functions are intermediate solutions as well: once the |
118 | // op registration has a complete, formal type definition, along with |
119 | // a solver-based type inference, it will replace these functions. |
120 | // |
121 | // TODO(mdan): Merge with shape inference. |
122 | // TODO(mdan): Replace with a union-based type inference algorithm. |
123 | ForwardTypeInferenceFn fwd_type_fn; |
124 | |
125 | // Reverse type inference function. This callable infers some input types |
126 | // based on the return type. |
127 | // |
128 | // TODO(mdan): Replace with a union-based type inference algorithm. |
129 | ForwardTypeInferenceFn rev_type_fn; |
130 | |
131 | // The input number affected by reverse type inference. Only one input may be |
132 | // updated in this manner. |
133 | // TODO(mdan): Encode in a manner more consistent with the forward version. |
134 | int rev_type_input; |
135 | |
136 | bool is_function_op = false; |
137 | }; |
138 | |
139 | // Builder class passed to the REGISTER_OP() macro. |
140 | class OpDefBuilder { |
141 | public: |
142 | // Constructs an OpDef with just the name field set. |
143 | explicit OpDefBuilder(std::string op_name); |
144 | |
145 | // Adds an attr to this OpDefBuilder (and returns *this). The spec has |
146 | // format "<name>:<type>" or "<name>:<type>=<default>" |
147 | // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]* |
148 | // (by convention only using capital letters for attrs that can be inferred) |
149 | // <type> can be: |
150 | // "string", "int", "float", "bool", "type", "shape", or "tensor" |
151 | // "numbertype", "realnumbertype", "quantizedtype" |
152 | // (meaning "type" with a restriction on valid values) |
153 | // "{int32,int64}" or {realnumbertype,quantizedtype,string}" |
154 | // (meaning "type" with a restriction containing unions of value types) |
155 | // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" |
156 | // (meaning "string" with a restriction on valid values) |
157 | // "list(string)", ..., "list(tensor)", "list(numbertype)", ... |
158 | // (meaning lists of the above types) |
159 | // "int >= 2" (meaning "int" with a restriction on valid values) |
160 | // "list(string) >= 2", "list(int) >= 2" |
161 | // (meaning "list(string)" / "list(int)" with length at least 2) |
162 | // <default>, if included, should use the Proto text format |
163 | // of <type>. For lists use [a, b, c] format. |
164 | // |
165 | // Note that any attr specifying the length of an input or output will |
166 | // get a default minimum of 1 unless the >= # syntax is used. |
167 | // |
168 | // TODO(josh11b): Perhaps support restrictions and defaults as optional |
169 | // extra arguments to Attr() instead of encoding them in the spec string. |
170 | // TODO(josh11b): Would like to have better dtype handling for tensor attrs: |
171 | // * Ability to say the type of an input/output matches the type of |
172 | // the tensor. |
173 | // * Ability to restrict the type of the tensor like the existing |
174 | // restrictions for type attrs. |
175 | // Perhaps by linking the type of the tensor to a type attr? |
176 | OpDefBuilder& Attr(std::string spec); |
177 | |
178 | // Adds an input or output to this OpDefBuilder (and returns *this). |
179 | // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)" |
180 | // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be: |
181 | // * For a single tensor: <type> |
182 | // * For a sequence of tensors with the same type: <number>*<type> |
183 | // * For a sequence of tensors with different types: <type-list> |
184 | // Where: |
185 | // <type> is either one of "float", "int32", "string", ... |
186 | // or the name of an attr (see above) with type "type". |
187 | // <number> is the name of an attr with type "int". |
188 | // <type-list> is the name of an attr with type "list(type)". |
189 | // TODO(josh11b): Indicate Ref() via an optional argument instead of |
190 | // in the spec? |
191 | // TODO(josh11b): SparseInput() and SparseOutput() matching the Python |
192 | // handling? |
193 | OpDefBuilder& Input(std::string spec); |
194 | OpDefBuilder& Output(std::string spec); |
195 | |
196 | // Turns on the indicated boolean flag in this OpDefBuilder (and |
197 | // returns *this). |
198 | OpDefBuilder& SetIsCommutative(); |
199 | OpDefBuilder& SetIsAggregate(); |
200 | OpDefBuilder& SetIsStateful(); |
201 | OpDefBuilder& SetAllowsUninitializedInput(); |
202 | OpDefBuilder& SetIsDistributedCommunication(); |
203 | |
204 | // Deprecate the op at a certain GraphDef version. |
205 | OpDefBuilder& Deprecated(int version, std::string explanation); |
206 | |
207 | // Adds docs to this OpDefBuilder (and returns *this). |
208 | // Docs have the format: |
209 | // <1-line summary> |
210 | // <rest of the description> |
211 | // <name>: <description of name> |
212 | // <name>: <description of name> |
213 | // <if long, indent the description on subsequent lines> |
214 | // Where <name> is the name of an attr, input, or output. Please |
215 | // wrap docs at 72 columns so that it may be indented in the |
216 | // generated output. For tensor inputs or outputs (not attrs), you |
217 | // may start the description with an "=" (like name:= <description>) |
218 | // to suppress the automatically-generated type documentation in |
219 | // generated output. |
220 | OpDefBuilder& Doc(std::string text); |
221 | |
222 | // Sets the function to be used as type constructor. |
223 | // See OpRegistrationData::type_ctor. |
224 | OpDefBuilder& SetTypeConstructor(OpTypeConstructor c); |
225 | |
226 | // Sets the function to be used for forward type inference. |
227 | // See OpRegistrationData::fwd_type_fn. |
228 | OpDefBuilder& SetForwardTypeFn(ForwardTypeInferenceFn f); |
229 | |
230 | // Sets the function to be used for reverse type inference. |
231 | // See OpRegistrationData::rew_type_fn. |
232 | OpDefBuilder& SetReverseTypeFn(int input_number, ForwardTypeInferenceFn f); |
233 | |
234 | // Sets the shape function to be used for shape inference. |
235 | // |
236 | // Note that currently (October 2016), python code still requires a |
237 | // RegisterShape call to invoke this; see call_cpp_shape_fn in |
238 | // python/framework/common_shapes.py |
239 | OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn); |
240 | |
241 | // Allows the `<type>` in calls to `Attr()` to be "any". |
242 | // This is used by PythonAPIWrapper for pass-through parameters. |
243 | OpDefBuilder& AllowAttrTypeAny(); |
244 | |
245 | // Sets op_reg_data->op_def to the requested OpDef and |
246 | // op_reg_data->shape_inference_fn to the requested shape inference function, |
247 | // or returns an error. |
248 | // Must be called after all of the above methods. |
249 | // |
250 | // Note that OpDefBuilder only reports parsing errors. You should also |
251 | // call ValidateOpDef() to detect other problems. |
252 | Status Finalize(OpRegistrationData* op_reg_data) const; |
253 | |
254 | private: |
255 | friend class FunctionDefHelper; |
256 | |
257 | // Adds control output to this OpDefBuilder (and returns *this). |
258 | // The <name> must be a valid node name (matches regexp |
259 | // [a-zA-Z][a-zA-Z0-9_]*). Named control output can only exist for functions. |
260 | OpDefBuilder& ControlOutput(std::string name); |
261 | |
262 | OpDef* op_def() { return &op_reg_data_.op_def; } |
263 | |
264 | OpRegistrationData op_reg_data_; |
265 | std::vector<string> attrs_; |
266 | std::vector<string> inputs_; |
267 | std::vector<string> outputs_; |
268 | std::vector<string> control_outputs_; |
269 | std::string doc_; |
270 | std::vector<string> errors_; |
271 | bool allow_attr_type_any_ = false; |
272 | }; |
273 | |
274 | } // namespace tensorflow |
275 | |
276 | #endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ |
277 | |