1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/xent_op.h" |
21 | |
22 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/util/bcast.h" |
28 | #include "tensorflow/core/util/determinism.h" |
29 | #include "tensorflow/core/util/env_var.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | typedef Eigen::ThreadPoolDevice CPUDevice; |
34 | typedef Eigen::GpuDevice GPUDevice; |
35 | |
36 | template <typename Device, typename T> |
37 | class SoftmaxXentWithLogitsOp : public OpKernel { |
38 | public: |
39 | explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context) |
40 | : OpKernel(context) {} |
41 | |
42 | void Compute(OpKernelContext* context) override { |
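    // For each example i (row of the possibly-broadcast inputs) this kernel
    // computes:
    //   loss(i)        = -sum_j labels(i, j) * log(softmax(logits)(i, j))
    //   backprop(i, j) =  softmax(logits)(i, j) - labels(i, j)
    // The heavy lifting happens in XentEigenImpl (xent_op.h), which subtracts
    // the per-row max from the logits for numerical stability.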
43 | const Tensor& logits_in = context->input(0); |
44 | const Tensor& labels_in = context->input(1); |
45 | |
46 | TensorShape shape_in = logits_in.shape(); |
47 | |
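    // Analyze the two shapes for broadcasting. fewer_dims_optimization is
    // disabled so that x_reshape()/y_reshape() keep the inputs' full rank,
    // which the rank-2 shaped<T, 2>() views below depend on.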
48 | BCast bcast(BCast::FromShape(logits_in.shape()), |
49 | BCast::FromShape(labels_in.shape()), |
50 | /*fewer_dims_optimization=*/false); |
51 | if (!logits_in.IsSameSize(labels_in)) { |
52 | OP_REQUIRES(context, bcast.IsValid(), |
53 | errors::InvalidArgument( |
54 | "logits and labels must be broadcastable: logits_size=" , |
55 | logits_in.shape().DebugString(), |
56 | " labels_size=" , labels_in.shape().DebugString())); |
57 | shape_in = BCast::ToShape(bcast.output_shape()); |
58 | } |
59 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in), |
60 | errors::InvalidArgument("logits and labels must be either " |
61 | "2-dimensional, or broadcasted to be " |
62 | "2-dimensional" )); |
63 | |
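    // OpDeterminismRequired() reflects a user request for determinism, e.g.
    // via tf.config.experimental.enable_op_determinism(). The plain GPU
    // kernel registered below is not deterministic, so fail loudly here
    // rather than return irreproducible results.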
64 | if (std::is_same<Device, GPUDevice>::value) { |
65 | OP_REQUIRES(context, !OpDeterminismRequired(), |
66 | errors::Unimplemented( |
67 | "The GPU implementation of SoftmaxCrossEntropyWithLogits" |
68 | " that would have been executed is not deterministic." |
69 | " Note that the Python API uses an alternative," |
70 | " deterministic, GPU-accelerated path when determinism is" |
71 | " enabled." )); |
72 | } |
73 | |
    // loss is 1-D (one element per example); its size is batch_size.
75 | |
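    // scratch holds one intermediate value per example, hence shape
    // {batch_size, 1}: XentEigenImpl uses it first for the per-row max of
    // the logits and then for the sum of the shifted exponentials.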
76 | Tensor scratch; |
77 | OP_REQUIRES_OK( |
78 | context, context->allocate_temp(DataTypeToEnum<T>::value, |
79 | TensorShape({shape_in.dim_size(0), 1}), |
80 | &scratch)); |
81 | |
82 | Tensor* loss_out = nullptr; |
83 | OP_REQUIRES_OK(context, |
84 | context->allocate_output( |
85 | 0, TensorShape({shape_in.dim_size(0)}), &loss_out)); |
86 | Tensor* back_out = nullptr; |
87 | // Try to reuse the logits_in buffer for the backprop output. |
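    // forward_input_or_allocate_output forwards input 0 only when its buffer
    // is exclusively owned and its shape and type match; otherwise it falls
    // back to allocating a fresh output of shape shape_in.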
88 | OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( |
89 | {0}, 1, shape_in, &back_out)); |
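    // An empty batch yields empty outputs, so the functor only needs to run
    // when there is at least one example.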
90 | if (shape_in.dim_size(0) > 0) { |
91 | functor::XentFunctor<Device, T> functor; |
92 | functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(), |
93 | BCast::ToIndexArray<2>(bcast.x_bcast()), |
94 | BCast::ToIndexArray<2>(bcast.y_bcast()), |
95 | logits_in.template shaped<T, 2>(bcast.x_reshape()), |
96 | labels_in.template shaped<T, 2>(bcast.y_reshape()), |
97 | scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>()); |
98 | } |
99 | } |
100 | }; |
101 | |
// Partial specialization for a CPUDevice that uses the Eigen implementation
// from XentEigenImpl.
104 | namespace functor { |
105 | template <typename Device, typename T> |
106 | struct XentFunctorBase { |
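  // shape is the broadcast (output) shape; logits_bcast and labels_bcast are
  // the per-dimension factors that tile each input up to that shape.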
107 | void operator()(const Device& d, |
108 | const Eigen::DSizes<Eigen::DenseIndex, 2>& shape, |
109 | const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast, |
110 | const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast, |
111 | typename TTypes<T>::ConstMatrix logits, |
112 | typename TTypes<T>::ConstMatrix labels, |
113 | typename TTypes<T>::Matrix scratch, |
114 | typename TTypes<T>::Vec loss, |
115 | typename TTypes<T>::Matrix backprop) { |
116 | XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast, |
117 | logits, labels, scratch, loss, backprop); |
118 | } |
119 | }; |
120 | |
121 | template <typename T> |
122 | struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {}; |
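// The GPU specialization of XentFunctor lives in xent_op_gpu.cu.cc.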
123 | |
124 | } // namespace functor |
125 | |
126 | #define REGISTER_CPU(T) \ |
127 | REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \ |
128 | .Device(DEVICE_CPU) \ |
129 | .TypeConstraint<T>("T"), \ |
130 | SoftmaxXentWithLogitsOp<CPUDevice, T>); |
131 | TF_CALL_half(REGISTER_CPU); |
132 | TF_CALL_float(REGISTER_CPU); |
133 | TF_CALL_double(REGISTER_CPU); |
134 | TF_CALL_bfloat16(REGISTER_CPU); |
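
// These kernels back tf.raw_ops.SoftmaxCrossEntropyWithLogits, which
// tf.nn.softmax_cross_entropy_with_logits lowers to (after reshaping its
// inputs to 2-D where necessary).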
135 | |
136 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
137 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
138 | REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits" ) |
139 | .Device(DEVICE_GPU) |
140 | .TypeConstraint<Eigen::half>("T" ), |
141 | SoftmaxXentWithLogitsOp<GPUDevice, Eigen::half>); |
142 | REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits" ) |
143 | .Device(DEVICE_GPU) |
144 | .TypeConstraint<float>("T" ), |
145 | SoftmaxXentWithLogitsOp<GPUDevice, float>); |
146 | REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits" ) |
147 | .Device(DEVICE_GPU) |
148 | .TypeConstraint<double>("T" ), |
149 | SoftmaxXentWithLogitsOp<GPUDevice, double>); |
150 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
151 | |
152 | } // namespace tensorflow |
153 | |