1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ |
18 | |
19 | #include "tensorflow/core/framework/node_def.pb.h" |
20 | #include "tensorflow/core/framework/op_def.pb.h" |
21 | #include "tensorflow/core/framework/op_def_builder.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/lib/core/status.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | class OpRegistryInterface; |
28 | |
29 | struct NodeProperties { |
30 | public: |
31 | NodeProperties(const OpDef* op_def, NodeDef node_def, |
32 | const DataTypeSlice inputs, const DataTypeSlice outputs) |
33 | : NodeProperties(op_def, std::move(node_def), |
34 | DataTypeVector(inputs.begin(), inputs.end()), |
35 | DataTypeVector(outputs.begin(), outputs.end())) {} |
36 | |
37 | NodeProperties(const OpDef* _op_def, NodeDef&& _node_def, |
38 | DataTypeVector inputs, DataTypeVector outputs) |
39 | : op_def(_op_def), |
40 | node_def(std::move(_node_def)), |
41 | input_types(std::move(inputs)), |
42 | input_types_slice(input_types), |
43 | output_types(std::move(outputs)), |
44 | output_types_slice(output_types) {} |
45 | |
46 | // Resets the 'props' shared pointer to point to a new NodeProperties created |
47 | // from the given NodeDef. 'op_registry' is used to look up the OpDef |
48 | // corresponding to node_def.op(). Returns an error if OpDef lookup or |
49 | // creation failed. |
50 | static Status CreateFromNodeDef(NodeDef node_def, |
51 | const OpRegistryInterface* op_registry, |
52 | std::shared_ptr<const NodeProperties>* props); |
53 | |
54 | const OpDef* op_def; // not owned. |
55 | NodeDef node_def; |
56 | DataTypeVector input_types; |
57 | DataTypeSlice input_types_slice; |
58 | DataTypeVector output_types; |
59 | DataTypeSlice output_types_slice; |
60 | }; |
61 | |
62 | } // namespace tensorflow |
63 | |
64 | #endif // TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ |
65 | |