1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_ |
16 | #define TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_ |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/types.h" |
25 | #include "tensorflow/core/framework/variant_op_registry.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | typedef Eigen::ThreadPoolDevice CPUDevice; |
31 | typedef Eigen::GpuDevice GPUDevice; |
32 | |
33 | template <typename Device> |
34 | Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) { |
35 | AllocatorAttributes attr; |
36 | if (x.dtype() == DT_VARIANT) { |
37 | attr.set_on_host(true); |
38 | } |
39 | TF_RETURN_IF_ERROR(ctx->allocate_temp(x.dtype(), x.shape(), out, attr)); |
40 | |
41 | switch (out->dtype()) { |
42 | #define DTYPE_CASE(dtype) \ |
43 | case DataTypeToEnum<dtype>::value: \ |
44 | /* TODO(skyewm): use SetZeroFunctor like in ZerosLikeOp? */ \ |
45 | out->flat<dtype>().device(ctx->eigen_device<Device>()) = \ |
46 | out->flat<dtype>().constant(dtype(0)); \ |
47 | break; |
48 | |
49 | TF_CALL_POD_TYPES(DTYPE_CASE) |
50 | #undef DTYPE_CASE |
51 | |
52 | case DT_INVALID: { |
53 | *out = Tensor(DT_INVALID); |
54 | break; |
55 | } |
56 | case DataTypeToEnum<Variant>::value: { |
57 | Variant* out_variant = out->scalar<Variant>().data(); |
58 | TF_RETURN_IF_ERROR( |
59 | UnaryOpVariant<Device>(ctx, ZEROS_LIKE_VARIANT_UNARY_OP, |
60 | x.scalar<Variant>()(), out_variant)); |
61 | break; |
62 | } |
63 | default: |
64 | return errors::InvalidArgument( |
65 | "Trying to compute zeros_like for unsupported dtype " , |
66 | DataTypeString(out->dtype())); |
67 | } |
68 | return OkStatus(); |
69 | } |
70 | |
71 | template <typename Device> |
72 | Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, const Tensor& b, |
73 | Tensor* out) { |
74 | if (a.dtype() == DT_INVALID) { |
75 | *out = b; |
76 | return OkStatus(); |
77 | } |
78 | if (b.dtype() == DT_INVALID) { |
79 | *out = a; |
80 | return OkStatus(); |
81 | } |
82 | if (a.dtype() != b.dtype()) { |
83 | return errors::InvalidArgument( |
84 | "Trying to add two tensors with incompatible element types. " , |
85 | "One is " , DataTypeString(a.dtype()), " and the other is " , |
86 | DataTypeString(b.dtype())); |
87 | } |
88 | if (a.shape() != b.shape()) { |
89 | // TODO(apassos) support broadcasting additions here? |
90 | return errors::InvalidArgument( |
91 | "Trying to add two tensors with incompatible element shapes. " , |
92 | "One is " , a.shape().DebugString(), " and the other is " , |
93 | b.shape().DebugString()); |
94 | } |
95 | |
96 | AllocatorAttributes attr; |
97 | if (a.dtype() == DT_VARIANT) { |
98 | attr.set_on_host(true); |
99 | } |
100 | TF_RETURN_IF_ERROR(ctx->allocate_temp(a.dtype(), a.shape(), out, attr)); |
101 | |
102 | switch (out->dtype()) { |
103 | #define DTYPE_CASE(dtype) \ |
104 | case DataTypeToEnum<dtype>::value: \ |
105 | out->flat<dtype>().device(ctx->eigen_device<Device>()) = \ |
106 | a.flat<dtype>() + b.flat<dtype>(); \ |
107 | break; |
108 | |
109 | TF_CALL_NUMBER_TYPES(DTYPE_CASE) |
110 | #undef DTYPE_CASE |
111 | |
112 | case DataTypeToEnum<Variant>::value: { |
113 | Variant* out_variant = out->scalar<Variant>().data(); |
114 | TF_RETURN_IF_ERROR(BinaryOpVariants<Device>( |
115 | ctx, ADD_VARIANT_BINARY_OP, a.scalar<Variant>()(), |
116 | b.scalar<Variant>()(), out_variant)); |
117 | break; |
118 | } |
119 | default: |
120 | return errors::InvalidArgument("Trying to add unsupported dtype " , |
121 | out->dtype()); |
122 | } |
123 | return OkStatus(); |
124 | } |
125 | |
126 | } // namespace tensorflow |
127 | |
128 | #endif // TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_ |
129 | |