1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_C_C_API_INTERNAL_H_ |
17 | #define TENSORFLOW_C_C_API_INTERNAL_H_ |
18 | |
19 | #include "tensorflow/c/c_api.h" |
20 | |
#include <atomic>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
26 | |
27 | // clang-format off |
28 | // Required for IS_MOBILE_PLATFORM |
29 | #include "tensorflow/core/platform/platform.h" |
30 | // clang-format on |
31 | |
32 | #include "tensorflow/c/tf_status_internal.h" |
33 | #include "tensorflow/c/tf_tensor_internal.h" |
34 | #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
35 | #include "tensorflow/core/framework/op_gen_lib.h" |
36 | #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
37 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
38 | #include "tensorflow/core/framework/tensor.h" |
39 | #include "tensorflow/core/framework/tensor_shape.h" |
40 | #include "tensorflow/core/graph/graph.h" |
41 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
42 | #include "tensorflow/core/graph/node_builder.h" |
43 | #include "tensorflow/core/platform/mutex.h" |
44 | #include "tensorflow/core/platform/status.h" |
45 | #include "tensorflow/core/platform/types.h" |
46 | #include "tensorflow/core/public/session.h" |
47 | |
// Forward declarations of TensorFlow runtime classes, so this header can name
// them without including their full definitions. ServerInterface is referenced
// further down by TF_Server.
namespace tensorflow {
class Device;
class DeviceMgr;
class ServerInterface;
}  // namespace tensorflow
53 | |
54 | // Internal structures used by the C API. These are likely to change and should |
55 | // not be depended on. |
56 | |
57 | struct TF_SessionOptions { |
58 | tensorflow::SessionOptions options; |
59 | }; |
60 | |
61 | struct TF_DeprecatedSession { |
62 | tensorflow::Session* session; |
63 | }; |
64 | |
65 | struct TF_Library { |
66 | void* lib_handle; |
67 | TF_Buffer op_list; |
68 | }; |
69 | |
70 | struct TF_Graph { |
71 | TF_Graph(); |
72 | |
73 | mutable tensorflow::mutex mu; |
74 | tensorflow::Graph graph TF_GUARDED_BY(mu); |
75 | |
76 | // Runs shape inference. |
77 | tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu); |
78 | |
79 | // Maps from name of an operation to the Node* in 'graph'. |
80 | std::unordered_map<tensorflow::string, tensorflow::Node*> name_map |
81 | TF_GUARDED_BY(mu); |
82 | |
83 | // The keys of this map are all the active sessions using this graph. Each |
84 | // value records whether the graph has been mutated since the corresponding |
85 | // session has been run (this is detected in RecordMutation function). If the |
86 | // string is empty, no mutation has occurred. Otherwise the string is a |
87 | // description of the mutation suitable for returning to the user. |
88 | // |
89 | // Sessions are added to this map in TF_NewSession, and removed in |
90 | // TF_DeleteSession. |
91 | // TF_Graph may only / must be deleted when |
92 | // sessions.size() == 0 && delete_requested == true |
93 | // |
94 | // TODO(b/74949947): mutations currently trigger a warning instead of a bad |
95 | // status, this should be reverted when possible. |
96 | tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions |
97 | TF_GUARDED_BY(mu); |
98 | bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph |
99 | |
100 | // Used to link graphs contained in TF_WhileParams to the parent graph that |
101 | // will eventually contain the full while loop. |
102 | TF_Graph* parent; |
103 | TF_Output* parent_inputs; |
104 | }; |
105 | |
106 | struct TF_OperationDescription { |
107 | TF_OperationDescription(TF_Graph* g, const char* op_type, |
108 | const char* node_name) |
109 | : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} |
110 | |
111 | tensorflow::NodeBuilder node_builder; |
112 | TF_Graph* graph; |
113 | std::set<tensorflow::string> colocation_constraints; |
114 | }; |
115 | |
116 | struct TF_Operation { |
117 | tensorflow::Node node; |
118 | }; |
119 | |
120 | struct TF_Session { |
121 | TF_Session(tensorflow::Session* s, TF_Graph* g); |
122 | |
123 | tensorflow::Session* session; |
124 | TF_Graph* const graph; |
125 | |
126 | tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu); |
127 | int last_num_graph_nodes; |
128 | |
129 | // If true, TF_SessionRun and similar methods will call |
130 | // ExtendSessionGraphHelper before running the graph (this is the default |
131 | // public behavior). Can be set to false if the caller needs to call |
132 | // ExtendSessionGraphHelper manually. |
133 | std::atomic<bool> extend_before_run; |
134 | }; |
135 | |
136 | struct TF_ImportGraphDefOptions { |
137 | tensorflow::ImportGraphDefOptions opts; |
138 | |
139 | // Backing memory for TensorId fields in opts. |
140 | // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. |
141 | std::list<tensorflow::string> tensor_id_data; |
142 | }; |
143 | |
144 | struct TF_ImportGraphDefResults { |
145 | std::vector<TF_Output> return_tensors; |
146 | std::vector<TF_Operation*> return_nodes; |
147 | std::vector<const char*> missing_unused_key_names; |
148 | std::vector<int> missing_unused_key_indexes; |
149 | |
150 | // Backing memory for missing_unused_key_names values. |
151 | std::list<tensorflow::string> missing_unused_key_names_data; |
152 | }; |
153 | |
154 | struct TF_DeviceList { |
155 | std::vector<tensorflow::DeviceAttributes> response; |
156 | }; |
157 | |
158 | struct TF_Function { |
159 | tensorflow::FunctionDef fdef; |
160 | tensorflow::StackTracesMap stack_traces; |
161 | }; |
162 | |
163 | struct TF_ApiDefMap { |
164 | explicit TF_ApiDefMap(const tensorflow::OpList& op_list) |
165 | : |
166 | #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
167 | api_def_map(op_list), |
168 | #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
169 | update_docs_called(false) { |
170 | } |
171 | |
172 | #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
173 | tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock); |
174 | #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
175 | bool update_docs_called TF_GUARDED_BY(lock); |
176 | tensorflow::mutex lock; |
177 | }; |
178 | |
179 | #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
180 | struct TF_Server { |
181 | TF_Server(std::unique_ptr<tensorflow::ServerInterface> server); |
182 | |
183 | const tensorflow::string target; |
184 | std::unique_ptr<tensorflow::ServerInterface> server; |
185 | }; |
186 | #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) |
187 | |
188 | namespace tensorflow { |
189 | |
190 | // Set the shapes and types of the output's handle. |
191 | // |
192 | // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must |
193 | // all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the |
194 | // rank is known), then it must be equal to the length of `shapes[i]`; if |
195 | // `ranks[i] == 1`, then `shapes[i]` may be nullptr. |
196 | // |
197 | // TODO(akshayka): Implement a corresponding getter method. |
198 | void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, |
199 | int num_shapes_and_types, |
200 | const int64_t** shapes, |
201 | const int* ranks, |
202 | const TF_DataType* types, |
203 | TF_Status* status); |
204 | |
205 | void RecordMutation(TF_Graph* graph, const TF_Operation& op, |
206 | const char* mutation_type) |
207 | TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu); |
208 | |
209 | bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) |
210 | TF_LOCKS_EXCLUDED(session->graph->mu, session->mu); |
211 | |
212 | std::string getTF_OutputDebugString(TF_Output node); |
213 | |
214 | } // end namespace tensorflow |
215 | |
216 | #endif // TENSORFLOW_C_C_API_INTERNAL_H_ |
217 | |