1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/ops.h" |
17 | |
18 | #include "tensorflow/c/tf_status_helper.h" |
19 | #include "tensorflow/core/framework/common_shape_fns.h" |
20 | #include "tensorflow/core/framework/op.h" |
21 | #include "tensorflow/core/framework/op_def_builder.h" |
22 | #include "tensorflow/core/framework/shape_inference.h" |
23 | |
24 | using ::tensorflow::DataType; |
25 | using ::tensorflow::OpDef; |
26 | using ::tensorflow::OpDefBuilder; |
27 | using ::tensorflow::OpDeprecation; |
28 | using ::tensorflow::OpShapeInferenceFn; |
29 | using ::tensorflow::Set_TF_Status_from_Status; |
30 | using ::tensorflow::Status; |
31 | using ::tensorflow::shape_inference::DimensionHandle; |
32 | using ::tensorflow::shape_inference::InferenceContext; |
33 | using ::tensorflow::shape_inference::ShapeHandle; |
34 | |
35 | TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) { |
36 | auto* result = new OpDefBuilder(op_name); |
37 | return reinterpret_cast<TF_OpDefinitionBuilder*>(result); |
38 | } |
39 | |
40 | void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) { |
41 | delete reinterpret_cast<OpDefBuilder*>(builder); |
42 | } |
43 | |
44 | void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder, |
45 | const char* input_spec) { |
46 | reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec); |
47 | } |
48 | |
49 | void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder, |
50 | const char* output_spec) { |
51 | reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec); |
52 | } |
53 | |
54 | #define DEFINE_BUILDER_BOOL_SETTER(func_name) \ |
55 | void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \ |
56 | bool arg_name) { \ |
57 | reinterpret_cast<OpDefBuilder*>(builder)->func_name(); \ |
58 | } |
59 | |
60 | DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative) |
61 | DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate) |
62 | DEFINE_BUILDER_BOOL_SETTER(SetIsStateful) |
63 | DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput) |
64 | |
65 | void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder, |
66 | const char* attr_spec) { |
67 | reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec); |
68 | } |
69 | |
70 | void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder, |
71 | int version, const char* explanation) { |
72 | reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation); |
73 | } |
74 | |
75 | void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder, |
76 | TF_Status* status) { |
77 | auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder); |
78 | TF_SetStatus(status, TF_OK, "" ); |
79 | ::tensorflow::OpRegistry::Global()->Register( |
80 | [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status { |
81 | Status result = cc_builder->Finalize(op_reg_data); |
82 | delete cc_builder; |
83 | return result; |
84 | }); |
85 | } |
86 | |
87 | void TF_OpDefinitionBuilderSetShapeInferenceFunction( |
88 | TF_OpDefinitionBuilder* builder, |
89 | void (*shape_inference_func)(TF_ShapeInferenceContext* ctx, |
90 | TF_Status* status)) { |
91 | auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder); |
92 | cc_builder->SetShapeFn( |
93 | [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status { |
94 | TF_Status* c_status = TF_NewStatus(); |
95 | auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx); |
96 | shape_inference_func(c_ctx, c_status); |
97 | tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status); |
98 | TF_DeleteStatus(c_status); |
99 | return result; |
100 | }); |
101 | } |
102 | |
103 | TF_ShapeHandle* TF_NewShapeHandle() { |
104 | return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle); |
105 | } |
106 | |
107 | TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) { |
108 | auto* handle = new ShapeHandle; |
109 | *handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar(); |
110 | return reinterpret_cast<TF_ShapeHandle*>(handle); |
111 | } |
112 | |
113 | TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( |
114 | TF_ShapeInferenceContext* ctx, size_t size) { |
115 | auto* handle = new ShapeHandle; |
116 | *handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size); |
117 | return reinterpret_cast<TF_ShapeHandle*>(handle); |
118 | } |
119 | |
120 | void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx, |
121 | TF_ShapeHandle* first, |
122 | TF_ShapeHandle* second, |
123 | TF_ShapeHandle* result, |
124 | TF_Status* status) { |
125 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
126 | Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first), |
127 | *reinterpret_cast<ShapeHandle*>(second), |
128 | reinterpret_cast<ShapeHandle*>(result)); |
129 | Set_TF_Status_from_Status(status, s); |
130 | } |
131 | |
132 | TF_DimensionHandle* TF_NewDimensionHandle() { |
133 | return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle); |
134 | } |
135 | |
136 | int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) { |
137 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
138 | return cc_ctx->num_inputs(); |
139 | } |
140 | |
141 | void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i, |
142 | TF_ShapeHandle* handle, |
143 | TF_Status* status) { |
144 | TF_SetStatus(status, TF_OK, "" ); |
145 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
146 | if (i < 0 || i >= cc_ctx->num_inputs()) { |
147 | TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range" ); |
148 | } |
149 | if (TF_GetCode(status) == TF_OK) { |
150 | auto* cc_result = reinterpret_cast<ShapeHandle*>(handle); |
151 | *cc_result = cc_ctx->input(i); |
152 | } |
153 | } |
154 | |
155 | int (TF_ShapeInferenceContext* ctx, |
156 | TF_ShapeHandle* handle) { |
157 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
158 | return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle)); |
159 | } |
160 | |
161 | void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i, |
162 | TF_ShapeHandle* handle, |
163 | TF_Status* status) { |
164 | TF_SetStatus(status, TF_OK, "" ); |
165 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
166 | if (i < 0 || i >= cc_ctx->num_outputs()) { |
167 | TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range" ); |
168 | } |
169 | if (TF_GetCode(status) == TF_OK) { |
170 | cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle))); |
171 | } |
172 | } |
173 | |
174 | void TF_DeleteShapeHandle(TF_ShapeHandle* handle) { |
175 | if (handle == nullptr) { |
176 | return; |
177 | } |
178 | |
179 | delete reinterpret_cast<ShapeHandle*>(handle); |
180 | } |
181 | |
182 | void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) { |
183 | if (handle == nullptr) { |
184 | return; |
185 | } |
186 | |
187 | delete reinterpret_cast<DimensionHandle*>(handle); |
188 | } |
189 | |
190 | #define DEFINE_TF_GETATTR(func, c_type, cc_type) \ |
191 | void TF_ShapeInferenceContext_GetAttr##func( \ |
192 | TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \ |
193 | TF_Status* status) { \ |
194 | TF_SetStatus(status, TF_OK, ""); \ |
195 | cc_type v; \ |
196 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \ |
197 | Status s = cc_ctx->GetAttr(attr_name, &v); \ |
198 | Set_TF_Status_from_Status(status, s); \ |
199 | if (s.ok()) { \ |
200 | *val = static_cast<c_type>(v); \ |
201 | } \ |
202 | } |
203 | |
204 | DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) |
205 | |
206 | #define DEFINE_RANK_FUNC(func_name) \ |
207 | void TF_ShapeInferenceContext##func_name( \ |
208 | TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \ |
209 | TF_ShapeHandle* result, TF_Status* status) { \ |
210 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \ |
211 | auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle); \ |
212 | auto* cc_result = reinterpret_cast<ShapeHandle*>(result); \ |
213 | Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \ |
214 | Set_TF_Status_from_Status(status, s); \ |
215 | } |
216 | |
217 | DEFINE_RANK_FUNC(WithRank) |
218 | DEFINE_RANK_FUNC(WithRankAtLeast) |
219 | DEFINE_RANK_FUNC(WithRankAtMost) |
220 | |
221 | int64_t (TF_ShapeInferenceContext* ctx, |
222 | TF_ShapeHandle* handle) { |
223 | return reinterpret_cast<InferenceContext*>(ctx)->Rank( |
224 | *reinterpret_cast<ShapeHandle*>(handle)); |
225 | } |
226 | |
227 | void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx, |
228 | TF_ShapeHandle* shape_handle, int64_t i, |
229 | TF_DimensionHandle* result) { |
230 | int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle); |
231 | auto* cc_result = reinterpret_cast<DimensionHandle*>(result); |
232 | |
233 | if (i < -rank || i >= rank) { |
234 | *cc_result = DimensionHandle(); |
235 | return; |
236 | } |
237 | |
238 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
239 | auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle); |
240 | *cc_result = cc_ctx->Dim(*cc_shape_handle, i); |
241 | } |
242 | |
243 | int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) { |
244 | return InferenceContext::ValueKnown( |
245 | *reinterpret_cast<DimensionHandle*>(dim_handle)); |
246 | } |
247 | |
248 | void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx, |
249 | TF_Status* status) { |
250 | Status s = ::tensorflow::shape_inference::UnknownShape( |
251 | reinterpret_cast<InferenceContext*>(ctx)); |
252 | Set_TF_Status_from_Status(status, s); |
253 | } |
254 | |
255 | void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx, |
256 | TF_ShapeHandle* shape_handle, |
257 | int64_t start, int64_t end, |
258 | TF_ShapeHandle* result, |
259 | TF_Status* status) { |
260 | TF_SetStatus(status, TF_OK, "" ); |
261 | auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); |
262 | auto* cc_result = reinterpret_cast<ShapeHandle*>(result); |
263 | Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle), |
264 | start, end, cc_result); |
265 | Set_TF_Status_from_Status(status, s); |
266 | } |
267 | |
268 | int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) { |
269 | return InferenceContext::Value( |
270 | *reinterpret_cast<DimensionHandle*>(dim_handle)); |
271 | } |
272 | |