1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_ |
18 | |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/framework/tensor.h" |
21 | #include "tensorflow/core/framework/types.h" |
22 | #include "tensorflow/core/framework/types.pb.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | // One input and one output, both the same type. |
29 | template <class T> |
30 | class UnaryOp : public OpKernel { |
31 | public: |
32 | explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) { |
33 | const DataType dt = DataTypeToEnum<T>::v(); |
34 | OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt})); |
35 | } |
36 | }; |
37 | |
38 | // Two inputs and one output, all the same type. |
39 | template <class T> |
40 | class BinaryOp : public OpKernel { |
41 | public: |
42 | explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) { |
43 | const DataType dt = DataTypeToEnum<T>::v(); |
44 | OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt})); |
45 | } |
46 | }; |
47 | |
48 | // For operations where the input and output are the same shape. |
49 | // |
50 | // For usage, see ../framework/elementwise_ops.cc. |
51 | template <class T, class CHILD> |
52 | class UnaryElementWiseOp : public UnaryOp<T> { |
53 | public: |
54 | using UnaryOp<T>::UnaryOp; |
55 | |
56 | void Compute(OpKernelContext* context) override { |
57 | // Output shape is the same as input shape. |
58 | const Tensor& input = context->input(0); |
59 | Tensor* output = nullptr; |
60 | OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( |
61 | {0}, 0, input.shape(), &output)); |
62 | static_cast<CHILD*>(this)->Operate(context, input, output); |
63 | } |
64 | }; |
65 | |
66 | // For binary elementwise operations. |
67 | template <class T, class CHILD> |
68 | class BinaryElementWiseOp : public BinaryOp<T> { |
69 | public: |
70 | using BinaryOp<T>::BinaryOp; |
71 | |
72 | void Compute(OpKernelContext* context) override { |
73 | const Tensor& a = context->input(0); |
74 | const Tensor& b = context->input(1); |
75 | |
76 | if (!context->ValidateInputsAreSameShape(this)) { |
77 | return; |
78 | } |
79 | |
80 | Tensor* output = nullptr; |
81 | OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( |
82 | {0, 1}, 0, a.shape(), &output)); |
83 | |
84 | // Dispatch to the descendant's Operate() function. |
85 | switch (a.dims()) { |
86 | #define NDIM_CASE(NDIMS) \ |
87 | case NDIMS: { \ |
88 | static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \ |
89 | break; \ |
90 | } |
91 | |
92 | NDIM_CASE(0); |
93 | NDIM_CASE(1); |
94 | NDIM_CASE(2); |
95 | NDIM_CASE(3); |
96 | NDIM_CASE(4); |
97 | NDIM_CASE(5); |
98 | NDIM_CASE(6); |
99 | NDIM_CASE(7); |
100 | NDIM_CASE(8); |
101 | #undef NDIM_CASE |
102 | |
103 | default: |
104 | context->SetStatus(errors::InvalidArgument( |
105 | "We only handle up to Tensor::dims() up to 8, not " , a.dims())); |
106 | break; |
107 | } |
108 | } |
109 | }; |
110 | |
111 | } // namespace tensorflow |
112 | |
113 | #endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_ |
114 | |