1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ |
17 | #define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ |
18 | |
19 | // Must be included first |
20 | #include "tensorflow/python/lib/core/numpy.h" |
21 | |
22 | #include "tensorflow/c/c_api.h" |
23 | #include "tensorflow/core/framework/graph.pb.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | // Container types for the various arguments and temporary values used |
31 | // in the wrapper. |
32 | |
33 | // A NameVector is a vector of tensor or operation names, as borrowed |
34 | // C strings. |
35 | typedef tensorflow::gtl::InlinedVector<const char*, 8> NameVector; |
36 | |
37 | // A PyObjectVector is a vector of borrowed pointers to PyObjects. |
38 | typedef tensorflow::gtl::InlinedVector<PyObject*, 8> PyObjectVector; |
39 | |
40 | // A TF_TensorVector is a vector of borrowed pointers to TF_Tensors. |
41 | typedef gtl::InlinedVector<TF_Tensor*, 8> TF_TensorVector; |
42 | |
43 | TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, |
44 | TF_Status* status); |
45 | |
46 | // Run the graph associated with the session starting with the |
47 | // supplied inputs[]. Regardless of success or failure, inputs[] are |
48 | // stolen by the implementation (i.e. the implementation will |
49 | // eventually call Py_DECREF on each array input). |
50 | // |
51 | // The PyObject* feed_dict must be a dictionary mapping strings to |
52 | // NumPy arrays. This function does not modify its reference count. |
53 | // |
54 | // On success, the tensors corresponding to output_names[0,noutputs-1] |
55 | // are placed in out_values[], and these outputs[] become the property |
56 | // of the caller (the caller must eventually call Py_DECREF on them). |
57 | // |
58 | // On failure, out_status contains a tensorflow::Status with an error |
59 | // message. |
60 | void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options, |
61 | PyObject* feed_dict, const NameVector& output_names, |
62 | const NameVector& target_nodes, TF_Status* out_status, |
63 | PyObjectVector* out_values, TF_Buffer* run_outputs); |
64 | |
65 | // Python wrappers for the `Session::MakeCallable()` API. |
66 | void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session, |
67 | const TF_Buffer* callable_options, |
68 | int64_t* out_handle, TF_Status* status); |
69 | void TF_SessionMakeCallable(TF_Session* session, |
70 | const TF_Buffer* callable_options, |
71 | int64_t* out_handle, TF_Status* status); |
72 | |
73 | // Python wrappers for the `Session::RunCallable()` API. |
74 | void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session, |
75 | int64_t handle, PyObject* feed_values, |
76 | PyObjectVector* out_values, |
77 | TF_Buffer* run_metadata, |
78 | TF_Status* status); |
79 | void TF_SessionRunCallable(TF_Session* session, int64_t handle, |
80 | PyObject* feed_values, PyObjectVector* out_values, |
81 | TF_Buffer* run_metadata, TF_Status* status); |
82 | |
83 | // Python wrappers for the `Session::ReleaseCallable()` API. |
84 | void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session, |
85 | int64_t handle, TF_Status* status); |
86 | void TF_SessionReleaseCallable(TF_Session* session, int64_t handle, |
87 | TF_Status* status); |
88 | |
89 | // Set up the graph with the intended feeds and fetches for partial run. |
90 | // *out_handle is owned by the caller. |
91 | // |
92 | // On success, returns a handle that is used for subsequent PRun calls. |
93 | // |
94 | // On failure, out_status contains a tensorflow::Status with an error |
95 | // message. |
96 | void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, |
97 | const NameVector& input_names, |
98 | const NameVector& output_names, |
99 | const NameVector& target_nodes, TF_Status* out_status, |
100 | const char** out_handle); |
101 | |
102 | // Continue to run the graph with additional feeds and fetches. The |
103 | // execution state is uniquely identified by the handle. |
104 | // |
105 | // The PyObject* feed_dict must be a dictionary mapping strings to |
106 | // NumPy arrays. This function does not modify its reference count. |
107 | // |
108 | // On success, the tensors corresponding to output_names[0,noutputs-1] |
109 | // are placed in out_values[], and these outputs[] become the property |
110 | // of the caller (the caller must eventually call Py_DECREF on them). |
111 | // |
112 | // On failure, out_status contains a tensorflow::Status with an error |
113 | // message. |
114 | void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle, |
115 | PyObject* feed_dict, const NameVector& output_names, |
116 | TF_Status* out_status, PyObjectVector* out_values); |
117 | |
118 | // Wrapper for TF_Reset that converts the string vectors to character arrays. |
119 | void TF_Reset_wrapper(const TF_SessionOptions* opt, |
120 | const NameVector& containers, TF_Status* status); |
121 | |
122 | // Convenience wrapper around EqualGraphDef to make it easier to wrap. |
123 | // Returns an explanation if a difference is found, or the empty string |
124 | // for no difference. |
125 | string EqualGraphDefWrapper(const string& actual, const string& expected); |
126 | |
127 | // Convenience wrapper around AreAttrValuesEqual to make it easier to wrap. |
128 | // The actual and expected strings must correspond to a serialized binary |
129 | // representation of two AttrValue proto instances. |
130 | // Returns an explanation if a difference is found, or the empty string |
131 | // for no difference. |
132 | string EqualAttrValueWrapper(const string& actual, const string& expected); |
133 | |
134 | // Gets shape from C API Graph object. |
135 | // |
136 | // If shape is known, returns shape vector where -1 means "unknown |
137 | // dimension". Sets unknown_shape to false. |
138 | // |
139 | // If shape is unknown, sets unknown_shape to true. |
140 | tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper( |
141 | TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape); |
142 | |
143 | // Runs the graph associated with the session starting with the supplied inputs. |
144 | // On success, `py_outputs` is populated with a numpy ndarray for each output |
145 | // (the caller must decref these ndarrays, although this will likely be handled |
146 | // by the Python gc). `session`, `out_status`, and `py_outputs` must be |
147 | // non-null. `py_outputs` should be empty. |
148 | void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, |
149 | const std::vector<TF_Output>& inputs, |
150 | const std::vector<PyObject*>& input_ndarrays, |
151 | const std::vector<TF_Output>& outputs, |
152 | const std::vector<TF_Operation*>& targets, |
153 | TF_Buffer* run_metadata, TF_Status* status, |
154 | std::vector<PyObject*>* py_outputs); |
155 | |
156 | // Set up the graph with the intended feeds (inputs) and fetches (output) for |
157 | // a sequence of partial run calls. |
158 | // |
159 | // On success, returns a handle that can be used for subsequent PRun calls. The |
160 | // handle is owned by the caller and should be deleted with TF_DeletePRunHandle |
161 | // when it is no longer needed. |
162 | // |
163 | // On failure, out_status contains a tensorflow::Status with an error |
164 | // message. |
165 | void TF_SessionPRunSetup_wrapper(TF_Session* session, |
166 | const std::vector<TF_Output>& inputs, |
167 | const std::vector<TF_Output>& outputs, |
168 | const std::vector<TF_Operation*>& targets, |
169 | const char** out_handle, TF_Status* status); |
170 | |
171 | // Continue to run the graph with additional feeds and fetches. The |
172 | // execution state is uniquely identified by the handle. |
173 | // |
174 | // On success, `py_outputs` is populated with a numpy ndarray for each output |
175 | // (the caller must decref these ndarrays, although this will likely be handled |
176 | // by the Python gc). `session`, `handle`, `out_status`, and `py_outputs` must |
177 | // be non-null. `py_outputs` should be empty. |
178 | // |
179 | // On failure, out_status contains a tensorflow::Status with an error |
180 | // message. |
181 | void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, |
182 | const std::vector<TF_Output>& inputs, |
183 | const std::vector<PyObject*>& input_ndarrays, |
184 | const std::vector<TF_Output>& outputs, |
185 | TF_Status* status, |
186 | std::vector<PyObject*>* py_outputs); |
187 | |
188 | // Retrieves the inputs of this operation. |
189 | std::vector<TF_Output> GetOperationInputs(TF_Operation* oper); |
190 | |
191 | // Retrieves the control inputs of this operation. |
192 | std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( |
193 | TF_Operation* oper); |
194 | |
195 | // Retrieves the control outputs of this operation. |
196 | std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper( |
197 | TF_Operation* oper); |
198 | |
199 | // Retrieves the op names of the consumers of `oper_out`. The returned strings |
200 | // have the lifetime of the underlying TF_Graph. |
201 | std::vector<const char*> TF_OperationOutputConsumers_wrapper( |
202 | TF_Output oper_out); |
203 | |
204 | // `opers` equaling NULL are converted to `nopers = -1`. |
205 | // `output_names` must be empty or have the same length as `outputs`. |
206 | TF_Function* TF_GraphToFunction_wrapper( |
207 | const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name, |
208 | const std::vector<TF_Operation*>* opers, |
209 | const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, |
210 | const NameVector& output_names, |
211 | const std::vector<TF_Operation*>* control_outputs, |
212 | const NameVector& control_output_names, const TF_FunctionOptions* opts, |
213 | const char* description, TF_Status* status); |
214 | |
215 | // Set the shapes and types for the output's handle. |
216 | // |
217 | // The sizes of 'shapes', 'ranks', and 'types' must be equal; `shapes[i]` |
218 | // contains the shape of the handle's i-th value, `ranks[i]` contains the i-th |
219 | // shape's rank, and `types[i]` contains the i-th value's dtype. If the i-th |
220 | // shape is unknown, then `ranks[i]` must be equal to -1. |
221 | // |
222 | // The space between the double angle brackets below looks extraneous, but |
223 | // our version of SWIG cannot parse ">>". |
224 | void TF_GraphSetOutputHandleShapesAndTypes_wrapper( |
225 | TF_Graph* graph, TF_Output output, |
226 | const std::vector<std::vector<int64_t> >& shapes, |
227 | const std::vector<int>& ranks, const std::vector<TF_DataType>& types, |
228 | TF_Status* status); |
229 | |
230 | // Creates Placeholders with specified types in the Graph. |
231 | // |
232 | // This is an internal API used to speed up creation of unused placeholders |
233 | // in while_v2 cond graph and is subject to change/removal. |
234 | std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes, |
235 | const char* prefix, |
236 | TF_Status* status); |
237 | |
238 | // Set the shape of output. If unknown is true, `num_dims` must be set to |
239 | // -1 and `dims` is set to nullptr. |
240 | void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, |
241 | const std::vector<int64_t>& dims, |
242 | bool unknown_shape, TF_Status* status); |
243 | |
244 | // Returns the string representations of the missing unused input mappings. |
245 | std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( |
246 | TF_ImportGraphDefResults* results); |
247 | |
248 | // If evaluation was possible, returns the numpy ndarray of the evaluated |
249 | // result. Otherwise returns None. |
250 | PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output, |
251 | TF_Status* status); |
252 | |
253 | } // namespace tensorflow |
254 | |
255 | #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ |
256 | |