1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/sparse_xent_op.h"
21
22#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/framework/tensor.h"
25#include "tensorflow/core/framework/tensor_shape.h"
26#include "tensorflow/core/framework/tensor_types.h"
27#include "tensorflow/core/util/determinism.h"
28#include "tensorflow/core/util/env_var.h"
29
30namespace tensorflow {
31
32typedef Eigen::ThreadPoolDevice CPUDevice;
33typedef Eigen::GpuDevice GPUDevice;
34
35template <typename Index>
36Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) {
37 if (labels.NumElements() == 0) return OkStatus();
38 const auto label_values = labels.vec<Index>();
39 int64_t bad_index;
40 auto min_max_dim_value = std::minmax_element(
41 label_values.data(), label_values.data() + label_values.size());
42 if (*min_max_dim_value.first < 0 || *min_max_dim_value.second >= max_index) {
43 bad_index = (*min_max_dim_value.first < 0) ? *min_max_dim_value.first
44 : *min_max_dim_value.second;
45 return errors::InvalidArgument(
46 "Received a label value of ", bad_index,
47 " which is outside the valid range of [0, ", max_index,
48 "). Label values: ", labels.SummarizeValue(labels.NumElements()));
49 }
50 return OkStatus();
51}
52
53template <typename Device, typename T, typename Index>
54class SparseSoftmaxXentWithLogitsOp : public OpKernel {
55 public:
56 explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* context)
57 : OpKernel(context) {}
58
59 void Compute(OpKernelContext* context) override {
60 const Tensor& logits = context->input(0);
61 const Tensor& labels = context->input(1);
62 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
63 errors::InvalidArgument("logits must be 2-D, but got shape ",
64 logits.shape().DebugString()));
65 OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
66 errors::InvalidArgument("labels must be 1-D, but got shape ",
67 labels.shape().DebugString()));
68 OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
69 errors::InvalidArgument(
70 "logits and labels must have the same first dimension, "
71 "got logits shape ",
72 logits.shape().DebugString(), " and labels shape ",
73 labels.shape().DebugString()));
74 OP_REQUIRES(context, logits.dim_size(1) > 0,
75 errors::InvalidArgument(
76 "Must have at least one class, but got logits shape ",
77 logits.shape().DebugString()));
78
79 if (std::is_same<Device, GPUDevice>::value) {
80 OP_REQUIRES(
81 context, !OpDeterminismRequired(),
82 errors::Unimplemented(
83 "The GPU implementation of SparseSoftmaxCrossEntropyWithLogits"
84 " that would have been executed is not deterministic. Note that"
85 " the Python API uses an alternative, deterministic,"
86 " GPU-accelerated path when determinsim is enabled."));
87 }
88
89 Tensor scratch;
90 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
91 labels.shape(), &scratch));
92
93 Tensor* loss_out = nullptr;
94 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
95 {1}, 0, labels.shape(), &loss_out));
96 Tensor* back_out = nullptr;
97 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
98 {0}, 1, logits.shape(), &back_out));
99
100 if (logits.dim_size(0) > 0) {
101 if (std::is_same<Device, CPUDevice>::value) {
102 OP_REQUIRES_OK(
103 context, CheckInvalidLabelIndex<Index>(labels, logits.dim_size(1)));
104 }
105 functor::SparseXentFunctor<Device, T, Index> functor;
106 functor(context, logits.matrix<T>(), labels.vec<Index>(),
107 scratch.vec<T>(), loss_out->vec<T>(), back_out->matrix<T>());
108 }
109 }
110};
111
112// Partial specialization for a CPUDevice, that uses the Eigen implementation
113// from XentEigenImpl.
114namespace functor {
115template <typename T, typename Index>
116struct SparseXentFunctor<CPUDevice, T, Index> {
117 void operator()(OpKernelContext* ctx, typename TTypes<T>::ConstMatrix logits,
118 typename TTypes<Index>::ConstVec labels,
119 typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss,
120 typename TTypes<T>::Matrix backprop) {
121 SparseXentEigenImpl<CPUDevice, T, Index>::Compute(ctx, logits, labels,
122 scratch, loss, backprop);
123 }
124};
125} // namespace functor
126
// Registers the kernel for a (device, logits type, label index type) triple.
// Dev is token-pasted into both DEVICE_##Dev and Dev##Device, so callers pass
// the bare tokens CPU or GPU.
#define REGISTER(Dev, T, Index) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseSoftmaxCrossEntropyWithLogits") \
          .Device(DEVICE_##Dev) \
          .TypeConstraint<T>("T") \
          .TypeConstraint<Index>("Tlabels"), \
      SparseSoftmaxXentWithLogitsOp<Dev##Device, T, Index>);
// CPU supports float, double, half, and bfloat16 with int32/int64 labels.
REGISTER(CPU, float, int32)
REGISTER(CPU, float, int64_t)
REGISTER(CPU, double, int32)
REGISTER(CPU, double, int64_t)
REGISTER(CPU, Eigen::half, int32)
REGISTER(CPU, Eigen::half, int64_t)
REGISTER(CPU, bfloat16, int32)
REGISTER(CPU, bfloat16, int64_t)

// GPU (CUDA/ROCm) supports only float and half logits.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER(GPU, float, int32)
REGISTER(GPU, float, int64_t)
REGISTER(GPU, Eigen::half, int32)
REGISTER(GPU, Eigen::half, int64_t)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER
151
152} // namespace tensorflow
153