1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <memory>
17#include <vector>
18
19#include "absl/strings/str_cat.h"
20#include "tensorflow/c/c_api.h"
21#include "tensorflow/c/eager/abstract_context.h"
22#include "tensorflow/c/eager/c_api_internal.h"
23#include "tensorflow/c/eager/c_api_unified_experimental.h"
24#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
25#include "tensorflow/c/eager/graph_function.h"
26#include "tensorflow/c/tf_datatype.h"
27#include "tensorflow/c/tf_status.h"
28#include "tensorflow/c/tf_status_helper.h"
29#include "tensorflow/core/framework/shape_inference.h"
30#include "tensorflow/core/framework/tensor_shape.h"
31#include "tensorflow/core/framework/types.pb.h"
32#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
33#include "tensorflow/core/platform/errors.h"
34#include "tensorflow/core/platform/strcat.h"
35#include "tensorflow/core/platform/types.h"
36
37using tensorflow::dyn_cast;
38using tensorflow::string;
39using tensorflow::gtl::ArraySlice;
40
41namespace tensorflow {
42namespace tracing {
43namespace graph {
44
45class GraphContext;
46class GraphOperation;
47class GraphTensor;
48
49auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
50auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
51
52// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
53// into the list of outputs for the operation.
54class GraphTensor : public TracingTensorHandle {
55 public:
56 explicit GraphTensor(TF_Output output, TF_Graph* graph)
57 : TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
58
59 tensorflow::DataType DataType() const override {
60 return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
61 }
62
63 tensorflow::Status Shape(
64 tensorflow::PartialTensorShape* shape) const override {
65 DCHECK(shape != nullptr);
66 TF_Status status;
67 int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
68 DCHECK_GE(num_dims, -1);
69 TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
70 if (num_dims == kUnknownRank) {
71 return OkStatus();
72 }
73
74 std::vector<int64_t> dims(num_dims, kUnknownDim);
75 TF_GraphGetTensorShape(graph_, output_,
76 reinterpret_cast<int64_t*>(dims.data()), num_dims,
77 &status);
78 TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
79 TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
80
81 return OkStatus();
82 }
83
84 TF_Output output_;
85
86 // For LLVM style RTTI.
87 static bool classof(const AbstractTensorHandle* ptr) {
88 return ptr->getKind() == kGraph;
89 }
90
91 private:
92 TF_Graph* graph_; // For shape inference.
93};
94
95// GraphOperation wraps and populates a TF_OperationDescription.
96class GraphOperation : public TracingOperation {
97 public:
98 explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
99 void Release() override { delete this; }
100 Status Reset(const char* op, const char* raw_device_name) override {
101 if (op_) {
102 return errors::FailedPrecondition("Reset called on already built op.");
103 }
104 if (raw_device_name) {
105 device_name_ = raw_device_name;
106 }
107 op_type_ = op;
108 return OkStatus();
109 }
110 Status SetOpName(const char* const op_name) override {
111 if (op_) {
112 return errors::FailedPrecondition(
113 "SetOpName called on already built op.");
114 }
115 if (op_type_.empty()) {
116 return errors::FailedPrecondition(
117 "GraphOperation::Reset must be called before calling SetOpName.");
118 }
119 // TODO(b/145674566): We use Graph::NewName to get a unique name here but
120 // this may not be consistent with python's naming policy.
121 mutex_lock l(g_->mu);
122 op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
123 g_->graph.NewName(op_name).c_str()));
124 return OkStatus();
125 }
126 const string& Name() const override { return op_type_; }
127 const string& DeviceName() const override { return device_name_; }
128
129 Status SetDeviceName(const char* name) override {
130 // TODO(srbs): Implement this.
131 device_name_ = name;
132 return OkStatus();
133 }
134
135 Status AddInput(AbstractTensorHandle* input) override {
136 GraphTensor* t = dyn_cast<GraphTensor>(input);
137 if (!t) {
138 return tensorflow::errors::InvalidArgument(
139 "Unable to cast input to GraphTensor");
140 }
141 TF_AddInput(op_.get(), t->output_);
142 return OkStatus();
143 }
144 Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
145 std::vector<TF_Output> tf_outputs(inputs.size());
146 for (int i = 0; i < inputs.size(); i++) {
147 GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
148 if (!t) {
149 return tensorflow::errors::InvalidArgument(
150 "Unable to cast input to GraphTensor");
151 }
152 tf_outputs[i] = t->output_;
153 }
154 TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
155 return OkStatus();
156 }
157 Status Execute(absl::Span<AbstractTensorHandle*> retvals,
158 int* num_retvals) override {
159 auto* tf_opdesc = op_.release();
160 if (tf_opdesc == nullptr) {
161 return errors::InvalidArgument("AbstractOp is incomplete.");
162 }
163 TF_Status* s = TF_NewStatus();
164 auto* operation = TF_FinishOperation(tf_opdesc, s);
165 TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
166 TF_DeleteStatus(s);
167 *num_retvals = TF_OperationNumOutputs(operation);
168 for (int i = 0; i < *num_retvals; ++i) {
169 retvals[i] = new GraphTensor({operation, i}, g_);
170 }
171 return OkStatus();
172 }
173
174 Status SetAttrString(const char* attr_name, const char* data,
175 size_t length) override {
176 tensorflow::StringPiece s(data, length);
177 op_->node_builder.Attr(attr_name, s);
178 return OkStatus();
179 }
180 Status SetAttrInt(const char* attr_name, int64_t value) override {
181 op_->node_builder.Attr(attr_name, static_cast<int64_t>(value));
182 return OkStatus();
183 }
184 Status SetAttrFloat(const char* attr_name, float value) override {
185 op_->node_builder.Attr(attr_name, value);
186 return OkStatus();
187 }
188 Status SetAttrBool(const char* attr_name, bool value) override {
189 op_->node_builder.Attr(attr_name, value);
190 return OkStatus();
191 }
192 Status SetAttrType(const char* const attr_name, DataType value) override {
193 if (!op_) {
194 return Status(
195 error::Code::FAILED_PRECONDITION,
196 "op_type and op_name must be specified before specifying attrs.");
197 }
198 op_->node_builder.Attr(attr_name, value);
199 return OkStatus();
200 }
201 Status SetAttrShape(const char* attr_name, const int64_t* dims,
202 const int num_dims) override {
203 PartialTensorShape shape;
204 if (num_dims >= 0) {
205 shape = PartialTensorShape(ArraySlice<int64_t>(
206 reinterpret_cast<const int64_t*>(dims), num_dims));
207 }
208 op_->node_builder.Attr(attr_name, shape);
209 return OkStatus();
210 }
211 Status SetAttrFunction(const char* attr_name,
212 const AbstractOperation* value) override {
213 return tensorflow::errors::Unimplemented(
214 "SetAttrFunction has not been implemented yet.");
215 }
216 Status SetAttrFunctionName(const char* attr_name, const char* value,
217 size_t length) override {
218 tensorflow::NameAttrList func_name;
219 func_name.set_name(string(value, value + length));
220 op_->node_builder.Attr(attr_name, func_name);
221 return OkStatus();
222 }
223 Status SetAttrTensor(const char* attr_name,
224 AbstractTensorInterface* tensor) override {
225 return tensorflow::errors::Unimplemented(
226 "SetAttrTensor has not been implemented yet.");
227 }
228 Status SetAttrStringList(const char* attr_name, const void* const* values,
229 const size_t* lengths, int num_values) override {
230 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
231 op_->colocation_constraints.clear();
232 for (int i = 0; i < num_values; ++i) {
233 op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
234 lengths[i]);
235 }
236 } else {
237 std::vector<tensorflow::StringPiece> v;
238 v.reserve(num_values);
239 for (int i = 0; i < num_values; ++i) {
240 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
241 }
242 op_->node_builder.Attr(attr_name, v);
243 }
244 return OkStatus();
245 }
246 Status SetAttrFloatList(const char* attr_name, const float* values,
247 int num_values) override {
248 op_->node_builder.Attr(attr_name,
249 ArraySlice<const float>(values, num_values));
250 return OkStatus();
251 }
252 Status SetAttrIntList(const char* attr_name, const int64_t* values,
253 int num_values) override {
254 op_->node_builder.Attr(
255 attr_name, ArraySlice<const int64_t>(
256 reinterpret_cast<const int64_t*>(values), num_values));
257 return OkStatus();
258 }
259 Status SetAttrTypeList(const char* attr_name, const DataType* values,
260 int num_values) override {
261 op_->node_builder.Attr(attr_name,
262 ArraySlice<const DataType>(values, num_values));
263 return OkStatus();
264 }
265 Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
266 int num_values) override {
267 std::unique_ptr<bool[]> b(new bool[num_values]);
268 for (int i = 0; i < num_values; ++i) {
269 b[i] = values[i];
270 }
271 op_->node_builder.Attr(attr_name,
272 ArraySlice<const bool>(b.get(), num_values));
273
274 return OkStatus();
275 }
276 Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
277 const int* num_dims, int num_values) override {
278 std::vector<PartialTensorShape> shapes;
279 shapes.reserve(num_values);
280 for (int i = 0; i < num_values; ++i) {
281 if (num_dims[i] < 0) {
282 shapes.emplace_back();
283 } else {
284 shapes.emplace_back(ArraySlice<int64_t>(
285 reinterpret_cast<const int64_t*>(dims[i]), num_dims[i]));
286 }
287 }
288 op_->node_builder.Attr(attr_name, shapes);
289 return OkStatus();
290 }
291 Status SetAttrFunctionList(
292 const char* attr_name,
293 absl::Span<const AbstractOperation*> values) override {
294 return tensorflow::errors::Unimplemented(
295 "SetAttrFunctionList has not been implemented yet.");
296 }
297 // For LLVM style RTTI.
298 static bool classof(const AbstractOperation* ptr) {
299 return ptr->getKind() == kGraph;
300 }
301 ~GraphOperation() override {}
302
303 private:
304 friend class GraphContext; // For access to op_.
305 TF_Graph* g_;
306 std::unique_ptr<TF_OperationDescription> op_;
307 // Hold `op_type` and `op_name` till both are available since we need both
308 // to build a graph operation.
309 string op_type_;
310 const char* op_name_ = nullptr;
311 // TODO(srbs): Use this.
312 string device_name_;
313};
314
315// GraphContext wraps a TF_Graph modeling a single function and manages the
316// "execution" of operation, i.e. adding them to the function.
317class GraphContext : public TracingContext {
318 public:
319 explicit GraphContext(const char* name)
320 : TracingContext(kGraph),
321 graph_(new TF_Graph(), TF_DeleteGraph),
322 name_(name) {}
323
324 void Release() override { delete this; }
325
326 TracingOperation* CreateOperation() override {
327 return new GraphOperation(graph_.get());
328 }
329
330 Status AddParameter(DataType dtype, const PartialTensorShape& shape,
331 TracingTensorHandle** output) override {
332 TracingOperationPtr operation(CreateOperation());
333 TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
334 TF_RETURN_IF_ERROR(
335 operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
336 TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
337 if (!shape.unknown_rank()) {
338 TF_RETURN_IF_ERROR(operation->SetAttrShape(
339 "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
340 shape.dims()));
341 }
342 int num_outputs = 1;
343 std::vector<AbstractTensorHandle*> outputs(num_outputs);
344 TF_RETURN_IF_ERROR(operation->Execute(
345 absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
346
347 if (num_outputs != 1) {
348 return errors::Internal("Expected 1 output but found ", num_outputs);
349 }
350 auto* t = dyn_cast<GraphTensor>(outputs[0]);
351 if (!t) {
352 return tensorflow::errors::InvalidArgument(
353 "Unable to cast input to GraphTensor");
354 }
355 inputs_.push_back(t->output_);
356 *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
357 return OkStatus();
358 }
359
360 Status Finalize(OutputList* outputs, AbstractFunction** f) override {
361 std::vector<TF_Output> graph_outputs;
362 graph_outputs.reserve(outputs->outputs.size());
363 for (auto* abstract_output : outputs->outputs) {
364 GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
365 if (!output) {
366 return errors::Unimplemented(
367 "Returning a non-graph tensor from a function has not "
368 "been implemented yet.");
369 }
370 graph_outputs.push_back(output->output_);
371 }
372
373 auto s = TF_NewStatus();
374 auto func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
375 inputs_.size(), inputs_.data(),
376 graph_outputs.size(), graph_outputs.data(),
377 nullptr, nullptr, name_.data(), s);
378 *f = new GraphFunction(std::move(func->fdef));
379 TF_DeleteFunction(func);
380 TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
381 TF_DeleteStatus(s);
382 return OkStatus();
383 }
384
385 Status RegisterFunction(AbstractFunction* func) override {
386 return errors::Unimplemented(
387 "Registering graph functions has not been implemented yet.");
388 }
389
390 Status RemoveFunction(const string& func) override {
391 return errors::Unimplemented(
392 "GraphContext::RemoveFunction has not been implemented yet.");
393 }
394 // For LLVM style RTTI.
395 static bool classof(const AbstractContext* ptr) {
396 return ptr->getKind() == kGraph;
397 }
398
399 private:
400 std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
401 std::vector<TF_Output> inputs_;
402 string name_;
403};
404
405static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
406 return new GraphContext(name);
407}
408
409// Register the tracing implemented in this file as the default tracing engine.
410static bool register_tracing = [] {
411 RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
412 SetDefaultTracingEngine("graphdef").IgnoreError();
413 return true;
414}();
415
416} // namespace graph
417} // namespace tracing
418} // namespace tensorflow
419