1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/ops.h"
17
18#include "tensorflow/c/tf_status_helper.h"
19#include "tensorflow/core/framework/common_shape_fns.h"
20#include "tensorflow/core/framework/op.h"
21#include "tensorflow/core/framework/op_def_builder.h"
22#include "tensorflow/core/framework/shape_inference.h"
23
24using ::tensorflow::DataType;
25using ::tensorflow::OpDef;
26using ::tensorflow::OpDefBuilder;
27using ::tensorflow::OpDeprecation;
28using ::tensorflow::OpShapeInferenceFn;
29using ::tensorflow::Set_TF_Status_from_Status;
30using ::tensorflow::Status;
31using ::tensorflow::shape_inference::DimensionHandle;
32using ::tensorflow::shape_inference::InferenceContext;
33using ::tensorflow::shape_inference::ShapeHandle;
34
35TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
36 auto* result = new OpDefBuilder(op_name);
37 return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
38}
39
40void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
41 delete reinterpret_cast<OpDefBuilder*>(builder);
42}
43
44void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
45 const char* input_spec) {
46 reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
47}
48
49void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
50 const char* output_spec) {
51 reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
52}
53
54#define DEFINE_BUILDER_BOOL_SETTER(func_name) \
55 void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
56 bool arg_name) { \
57 reinterpret_cast<OpDefBuilder*>(builder)->func_name(); \
58 }
59
60DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
61DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
62DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
63DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)
64
65void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
66 const char* attr_spec) {
67 reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
68}
69
70void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
71 int version, const char* explanation) {
72 reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
73}
74
75void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
76 TF_Status* status) {
77 auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
78 TF_SetStatus(status, TF_OK, "");
79 ::tensorflow::OpRegistry::Global()->Register(
80 [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
81 Status result = cc_builder->Finalize(op_reg_data);
82 delete cc_builder;
83 return result;
84 });
85}
86
87void TF_OpDefinitionBuilderSetShapeInferenceFunction(
88 TF_OpDefinitionBuilder* builder,
89 void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
90 TF_Status* status)) {
91 auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
92 cc_builder->SetShapeFn(
93 [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
94 TF_Status* c_status = TF_NewStatus();
95 auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
96 shape_inference_func(c_ctx, c_status);
97 tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
98 TF_DeleteStatus(c_status);
99 return result;
100 });
101}
102
103TF_ShapeHandle* TF_NewShapeHandle() {
104 return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
105}
106
107TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
108 auto* handle = new ShapeHandle;
109 *handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
110 return reinterpret_cast<TF_ShapeHandle*>(handle);
111}
112
113TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
114 TF_ShapeInferenceContext* ctx, size_t size) {
115 auto* handle = new ShapeHandle;
116 *handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
117 return reinterpret_cast<TF_ShapeHandle*>(handle);
118}
119
120void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
121 TF_ShapeHandle* first,
122 TF_ShapeHandle* second,
123 TF_ShapeHandle* result,
124 TF_Status* status) {
125 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
126 Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
127 *reinterpret_cast<ShapeHandle*>(second),
128 reinterpret_cast<ShapeHandle*>(result));
129 Set_TF_Status_from_Status(status, s);
130}
131
132TF_DimensionHandle* TF_NewDimensionHandle() {
133 return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
134}
135
136int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
137 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
138 return cc_ctx->num_inputs();
139}
140
141void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
142 TF_ShapeHandle* handle,
143 TF_Status* status) {
144 TF_SetStatus(status, TF_OK, "");
145 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
146 if (i < 0 || i >= cc_ctx->num_inputs()) {
147 TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
148 }
149 if (TF_GetCode(status) == TF_OK) {
150 auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
151 *cc_result = cc_ctx->input(i);
152 }
153}
154
155int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
156 TF_ShapeHandle* handle) {
157 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
158 return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
159}
160
161void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
162 TF_ShapeHandle* handle,
163 TF_Status* status) {
164 TF_SetStatus(status, TF_OK, "");
165 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
166 if (i < 0 || i >= cc_ctx->num_outputs()) {
167 TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
168 }
169 if (TF_GetCode(status) == TF_OK) {
170 cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
171 }
172}
173
174void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
175 if (handle == nullptr) {
176 return;
177 }
178
179 delete reinterpret_cast<ShapeHandle*>(handle);
180}
181
182void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
183 if (handle == nullptr) {
184 return;
185 }
186
187 delete reinterpret_cast<DimensionHandle*>(handle);
188}
189
190#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
191 void TF_ShapeInferenceContext_GetAttr##func( \
192 TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
193 TF_Status* status) { \
194 TF_SetStatus(status, TF_OK, ""); \
195 cc_type v; \
196 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
197 Status s = cc_ctx->GetAttr(attr_name, &v); \
198 Set_TF_Status_from_Status(status, s); \
199 if (s.ok()) { \
200 *val = static_cast<c_type>(v); \
201 } \
202 }
203
204DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
205
206#define DEFINE_RANK_FUNC(func_name) \
207 void TF_ShapeInferenceContext##func_name( \
208 TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
209 TF_ShapeHandle* result, TF_Status* status) { \
210 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
211 auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle); \
212 auto* cc_result = reinterpret_cast<ShapeHandle*>(result); \
213 Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \
214 Set_TF_Status_from_Status(status, s); \
215 }
216
217DEFINE_RANK_FUNC(WithRank)
218DEFINE_RANK_FUNC(WithRankAtLeast)
219DEFINE_RANK_FUNC(WithRankAtMost)
220
221int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
222 TF_ShapeHandle* handle) {
223 return reinterpret_cast<InferenceContext*>(ctx)->Rank(
224 *reinterpret_cast<ShapeHandle*>(handle));
225}
226
227void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
228 TF_ShapeHandle* shape_handle, int64_t i,
229 TF_DimensionHandle* result) {
230 int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
231 auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
232
233 if (i < -rank || i >= rank) {
234 *cc_result = DimensionHandle();
235 return;
236 }
237
238 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
239 auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
240 *cc_result = cc_ctx->Dim(*cc_shape_handle, i);
241}
242
243int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
244 return InferenceContext::ValueKnown(
245 *reinterpret_cast<DimensionHandle*>(dim_handle));
246}
247
248void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
249 TF_Status* status) {
250 Status s = ::tensorflow::shape_inference::UnknownShape(
251 reinterpret_cast<InferenceContext*>(ctx));
252 Set_TF_Status_from_Status(status, s);
253}
254
255void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
256 TF_ShapeHandle* shape_handle,
257 int64_t start, int64_t end,
258 TF_ShapeHandle* result,
259 TF_Status* status) {
260 TF_SetStatus(status, TF_OK, "");
261 auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
262 auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
263 Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
264 start, end, cc_result);
265 Set_TF_Status_from_Status(status, s);
266}
267
268int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
269 return InferenceContext::Value(
270 *reinterpret_cast<DimensionHandle*>(dim_handle));
271}
272