1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include <cassert>
#include <set>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
30 | |
31 | using tensorflow::string; |
32 | |
33 | namespace tensorflow { |
34 | namespace tracing { |
35 | typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap; |
36 | |
37 | static FactoriesMap& GetFactories() { |
38 | static FactoriesMap* factories = new FactoriesMap; |
39 | return *factories; |
40 | } |
41 | |
42 | static tracing::FactoryFunction default_factory; |
43 | |
44 | void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { |
45 | assert((!GetFactories().count(name)) || |
46 | (GetFactories()[name] == factory) && |
47 | "Duplicate tracing factory registration" ); |
48 | GetFactories()[name] = factory; |
49 | } |
50 | |
51 | Status SetDefaultTracingEngine(const char* name) { |
52 | auto entry = GetFactories().find(name); |
53 | if (entry != GetFactories().end()) { |
54 | default_factory = GetFactories().find(name)->second; |
55 | return OkStatus(); |
56 | } |
57 | string msg = absl::StrCat( |
58 | "No tracing engine factory has been registered with the key '" , name, |
59 | "' (available: " ); |
60 | // Ensure deterministic (sorted) order in the error message |
61 | std::set<string> factories_sorted; |
62 | for (const auto& factory : GetFactories()) |
63 | factories_sorted.insert(factory.first); |
64 | const char* comma = "" ; |
65 | for (const string& factory : factories_sorted) { |
66 | msg += comma + factory; |
67 | comma = ", " ; |
68 | } |
69 | msg += ")" ; |
70 | |
71 | return errors::InvalidArgument(msg.c_str()); |
72 | } |
73 | |
74 | static TracingContext* CreateTracingExecutionContext(const char* fn_name, |
75 | TF_Status* s) { |
76 | if (default_factory) { |
77 | return default_factory(fn_name, s); |
78 | } |
79 | Set_TF_Status_from_Status( |
80 | s, errors::FailedPrecondition("default_factory is nullptr" )); |
81 | return nullptr; |
82 | } |
83 | |
84 | } // end namespace tracing |
85 | } // end namespace tensorflow |
86 | |
87 | // ============================================================================= |
88 | // Public C API entry points |
89 | // |
90 | // These are only the generic entry points for the C API. This file does not |
91 | // have any visibility into the graph/eager implementation and is only providing |
92 | // C bindings to the abstract classes defined in the |
93 | // c_api_unified_experimental_internal.h header. |
94 | // |
95 | // ============================================================================= |
96 | |
97 | using tensorflow::AbstractFunction; |
98 | using tensorflow::AbstractTensorHandle; |
99 | using tensorflow::DataType; |
100 | using tensorflow::dyn_cast; |
101 | using tensorflow::OutputList; |
102 | using tensorflow::Status; |
103 | using tensorflow::unwrap; |
104 | using tensorflow::wrap; |
105 | using tensorflow::tracing::CreateTracingExecutionContext; |
106 | using tensorflow::tracing::SetDefaultTracingEngine; |
107 | using tensorflow::tracing::TracingContext; |
108 | using tensorflow::tracing::TracingOperation; |
109 | using tensorflow::tracing::TracingTensorHandle; |
110 | |
111 | void TF_SetTracingImplementation(const char* name, TF_Status* s) { |
112 | Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name)); |
113 | } |
114 | |
115 | // Creates a new TensorFlow function, it is an execution context attached to a |
116 | // given tracing context. |
117 | TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) { |
118 | return wrap(CreateTracingExecutionContext(fn_name, s)); |
119 | } |
120 | |
121 | TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, |
122 | TF_OutputList* outputs, TF_Status* s) { |
123 | AbstractFunction* func; |
124 | TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx)); |
125 | if (!tracing_ctx) { |
126 | Set_TF_Status_from_Status( |
127 | s, tensorflow::errors::InvalidArgument( |
128 | "Only TracingContext can be converted into a function." )); |
129 | return nullptr; |
130 | } |
131 | Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func)); |
132 | TF_DeleteExecutionContext(ctx); |
133 | return wrap(func); |
134 | } |
135 | |
136 | TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, |
137 | TF_DataType dtype, TF_Shape shape, |
138 | TF_Status* s) { |
139 | DCHECK_GE(shape.num_dims, -1); |
140 | TracingTensorHandle* t; |
141 | TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func)); |
142 | if (!tracing_ctx) { |
143 | Set_TF_Status_from_Status( |
144 | s, tensorflow::errors::InvalidArgument( |
145 | "TF_AddFunctionParameter must be called on a TracingContext." )); |
146 | return nullptr; |
147 | } |
148 | tensorflow::PartialTensorShape partial_shape; |
149 | if (shape.num_dims != -1) { |
150 | DCHECK(shape.dim_sizes != nullptr); |
151 | Status status = tensorflow::PartialTensorShape::MakePartialShape( |
152 | reinterpret_cast<int64_t*>(shape.dim_sizes), shape.num_dims, |
153 | &partial_shape); |
154 | if (!status.ok()) { |
155 | Set_TF_Status_from_Status(s, status); |
156 | return nullptr; |
157 | } |
158 | } |
159 | Set_TF_Status_from_Status( |
160 | s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape, |
161 | &t)); |
162 | return wrap(t); |
163 | } |
164 | |
165 | void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); } |
166 | |
167 | TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { |
168 | return wrap((unwrap(c)->CreateOperation())); |
169 | } |
170 | |
171 | void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); } |
172 | |
173 | void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); } |
174 | |
175 | TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } |
176 | void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } |
177 | void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, |
178 | TF_Status* s) { |
179 | unwrap(o)->expected_num_outputs = num_outputs; |
180 | unwrap(o)->outputs.clear(); |
181 | unwrap(o)->outputs.resize(num_outputs); |
182 | } |
183 | int TF_OutputListNumOutputs(TF_OutputList* o) { |
184 | return unwrap(o)->outputs.size(); |
185 | } |
186 | TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { |
187 | return wrap(unwrap(o)->outputs[i]); |
188 | } |
189 | void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, |
190 | TF_Status* s) { |
191 | unwrap(o)->outputs.push_back(unwrap(tensor)); |
192 | } |
193 | |
194 | void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, |
195 | TF_Status* s) { |
196 | Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type, |
197 | /*raw_device_name=*/nullptr)); |
198 | } |
199 | |
200 | void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, |
201 | TF_Status* s) { |
202 | TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op)); |
203 | if (!tracing_op) { |
204 | Set_TF_Status_from_Status( |
205 | s, tensorflow::errors::InvalidArgument( |
206 | "TF_AbstractOpSetOpName must be called on a TracingOperation." )); |
207 | return; |
208 | } |
209 | Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name)); |
210 | } |
211 | |
212 | void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, |
213 | TF_DataType value, TF_Status* s) { |
214 | Status status = |
215 | unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value)); |
216 | TF_SetStatus(s, static_cast<TF_Code>(status.code()), |
217 | status.error_message().c_str()); |
218 | } |
219 | |
220 | void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, |
221 | TF_AbstractTensor* const* inputs, TF_OutputList* o, |
222 | TF_Status* s) { |
223 | for (int i = 0; i < num_inputs; i++) { |
224 | Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i]))); |
225 | if (TF_GetCode(s) != TF_OK) { |
226 | return; |
227 | } |
228 | } |
229 | int num_outputs = unwrap(o)->expected_num_outputs; |
230 | Set_TF_Status_from_Status( |
231 | s, unwrap(op)->Execute( |
232 | absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>( |
233 | unwrap(o)->outputs.data()), |
234 | unwrap(o)->outputs.size()), |
235 | &num_outputs)); |
236 | } |
237 | |
238 | void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { |
239 | unwrap(func)->Unref(); |
240 | } |
241 | |
242 | void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx, |
243 | TF_AbstractFunction* func, |
244 | TF_Status* s) { |
245 | Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func))); |
246 | } |
247 | |