1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <memory> |
17 | #include <vector> |
18 | |
19 | #include "absl/strings/str_cat.h" |
20 | #include "tensorflow/c/c_api.h" |
21 | #include "tensorflow/c/eager/abstract_context.h" |
22 | #include "tensorflow/c/eager/c_api_internal.h" |
23 | #include "tensorflow/c/eager/c_api_unified_experimental.h" |
24 | #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" |
25 | #include "tensorflow/c/eager/graph_function.h" |
26 | #include "tensorflow/c/tf_datatype.h" |
27 | #include "tensorflow/c/tf_status.h" |
28 | #include "tensorflow/c/tf_status_helper.h" |
29 | #include "tensorflow/core/framework/shape_inference.h" |
30 | #include "tensorflow/core/framework/tensor_shape.h" |
31 | #include "tensorflow/core/framework/types.pb.h" |
32 | #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" |
33 | #include "tensorflow/core/platform/errors.h" |
34 | #include "tensorflow/core/platform/strcat.h" |
35 | #include "tensorflow/core/platform/types.h" |
36 | |
37 | using tensorflow::dyn_cast; |
38 | using tensorflow::string; |
39 | using tensorflow::gtl::ArraySlice; |
40 | |
41 | namespace tensorflow { |
42 | namespace tracing { |
43 | namespace graph { |
44 | |
// Forward declarations for the tracing types defined below in this file.
class GraphContext;
class GraphOperation;
class GraphTensor;

// Local aliases for the shape-inference sentinel constants used to mark an
// unknown dimension size and an unknown (absent) rank, respectively.
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
51 | |
52 | // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index |
53 | // into the list of outputs for the operation. |
54 | class GraphTensor : public TracingTensorHandle { |
55 | public: |
56 | explicit GraphTensor(TF_Output output, TF_Graph* graph) |
57 | : TracingTensorHandle(kGraph), output_(output), graph_(graph) {} |
58 | |
59 | tensorflow::DataType DataType() const override { |
60 | return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_)); |
61 | } |
62 | |
63 | tensorflow::Status Shape( |
64 | tensorflow::PartialTensorShape* shape) const override { |
65 | DCHECK(shape != nullptr); |
66 | TF_Status status; |
67 | int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status); |
68 | DCHECK_GE(num_dims, -1); |
69 | TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); |
70 | if (num_dims == kUnknownRank) { |
71 | return OkStatus(); |
72 | } |
73 | |
74 | std::vector<int64_t> dims(num_dims, kUnknownDim); |
75 | TF_GraphGetTensorShape(graph_, output_, |
76 | reinterpret_cast<int64_t*>(dims.data()), num_dims, |
77 | &status); |
78 | TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); |
79 | TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape)); |
80 | |
81 | return OkStatus(); |
82 | } |
83 | |
84 | TF_Output output_; |
85 | |
86 | // For LLVM style RTTI. |
87 | static bool classof(const AbstractTensorHandle* ptr) { |
88 | return ptr->getKind() == kGraph; |
89 | } |
90 | |
91 | private: |
92 | TF_Graph* graph_; // For shape inference. |
93 | }; |
94 | |
95 | // GraphOperation wraps and populates a TF_OperationDescription. |
96 | class GraphOperation : public TracingOperation { |
97 | public: |
98 | explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {} |
99 | void Release() override { delete this; } |
100 | Status Reset(const char* op, const char* raw_device_name) override { |
101 | if (op_) { |
102 | return errors::FailedPrecondition("Reset called on already built op." ); |
103 | } |
104 | if (raw_device_name) { |
105 | device_name_ = raw_device_name; |
106 | } |
107 | op_type_ = op; |
108 | return OkStatus(); |
109 | } |
110 | Status SetOpName(const char* const op_name) override { |
111 | if (op_) { |
112 | return errors::FailedPrecondition( |
113 | "SetOpName called on already built op." ); |
114 | } |
115 | if (op_type_.empty()) { |
116 | return errors::FailedPrecondition( |
117 | "GraphOperation::Reset must be called before calling SetOpName." ); |
118 | } |
119 | // TODO(b/145674566): We use Graph::NewName to get a unique name here but |
120 | // this may not be consistent with python's naming policy. |
121 | mutex_lock l(g_->mu); |
122 | op_.reset(new TF_OperationDescription(g_, op_type_.c_str(), |
123 | g_->graph.NewName(op_name).c_str())); |
124 | return OkStatus(); |
125 | } |
126 | const string& Name() const override { return op_type_; } |
127 | const string& DeviceName() const override { return device_name_; } |
128 | |
129 | Status SetDeviceName(const char* name) override { |
130 | // TODO(srbs): Implement this. |
131 | device_name_ = name; |
132 | return OkStatus(); |
133 | } |
134 | |
135 | Status AddInput(AbstractTensorHandle* input) override { |
136 | GraphTensor* t = dyn_cast<GraphTensor>(input); |
137 | if (!t) { |
138 | return tensorflow::errors::InvalidArgument( |
139 | "Unable to cast input to GraphTensor" ); |
140 | } |
141 | TF_AddInput(op_.get(), t->output_); |
142 | return OkStatus(); |
143 | } |
144 | Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override { |
145 | std::vector<TF_Output> tf_outputs(inputs.size()); |
146 | for (int i = 0; i < inputs.size(); i++) { |
147 | GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]); |
148 | if (!t) { |
149 | return tensorflow::errors::InvalidArgument( |
150 | "Unable to cast input to GraphTensor" ); |
151 | } |
152 | tf_outputs[i] = t->output_; |
153 | } |
154 | TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size()); |
155 | return OkStatus(); |
156 | } |
157 | Status Execute(absl::Span<AbstractTensorHandle*> retvals, |
158 | int* num_retvals) override { |
159 | auto* tf_opdesc = op_.release(); |
160 | if (tf_opdesc == nullptr) { |
161 | return errors::InvalidArgument("AbstractOp is incomplete." ); |
162 | } |
163 | TF_Status* s = TF_NewStatus(); |
164 | auto* operation = TF_FinishOperation(tf_opdesc, s); |
165 | TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); |
166 | TF_DeleteStatus(s); |
167 | *num_retvals = TF_OperationNumOutputs(operation); |
168 | for (int i = 0; i < *num_retvals; ++i) { |
169 | retvals[i] = new GraphTensor({operation, i}, g_); |
170 | } |
171 | return OkStatus(); |
172 | } |
173 | |
174 | Status SetAttrString(const char* attr_name, const char* data, |
175 | size_t length) override { |
176 | tensorflow::StringPiece s(data, length); |
177 | op_->node_builder.Attr(attr_name, s); |
178 | return OkStatus(); |
179 | } |
180 | Status SetAttrInt(const char* attr_name, int64_t value) override { |
181 | op_->node_builder.Attr(attr_name, static_cast<int64_t>(value)); |
182 | return OkStatus(); |
183 | } |
184 | Status SetAttrFloat(const char* attr_name, float value) override { |
185 | op_->node_builder.Attr(attr_name, value); |
186 | return OkStatus(); |
187 | } |
188 | Status SetAttrBool(const char* attr_name, bool value) override { |
189 | op_->node_builder.Attr(attr_name, value); |
190 | return OkStatus(); |
191 | } |
192 | Status SetAttrType(const char* const attr_name, DataType value) override { |
193 | if (!op_) { |
194 | return Status( |
195 | error::Code::FAILED_PRECONDITION, |
196 | "op_type and op_name must be specified before specifying attrs." ); |
197 | } |
198 | op_->node_builder.Attr(attr_name, value); |
199 | return OkStatus(); |
200 | } |
201 | Status SetAttrShape(const char* attr_name, const int64_t* dims, |
202 | const int num_dims) override { |
203 | PartialTensorShape shape; |
204 | if (num_dims >= 0) { |
205 | shape = PartialTensorShape(ArraySlice<int64_t>( |
206 | reinterpret_cast<const int64_t*>(dims), num_dims)); |
207 | } |
208 | op_->node_builder.Attr(attr_name, shape); |
209 | return OkStatus(); |
210 | } |
211 | Status SetAttrFunction(const char* attr_name, |
212 | const AbstractOperation* value) override { |
213 | return tensorflow::errors::Unimplemented( |
214 | "SetAttrFunction has not been implemented yet." ); |
215 | } |
216 | Status SetAttrFunctionName(const char* attr_name, const char* value, |
217 | size_t length) override { |
218 | tensorflow::NameAttrList func_name; |
219 | func_name.set_name(string(value, value + length)); |
220 | op_->node_builder.Attr(attr_name, func_name); |
221 | return OkStatus(); |
222 | } |
223 | Status SetAttrTensor(const char* attr_name, |
224 | AbstractTensorInterface* tensor) override { |
225 | return tensorflow::errors::Unimplemented( |
226 | "SetAttrTensor has not been implemented yet." ); |
227 | } |
228 | Status SetAttrStringList(const char* attr_name, const void* const* values, |
229 | const size_t* lengths, int num_values) override { |
230 | if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { |
231 | op_->colocation_constraints.clear(); |
232 | for (int i = 0; i < num_values; ++i) { |
233 | op_->colocation_constraints.emplace(static_cast<const char*>(values[i]), |
234 | lengths[i]); |
235 | } |
236 | } else { |
237 | std::vector<tensorflow::StringPiece> v; |
238 | v.reserve(num_values); |
239 | for (int i = 0; i < num_values; ++i) { |
240 | v.emplace_back(static_cast<const char*>(values[i]), lengths[i]); |
241 | } |
242 | op_->node_builder.Attr(attr_name, v); |
243 | } |
244 | return OkStatus(); |
245 | } |
246 | Status SetAttrFloatList(const char* attr_name, const float* values, |
247 | int num_values) override { |
248 | op_->node_builder.Attr(attr_name, |
249 | ArraySlice<const float>(values, num_values)); |
250 | return OkStatus(); |
251 | } |
252 | Status SetAttrIntList(const char* attr_name, const int64_t* values, |
253 | int num_values) override { |
254 | op_->node_builder.Attr( |
255 | attr_name, ArraySlice<const int64_t>( |
256 | reinterpret_cast<const int64_t*>(values), num_values)); |
257 | return OkStatus(); |
258 | } |
259 | Status SetAttrTypeList(const char* attr_name, const DataType* values, |
260 | int num_values) override { |
261 | op_->node_builder.Attr(attr_name, |
262 | ArraySlice<const DataType>(values, num_values)); |
263 | return OkStatus(); |
264 | } |
265 | Status SetAttrBoolList(const char* attr_name, const unsigned char* values, |
266 | int num_values) override { |
267 | std::unique_ptr<bool[]> b(new bool[num_values]); |
268 | for (int i = 0; i < num_values; ++i) { |
269 | b[i] = values[i]; |
270 | } |
271 | op_->node_builder.Attr(attr_name, |
272 | ArraySlice<const bool>(b.get(), num_values)); |
273 | |
274 | return OkStatus(); |
275 | } |
276 | Status SetAttrShapeList(const char* attr_name, const int64_t** dims, |
277 | const int* num_dims, int num_values) override { |
278 | std::vector<PartialTensorShape> shapes; |
279 | shapes.reserve(num_values); |
280 | for (int i = 0; i < num_values; ++i) { |
281 | if (num_dims[i] < 0) { |
282 | shapes.emplace_back(); |
283 | } else { |
284 | shapes.emplace_back(ArraySlice<int64_t>( |
285 | reinterpret_cast<const int64_t*>(dims[i]), num_dims[i])); |
286 | } |
287 | } |
288 | op_->node_builder.Attr(attr_name, shapes); |
289 | return OkStatus(); |
290 | } |
291 | Status SetAttrFunctionList( |
292 | const char* attr_name, |
293 | absl::Span<const AbstractOperation*> values) override { |
294 | return tensorflow::errors::Unimplemented( |
295 | "SetAttrFunctionList has not been implemented yet." ); |
296 | } |
297 | // For LLVM style RTTI. |
298 | static bool classof(const AbstractOperation* ptr) { |
299 | return ptr->getKind() == kGraph; |
300 | } |
301 | ~GraphOperation() override {} |
302 | |
303 | private: |
304 | friend class GraphContext; // For access to op_. |
305 | TF_Graph* g_; |
306 | std::unique_ptr<TF_OperationDescription> op_; |
307 | // Hold `op_type` and `op_name` till both are available since we need both |
308 | // to build a graph operation. |
309 | string op_type_; |
310 | const char* op_name_ = nullptr; |
311 | // TODO(srbs): Use this. |
312 | string device_name_; |
313 | }; |
314 | |
315 | // GraphContext wraps a TF_Graph modeling a single function and manages the |
316 | // "execution" of operation, i.e. adding them to the function. |
317 | class GraphContext : public TracingContext { |
318 | public: |
319 | explicit GraphContext(const char* name) |
320 | : TracingContext(kGraph), |
321 | graph_(new TF_Graph(), TF_DeleteGraph), |
322 | name_(name) {} |
323 | |
324 | void Release() override { delete this; } |
325 | |
326 | TracingOperation* CreateOperation() override { |
327 | return new GraphOperation(graph_.get()); |
328 | } |
329 | |
330 | Status AddParameter(DataType dtype, const PartialTensorShape& shape, |
331 | TracingTensorHandle** output) override { |
332 | TracingOperationPtr operation(CreateOperation()); |
333 | TF_RETURN_IF_ERROR(operation->Reset("Placeholder" , nullptr)); |
334 | TF_RETURN_IF_ERROR( |
335 | operation->SetOpName(absl::StrCat("_input_" , inputs_.size()).c_str())); |
336 | TF_RETURN_IF_ERROR(operation->SetAttrType("dtype" , dtype)); |
337 | if (!shape.unknown_rank()) { |
338 | TF_RETURN_IF_ERROR(operation->SetAttrShape( |
339 | "shape" , reinterpret_cast<int64_t*>(shape.dim_sizes().data()), |
340 | shape.dims())); |
341 | } |
342 | int num_outputs = 1; |
343 | std::vector<AbstractTensorHandle*> outputs(num_outputs); |
344 | TF_RETURN_IF_ERROR(operation->Execute( |
345 | absl::Span<AbstractTensorHandle*>(outputs), &num_outputs)); |
346 | |
347 | if (num_outputs != 1) { |
348 | return errors::Internal("Expected 1 output but found " , num_outputs); |
349 | } |
350 | auto* t = dyn_cast<GraphTensor>(outputs[0]); |
351 | if (!t) { |
352 | return tensorflow::errors::InvalidArgument( |
353 | "Unable to cast input to GraphTensor" ); |
354 | } |
355 | inputs_.push_back(t->output_); |
356 | *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]); |
357 | return OkStatus(); |
358 | } |
359 | |
360 | Status Finalize(OutputList* outputs, AbstractFunction** f) override { |
361 | std::vector<TF_Output> graph_outputs; |
362 | graph_outputs.reserve(outputs->outputs.size()); |
363 | for (auto* abstract_output : outputs->outputs) { |
364 | GraphTensor* output = dyn_cast<GraphTensor>(abstract_output); |
365 | if (!output) { |
366 | return errors::Unimplemented( |
367 | "Returning a non-graph tensor from a function has not " |
368 | "been implemented yet." ); |
369 | } |
370 | graph_outputs.push_back(output->output_); |
371 | } |
372 | |
373 | auto s = TF_NewStatus(); |
374 | auto func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr, |
375 | inputs_.size(), inputs_.data(), |
376 | graph_outputs.size(), graph_outputs.data(), |
377 | nullptr, nullptr, name_.data(), s); |
378 | *f = new GraphFunction(std::move(func->fdef)); |
379 | TF_DeleteFunction(func); |
380 | TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); |
381 | TF_DeleteStatus(s); |
382 | return OkStatus(); |
383 | } |
384 | |
385 | Status RegisterFunction(AbstractFunction* func) override { |
386 | return errors::Unimplemented( |
387 | "Registering graph functions has not been implemented yet." ); |
388 | } |
389 | |
390 | Status RemoveFunction(const string& func) override { |
391 | return errors::Unimplemented( |
392 | "GraphContext::RemoveFunction has not been implemented yet." ); |
393 | } |
394 | // For LLVM style RTTI. |
395 | static bool classof(const AbstractContext* ptr) { |
396 | return ptr->getKind() == kGraph; |
397 | } |
398 | |
399 | private: |
400 | std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_; |
401 | std::vector<TF_Output> inputs_; |
402 | string name_; |
403 | }; |
404 | |
405 | static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { |
406 | return new GraphContext(name); |
407 | } |
408 | |
// Register the tracing implemented in this file as the default tracing engine.
// The immediately-invoked lambda runs at static-initialization time when this
// translation unit is linked in, so no explicit setup call is required.
static bool register_tracing = [] {
  RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
  // Best-effort: ignore failure to set the default engine at startup.
  SetDefaultTracingEngine("graphdef").IgnoreError();
  return true;
}();
415 | |
416 | } // namespace graph |
417 | } // namespace tracing |
418 | } // namespace tensorflow |
419 | |