1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_C_PYTHON_API_H_
17#define TENSORFLOW_C_PYTHON_API_H_
18
19#include <string>
20
21#include "tensorflow/c/c_api.h"
22#include "tensorflow/core/framework/full_type.pb.h"
23
24// These functions can be removed without notice. They exist to facilitate some
25// refactoring of graph construction code in the Python API.
26
namespace tensorflow {

// Adds `input` as a control input of `op` in `graph`.
// NOTE(review): description inferred from the function/parameter names;
// confirm against the implementation. This call takes no TF_Status, so it
// cannot report errors to the caller through a status object.
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);

// Changes the value of attr `attr_name` in `op`'s node_def Protocol Buffer.
// `attr_value_proto` carries the new value; presumably a serialized AttrValue
// proto (TODO confirm). The outcome is reported through `status`.
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
             TF_Buffer* attr_value_proto, TF_Status* status);

// Clears attr `attr_name` in `op`'s node_def Protocol Buffer. The outcome is
// reported through `status`.
void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
               TF_Status* status);

// Sets the `experimental_type` field in `op`'s node_def Protocol Buffer to
// `full_type`. This call takes no TF_Status, so it cannot report errors to the
// caller through a status object.
void SetFullType(TF_Graph* graph, TF_Operation* op,
                 const FullTypeDef& full_type);

// Sets the device requested for `op` to the string `device`.
// NOTE(review): inferred from the name; presumably this updates the node_def's
// requested-device field. Confirm against the implementation. No TF_Status is
// taken, so errors are not reported through a status object.
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);

// Updates input `dst` so that it consumes `new_src` instead of its current
// source. The outcome is reported through `status`.
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                TF_Status* status);

// Removes every control input of `op` in `graph`.
// NOTE(review): inferred from the name; confirm against the implementation.
// No TF_Status is taken, so errors are not reported through a status object.
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);

// Sets whether ops missing a shape inference function should trigger an
// error. The default is true.
void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);

// Extends `session` with any new operations added to its associated graph.
// Usually this happens automatically in TF_SessionRun. After this is called,
// TF_SessionRun will no longer extend the session on every call.
//
// We expose this here to allow fine-grained synchronization in multi-threaded
// workloads. This is required because the Python implementation depends on the
// mutation methods declared above. It allows us to prevent modifications to
// nodes in the graph after the session has been made aware of them.
void ExtendSession(TF_Session* session, TF_Status* status);

// Returns the serialized CppShapeInferenceResult::HandleData proto for
// `output` if it's a resource or variant tensor; otherwise returns the empty
// string.
std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);

// Sets the handle shape and type of `output` based on `proto`, which should be
// a serialized CppShapeInferenceResult::HandleData proto. `proto` points to the
// serialized bytes and `proto_len` is their length. `output` should be a
// resource or variant tensor. The outcome is reported through `status`.
// NOTE(skyewm): `proto` is passed as a void*/size_t pair instead of a
// std::string because I couldn't get SWIG to work otherwise.
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
                           size_t proto_len, TF_Status* status);

// Adds a new input edge from `new_src` to `dst`, which must be a While op.
// The While op's "T" attribute must have already been updated to include the
// new edge. This is used to construct tf.while_loop gradients. The outcome is
// reported through `status`.
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
                       TF_Status* status);

}  // namespace tensorflow
87
88#endif // TENSORFLOW_C_PYTHON_API_H_
89