1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stdint.h>
17
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20#include "tensorflow/lite/kernels/internal/tensor.h"
21#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22#include "tensorflow/lite/kernels/kernel_util.h"
23#include "tensorflow/lite/string_util.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace fill {
29
30namespace {
31
// Tensor indices used by this kernel: input 0 holds the requested output
// dimensions, input 1 holds the value to replicate, and the single output
// lives at index 0.
constexpr int kDimsTensor{0};
constexpr int kValueTensor{1};
constexpr int kOutputTensor{0};
35
36template <typename T>
37TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims,
38 TfLiteTensor* output) {
39 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dims->dims->data[0]);
40 for (int i = 0; i < output_shape->size; ++i) {
41 T data = GetTensorData<T>(dims)[i];
42 if (data < 0) {
43 TfLiteIntArrayFree(output_shape);
44 TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0", dims->type);
45 return kTfLiteError;
46 }
47 output_shape->data[i] = data;
48 }
49 return context->ResizeTensor(context, output, output_shape);
50}
51
52TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* dims,
53 TfLiteTensor* output) {
54 switch (dims->type) {
55 case kTfLiteInt32:
56 return ResizeOutputImpl<int32_t>(context, dims, output);
57 case kTfLiteInt64:
58 return ResizeOutputImpl<int64_t>(context, dims, output);
59 default:
60 TF_LITE_KERNEL_LOG(
61 context,
62 "Fill only currently supports int32, int64 for input 0, "
63 "got %d.",
64 dims->type);
65 return kTfLiteError;
66 }
67}
68
69} // namespace
70
71TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
72 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
73 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
74
75 const TfLiteTensor* dims;
76 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
77 const TfLiteTensor* value;
78 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
79
80 // Make sure the 1st input tensor is 1-D.
81 TF_LITE_ENSURE_EQ(context, NumDimensions(dims), 1);
82
83 // Make sure the 1st input tensor is int32 or int64.
84 const auto dtype = dims->type;
85 TF_LITE_ENSURE(context, dtype == kTfLiteInt32 || dtype == kTfLiteInt64);
86
87 // Make sure the 2nd input tensor is a scalar.
88 TF_LITE_ENSURE_EQ(context, NumDimensions(value), 0);
89
90 TfLiteTensor* output;
91 TF_LITE_ENSURE_OK(context,
92 GetOutputSafe(context, node, kOutputTensor, &output));
93 output->type = value->type;
94
95 TF_LITE_ENSURE_EQ(context, output->params.scale, value->params.scale);
96 TF_LITE_ENSURE_EQ(context, output->params.zero_point,
97 value->params.zero_point);
98
99 if (value->type == kTfLiteInt16) {
100 TF_LITE_ENSURE_EQ(context, value->params.zero_point, 0);
101 }
102
103 if (IsConstantTensor(dims)) {
104 TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
105 } else {
106 SetTensorToDynamic(output);
107 }
108 return kTfLiteOk;
109}
110
111TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
112 DynamicBuffer buffer;
113 const auto string_ref = GetString(value, 0);
114 int n = 1;
115 for (int i = 0; i < output->dims->size; ++i) {
116 n *= output->dims->data[i];
117 }
118 for (int i = 0; i < n; ++i) {
119 buffer.AddString(string_ref.str, string_ref.len);
120 }
121 buffer.WriteToTensor(output, /*new_shape=*/nullptr);
122 return kTfLiteOk;
123}
124
125TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
126 const TfLiteTensor* value;
127 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
128
129 TfLiteTensor* output;
130 TF_LITE_ENSURE_OK(context,
131 GetOutputSafe(context, node, kOutputTensor, &output));
132
133 if (IsDynamicTensor(output)) {
134 const TfLiteTensor* dims;
135 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDimsTensor, &dims));
136 TF_LITE_ENSURE_OK(context, ResizeOutput(context, dims, output));
137 }
138#define TF_LITE_FILL(data_type) \
139 reference_ops::Fill(GetTensorShape(value), GetTensorData<data_type>(value), \
140 GetTensorShape(output), \
141 GetTensorData<data_type>(output))
142 switch (output->type) {
143 case kTfLiteInt8:
144 TF_LITE_FILL(int8_t);
145 break;
146 case kTfLiteInt16:
147 TF_LITE_FILL(int16_t);
148 break;
149 case kTfLiteInt32:
150 TF_LITE_FILL(int32_t);
151 break;
152 case kTfLiteInt64:
153 TF_LITE_FILL(int64_t);
154 break;
155 case kTfLiteFloat32:
156 TF_LITE_FILL(float);
157 break;
158 case kTfLiteBool:
159 TF_LITE_FILL(bool);
160 break;
161 case kTfLiteString:
162 FillString(value, output);
163 break;
164 default:
165 TF_LITE_KERNEL_LOG(
166 context,
167 "Fill only currently supports int8, int16, int32, int64, float32, "
168 "bool, string for input 1, got %d.",
169 value->type);
170 return kTfLiteError;
171 }
172#undef TF_LITE_FILL
173 return kTfLiteOk;
174}
175
176} // namespace fill
177
178TfLiteRegistration* Register_FILL() {
179 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
180 fill::Prepare, fill::Eval};
181 return &r;
182}
183
184} // namespace builtin
185} // namespace ops
186} // namespace tflite
187