1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include "tensorflow/lite/c/builtin_op_data.h" |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
20 | #include "tensorflow/lite/kernels/internal/tensor.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
22 | #include "tensorflow/lite/kernels/kernel_util.h" |
23 | |
24 | namespace tflite { |
25 | namespace ops { |
26 | namespace builtin { |
27 | namespace shape { |
28 | |
// Index of the node's single input: the tensor whose shape is reported.
constexpr int kInputTensor{0};
// Index of the node's single output: the 1-D tensor that receives the shape.
constexpr int kOutputTensor{0};
31 | |
32 | template <typename OutType> |
33 | void (const TfLiteTensor* input, OutType* output_data) { |
34 | for (int i = 0; i < NumDimensions(input); ++i) { |
35 | output_data[i] = SizeOfDimension(input, i); |
36 | } |
37 | } |
38 | |
39 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
40 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
41 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
42 | |
43 | const TfLiteTensor* input; |
44 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
45 | TfLiteTensor* output; |
46 | TF_LITE_ENSURE_OK(context, |
47 | GetOutputSafe(context, node, kOutputTensor, &output)); |
48 | |
49 | auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data); |
50 | switch (params->out_type) { |
51 | case kTfLiteInt32: |
52 | output->type = kTfLiteInt32; |
53 | break; |
54 | case kTfLiteInt64: |
55 | output->type = kTfLiteInt64; |
56 | break; |
57 | default: |
58 | TF_LITE_KERNEL_LOG(context, "Unknown shape output data type: %d" , |
59 | params->out_type); |
60 | return kTfLiteError; |
61 | } |
62 | |
63 | // By design, the input shape is always known at the time of Prepare, even |
64 | // if the preceding op that generates |input| is dynamic. Thus, we can |
65 | // always compute the shape immediately, without waiting for Eval. |
66 | SetTensorToPersistentRo(output); |
67 | |
68 | // Shape always produces a 1-dimensional output tensor, where each output |
69 | // element is the length of the corresponding input tensor's dimension. |
70 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); |
71 | output_size->data[0] = NumDimensions(input); |
72 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_size)); |
73 | |
74 | TFLITE_DCHECK_EQ(NumDimensions(output), 1); |
75 | TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); |
76 | |
77 | // Immediately propagate the known shape to the output tensor. This allows |
78 | // downstream ops that rely on the value to use it during prepare. |
79 | switch (output->type) { |
80 | case kTfLiteInt32: |
81 | ExtractShape(input, GetTensorData<int32_t>(output)); |
82 | break; |
83 | case kTfLiteInt64: |
84 | ExtractShape(input, GetTensorData<int64_t>(output)); |
85 | break; |
86 | default: |
87 | return kTfLiteError; |
88 | } |
89 | |
90 | return kTfLiteOk; |
91 | } |
92 | |
93 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
94 | return kTfLiteOk; |
95 | } |
96 | |
97 | } // namespace shape |
98 | |
99 | TfLiteRegistration* Register_SHAPE() { |
100 | static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval}; |
101 | return &r; |
102 | } |
103 | |
104 | } // namespace builtin |
105 | } // namespace ops |
106 | } // namespace tflite |
107 | |