1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include <cassert>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
30
31using tensorflow::string;
32
33namespace tensorflow {
34namespace tracing {
35typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;
36
37static FactoriesMap& GetFactories() {
38 static FactoriesMap* factories = new FactoriesMap;
39 return *factories;
40}
41
42static tracing::FactoryFunction default_factory;
43
44void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
45 assert((!GetFactories().count(name)) ||
46 (GetFactories()[name] == factory) &&
47 "Duplicate tracing factory registration");
48 GetFactories()[name] = factory;
49}
50
51Status SetDefaultTracingEngine(const char* name) {
52 auto entry = GetFactories().find(name);
53 if (entry != GetFactories().end()) {
54 default_factory = GetFactories().find(name)->second;
55 return OkStatus();
56 }
57 string msg = absl::StrCat(
58 "No tracing engine factory has been registered with the key '", name,
59 "' (available: ");
60 // Ensure deterministic (sorted) order in the error message
61 std::set<string> factories_sorted;
62 for (const auto& factory : GetFactories())
63 factories_sorted.insert(factory.first);
64 const char* comma = "";
65 for (const string& factory : factories_sorted) {
66 msg += comma + factory;
67 comma = ", ";
68 }
69 msg += ")";
70
71 return errors::InvalidArgument(msg.c_str());
72}
73
74static TracingContext* CreateTracingExecutionContext(const char* fn_name,
75 TF_Status* s) {
76 if (default_factory) {
77 return default_factory(fn_name, s);
78 }
79 Set_TF_Status_from_Status(
80 s, errors::FailedPrecondition("default_factory is nullptr"));
81 return nullptr;
82}
83
84} // end namespace tracing
85} // end namespace tensorflow
86
87// =============================================================================
88// Public C API entry points
89//
90// These are only the generic entry points for the C API. This file does not
91// have any visibility into the graph/eager implementation and is only providing
92// C bindings to the abstract classes defined in the
93// c_api_unified_experimental_internal.h header.
94//
95// =============================================================================
96
97using tensorflow::AbstractFunction;
98using tensorflow::AbstractTensorHandle;
99using tensorflow::DataType;
100using tensorflow::dyn_cast;
101using tensorflow::OutputList;
102using tensorflow::Status;
103using tensorflow::unwrap;
104using tensorflow::wrap;
105using tensorflow::tracing::CreateTracingExecutionContext;
106using tensorflow::tracing::SetDefaultTracingEngine;
107using tensorflow::tracing::TracingContext;
108using tensorflow::tracing::TracingOperation;
109using tensorflow::tracing::TracingTensorHandle;
110
111void TF_SetTracingImplementation(const char* name, TF_Status* s) {
112 Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
113}
114
115// Creates a new TensorFlow function, it is an execution context attached to a
116// given tracing context.
117TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
118 return wrap(CreateTracingExecutionContext(fn_name, s));
119}
120
121TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
122 TF_OutputList* outputs, TF_Status* s) {
123 AbstractFunction* func;
124 TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
125 if (!tracing_ctx) {
126 Set_TF_Status_from_Status(
127 s, tensorflow::errors::InvalidArgument(
128 "Only TracingContext can be converted into a function."));
129 return nullptr;
130 }
131 Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
132 TF_DeleteExecutionContext(ctx);
133 return wrap(func);
134}
135
136TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
137 TF_DataType dtype, TF_Shape shape,
138 TF_Status* s) {
139 DCHECK_GE(shape.num_dims, -1);
140 TracingTensorHandle* t;
141 TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
142 if (!tracing_ctx) {
143 Set_TF_Status_from_Status(
144 s, tensorflow::errors::InvalidArgument(
145 "TF_AddFunctionParameter must be called on a TracingContext."));
146 return nullptr;
147 }
148 tensorflow::PartialTensorShape partial_shape;
149 if (shape.num_dims != -1) {
150 DCHECK(shape.dim_sizes != nullptr);
151 Status status = tensorflow::PartialTensorShape::MakePartialShape(
152 reinterpret_cast<int64_t*>(shape.dim_sizes), shape.num_dims,
153 &partial_shape);
154 if (!status.ok()) {
155 Set_TF_Status_from_Status(s, status);
156 return nullptr;
157 }
158 }
159 Set_TF_Status_from_Status(
160 s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
161 &t));
162 return wrap(t);
163}
164
165void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }
166
167TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
168 return wrap((unwrap(c)->CreateOperation()));
169}
170
171void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
172
173void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
174
175TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
176void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
177void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
178 TF_Status* s) {
179 unwrap(o)->expected_num_outputs = num_outputs;
180 unwrap(o)->outputs.clear();
181 unwrap(o)->outputs.resize(num_outputs);
182}
183int TF_OutputListNumOutputs(TF_OutputList* o) {
184 return unwrap(o)->outputs.size();
185}
186TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
187 return wrap(unwrap(o)->outputs[i]);
188}
189void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
190 TF_Status* s) {
191 unwrap(o)->outputs.push_back(unwrap(tensor));
192}
193
194void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
195 TF_Status* s) {
196 Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
197 /*raw_device_name=*/nullptr));
198}
199
200void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
201 TF_Status* s) {
202 TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
203 if (!tracing_op) {
204 Set_TF_Status_from_Status(
205 s, tensorflow::errors::InvalidArgument(
206 "TF_AbstractOpSetOpName must be called on a TracingOperation."));
207 return;
208 }
209 Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
210}
211
212void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
213 TF_DataType value, TF_Status* s) {
214 Status status =
215 unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
216 TF_SetStatus(s, static_cast<TF_Code>(status.code()),
217 status.error_message().c_str());
218}
219
220void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
221 TF_AbstractTensor* const* inputs, TF_OutputList* o,
222 TF_Status* s) {
223 for (int i = 0; i < num_inputs; i++) {
224 Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
225 if (TF_GetCode(s) != TF_OK) {
226 return;
227 }
228 }
229 int num_outputs = unwrap(o)->expected_num_outputs;
230 Set_TF_Status_from_Status(
231 s, unwrap(op)->Execute(
232 absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
233 unwrap(o)->outputs.data()),
234 unwrap(o)->outputs.size()),
235 &num_outputs));
236}
237
238void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
239 unwrap(func)->Unref();
240}
241
242void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
243 TF_AbstractFunction* func,
244 TF_Status* s) {
245 Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
246}
247