1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <algorithm>
16#include <complex>
17
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/tensor.h"
21#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22#include "tensorflow/lite/kernels/kernel_util.h"
23#include "tensorflow/lite/kernels/op_macros.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace cast {
// Index of the op's sole input, as passed to GetInputSafe().
constexpr int kInputTensor{0};
// Index of the op's sole output, as passed to GetOutputSafe().
constexpr int kOutputTensor{0};
31
32TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
34 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35 const TfLiteTensor* input;
36 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
37 TfLiteTensor* output;
38 TF_LITE_ENSURE_OK(context,
39 GetOutputSafe(context, node, kOutputTensor, &output));
40
41 // TODO(ahentz): these two checks would make the new implementation
42 // incompatible with some existing models, where params is not specified. It
43 // is OK not to have them because toco would have set input and output types
44 // to match the parameters.
45 // auto* params = reinterpret_cast<TfLiteCastParams*>(node->builtin_data);
46 // TF_LITE_ENSURE_EQ(context, input->type, params->in_data_type);
47 // TF_LITE_ENSURE_EQ(context, output->type, params->out_data_type);
48
49 return context->ResizeTensor(context, output,
50 TfLiteIntArrayCopy(input->dims));
51}
52
// Generic element-wise conversion: writes static_cast<ToT>(in[i]) to out[i]
// for each of the first `num_elements` positions.
template <typename FromT, typename ToT>
void copyCast(const FromT* in, ToT* out, int num_elements) {
  for (int idx = 0; idx < num_elements; ++idx) {
    out[idx] = static_cast<ToT>(in[idx]);
  }
}
58
// Conversion from complex64: each output element is static_cast<ToT> of the
// corresponding input's real component. The imaginary component is discarded.
template <typename ToT>
void copyCast(const std::complex<float>* in, ToT* out, int num_elements) {
  for (int idx = 0; idx < num_elements; ++idx) {
    const float real_part = in[idx].real();
    out[idx] = static_cast<ToT>(real_part);
  }
}
65
66template <>
67void copyCast(const std::complex<float>* in, std::complex<float>* out,
68 int num_elements) {
69 std::transform(in, in + num_elements, out,
70 [](std::complex<float> a) { return a; });
71}
72
73template <typename FromT>
74TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
75 TfLiteTensor* out, int num_elements) {
76 switch (out->type) {
77 case kTfLiteInt64:
78 copyCast(in, out->data.i64, num_elements);
79 break;
80 case kTfLiteInt32:
81 copyCast(in, out->data.i32, num_elements);
82 break;
83 case kTfLiteUInt32:
84 copyCast(in, out->data.u32, num_elements);
85 break;
86 case kTfLiteInt16:
87 copyCast(in, out->data.i16, num_elements);
88 break;
89 case kTfLiteUInt16:
90 copyCast(in, out->data.ui16, num_elements);
91 break;
92 case kTfLiteUInt8:
93 copyCast(in, out->data.uint8, num_elements);
94 break;
95 case kTfLiteInt8:
96 copyCast(in, out->data.int8, num_elements);
97 break;
98 case kTfLiteFloat32:
99 copyCast(in, GetTensorData<float>(out), num_elements);
100 break;
101 case kTfLiteBool:
102 copyCast(in, out->data.b, num_elements);
103 break;
104 case kTfLiteComplex64:
105 copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
106 num_elements);
107 break;
108 default:
109 // Unsupported type.
110 TF_LITE_UNSUPPORTED_TYPE(context, out->type, "Cast");
111 }
112 return kTfLiteOk;
113}
114
115TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
116 const TfLiteTensor* input;
117 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
118 TfLiteTensor* output;
119 TF_LITE_ENSURE_OK(context,
120 GetOutputSafe(context, node, kOutputTensor, &output));
121 const int num_elements = NumElements(input);
122 TF_LITE_ENSURE_EQ(context, num_elements, NumElements(output));
123 switch (input->type) {
124 case kTfLiteInt64:
125 return copyToTensor(context, input->data.i64, output, num_elements);
126 case kTfLiteInt32:
127 return copyToTensor(context, input->data.i32, output, num_elements);
128 case kTfLiteUInt32:
129 return copyToTensor(context, input->data.u32, output, num_elements);
130 case kTfLiteUInt16:
131 return copyToTensor(context, input->data.ui16, output, num_elements);
132 case kTfLiteInt16:
133 return copyToTensor(context, input->data.i16, output, num_elements);
134 case kTfLiteUInt8:
135 return copyToTensor(context, input->data.uint8, output, num_elements);
136 case kTfLiteInt8:
137 return copyToTensor(context, input->data.int8, output, num_elements);
138 case kTfLiteFloat32:
139 return copyToTensor(context, GetTensorData<float>(input), output,
140 num_elements);
141 case kTfLiteBool:
142 return copyToTensor(context, input->data.b, output, num_elements);
143 case kTfLiteComplex64:
144 return copyToTensor(
145 context, reinterpret_cast<std::complex<float>*>(input->data.c64),
146 output, num_elements);
147 default:
148 // Unsupported type.
149 TF_LITE_UNSUPPORTED_TYPE(context, input->type, "Cast");
150 }
151}
152} // namespace cast
153
154TfLiteRegistration* Register_CAST() {
155 static TfLiteRegistration r = {nullptr, nullptr, cast::Prepare, cast::Eval};
156 return &r;
157}
158
159} // namespace builtin
160} // namespace ops
161} // namespace tflite
162