1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/xent_op.h"
21
22#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/register_types.h"
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_shape.h"
27#include "tensorflow/core/util/bcast.h"
28#include "tensorflow/core/util/determinism.h"
29#include "tensorflow/core/util/env_var.h"
30
31namespace tensorflow {
32
33typedef Eigen::ThreadPoolDevice CPUDevice;
34typedef Eigen::GpuDevice GPUDevice;
35
36template <typename Device, typename T>
37class SoftmaxXentWithLogitsOp : public OpKernel {
38 public:
39 explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context)
40 : OpKernel(context) {}
41
42 void Compute(OpKernelContext* context) override {
43 const Tensor& logits_in = context->input(0);
44 const Tensor& labels_in = context->input(1);
45
46 TensorShape shape_in = logits_in.shape();
47
48 BCast bcast(BCast::FromShape(logits_in.shape()),
49 BCast::FromShape(labels_in.shape()),
50 /*fewer_dims_optimization=*/false);
51 if (!logits_in.IsSameSize(labels_in)) {
52 OP_REQUIRES(context, bcast.IsValid(),
53 errors::InvalidArgument(
54 "logits and labels must be broadcastable: logits_size=",
55 logits_in.shape().DebugString(),
56 " labels_size=", labels_in.shape().DebugString()));
57 shape_in = BCast::ToShape(bcast.output_shape());
58 }
59 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
60 errors::InvalidArgument("logits and labels must be either "
61 "2-dimensional, or broadcasted to be "
62 "2-dimensional"));
63
64 if (std::is_same<Device, GPUDevice>::value) {
65 OP_REQUIRES(context, !OpDeterminismRequired(),
66 errors::Unimplemented(
67 "The GPU implementation of SoftmaxCrossEntropyWithLogits"
68 " that would have been executed is not deterministic."
69 " Note that the Python API uses an alternative,"
70 " deterministic, GPU-accelerated path when determinism is"
71 " enabled."));
72 }
73
74 // loss is 1-D (one per example), and size is batch_size.
75
76 Tensor scratch;
77 OP_REQUIRES_OK(
78 context, context->allocate_temp(DataTypeToEnum<T>::value,
79 TensorShape({shape_in.dim_size(0), 1}),
80 &scratch));
81
82 Tensor* loss_out = nullptr;
83 OP_REQUIRES_OK(context,
84 context->allocate_output(
85 0, TensorShape({shape_in.dim_size(0)}), &loss_out));
86 Tensor* back_out = nullptr;
87 // Try to reuse the logits_in buffer for the backprop output.
88 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
89 {0}, 1, shape_in, &back_out));
90 if (shape_in.dim_size(0) > 0) {
91 functor::XentFunctor<Device, T> functor;
92 functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
93 BCast::ToIndexArray<2>(bcast.x_bcast()),
94 BCast::ToIndexArray<2>(bcast.y_bcast()),
95 logits_in.template shaped<T, 2>(bcast.x_reshape()),
96 labels_in.template shaped<T, 2>(bcast.y_reshape()),
97 scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
98 }
99 }
100};
101
102// Partial specialization for a CPUDevice, that uses the Eigen implementation
103// from XentEigenImpl.
104namespace functor {
105template <typename Device, typename T>
106struct XentFunctorBase {
107 void operator()(const Device& d,
108 const Eigen::DSizes<Eigen::DenseIndex, 2>& shape,
109 const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast,
110 const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast,
111 typename TTypes<T>::ConstMatrix logits,
112 typename TTypes<T>::ConstMatrix labels,
113 typename TTypes<T>::Matrix scratch,
114 typename TTypes<T>::Vec loss,
115 typename TTypes<T>::Matrix backprop) {
116 XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast,
117 logits, labels, scratch, loss, backprop);
118 }
119};
120
121template <typename T>
122struct XentFunctor<CPUDevice, T> : XentFunctorBase<CPUDevice, T> {};
123
124} // namespace functor
125
126#define REGISTER_CPU(T) \
127 REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \
128 .Device(DEVICE_CPU) \
129 .TypeConstraint<T>("T"), \
130 SoftmaxXentWithLogitsOp<CPUDevice, T>);
131TF_CALL_half(REGISTER_CPU);
132TF_CALL_float(REGISTER_CPU);
133TF_CALL_double(REGISTER_CPU);
134TF_CALL_bfloat16(REGISTER_CPU);
135
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Registers SoftmaxXentWithLogitsOp<GPUDevice, T> as the GPU kernel of
// "SoftmaxCrossEntropyWithLogits" for element type T. Only compiled into CUDA
// or ROCm builds; applied to Eigen::half, float and double.
#define REGISTER_XENT_GPU_KERNEL(T)                             \
  REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") \
                              .Device(DEVICE_GPU)               \
                              .TypeConstraint<T>("T"),          \
                          SoftmaxXentWithLogitsOp<GPUDevice, T>)
REGISTER_XENT_GPU_KERNEL(Eigen::half);
REGISTER_XENT_GPU_KERNEL(float);
REGISTER_XENT_GPU_KERNEL(double);
#undef REGISTER_XENT_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
151
152} // namespace tensorflow
153