1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_C_C_API_INTERNAL_H_
17#define TENSORFLOW_C_C_API_INTERNAL_H_
18
19#include "tensorflow/c/c_api.h"
20
21#include <list>
22#include <set>
23#include <string>
24#include <unordered_map>
25#include <vector>
26
27// clang-format off
28// Required for IS_MOBILE_PLATFORM
29#include "tensorflow/core/platform/platform.h"
30// clang-format on
31
32#include "tensorflow/c/tf_status_internal.h"
33#include "tensorflow/c/tf_tensor_internal.h"
34#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
35#include "tensorflow/core/framework/op_gen_lib.h"
36#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
37#include "tensorflow/core/common_runtime/shape_refiner.h"
38#include "tensorflow/core/framework/tensor.h"
39#include "tensorflow/core/framework/tensor_shape.h"
40#include "tensorflow/core/graph/graph.h"
41#include "tensorflow/core/common_runtime/graph_constructor.h"
42#include "tensorflow/core/graph/node_builder.h"
43#include "tensorflow/core/platform/mutex.h"
44#include "tensorflow/core/platform/status.h"
45#include "tensorflow/core/platform/types.h"
46#include "tensorflow/core/public/session.h"
47
// Forward declarations, so this header does not need their full definitions.
namespace tensorflow {
class DeviceMgr;
class Device;
class ServerInterface;
}  // namespace tensorflow
53
54// Internal structures used by the C API. These are likely to change and should
55// not be depended on.
56
57struct TF_SessionOptions {
58 tensorflow::SessionOptions options;
59};
60
61struct TF_DeprecatedSession {
62 tensorflow::Session* session;
63};
64
65struct TF_Library {
66 void* lib_handle;
67 TF_Buffer op_list;
68};
69
70struct TF_Graph {
71 TF_Graph();
72
73 mutable tensorflow::mutex mu;
74 tensorflow::Graph graph TF_GUARDED_BY(mu);
75
76 // Runs shape inference.
77 tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
78
79 // Maps from name of an operation to the Node* in 'graph'.
80 std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
81 TF_GUARDED_BY(mu);
82
83 // The keys of this map are all the active sessions using this graph. Each
84 // value records whether the graph has been mutated since the corresponding
85 // session has been run (this is detected in RecordMutation function). If the
86 // string is empty, no mutation has occurred. Otherwise the string is a
87 // description of the mutation suitable for returning to the user.
88 //
89 // Sessions are added to this map in TF_NewSession, and removed in
90 // TF_DeleteSession.
91 // TF_Graph may only / must be deleted when
92 // sessions.size() == 0 && delete_requested == true
93 //
94 // TODO(b/74949947): mutations currently trigger a warning instead of a bad
95 // status, this should be reverted when possible.
96 tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
97 TF_GUARDED_BY(mu);
98 bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
99
100 // Used to link graphs contained in TF_WhileParams to the parent graph that
101 // will eventually contain the full while loop.
102 TF_Graph* parent;
103 TF_Output* parent_inputs;
104};
105
106struct TF_OperationDescription {
107 TF_OperationDescription(TF_Graph* g, const char* op_type,
108 const char* node_name)
109 : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
110
111 tensorflow::NodeBuilder node_builder;
112 TF_Graph* graph;
113 std::set<tensorflow::string> colocation_constraints;
114};
115
116struct TF_Operation {
117 tensorflow::Node node;
118};
119
120struct TF_Session {
121 TF_Session(tensorflow::Session* s, TF_Graph* g);
122
123 tensorflow::Session* session;
124 TF_Graph* const graph;
125
126 tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
127 int last_num_graph_nodes;
128
129 // If true, TF_SessionRun and similar methods will call
130 // ExtendSessionGraphHelper before running the graph (this is the default
131 // public behavior). Can be set to false if the caller needs to call
132 // ExtendSessionGraphHelper manually.
133 std::atomic<bool> extend_before_run;
134};
135
136struct TF_ImportGraphDefOptions {
137 tensorflow::ImportGraphDefOptions opts;
138
139 // Backing memory for TensorId fields in opts.
140 // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
141 std::list<tensorflow::string> tensor_id_data;
142};
143
144struct TF_ImportGraphDefResults {
145 std::vector<TF_Output> return_tensors;
146 std::vector<TF_Operation*> return_nodes;
147 std::vector<const char*> missing_unused_key_names;
148 std::vector<int> missing_unused_key_indexes;
149
150 // Backing memory for missing_unused_key_names values.
151 std::list<tensorflow::string> missing_unused_key_names_data;
152};
153
154struct TF_DeviceList {
155 std::vector<tensorflow::DeviceAttributes> response;
156};
157
158struct TF_Function {
159 tensorflow::FunctionDef fdef;
160 tensorflow::StackTracesMap stack_traces;
161};
162
163struct TF_ApiDefMap {
164 explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
165 :
166#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
167 api_def_map(op_list),
168#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
169 update_docs_called(false) {
170 }
171
172#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
173 tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
174#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
175 bool update_docs_called TF_GUARDED_BY(lock);
176 tensorflow::mutex lock;
177};
178
179#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
180struct TF_Server {
181 TF_Server(std::unique_ptr<tensorflow::ServerInterface> server);
182
183 const tensorflow::string target;
184 std::unique_ptr<tensorflow::ServerInterface> server;
185};
186#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
187
188namespace tensorflow {
189
190// Set the shapes and types of the output's handle.
191//
192// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
193// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
194// rank is known), then it must be equal to the length of `shapes[i]`; if
195// `ranks[i] == 1`, then `shapes[i]` may be nullptr.
196//
197// TODO(akshayka): Implement a corresponding getter method.
198void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
199 int num_shapes_and_types,
200 const int64_t** shapes,
201 const int* ranks,
202 const TF_DataType* types,
203 TF_Status* status);
204
205void RecordMutation(TF_Graph* graph, const TF_Operation& op,
206 const char* mutation_type)
207 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
208
209bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
210 TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
211
212std::string getTF_OutputDebugString(TF_Output node);
213
214} // end namespace tensorflow
215
216#endif // TENSORFLOW_C_C_API_INTERNAL_H_
217