1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/python_api.h" |
17 | |
18 | #include "tensorflow/c/c_api_internal.h" |
19 | #include "tensorflow/core/framework/full_type.pb.h" |
20 | #include "tensorflow/python/framework/cpp_shape_inference.pb.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { |
25 | mutex_lock l(graph->mu); |
26 | graph->graph.AddControlEdge(&input->node, &op->node); |
27 | RecordMutation(graph, *op, "adding control input" ); |
28 | } |
29 | |
30 | void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, |
31 | TF_Buffer* attr_value_proto, TF_Status* status) { |
32 | AttrValue attr_val; |
33 | if (!attr_val.ParseFromArray(attr_value_proto->data, |
34 | attr_value_proto->length)) { |
35 | status->status = |
36 | tensorflow::errors::InvalidArgument("Invalid AttrValue proto" ); |
37 | return; |
38 | } |
39 | |
40 | mutex_lock l(graph->mu); |
41 | op->node.AddAttr(attr_name, attr_val); |
42 | RecordMutation(graph, *op, "setting attribute" ); |
43 | } |
44 | |
45 | void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, |
46 | TF_Status* status) { |
47 | mutex_lock l(graph->mu); |
48 | op->node.ClearAttr(attr_name); |
49 | RecordMutation(graph, *op, "clearing attribute" ); |
50 | } |
51 | |
52 | void SetFullType(TF_Graph* graph, TF_Operation* op, |
53 | const FullTypeDef& full_type) { |
54 | mutex_lock l(graph->mu); |
55 | *op->node.mutable_def()->mutable_experimental_type() = full_type; |
56 | RecordMutation(graph, *op, "setting fulltype" ); |
57 | } |
58 | |
59 | void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { |
60 | mutex_lock l(graph->mu); |
61 | op->node.set_requested_device(device); |
62 | RecordMutation(graph, *op, "setting device" ); |
63 | } |
64 | |
65 | void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, |
66 | TF_Status* status) { |
67 | TF_UpdateEdge(graph, new_src, dst, status); |
68 | } |
69 | |
70 | void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { |
71 | mutex_lock l(graph->mu); |
72 | std::vector<const Edge*> control_edges; |
73 | for (const Edge* edge : op->node.in_edges()) { |
74 | if (!edge->IsControlEdge()) continue; |
75 | control_edges.push_back(edge); |
76 | } |
77 | for (const Edge* edge : control_edges) { |
78 | graph->graph.RemoveControlEdge(edge); |
79 | } |
80 | } |
81 | |
82 | void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { |
83 | mutex_lock l(graph->mu); |
84 | graph->refiner.set_require_shape_inference_fns(require); |
85 | } |
86 | |
87 | void ExtendSession(TF_Session* session, TF_Status* status) { |
88 | ExtendSessionGraphHelper(session, status); |
89 | session->extend_before_run = false; |
90 | } |
91 | |
92 | std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { |
93 | Node* node = &output.oper->node; |
94 | CppShapeInferenceResult::HandleData handle_data; |
95 | handle_data.set_is_set(true); |
96 | { |
97 | mutex_lock l(graph->mu); |
98 | tensorflow::shape_inference::InferenceContext* ic = |
99 | graph->refiner.GetContext(node); |
100 | CHECK(ic != nullptr); |
101 | CHECK_LT(output.index, ic->num_outputs()); |
102 | const auto* shapes_and_types = |
103 | ic->output_handle_shapes_and_types(output.index); |
104 | if (shapes_and_types == nullptr) return "" ; |
105 | |
106 | for (const auto& p : *shapes_and_types) { |
107 | auto* out_shape_and_type = handle_data.add_shape_and_type(); |
108 | ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); |
109 | out_shape_and_type->set_dtype(p.dtype); |
110 | *out_shape_and_type->mutable_type() = p.type; |
111 | } |
112 | } |
113 | string result; |
114 | handle_data.SerializeToString(&result); |
115 | return result; |
116 | } |
117 | |
118 | void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, |
119 | size_t proto_len, TF_Status* status) { |
120 | tensorflow::CppShapeInferenceResult::HandleData handle_data; |
121 | if (!handle_data.ParseFromArray(proto, proto_len)) { |
122 | status->status = tensorflow::errors::InvalidArgument( |
123 | "Couldn't deserialize HandleData proto" ); |
124 | return; |
125 | } |
126 | DCHECK(handle_data.is_set()); |
127 | |
128 | tensorflow::mutex_lock l(graph->mu); |
129 | tensorflow::shape_inference::InferenceContext* ic = |
130 | graph->refiner.GetContext(&output.oper->node); |
131 | |
132 | std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types; |
133 | for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { |
134 | tensorflow::shape_inference::ShapeHandle shape; |
135 | status->status = |
136 | ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); |
137 | if (TF_GetCode(status) != TF_OK) return; |
138 | shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), |
139 | shape_and_type_proto.type()); |
140 | } |
141 | ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); |
142 | } |
143 | |
144 | void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, |
145 | TF_Status* status) { |
146 | mutex_lock l(graph->mu); |
147 | status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, |
148 | new_src.index, &dst->node); |
149 | if (TF_GetCode(status) == TF_OK) { |
150 | // This modification only updates the destination node for |
151 | // the purposes of running this graph in a session. Thus, we don't |
152 | // record the source node as being modified. |
153 | RecordMutation(graph, *dst, "adding input tensor" ); |
154 | } |
155 | } |
156 | |
157 | } // namespace tensorflow |
158 | |