/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_

#include <cstdint>

#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/types.h"
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/util.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/core/util/gpu_solvers.h"

using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#include "tensorflow/core/util/gpu_solvers.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace internal {
Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input,
                                const Tensor& segment_ids);
Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel,
                                        OpKernelContext* context,
                                        const Tensor& data,
                                        const Tensor& segment_ids,
                                        const Tensor& num_segments);
Status ValidateSparseSegmentReduction(OpKernelContext* context,
                                      const Tensor& input,
                                      const Tensor& indices,
                                      const Tensor& segment_ids,
                                      bool has_num_segments);
}  // namespace internal

// This operator handles reducing segments along the first dimension.
// See core/ops/math_ops.cc for more details.
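//
// For illustration only (hypothetical values, not taken from any test): with
//   data        = [[1, 2], [3, 4], [5, 6]]   (shape [3, 2])
//   segment_ids = [0, 0, 1]                  (sorted, one id per row of data)
// a SegmentSum instantiation produces [[1+3, 2+4], [5, 6]] = [[4, 6], [5, 6]].
// Ids that are skipped (e.g. segment_ids = [0, 0, 2]) leave the corresponding
// output rows filled with `default_value`.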
template <typename Device, class T, class Index, typename Reducer,
          int default_value>
class SegmentReductionOp : public OpKernel {
 public:
  explicit SegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input,
                                                               segment_ids));

    const int64_t num_indices = segment_ids.NumElements();
    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);

    const auto segment_vec = segment_ids.vec<Index>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const Index output_rows =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    OP_REQUIRES(context, input.dims() >= 1,
                errors::InvalidArgument("Shape must be at least rank 1"));

    TensorShape output_shape = input.shape();
    // Since we're changing the first dimension of the shape, we need to make
    // sure the new shape won't overflow.
    OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) return;
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce;
    Index start = 0, end = 1;

    Index uninitialized_index = 0;  // Index from which the output is not set.
    Index out_index = internal::SubtleMustCopy(segment_vec(start));

    // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
    // across threads.
    Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
    while (end <= num_indices) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      Index next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      // Process segment [start, end)
      const T* in_slice_ptr = &input_flat(start, 0);
      typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                               Eigen::Unaligned>
          OutT;

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(T(default_value));
      }

      T* out_slice_ptr = &output_flat(out_index, 0);
      OutT out_slice(out_slice_ptr, out_slice_shape);
      // We don't use out_slice.device(context->eigen_device<Device>)
      // because these pieces of work are likely to be very small and
      // the context switching overhead dwarfs any benefit we get from
      // using another thread to do this work.
      if (start == end - 1) {
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, out_slice_shape);
        out_slice = in_slice;
      } else {
        Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
                                                           num_col);
        typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                                 Eigen::Unaligned>
            InT;
        InT in_slice(in_slice_ptr, in_slice_shape);

        out_slice = in_slice.reduce(dims_to_reduce, Reducer());
      }
      if (end >= num_indices) break;
      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
    }
  }
};
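
// A rough registration sketch (hypothetical; the actual registrations live in
// the segment_reduction_ops_impl_*.cc files and are generated via macros):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("SegmentSum")
//           .Device(DEVICE_CPU)
//           .TypeConstraint<float>("T")
//           .TypeConstraint<int32>("Tindices"),
//       SegmentReductionOp<CPUDevice, float, int32,
//                          Eigen::internal::SumReducer<float>,
//                          /*default_value=*/0>);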

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// SegmentReductionGPUOp is a segment reduction operator implemented for GPU
// only.
// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
// its unsorted counterpart (mostly when the problem size is small).
// This is due to the following two main reasons; a cost-effective way to
// resolve these problems is desirable.
// 1. Sorted segment reduction requires a memory transfer from device to host
//    in order to know the size of the output dimension, whereas unsorted
//    segment reduction receives the size of the output dimension as an input
//    parameter.
// 2. Sorted segment reduction is essentially a tiled version of unsorted
//    segment reduction, and therefore such optimization comes at an inherent
//    cost. However, that cost may not be justified when the problem size is
//    small. Whether to use the tiled or the untiled version depends on many
//    factors, including data alignment, the ratio of computation to memory
//    traffic and, of course, the problem size.
template <class T, class Index, class SegmentReductionFunctor, bool IsMean>
class SegmentReductionGPUOp : public AsyncOpKernel {
 public:
  explicit SegmentReductionGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& segment_ids = context->input(1);

    OP_REQUIRES_ASYNC(
        context, TensorShapeUtils::IsVector(segment_ids.shape()),
        errors::InvalidArgument("segment_ids should be a vector."), done);

    OP_REQUIRES_ASYNC(context, input.dims() >= 1,
                      errors::InvalidArgument("Shape must be at least rank 1"),
                      done);

    const int64_t num_indices = segment_ids.NumElements();
    OP_REQUIRES_ASYNC(
        context, num_indices == input.dim_size(0),
        errors::InvalidArgument(
            "segment_ids should be the same size as dimension 0 of"
            " input."),
        done);

    if (num_indices == 0) {
      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, 0);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);
      done();
      return;
    }

    se::DeviceMemoryBase output_rows_device(
        const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
        (num_indices - 1));
    ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
                         sizeof(Index))
            .ok(),
        errors::Internal(type_string() +
                         ": failed to copy output_rows from device"),
        done);

    SegmentReductionFunctor functor_;
    auto create_and_check_output = [context, output_rows_host, &input,
                                    &segment_ids, &functor_, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      Index output_rows = *output_rows_host.data();
      output_rows++;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      // Since we're changing the first dimension of the shape, we need to make
      // sure the new shape won't overflow.
      OP_REQUIRES_OK_ASYNC(context,
                           output_shape.SetDimWithStatus(0, output_rows), done);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      bool use_deterministic_kernels =
#if defined(PLATFORM_WINDOWS)
          // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows
          // CI build error.
          false;
#else
          UseDeterministicSegmentReductions() ||
          (!SegmentReductionFunctor::atomic_reduction_is_associative &&
           OpDeterminismRequired());
#endif

      // The determinism check is here, rather than inside the functor (as it
      // is for the unsorted segment reduction ops) because the done callback
      // (required for OP_REQUIRES_ASYNC) is not available inside the functor.
      bool determinism_requirement_met =
          use_deterministic_kernels ||
          SegmentReductionFunctor::atomic_reduction_is_associative ||
          !OpDeterminismRequired() ||
          DisableSegmentReductionOpDeterminismExceptions();
      OP_REQUIRES_ASYNC(
          context, determinism_requirement_met,
          errors::Unimplemented(
              "Deterministic GPU implementation of sorted segment reduction op"
              " not available."),
          done);

      auto output_flat = output->flat_outer_dims<T>();
      auto data_ptr = input.template flat<T>().data();
      auto segment_flat = segment_ids.flat<Index>();
      functor_(context, context->eigen_device<GPUDevice>(), output_rows,
               segment_ids.shape(), IsMean, segment_flat, input.NumElements(),
               data_ptr, output_flat);

      done();
    };

    context->device()
        ->tensorflow_accelerator_device_info()
        ->event_mgr->ThenExecute(stream, create_and_check_output);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// ____________________________________________________________________________
// Unsorted segment reduction ops.

namespace functor {

// The ReductionFunctor implementation for CPU.
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF>
struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  typename TTypes<T, 2>::ConstTensor data,
                  typename TTypes<T, 2>::Tensor output) {
    auto cpu_device = ctx->eigen_cpu_device();
    output.device(cpu_device) = output.constant(InitialValueF()());
    if (data.size() == 0) {
      return;
    }

    // This functor reduces an input of `N` rows to an output of `num_segments`
    // rows.
    const int64_t N = segment_ids.dimension(0);
    const int64_t num_segments = output.dimension(0);
    const int64_t inner_dim = data.dimension(1);
    ReductionF reduction;

    // `num_real_segment` counts the input rows that are actually reduced;
    // rows with a negative segment index are excluded. It is used by the
    // cost model below.
    int64_t num_real_segment = N;
    // `num_reductions` counts the output rows that actually receive a
    // reduction; rows left at InitialValueF() are excluded.
    int64_t num_reductions = 0;
    // `row_counter` records how many input rows are reduced into each output
    // row; rows that only hold InitialValueF() stay at 0. The number of
    // non-zero entries equals `num_reductions`.
    std::vector<Index> row_counter(num_segments, 0);

    for (int64_t i = 0; i < N; ++i) {
      Index j = internal::SubtleMustCopy(segment_ids(i));
      if (j < 0) {
        --num_real_segment;
        continue;
      }
      OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments),
                  errors::InvalidArgument(
                      "segment_ids", SliceDebugString(segment_ids_shape, i),
                      " = ", j, " is out of range [0, ", num_segments, ")"));
      if (row_counter[j] == 0) num_reductions++;
      row_counter[j]++;
    }

    // Nothing to reduce. All output values are equal to `InitialValueF()`.
    if (num_reductions == 0) return;

    // Parallelize by `num_segments`. It's simple, efficient and safe
    // (no data dependency):
    //
    //     input    segment_ids   num_segments   operation
    //   | a0 |       | 0 |                      worker 1: |0| f(a0, a1)
    //   | b0 |       | 1 |                      worker 2: |1| f(b0, b1)
    // N | c0 |       | 2 |           -->        worker 3: |2| f(c0)
    //   | b1 |       | 1 |
    //   | a1 |       | 0 |
    //
    // TODO(intel-tf): Balance workload in `row_counter` to make parallelism
    // more efficient.
    auto reductionWorker = [&](int64_t begin, int64_t end) -> void {
      for (int64_t i = 0; i < N; i++) {
        Index j = internal::SubtleMustCopy(segment_ids(i));
        // If `j` is in work scope of this worker, do the reduction.
        if (j >= begin && j < end) {
          reduction(data.template chip<0>(i), output.template chip<0>(j));
        }
      }
    };

    // Reduction functors include Sum, Max, Min, etc. Assume each operation
    // costs roughly 5 cycles.
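    // For example (hypothetical sizes, assuming no negative segment ids):
    // with N = 1000 input rows, num_segments = 100 and inner_dim = 64, the
    // average task size is 1000 / 100 = 10, so compute_cycles = 5 * 64 * 10 =
    // 3200 and input_bytes = output_bytes = 4 * 64 * 10 = 2560 for float data.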
    const int64_t kAverTaskSize = num_real_segment / num_segments;
    const int64_t compute_cycles = 5 * inner_dim * kAverTaskSize;
    const int64_t input_bytes = sizeof(T) * inner_dim * kAverTaskSize;
    const int64_t output_bytes = sizeof(T) * inner_dim * kAverTaskSize;
    const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
    cpu_device.parallelFor(num_segments, cost, reductionWorker);
  }
};

template <typename T>
using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>;

template <typename T>
using constMatrixChip =
    Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>;

// reduction functors
template <typename T>
struct SumOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output += data;
  }
};

template <typename T>
struct MaxOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMax(output);
  }
};

template <typename T>
struct MinOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output = data.cwiseMin(output);
  }
};

template <typename T>
struct ProdOp {
  void operator()(const constMatrixChip<T> data, MatrixChip<T> output) {
    output *= data;
  }
};
}  // namespace functor

// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor
// is the device specific implementation of the reduction. These device
// specific implementations are templated themselves with the corresponding
// initial value functors and reduction functors.
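//
// For illustration only (hypothetical values): with
//   data         = [[1, 2], [3, 4], [5, 6]]
//   segment_ids  = [2, 0, 2]       (need not be sorted)
//   num_segments = 4
// an UnsortedSegmentSum instantiation produces
//   [[3, 4], [0, 0], [1+5, 2+6], [0, 0]] = [[3, 4], [0, 0], [6, 8], [0, 0]];
// output rows that receive no input keep the initial value of the reduction,
// and input rows with a negative segment id are dropped.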
template <typename T, typename Index, typename DeviceReductionFunctor>
class UnsortedSegmentReductionOp : public OpKernel {
 public:
  explicit UnsortedSegmentReductionOp(OpKernelConstruction* context)
      : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& segment_ids = context->input(1);
    const Tensor& num_segments = context->input(2);
    OP_REQUIRES_OK(context,
                   internal::ValidateUnsortedSegmentReduction(
                       this, context, data, segment_ids, num_segments));
    const auto segment_flat = segment_ids.flat<Index>();
    const int64_t output_rows = internal::SubtleMustCopy(static_cast<int64_t>(
        num_segments.dtype() == DT_INT32 ? num_segments.scalar<int32>()()
                                         : num_segments.scalar<int64_t>()()));
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("Input num_segments == ", output_rows,
                                        " must not be negative."));
    TensorShape output_shape;
    OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(output_rows));
    for (int i = segment_ids.dims(); i < data.dims(); i++) {
      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(data.dim_size(i)));
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_flat = output->flat_outer_dims<T>();
    auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1);
    reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat,
                       output_flat);
  }

 protected:
  DeviceReductionFunctor reduction_functor_;
};

// ____________________________________________________________________________
// Sparse segment reduction ops.

// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
// by two dense tensors, one containing the data, and the other containing
// indices into the data.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
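//
// For illustration only (hypothetical values): with
//   data        = [[1, 2], [3, 4], [5, 6]]
//   indices     = [0, 2, 2]        (rows of data to gather)
//   segment_ids = [0, 0, 1]        (sorted, one id per entry of indices)
// a SparseSegmentSum instantiation first gathers rows 0, 2 and 2 of data and
// then reduces them by segment: [[1+5, 2+6], [5, 6]] = [[6, 8], [5, 6]]. The
// Mean and SqrtN variants additionally divide each output row by the segment
// size or by its square root.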
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase : public OpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : OpKernel(context),
        dtidx_(DataTypeToEnum<Index>::v()),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK(
        context, internal::ValidateSparseSegmentReduction(
                     context, input, indices, segment_ids, has_num_segments_));

    Index output_rows = -1;
    if (has_num_segments_) {
      const Tensor& num_segments = context->input(3);
      // Note that there is a Tnumsegments parameter on the op, but it is not
      // plumbed through to here and so always takes its default value of
      // int32.
      output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
    }
    const int64_t num_indices = indices.NumElements();

    auto input_flat = input.flat_outer_dims<T>();
    const int64_t num_col = input_flat.dimension(1);
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();
    // Note that the current implementation assumes that segment_vec values are
    // sorted.
    const SegmentId last_segment_id =
        num_indices > 0 ? segment_vec(num_indices - 1) : 0;
    int64_t limit = dtidx_ == DataType::DT_INT32 ? kint32max : kint64max;

    OP_REQUIRES(
        context, last_segment_id < limit,
        errors::InvalidArgument("Last segment id must be < kintmax, got ",
                                last_segment_id, " limit ", limit));

    const SegmentId last_segment_id_plus_one =
        num_indices > 0
            ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
            : 0;

    if (has_num_segments_) {
      OP_REQUIRES(
          context, output_rows >= last_segment_id_plus_one,
          errors::InvalidArgument("segment ids must be < num_segments"));
    } else {
      output_rows = last_segment_id_plus_one;
    }
    OP_REQUIRES(context, output_rows >= 0,
                errors::InvalidArgument("segment ids must be >= 0"));

    TensorShape output_shape = input.shape();
    OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows));

    // Note that we do not initialize the output buffer with a default value,
    // so we need to explicitly set missing indices to the default value.
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (num_indices == 0) {
      if (output_rows > 0) {
        output->flat_outer_dims<T>().setConstant(default_value_);
      }
      return;
    }
    OP_REQUIRES(context, output_rows > 0,
                errors::InvalidArgument("segment ids must be >= 0"));
    auto output_flat = output->flat_outer_dims<T>();

    Tensor temp;
    if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) {
      temp = tensorflow::Tensor(DT_FLOAT, output_shape);
    }
    auto temp_flat = temp.flat_outer_dims<float>();

    int64_t start = 0, end = 1;
    // Index from which the output is not initialized.
    SegmentId uninitialized_index = 0;
    SegmentId out_index = internal::SubtleMustCopy(segment_vec(start));

    while (true) {
      // We initialize next_index to 0 to avoid "warning: 'next_index' may be
      // used uninitialized in this function" in the Mac build (since the
      // compiler isn't smart enough to realize the code is safe).
      SegmentId next_index = 0;
      if (end < num_indices) {
        next_index = internal::SubtleMustCopy(segment_vec(end));
        if (out_index == next_index) {
          ++end;
          continue;
        }
        // We have a new segment here. Verify that the segment ids are growing.
        OP_REQUIRES(context, out_index < next_index,
                    errors::InvalidArgument("segment ids are not increasing"));
      }

      OP_REQUIRES(
          context, FastBoundsCheck(out_index, output_rows),
          errors::InvalidArgument(
              "Segment id ", out_index, " out of range [0, ", output_rows,
              "), possibly because 'segment_ids' input is not sorted."));

      // If there is a gap between two indices, we need to set that gap to the
      // default value.
      if (out_index > uninitialized_index) {
        Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
            out_index - uninitialized_index, num_col);
        Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
            gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
        gap_slice.setConstant(default_value_);
      }

      auto out = output_flat.template chip<0>(out_index);
      auto temp = temp_flat.template chip<0>(out_index);
      const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start,
                                              end - start, out, temp);
      OP_REQUIRES(context, bad_offset < 0,
                  errors::InvalidArgument(
                      "Bad: indices[", start + bad_offset,
                      "] == ", indices_vec(start + bad_offset),
                      " out of range [0, ", input_flat.dimension(0), ")"));

      start = end;
      ++end;
      uninitialized_index = out_index + 1;
      out_index = next_index;
      if (end > num_indices) break;
    }

    // Fill the gap at the end with the default value.
    if (uninitialized_index < output_rows) {
      Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
          output_rows - uninitialized_index, num_col);
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
          gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
      gap_slice.setConstant(default_value_);
    }
  }

 private:
  const DataType dtidx_;
  template <typename Tin>
  using EnableIfBfloat16OrHalf =
      typename std::enable_if<std::is_same<Tin, bfloat16>::value ||
                                  std::is_same<Tin, Eigen::half>::value,
                              int>::type;
  template <typename Tin>
  using EnableIfNotBfloat16OrHalf =
      typename std::enable_if<!std::is_same<Tin, bfloat16>::value &&
                                  !std::is_same<Tin, Eigen::half>::value,
                              int>::type;

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index);
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  EIGEN_ALWAYS_INLINE auto fetch_val(
      const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) {
    return input_flat.template chip<0>(index).template cast<float>();
  }

  template <typename Tout>
  EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) {
    Tout m(1);
    if (is_mean_ && (num < 10)) {
      m = Tout(num);
    }
    if (is_sqrtn_ && (num < 10)) {
      m = Tout(sqrt(num));
    }
    return Tout(1) / m;
  }

  template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0>
  int64_t Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num,
                                        out, get_scaling_factor<Tin>(num));
  }

  template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0>
  int64_t Reduce(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out,
      Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) {
    int64_t res =
        ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num,
                                       temp, get_scaling_factor<float>(num));
    out = temp.template cast<Tin>();
    return res;
  }

  template <typename Tin, typename Tindex, typename Tout>
  int64_t ReduceImpl(
      const typename TTypes<Tin>::ConstMatrix& input_flat,
      const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start,
      int64_t num,
      Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out,
      const Tout scaling_factor) {
#define INDEX(n, i)                               \
  const auto index##n = indices_vec(start + (i)); \
  if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i);

#define L(n) fetch_val<Tin, Tindex>(input_flat, index##n)

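    // The reduction below is manually unrolled by 8: the switch handles the
    // `num & 7` remainder first (treating remainders 0 and 1 as 8 and 9
    // elements) and folds in `scaling_factor`, which is 1 unless num < 10 (see
    // get_scaling_factor); the remaining indices are then accumulated in
    // blocks of 8, and for num >= 10 the mean/sqrtn division is applied at the
    // end instead.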
    if (num == 1) {
      INDEX(0, 0);
      out = L(0);
    } else {
      int64_t r = num & 7;
      switch (r) {
        case 2: {
          INDEX(0, 0);
          INDEX(1, 1);
          out = (L(0) + L(1)) * scaling_factor;
          break;
        }
        case 3: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          out = (L(0) + L(1) + L(2)) * scaling_factor;
          break;
        }
        case 4: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor;
          break;
        }
        case 5: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor;
          break;
        }
        case 6: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor;
          break;
        }
        case 7: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) *
                scaling_factor;
          break;
        }
        case 0: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) *
                scaling_factor;
          r = 8;
          break;
        }
        case 1: {
          INDEX(0, 0);
          INDEX(1, 1);
          INDEX(2, 2);
          INDEX(3, 3);
          INDEX(4, 4);
          INDEX(5, 5);
          INDEX(6, 6);
          INDEX(7, 7);
          INDEX(8, 8);
          out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) +
                 L(8)) *
                scaling_factor;
          r = 9;
          break;
        }
      }
      for (; r < num; r += 8) {
        INDEX(0, r);
        INDEX(1, r + 1);
        INDEX(2, r + 2);
        INDEX(3, r + 3);
        INDEX(4, r + 4);
        INDEX(5, r + 5);
        INDEX(6, r + 6);
        INDEX(7, r + 7);
        out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7);
      }
      if (is_mean_ && num >= 10) {
        out = out / static_cast<Tout>(num);
      }
      if (is_sqrtn_ && num >= 10) {
        out = out / static_cast<Tout>(sqrt(num));
      }
    }

    return -1;
#undef L
#undef INDEX
  }

  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Specialization for GPU. Must be async because it may need to wait for a
// device-to-host memcpy before allocating the output.
template <class T, typename Index, typename SegmentId>
class SparseSegmentReductionOpBase<GPUDevice, T, Index, SegmentId>
    : public AsyncOpKernel {
 public:
  explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
                                        bool is_mean, bool is_sqrtn,
                                        bool has_num_segments, T default_value)
      : AsyncOpKernel(context),
        is_mean_(is_mean),
        is_sqrtn_(is_sqrtn),
        has_num_segments_(has_num_segments),
        default_value_(default_value) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);

    OP_REQUIRES_OK_ASYNC(
        context,
        internal::ValidateSparseSegmentReduction(
            context, input, indices, segment_ids, has_num_segments_),
        done);

    ScratchSpace<SegmentId> last_segment_id_host(context, 1, /*on_host=*/true);

    auto create_and_check_output = [this, context, input, indices, segment_ids,
                                    last_segment_id_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      SegmentId last_segment_id = *last_segment_id_host.data();
      SegmentId output_rows = last_segment_id + 1;
      OP_REQUIRES_ASYNC(context, output_rows > 0,
                        errors::InvalidArgument("segment ids must be >= 0"),
                        done);

      TensorShape output_shape = input.shape();
      output_shape.set_dim(0, output_rows);

      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          context, context->allocate_output(0, output_shape, &output), done);

      auto input_flat = input.flat_outer_dims<T>();
      const auto indices_vec = indices.vec<Index>();
      const auto segment_ids_vec = segment_ids.vec<SegmentId>();
      auto output_flat = output->flat_outer_dims<T>();

      functor::SparseSegmentReductionFunctor<T, Index, SegmentId> functor;
      OP_REQUIRES_OK_ASYNC(
          context,
          functor(context, is_mean_, is_sqrtn_, default_value_, input_flat,
                  indices_vec, segment_ids_vec, output_flat),
          done);
      done();
    };

    if (has_num_segments_) {
      // No need to do any device to host memcpy, just compute synchronously.
      const Tensor& num_segments_t = context->input(3);
      SegmentId num_segments =
          internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32
                                       ? num_segments_t.scalar<int32>()()
                                       : num_segments_t.scalar<int64_t>()());
      *last_segment_id_host.mutable_data() = num_segments - 1;
      create_and_check_output();
    } else {
      const int64_t num_indices = indices.NumElements();
      // Need to copy last element of segment_ids from device to host, and then
      // asynchronously allocate the output and finish the computation.
      se::DeviceMemoryBase last_segment_id_device(
          const_cast<Tensor&>(segment_ids).template flat<SegmentId>().data() +
          (num_indices - 1));
      auto stream = context->op_device_context()->stream();
      OP_REQUIRES_ASYNC(
          context,
          stream
              ->ThenMemcpy(last_segment_id_host.mutable_data(),
                           last_segment_id_device, sizeof(SegmentId))
              .ok(),
          errors::Internal(type_string() +
                           ": failed to copy last_segment_id from device"),
          done);
      context->device()
          ->tensorflow_accelerator_device_info()
          ->event_mgr->ThenExecute(stream, create_and_check_output);
    }
  }

 private:
  const bool is_mean_;
  const bool is_sqrtn_;
  const bool has_num_segments_;
  const T default_value_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionMeanWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionMeanWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, true /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSqrtNWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, true /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            false /* has_num_segments */, T(0) /* default_value */) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentReductionSumWithNumSegmentsOp
    : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentReductionSumWithNumSegmentsOp(
      OpKernelConstruction* context)
      : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>(
            context, false /*is_mean*/, false /*is_sqrtn*/,
            true /* has_num_segments */, T(0) /* default_value */) {}
};

namespace functor {

template <typename T, typename Index, typename SegmentId>
struct SparseSegmentGradFunctor<CPUDevice, T, Index, SegmentId> {
  void operator()(OpKernelContext* context,
                  SparseSegmentReductionOperation operation,
                  typename TTypes<T>::ConstMatrix input_flat,
                  typename TTypes<Index>::ConstVec indices_vec,
                  typename TTypes<SegmentId>::ConstVec segment_vec,
                  typename TTypes<T>::Matrix output_flat) {
    const int64_t N = indices_vec.size();
    const SegmentId M = output_flat.dimension(0);

    // Note that similar to SparseSegmentMean, we assume that segment_vec is
    // already sorted and has non-negative values.
    const SegmentId num_segments = input_flat.dimension(0);
    const SegmentId last_segment_id_plus_one =
        internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
    OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
                errors::InvalidArgument("Invalid number of segments"));

    // Compute scaling factors for input.
    std::vector<double> scaling(
        (operation == SparseSegmentReductionOperation::kSum ? 0
                                                            : num_segments),
        0.0);
    if (operation != SparseSegmentReductionOperation::kSum) {
      for (int64_t i = 0; i < N; ++i) {
        const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
        OP_REQUIRES(
            context, FastBoundsCheck(idx, num_segments),
            errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                    num_segments, ")."));
        scaling[idx] += 1;
      }
      for (size_t i = 0; i < scaling.size(); ++i) {
        switch (operation) {
          case SparseSegmentReductionOperation::kSum: {
            OP_REQUIRES(
                context, false,
                errors::Internal(
                    "Should not happen: sum inside SparseSegmentReductionOp "
                    "scaling generation."));
          }
          case SparseSegmentReductionOperation::kMean: {
            scaling[i] = 1.0 / std::max(scaling[i], 1.0);
            break;
          }
          case SparseSegmentReductionOperation::kSqrtN: {
            scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0));
            break;
          }
            // No default to get compiler warnings for missing cases.
        }
      }
    }

    output_flat.setZero();
    std::vector<bool> is_modified(M, false);

    for (int64_t i = 0; i < N; ++i) {
      const Index output_idx = internal::SubtleMustCopy(indices_vec(i));
      OP_REQUIRES(context, FastBoundsCheck(output_idx, M),
                  errors::InvalidArgument("Index ", output_idx,
                                          " out of range [0, ", M, ")."));

      const SegmentId idx = internal::SubtleMustCopy(segment_vec(i));
      OP_REQUIRES(
          context, FastBoundsCheck(idx, num_segments),
          errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
                                  num_segments, ")."));

      const T scale = (operation == SparseSegmentReductionOperation::kSum
                           ? static_cast<T>(1)
                           : static_cast<T>(scaling[idx]));
      if (is_modified[output_idx]) {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) +=
              input_flat.template chip<0>(idx) * scale;
        }
      } else {
        if (scale == 1.0) {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx);
        } else {
          output_flat.template chip<0>(output_idx) =
              input_flat.template chip<0>(idx) * scale;
        }
      }
      is_modified[output_idx] = true;
    }
  }
};

}  // namespace functor

// Implements the common logic for the gradients of SparseSegmentReduction
// kernels.
//
// The template parameters are:
// * Device: An Eigen device object, on which the kernel will execute.
// * T: The value type.
// * Index: The element type of the indices tensor (int32 or int64).
// * SegmentId: The element type of the segment_ids tensor (int32 or int64).
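//
// For illustration only (hypothetical values): for a SparseSegmentSumGrad
// instantiation with
//   grad (input) = [[10, 10], [20, 20]]   (one row per forward-pass segment)
//   indices      = [0, 2, 2]
//   segment_ids  = [0, 0, 1]
//   output_dim0  = 3
// row segment_ids[i] of grad is accumulated into row indices[i] of the output,
// giving [[10, 10], [0, 0], [10+20, 10+20]] = [[10, 10], [0, 0], [30, 30]].
// The Mean and SqrtN variants scale each contribution by 1/|segment| or
// 1/sqrt(|segment|), respectively.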
template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentGradOpBase : public OpKernel {
 public:
  explicit SparseSegmentGradOpBase(OpKernelConstruction* context,
                                   SparseSegmentReductionOperation operation)
      : OpKernel(context), operation_(operation) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& segment_ids = context->input(2);
    const Tensor& output_dim0 = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument("indices should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                errors::InvalidArgument("segment_ids should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()),
                errors::InvalidArgument("output_dim0 should be a scalar."));

    const int64_t N = indices.NumElements();
    OP_REQUIRES(context, N == segment_ids.NumElements(),
                errors::InvalidArgument(
                    "segment_ids and indices should have same size."));
    const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()());

    auto input_flat = input.flat_outer_dims<T>();
    const auto indices_vec = indices.vec<Index>();
    const auto segment_vec = segment_ids.vec<SegmentId>();

    TensorShape output_shape = input.shape();
    OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, M));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (M == 0 || N == 0) return;

    auto output_flat = output->flat_outer_dims<T>();
    functor::SparseSegmentGradFunctor<Device, T, Index, SegmentId>()(
        context, operation_, input_flat, indices_vec, segment_vec,
        output_flat);
  }

 private:
  const SparseSegmentReductionOperation operation_;
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSumGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSumGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSum) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentMeanGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kMean) {}
};

template <typename Device, class T, typename Index, typename SegmentId>
class SparseSegmentSqrtNGradOp
    : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> {
 public:
  explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context)
      : SparseSegmentGradOpBase<Device, T, Index, SegmentId>(
            context, SparseSegmentReductionOperation::kSqrtN) {}
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_