1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/densify.h" |
16 | |
17 | #include <stddef.h> |
18 | |
19 | #include <cstdint> |
20 | |
21 | #include "tensorflow/lite/c/common.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
23 | #include "tensorflow/lite/kernels/kernel_util.h" |
24 | |
25 | namespace tflite { |
26 | namespace ops { |
27 | namespace builtin { |
28 | namespace densify { |
29 | |
30 | struct OpContext { |
31 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
32 | input = GetInput(context, node, 0); |
33 | output = GetOutput(context, node, 0); |
34 | } |
35 | const TfLiteTensor* input; |
36 | TfLiteTensor* output; |
37 | }; |
38 | |
// Per-node state, allocated in Init and released in Free.
struct OpData {
  // Set once Eval has written the densified input into the output tensor.
  // Defaults to false so that no construction path leaves it indeterminate.
  bool dense_weights_initialized = false;
};
42 | |
43 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
44 | auto* op_data = new OpData(); |
45 | op_data->dense_weights_initialized = false; |
46 | return op_data; |
47 | } |
48 | |
49 | void Free(TfLiteContext* context, void* buffer) { |
50 | delete reinterpret_cast<OpData*>(buffer); |
51 | } |
52 | |
53 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
54 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
55 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
56 | |
57 | OpContext op_context(context, node); |
58 | |
59 | TF_LITE_ENSURE(context, op_context.input->type != kTfLiteString); |
60 | TF_LITE_ENSURE(context, IsConstantTensor(op_context.input)); |
61 | TF_LITE_ENSURE(context, op_context.input->sparsity != nullptr); |
62 | |
63 | op_context.output->type = op_context.input->type; |
64 | op_context.output->name = "Densify_output" ; |
65 | op_context.output->allocation_type = kTfLiteArenaRwPersistent; |
66 | |
67 | return context->ResizeTensor(context, op_context.output, |
68 | TfLiteIntArrayCopy(op_context.input->dims)); |
69 | } |
70 | |
71 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
72 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
73 | OpContext op_context(context, node); |
74 | if (op_data->dense_weights_initialized) { |
75 | return kTfLiteOk; |
76 | } |
77 | |
78 | switch (op_context.input->type) { |
79 | case kTfLiteFloat32: |
80 | reference_ops::Densify(op_context.input->sparsity, |
81 | GetTensorShape(op_context.input), |
82 | GetTensorData<float>(op_context.input), |
83 | GetTensorShape(op_context.output), |
84 | GetTensorData<float>(op_context.output), context); |
85 | break; |
86 | case kTfLiteFloat16: |
87 | reference_ops::Densify( |
88 | op_context.input->sparsity, GetTensorShape(op_context.input), |
89 | GetTensorData<Eigen::half>(op_context.input), |
90 | GetTensorShape(op_context.output), |
91 | GetTensorData<Eigen::half>(op_context.output), context); |
92 | break; |
93 | case kTfLiteInt8: |
94 | reference_ops::Densify(op_context.input->sparsity, |
95 | GetTensorShape(op_context.input), |
96 | GetTensorData<int8_t>(op_context.input), |
97 | GetTensorShape(op_context.output), |
98 | GetTensorData<int8_t>(op_context.output), context); |
99 | break; |
100 | |
101 | default: |
102 | TF_LITE_KERNEL_LOG(context, "Type %d not supported." , |
103 | op_context.input->type); |
104 | return kTfLiteError; |
105 | } |
106 | |
107 | op_data->dense_weights_initialized = true; |
108 | return kTfLiteOk; |
109 | } |
110 | |
111 | } // namespace densify |
112 | |
113 | TfLiteRegistration* Register_DENSIFY() { |
114 | static TfLiteRegistration r = {densify::Init, densify::Free, densify::Prepare, |
115 | densify::Eval}; |
116 | return &r; |
117 | } |
118 | |
119 | } // namespace builtin |
120 | } // namespace ops |
121 | } // namespace tflite |
122 | |