/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace shape {

constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

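// Copies the length of each dimension of |input| into the flat |output_data|
// buffer, which must hold at least NumDimensions(input) elements.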
template <typename OutType>
void ExtractShape(const TfLiteTensor* input, OutType* output_data) {
  for (int i = 0; i < NumDimensions(input); ++i) {
    output_data[i] = SizeOfDimension(input, i);
  }
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

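  // The requested output type (int32 or int64) is carried in the op's builtin
  // parameters (TfLiteShapeParams); any other type is rejected.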
  auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data);
  switch (params->out_type) {
    case kTfLiteInt32:
      output->type = kTfLiteInt32;
      break;
    case kTfLiteInt64:
      output->type = kTfLiteInt64;
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Unknown shape output data type: %d",
                         params->out_type);
      return kTfLiteError;
  }

  // By design, the input shape is always known at the time of Prepare, even
  // if the preceding op that generates |input| is dynamic. Thus, we can
  // always compute the shape immediately, without waiting for Eval.
  SetTensorToPersistentRo(output);

  // Shape always produces a 1-dimensional output tensor, where each output
  // element is the length of the corresponding input tensor's dimension.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
  output_size->data[0] = NumDimensions(input);
  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size));

  TFLITE_DCHECK_EQ(NumDimensions(output), 1);
  TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input));

  // Immediately propagate the known shape to the output tensor. This allows
  // downstream ops that rely on the value to use it during prepare.
  switch (output->type) {
    case kTfLiteInt32:
      ExtractShape(input, GetTensorData<int32_t>(output));
      break;
    case kTfLiteInt64:
      ExtractShape(input, GetTensorData<int64_t>(output));
      break;
    default:
      return kTfLiteError;
  }

  return kTfLiteOk;
}

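// The output was fully computed in Prepare and marked persistent read-only,
// so there is nothing left to do at evaluation time.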
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

}  // namespace shape

TfLiteRegistration* Register_SHAPE() {
  static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite