1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <vector> |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/register_types.h" |
20 | #include "tensorflow/core/framework/tensor_shape.h" |
21 | #include "tensorflow/core/framework/tensor_types.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/kernels/reshape_op.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | class QuantizedReshapeOp : public ReshapeOp { |
28 | public: |
29 | explicit QuantizedReshapeOp(OpKernelConstruction* c) : ReshapeOp(c) {} |
30 | |
31 | void Compute(OpKernelContext* ctx) override { |
32 | // This call processes inputs 1 and 2 to write output 0. |
33 | ReshapeOp::Compute(ctx); |
34 | if (!ctx->status().ok()) { |
35 | return; |
36 | } |
37 | |
38 | const auto& input_min_float_tensor = ctx->input(2); |
39 | const auto& input_min_float_shape = input_min_float_tensor.shape(); |
40 | OP_REQUIRES(ctx, |
41 | TensorShapeUtils::IsScalar(input_min_float_shape) || |
42 | (TensorShapeUtils::IsVector(input_min_float_shape) && |
43 | (input_min_float_shape.dim_size(0) == 1)), |
44 | errors::InvalidArgument( |
45 | "input_min must be a scalar or a vector of 1 element" )); |
46 | const float input_min_float = input_min_float_tensor.flat<float>()(0); |
47 | const auto& input_max_float_tensor = ctx->input(3); |
48 | const auto& input_max_float_shape = input_max_float_tensor.shape(); |
49 | OP_REQUIRES(ctx, |
50 | TensorShapeUtils::IsScalar(input_max_float_shape) || |
51 | (TensorShapeUtils::IsVector(input_max_float_shape) && |
52 | (input_max_float_shape.dim_size(0) == 1)), |
53 | errors::InvalidArgument( |
54 | "input_max must be a scalar or a vector of 1 element" )); |
55 | const float input_max_float = input_max_float_tensor.flat<float>()(0); |
56 | |
57 | Tensor* output_min = nullptr; |
58 | OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &output_min)); |
59 | output_min->flat<float>()(0) = input_min_float; |
60 | |
61 | Tensor* output_max = nullptr; |
62 | OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &output_max)); |
63 | output_max->flat<float>()(0) = input_max_float; |
64 | } |
65 | }; |
66 | |
67 | #define REGISTER_CPU_KERNEL(type) \ |
68 | REGISTER_KERNEL_BUILDER(Name("QuantizedReshape") \ |
69 | .Device(DEVICE_CPU) \ |
70 | .HostMemory("shape") \ |
71 | .TypeConstraint<type>("T"), \ |
72 | QuantizedReshapeOp) |
73 | |
74 | REGISTER_CPU_KERNEL(::tensorflow::quint8); |
75 | REGISTER_CPU_KERNEL(::tensorflow::qint32); |
76 | |
77 | #undef REGISTER_CPU_KERNEL |
78 | |
79 | } // namespace tensorflow |
80 | |