/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/c/python_api.h"
17
18#include "tensorflow/c/c_api_internal.h"
19#include "tensorflow/core/framework/full_type.pb.h"
20#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
21
22namespace tensorflow {
23
24void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
25 mutex_lock l(graph->mu);
26 graph->graph.AddControlEdge(&input->node, &op->node);
27 RecordMutation(graph, *op, "adding control input");
28}
29
30void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
31 TF_Buffer* attr_value_proto, TF_Status* status) {
32 AttrValue attr_val;
33 if (!attr_val.ParseFromArray(attr_value_proto->data,
34 attr_value_proto->length)) {
35 status->status =
36 tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
37 return;
38 }
39
40 mutex_lock l(graph->mu);
41 op->node.AddAttr(attr_name, attr_val);
42 RecordMutation(graph, *op, "setting attribute");
43}
44
45void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
46 TF_Status* status) {
47 mutex_lock l(graph->mu);
48 op->node.ClearAttr(attr_name);
49 RecordMutation(graph, *op, "clearing attribute");
50}
51
52void SetFullType(TF_Graph* graph, TF_Operation* op,
53 const FullTypeDef& full_type) {
54 mutex_lock l(graph->mu);
55 *op->node.mutable_def()->mutable_experimental_type() = full_type;
56 RecordMutation(graph, *op, "setting fulltype");
57}
58
59void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
60 mutex_lock l(graph->mu);
61 op->node.set_requested_device(device);
62 RecordMutation(graph, *op, "setting device");
63}
64
65void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
66 TF_Status* status) {
67 TF_UpdateEdge(graph, new_src, dst, status);
68}
69
70void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
71 mutex_lock l(graph->mu);
72 std::vector<const Edge*> control_edges;
73 for (const Edge* edge : op->node.in_edges()) {
74 if (!edge->IsControlEdge()) continue;
75 control_edges.push_back(edge);
76 }
77 for (const Edge* edge : control_edges) {
78 graph->graph.RemoveControlEdge(edge);
79 }
80}
81
82void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
83 mutex_lock l(graph->mu);
84 graph->refiner.set_require_shape_inference_fns(require);
85}
86
87void ExtendSession(TF_Session* session, TF_Status* status) {
88 ExtendSessionGraphHelper(session, status);
89 session->extend_before_run = false;
90}
91
92std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
93 Node* node = &output.oper->node;
94 CppShapeInferenceResult::HandleData handle_data;
95 handle_data.set_is_set(true);
96 {
97 mutex_lock l(graph->mu);
98 tensorflow::shape_inference::InferenceContext* ic =
99 graph->refiner.GetContext(node);
100 CHECK(ic != nullptr);
101 CHECK_LT(output.index, ic->num_outputs());
102 const auto* shapes_and_types =
103 ic->output_handle_shapes_and_types(output.index);
104 if (shapes_and_types == nullptr) return "";
105
106 for (const auto& p : *shapes_and_types) {
107 auto* out_shape_and_type = handle_data.add_shape_and_type();
108 ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
109 out_shape_and_type->set_dtype(p.dtype);
110 *out_shape_and_type->mutable_type() = p.type;
111 }
112 }
113 string result;
114 handle_data.SerializeToString(&result);
115 return result;
116}
117
118void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
119 size_t proto_len, TF_Status* status) {
120 tensorflow::CppShapeInferenceResult::HandleData handle_data;
121 if (!handle_data.ParseFromArray(proto, proto_len)) {
122 status->status = tensorflow::errors::InvalidArgument(
123 "Couldn't deserialize HandleData proto");
124 return;
125 }
126 DCHECK(handle_data.is_set());
127
128 tensorflow::mutex_lock l(graph->mu);
129 tensorflow::shape_inference::InferenceContext* ic =
130 graph->refiner.GetContext(&output.oper->node);
131
132 std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
133 for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
134 tensorflow::shape_inference::ShapeHandle shape;
135 status->status =
136 ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
137 if (TF_GetCode(status) != TF_OK) return;
138 shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
139 shape_and_type_proto.type());
140 }
141 ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
142}
143
144void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
145 TF_Status* status) {
146 mutex_lock l(graph->mu);
147 status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
148 new_src.index, &dst->node);
149 if (TF_GetCode(status) == TF_OK) {
150 // This modification only updates the destination node for
151 // the purposes of running this graph in a session. Thus, we don't
152 // record the source node as being modified.
153 RecordMutation(graph, *dst, "adding input tensor");
154 }
155}
156
157} // namespace tensorflow
158