1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | |
18 | #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_ |
19 | #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_ |
20 | |
21 | #include <cstdint> |
22 | |
23 | #include "tensorflow/core/framework/op_requires.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | #define EIGEN_USE_THREADS |
26 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
27 | #define EIGEN_USE_GPU |
28 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
29 | |
30 | #include "third_party/eigen3/Eigen/Core" |
31 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
32 | #include "tensorflow/core/framework/bounds_check.h" |
33 | #include "tensorflow/core/framework/numeric_op.h" |
34 | #include "tensorflow/core/framework/op_kernel.h" |
35 | #include "tensorflow/core/framework/register_types.h" |
36 | #include "tensorflow/core/framework/tensor.h" |
37 | #include "tensorflow/core/framework/tensor_types.h" |
38 | #include "tensorflow/core/framework/tensor_util.h" |
39 | #include "tensorflow/core/framework/types.h" |
40 | #include "tensorflow/core/kernels/segment_reduction_ops.h" |
41 | #include "tensorflow/core/lib/core/status.h" |
42 | #include "tensorflow/core/platform/logging.h" |
43 | #include "tensorflow/core/util/determinism.h" |
44 | #include "tensorflow/core/util/util.h" |
45 | |
46 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
47 | #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" |
48 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
49 | |
50 | #if GOOGLE_CUDA |
51 | #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h" |
52 | #include "tensorflow/core/util/gpu_solvers.h" |
53 | |
54 | using stream_executor::cuda::ScopedActivateExecutorContext; |
55 | #elif TENSORFLOW_USE_ROCM |
56 | #include "tensorflow/core/platform/rocm.h" |
57 | #include "tensorflow/core/util/gpu_solvers.h" |
58 | using stream_executor::rocm::ScopedActivateExecutorContext; |
59 | #endif // GOOGLE_CUDA |
60 | |
61 | namespace tensorflow { |
62 | |
63 | typedef Eigen::ThreadPoolDevice CPUDevice; |
64 | typedef Eigen::GpuDevice GPUDevice; |
65 | |
66 | namespace internal { |
67 | Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input, |
68 | const Tensor& segment_ids); |
69 | Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, |
70 | OpKernelContext* context, |
71 | const Tensor& data, |
72 | const Tensor& segment_ids, |
73 | const Tensor& num_segments); |
74 | Status ValidateSparseSegmentReduction(OpKernelContext* context, |
75 | const Tensor& input, |
76 | const Tensor& indices, |
77 | const Tensor& segment_ids, |
78 | bool has_num_segments); |
79 | } // namespace internal |
80 | |
81 | // This operator handles reducing segments along the first dimension. |
82 | // See core/ops/math_ops.cc for more details. |
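//
// For example (an illustrative sketch, assuming the Sum reducer and a
// default_value of 0, as used by SegmentSum):
//   data        = [[1, 2], [3, 4], [5, 6]]
//   segment_ids = [0, 0, 2]           // must be sorted and non-negative
//   output      = [[4, 6], [0, 0], [5, 6]]
// Segment 1 never appears in segment_ids, so its output row is filled with
// the default value.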
83 | template <typename Device, class T, class Index, typename Reducer, |
84 | int default_value> |
85 | class SegmentReductionOp : public OpKernel { |
86 | public: |
87 | explicit SegmentReductionOp(OpKernelConstruction* context) |
88 | : OpKernel(context) {} |
89 | |
90 | void Compute(OpKernelContext* context) override { |
91 | const Tensor& input = context->input(0); |
92 | const Tensor& segment_ids = context->input(1); |
93 | |
94 | OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input, |
95 | segment_ids)); |
96 | |
97 | const int64_t num_indices = segment_ids.NumElements(); |
98 | auto input_flat = input.flat_outer_dims<T>(); |
99 | const int64_t num_col = input_flat.dimension(1); |
100 | |
101 | const auto segment_vec = segment_ids.vec<Index>(); |
102 | // Note that the current implementation assumes that segment_vec values are |
103 | // sorted. |
104 | const Index output_rows = |
105 | num_indices > 0 |
106 | ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1 |
107 | : 0; |
108 | OP_REQUIRES(context, output_rows >= 0, |
109 | errors::InvalidArgument("segment ids must be >= 0" )); |
110 | |
111 | OP_REQUIRES(context, input.dims() >= 1, |
112 | errors::InvalidArgument("Shape must be at least rank 1" )); |
113 | |
114 | TensorShape output_shape = input.shape(); |
115 | // Since we're changing the first dimension of the shape, we need to make |
116 | // sure the new shape won't overflow. |
117 | OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows)); |
118 | |
119 | // Note that we do not initialize the output buffer with a default value, so |
120 | // we need to explicitly set missing indices to the default value. |
121 | Tensor* output = nullptr; |
122 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); |
123 | if (num_indices == 0) return; |
124 | OP_REQUIRES(context, output_rows > 0, |
125 | errors::InvalidArgument("segment ids must be >= 0" )); |
126 | auto output_flat = output->flat_outer_dims<T>(); |
127 | |
128 | Eigen::IndexList<Eigen::type2index<0> > dims_to_reduce; |
129 | Index start = 0, end = 1; |
130 | |
131 | Index uninitialized_index = 0; // Index from which the output is not set. |
132 | Index out_index = internal::SubtleMustCopy(segment_vec(start)); |
133 | |
134 | // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it |
135 | // across threads. |
136 | Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col); |
137 | while (end <= num_indices) { |
138 | // We initialize next_index to 0 to avoid "warning: 'next_index' may be |
139 | // used uninitialized in this function" in the Mac build (since the |
140 | // compiler isn't smart enough to realize the code is safe). |
141 | Index next_index = 0; |
142 | if (end < num_indices) { |
143 | next_index = internal::SubtleMustCopy(segment_vec(end)); |
144 | if (out_index == next_index) { |
145 | ++end; |
146 | continue; |
147 | } |
148 | // We have a new segment here. Verify that the segment ids are growing. |
149 | OP_REQUIRES(context, out_index < next_index, |
150 | errors::InvalidArgument("segment ids are not increasing" )); |
151 | } |
152 | |
153 | // Process segment [start, end) |
154 | const T* in_slice_ptr = &input_flat(start, 0); |
155 | typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, |
156 | Eigen::Unaligned> |
157 | OutT; |
158 | |
159 | OP_REQUIRES( |
160 | context, FastBoundsCheck(out_index, output_rows), |
errors::InvalidArgument(
"Segment id ", out_index, " out of range [0, ", output_rows,
"), possibly because 'segment_ids' input is not sorted."));
164 | |
165 | // If there is a gap between two indices, we need to set that gap to the |
166 | // default value. |
167 | if (out_index > uninitialized_index) { |
168 | Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape( |
169 | out_index - uninitialized_index, num_col); |
170 | Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned> |
171 | gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); |
172 | gap_slice.setConstant(T(default_value)); |
173 | } |
174 | |
175 | T* out_slice_ptr = &output_flat(out_index, 0); |
176 | OutT out_slice(out_slice_ptr, out_slice_shape); |
177 | // We don't use out_slice.device(context->eigen_device<Device>) |
178 | // because these pieces of work are likely to be very small and |
179 | // the context switching overhead dwarfs any benefit we get from |
180 | // using another thread to do this work. |
181 | if (start == end - 1) { |
182 | typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>, |
183 | Eigen::Unaligned> |
184 | InT; |
185 | InT in_slice(in_slice_ptr, out_slice_shape); |
186 | out_slice = in_slice; |
187 | } else { |
188 | Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start, |
189 | num_col); |
190 | typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, |
191 | Eigen::Unaligned> |
192 | InT; |
193 | InT in_slice(in_slice_ptr, in_slice_shape); |
194 | |
195 | out_slice = in_slice.reduce(dims_to_reduce, Reducer()); |
196 | } |
197 | if (end >= num_indices) break; |
198 | start = end; |
199 | ++end; |
200 | uninitialized_index = out_index + 1; |
201 | out_index = next_index; |
202 | } |
203 | } |
204 | }; |
205 | |
206 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
207 | |
208 | // SegmentReductionGPUOp is a segment reduction operator implemented for GPU |
209 | // only. |
210 | // TODO: This implementation of SegmentReductionGPUOp is sometimes slower than |
211 | // its unsorted counterpart (mostly when problem size is small). |
// This is mainly due to the two reasons below, and a cost-effective way to
// resolve them would be desirable.
214 | // 1. Sorted segment reduction requires a memory transfer from device to host |
215 | // in order to know the size of the output dimension whereas unsorted |
216 | // segment reduction receives the size of the output dimension as an input |
217 | // parameter. |
218 | // 2. Sorted segment reduction is essentially a tiled version of unsorted |
// segment reduction, and such an optimization comes at an inherent cost.
// However, that cost may not be justified when the problem size is small.
// Whether to use the tiled or the untiled version depends on many factors,
// including data alignment, the ratio of computation to memory traffic and,
// of course, the problem size.
224 | template <class T, class Index, class SegmentReductionFunctor, bool IsMean> |
225 | class SegmentReductionGPUOp : public AsyncOpKernel { |
226 | public: |
227 | explicit SegmentReductionGPUOp(OpKernelConstruction* context) |
228 | : AsyncOpKernel(context) {} |
229 | |
230 | void ComputeAsync(OpKernelContext* context, DoneCallback done) override { |
231 | const Tensor& input = context->input(0); |
232 | const Tensor& segment_ids = context->input(1); |
233 | |
234 | OP_REQUIRES_ASYNC( |
235 | context, TensorShapeUtils::IsVector(segment_ids.shape()), |
236 | errors::InvalidArgument("segment_ids should be a vector." ), done); |
237 | |
238 | OP_REQUIRES_ASYNC(context, input.dims() >= 1, |
239 | errors::InvalidArgument("Shape must be at least rank 1" ), |
240 | done); |
241 | |
242 | const int64_t num_indices = segment_ids.NumElements(); |
243 | OP_REQUIRES_ASYNC( |
244 | context, num_indices == input.dim_size(0), |
245 | errors::InvalidArgument( |
246 | "segment_ids should be the same size as dimension 0 of" |
247 | " input." ), |
248 | done); |
249 | |
250 | if (num_indices == 0) { |
251 | TensorShape output_shape = input.shape(); |
252 | output_shape.set_dim(0, 0); |
253 | |
254 | Tensor* output = nullptr; |
255 | OP_REQUIRES_OK_ASYNC( |
256 | context, context->allocate_output(0, output_shape, &output), done); |
257 | done(); |
258 | return; |
259 | } |
260 | |
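// Copy the last element of segment_ids from the device; once it arrives on
// the host, the callback below allocates the output (whose first dimension
// is last_segment_id + 1) and runs the reduction functor.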
261 | se::DeviceMemoryBase output_rows_device( |
262 | const_cast<Tensor&>(segment_ids).template flat<Index>().data() + |
263 | (num_indices - 1)); |
264 | ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true); |
265 | |
266 | auto stream = context->op_device_context()->stream(); |
267 | OP_REQUIRES_ASYNC( |
268 | context, |
269 | stream |
270 | ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device, |
271 | sizeof(Index)) |
272 | .ok(), |
273 | errors::Internal(type_string() + |
274 | ": failed to copy output_rows from device" ), |
275 | done); |
276 | |
277 | SegmentReductionFunctor functor_; |
278 | auto create_and_check_output = [context, output_rows_host, &input, |
279 | &segment_ids, &functor_, done]() { |
280 | // Ensure that within the callback, the proper GPU settings are |
281 | // configured. |
282 | auto stream = context->op_device_context()->stream(); |
283 | ScopedActivateExecutorContext scoped_activation{stream->parent()}; |
284 | |
285 | Index output_rows = *output_rows_host.data(); |
286 | output_rows++; |
287 | OP_REQUIRES_ASYNC(context, output_rows > 0, |
288 | errors::InvalidArgument("segment ids must be >= 0" ), |
289 | done); |
290 | |
291 | TensorShape output_shape = input.shape(); |
292 | // Since we're changing the first dimension of the shape, we need to make |
293 | // sure the new shape won't overflow. |
294 | OP_REQUIRES_OK_ASYNC(context, |
295 | output_shape.SetDimWithStatus(0, output_rows), done); |
296 | |
297 | Tensor* output = nullptr; |
298 | OP_REQUIRES_OK_ASYNC( |
299 | context, context->allocate_output(0, output_shape, &output), done); |
300 | |
301 | bool use_deterministic_kernels = |
302 | #if defined(PLATFORM_WINDOWS) |
303 | // See comment in segment_reduction_ops_gpu_0.cu.cc regarding Windows |
304 | // CI build error. |
305 | false; |
306 | #else |
307 | UseDeterministicSegmentReductions() || |
308 | (!SegmentReductionFunctor::atomic_reduction_is_associative && |
309 | OpDeterminismRequired()); |
310 | #endif |
311 | |
312 | // The determinism check is here, rather than inside the functor (as it is |
313 | // for the unsorted segment reduction ops) because the done callback |
314 | // (required for OP_REQUIRES_ASYNC) is not available inside the functor. |
315 | bool determinism_requirement_met = |
316 | use_deterministic_kernels || |
317 | SegmentReductionFunctor::atomic_reduction_is_associative || |
318 | !OpDeterminismRequired() || |
319 | DisableSegmentReductionOpDeterminismExceptions(); |
320 | OP_REQUIRES_ASYNC( |
321 | context, determinism_requirement_met, |
322 | errors::Unimplemented( |
323 | "Deterministic GPU implementation of sorted segment reduction op" |
324 | " not available." ), |
325 | done); |
326 | |
327 | auto output_flat = output->flat_outer_dims<T>(); |
328 | auto data_ptr = input.template flat<T>().data(); |
329 | auto segment_flat = segment_ids.flat<Index>(); |
330 | functor_(context, context->eigen_device<GPUDevice>(), output_rows, |
331 | segment_ids.shape(), IsMean, segment_flat, input.NumElements(), |
332 | data_ptr, output_flat); |
333 | |
334 | done(); |
335 | }; |
336 | |
337 | context->device() |
338 | ->tensorflow_accelerator_device_info() |
339 | ->event_mgr->ThenExecute(stream, create_and_check_output); |
340 | } |
341 | }; |
342 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
343 | |
344 | // ____________________________________________________________________________ |
345 | // Unsorted segment reduction ops. |
346 | |
347 | namespace functor { |
348 | |
349 | // The ReductionFunctor implementation for CPU. |
350 | template <typename T, typename Index, typename InitialValueF, |
351 | typename ReductionF> |
352 | struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> { |
353 | void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape, |
354 | typename TTypes<Index>::ConstFlat segment_ids, |
355 | typename TTypes<T, 2>::ConstTensor data, |
356 | typename TTypes<T, 2>::Tensor output) { |
357 | auto cpu_device = ctx->eigen_cpu_device(); |
358 | output.device(cpu_device) = output.constant(InitialValueF()()); |
359 | if (data.size() == 0) { |
360 | return; |
361 | } |
362 | |
// This functor reduces the `N` input rows to `num_segments` output rows.
364 | const int64_t N = segment_ids.dimension(0); |
365 | const int64_t num_segments = output.dimension(0); |
366 | const int64_t inner_dim = data.dimension(1); |
367 | ReductionF reduction; |
368 | |
// `num_real_segment` counts the input rows that actually take part in the
// reduction; rows with a negative segment index are excluded. It is used by
// the cost model below.
372 | int64_t num_real_segment = N; |
// `num_reductions` counts the output rows that actually receive a reduction;
// rows left at InitialValueF() are excluded.
375 | int64_t num_reductions = 0; |
// `row_counter` records how many input rows are reduced into each output
// row; rows left at InitialValueF() keep a count of 0. The number of
// non-zero entries equals `num_reductions`.
379 | std::vector<Index> row_counter(num_segments, 0); |
380 | |
381 | for (int64_t i = 0; i < N; ++i) { |
382 | Index j = internal::SubtleMustCopy(segment_ids(i)); |
383 | if (j < 0) { |
384 | --num_real_segment; |
385 | continue; |
386 | } |
387 | OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments), |
errors::InvalidArgument(
"segment_ids", SliceDebugString(segment_ids_shape, i),
" = ", j, " is out of range [0, ", num_segments, ")"));
391 | if (row_counter[j] == 0) num_reductions++; |
392 | row_counter[j]++; |
393 | } |
394 | |
395 | // Nothing to reduce. All output values equal to `InitialValueF()`. |
396 | if (num_reductions == 0) return; |
397 | |
398 | // Parallelize by `num_segments`. It's simple, efficient and safe |
399 | // (no data dependency): |
400 | // |
401 | // input segment_ids num_segments operation |
402 | // | a0 | | 0 | worker 1: |0| f(a0, a1) |
403 | // | b0 | | 1 | worker 2: |1| f(b0, b1) |
404 | // N | c0 | | 2 | --> worker 3: |2| f(c0) |
405 | // | b1 | | 1 | |
406 | // | a1 | | 0 | |
407 | // |
408 | // TODO(intel-tf): Balance workload in `row_counter` to make parallelism |
409 | // more efficient. |
410 | auto reductionWorker = [&](int64_t begin, int64_t end) -> void { |
411 | for (int64_t i = 0; i < N; i++) { |
412 | Index j = internal::SubtleMustCopy(segment_ids(i)); |
// If `j` falls within this worker's range, do the reduction.
414 | if (j >= begin && j < end) { |
415 | reduction(data.template chip<0>(i), output.template chip<0>(j)); |
416 | } |
417 | } |
418 | }; |
419 | |
// Reduction functors include Sum, Max, Min, etc. Assume each such operation
// costs roughly 5 cycles.
422 | const int64_t kAverTaskSize = num_real_segment / num_segments; |
423 | const int64_t compute_cycles = 5 * inner_dim * kAverTaskSize; |
424 | const int64_t input_bytes = sizeof(T) * inner_dim * kAverTaskSize; |
425 | const int64_t output_bytes = sizeof(T) * inner_dim * kAverTaskSize; |
426 | const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); |
427 | cpu_device.parallelFor(num_segments, cost, reductionWorker); |
428 | } |
429 | }; |
430 | |
431 | template <typename T> |
432 | using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes<T, 2>::Matrix>; |
433 | |
434 | template <typename T> |
435 | using constMatrixChip = |
436 | Eigen::TensorChippingOp<0l, const typename TTypes<T, 2>::ConstMatrix>; |
437 | |
438 | // reduction functors |
439 | template <typename T> |
440 | struct SumOp { |
441 | void operator()(const constMatrixChip<T> data, MatrixChip<T> output) { |
442 | output += data; |
443 | } |
444 | }; |
445 | |
446 | template <typename T> |
447 | struct MaxOp { |
448 | void operator()(const constMatrixChip<T> data, MatrixChip<T> output) { |
449 | output = data.cwiseMax(output); |
450 | } |
451 | }; |
452 | |
453 | template <typename T> |
454 | struct MinOp { |
455 | void operator()(const constMatrixChip<T> data, MatrixChip<T> output) { |
456 | output = data.cwiseMin(output); |
457 | } |
458 | }; |
459 | |
460 | template <typename T> |
461 | struct ProdOp { |
462 | void operator()(const constMatrixChip<T> data, MatrixChip<T> output) { |
463 | output *= data; |
464 | } |
465 | }; |
466 | } // namespace functor |
467 | |
// The UnsortedSegmentReduction OpKernel. DeviceReductionFunctor is the
// device-specific implementation of the reduction; these implementations are
// themselves templated on the corresponding initial-value and reduction
// functors.
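//
// For example (an illustrative sketch for an unsorted segment sum with
// num_segments = 3):
//   data        = [1, 2, 3, 4]
//   segment_ids = [2, 0, -1, 0]       // ids may appear in any order
//   output      = [6, 0, 1]
// The row with segment id -1 is dropped, and segment 1, which receives no
// rows, keeps the initial value (0 for a sum).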
472 | template <typename T, typename Index, typename DeviceReductionFunctor> |
473 | class UnsortedSegmentReductionOp : public OpKernel { |
474 | public: |
475 | explicit UnsortedSegmentReductionOp(OpKernelConstruction* context) |
476 | : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {} |
477 | |
478 | void Compute(OpKernelContext* context) override { |
479 | const Tensor& data = context->input(0); |
480 | const Tensor& segment_ids = context->input(1); |
481 | const Tensor& num_segments = context->input(2); |
482 | OP_REQUIRES_OK(context, |
483 | internal::ValidateUnsortedSegmentReduction( |
484 | this, context, data, segment_ids, num_segments)); |
485 | const auto segment_flat = segment_ids.flat<Index>(); |
486 | const int64_t output_rows = internal::SubtleMustCopy(static_cast<int64_t>( |
487 | num_segments.dtype() == DT_INT32 ? num_segments.scalar<int32>()() |
488 | : num_segments.scalar<int64_t>()())); |
489 | OP_REQUIRES(context, output_rows >= 0, |
errors::InvalidArgument("Input num_segments == ", output_rows,
" must not be negative."));
492 | TensorShape output_shape; |
493 | OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(output_rows)); |
494 | for (int i = segment_ids.dims(); i < data.dims(); i++) { |
495 | OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(data.dim_size(i))); |
496 | } |
497 | Tensor* output = nullptr; |
498 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); |
499 | auto output_flat = output->flat_outer_dims<T>(); |
500 | auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1); |
501 | reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat, |
502 | output_flat); |
503 | } |
504 | |
505 | protected: |
506 | DeviceReductionFunctor reduction_functor_; |
507 | }; |
508 | |
509 | // ____________________________________________________________________________ |
510 | // Sparse segment reduction ops. |
511 | |
512 | // Same as SegmentReductionOp but takes as input a "sparse" tensor, represented |
513 | // by two dense tensors, one containing the data, and the other containing |
514 | // indices into the data. |
515 | // |
516 | // The template parameters are: |
517 | // * Device: An Eigen device object, on which the kernel will execute. |
518 | // * T: The value type. |
519 | // * Index: The element type of the indices tensor (int32 or int64). |
520 | // * SegmentId: The element type of the segment_ids tensor (int32 or int64). |
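//
// For example (an illustrative sketch for SparseSegmentSum):
//   data        = [[1, 2], [3, 4], [5, 6]]
//   indices     = [0, 2, 2]
//   segment_ids = [0, 0, 1]           // must be sorted and non-negative
//   output      = [[6, 8], [5, 6]]    // row 0 = data[0] + data[2]
//                                     // row 1 = data[2]
// The mean and sqrtn variants additionally divide each output row by the
// number of selected rows, or by its square root, respectively.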
521 | template <typename Device, class T, typename Index, typename SegmentId> |
522 | class SparseSegmentReductionOpBase : public OpKernel { |
523 | public: |
524 | explicit SparseSegmentReductionOpBase(OpKernelConstruction* context, |
525 | bool is_mean, bool is_sqrtn, |
526 | bool has_num_segments, T default_value) |
527 | : OpKernel(context), |
528 | dtidx_(DataTypeToEnum<Index>::v()), |
529 | is_mean_(is_mean), |
530 | is_sqrtn_(is_sqrtn), |
531 | has_num_segments_(has_num_segments), |
532 | default_value_(default_value) {} |
533 | |
534 | void Compute(OpKernelContext* context) override { |
535 | const Tensor& input = context->input(0); |
536 | const Tensor& indices = context->input(1); |
537 | const Tensor& segment_ids = context->input(2); |
538 | |
539 | OP_REQUIRES_OK( |
540 | context, internal::ValidateSparseSegmentReduction( |
541 | context, input, indices, segment_ids, has_num_segments_)); |
542 | |
543 | Index output_rows = -1; |
544 | if (has_num_segments_) { |
545 | const Tensor& num_segments = context->input(3); |
546 | // Note that there is a Tnumsegments parameter on the op, but it is not |
547 | // plumbed through to here and so always takes its default value of int32. |
548 | output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()()); |
549 | } |
550 | const int64_t num_indices = indices.NumElements(); |
551 | |
552 | auto input_flat = input.flat_outer_dims<T>(); |
553 | const int64_t num_col = input_flat.dimension(1); |
554 | const auto indices_vec = indices.vec<Index>(); |
555 | const auto segment_vec = segment_ids.vec<SegmentId>(); |
556 | // Note that the current implementation assumes that segment_vec values are |
557 | // sorted. |
558 | const SegmentId last_segment_id = |
559 | num_indices > 0 ? segment_vec(num_indices - 1) : 0; |
560 | int64_t limit = dtidx_ == DataType::DT_INT32 ? kint32max : kint64max; |
561 | |
562 | OP_REQUIRES( |
563 | context, last_segment_id < limit, |
errors::InvalidArgument("Last segment id must be < kintmax, got ",
last_segment_id, " limit ", limit));
566 | |
567 | const SegmentId last_segment_id_plus_one = |
568 | num_indices > 0 |
569 | ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1 |
570 | : 0; |
571 | |
572 | if (has_num_segments_) { |
573 | OP_REQUIRES( |
574 | context, output_rows >= last_segment_id_plus_one, |
575 | errors::InvalidArgument("segment ids must be < num_segments" )); |
576 | } else { |
577 | output_rows = last_segment_id_plus_one; |
578 | } |
579 | OP_REQUIRES(context, output_rows >= 0, |
580 | errors::InvalidArgument("segment ids must be >= 0" )); |
581 | |
582 | TensorShape output_shape = input.shape(); |
583 | OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows)); |
584 | |
585 | // Note that we do not initialize the output buffer with a default value, so |
586 | // we need to explicitly set missing indices to the default value. |
587 | Tensor* output = nullptr; |
588 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); |
589 | if (num_indices == 0) { |
590 | if (output_rows > 0) { |
591 | output->flat_outer_dims<T>().setConstant(default_value_); |
592 | } |
593 | return; |
594 | } |
595 | OP_REQUIRES(context, output_rows > 0, |
596 | errors::InvalidArgument("segment ids must be >= 0" )); |
597 | auto output_flat = output->flat_outer_dims<T>(); |
598 | |
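// For bfloat16/half inputs the reduction is accumulated into a float
// temporary and cast back to the input type at the end, to avoid
// accumulating rounding error in low precision.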
599 | Tensor temp; |
600 | if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) { |
601 | temp = tensorflow::Tensor(DT_FLOAT, output_shape); |
602 | } |
603 | auto temp_flat = temp.flat_outer_dims<float>(); |
604 | |
605 | int64_t start = 0, end = 1; |
606 | // Index from which the output is not initialized. |
607 | SegmentId uninitialized_index = 0; |
608 | SegmentId out_index = internal::SubtleMustCopy(segment_vec(start)); |
609 | |
610 | while (true) { |
611 | // We initialize next_index to 0 to avoid "warning: 'next_index' may be |
612 | // used uninitialized in this function" in the Mac build (since the |
613 | // compiler isn't smart enough to realize the code is safe). |
614 | SegmentId next_index = 0; |
615 | if (end < num_indices) { |
616 | next_index = internal::SubtleMustCopy(segment_vec(end)); |
617 | if (out_index == next_index) { |
618 | ++end; |
619 | continue; |
620 | } |
621 | // We have a new segment here. Verify that the segment ids are growing. |
622 | OP_REQUIRES(context, out_index < next_index, |
623 | errors::InvalidArgument("segment ids are not increasing" )); |
624 | } |
625 | |
626 | OP_REQUIRES( |
627 | context, FastBoundsCheck(out_index, output_rows), |
errors::InvalidArgument(
"Segment id ", out_index, " out of range [0, ", output_rows,
"), possibly because 'segment_ids' input is not sorted."));
631 | |
632 | // If there is a gap between two indices, we need to set that gap to the |
633 | // default value. |
634 | if (out_index > uninitialized_index) { |
635 | Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape( |
636 | out_index - uninitialized_index, num_col); |
637 | Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned> |
638 | gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); |
639 | gap_slice.setConstant(default_value_); |
640 | } |
641 | |
642 | auto out = output_flat.template chip<0>(out_index); |
643 | auto temp = temp_flat.template chip<0>(out_index); |
644 | const int bad_offset = Reduce<T, Index>(input_flat, indices_vec, start, |
645 | end - start, out, temp); |
646 | OP_REQUIRES(context, bad_offset < 0, |
errors::InvalidArgument(
"Bad: indices[", start + bad_offset,
"] == ", indices_vec(start + bad_offset),
" out of range [0, ", input_flat.dimension(0), ")"));
651 | |
652 | start = end; |
653 | ++end; |
654 | uninitialized_index = out_index + 1; |
655 | out_index = next_index; |
656 | if (end > num_indices) break; |
657 | } |
658 | |
659 | // Fill the gap at the end with the default value. |
660 | if (uninitialized_index < output_rows) { |
661 | Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape( |
662 | output_rows - uninitialized_index, num_col); |
663 | Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned> |
664 | gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); |
665 | gap_slice.setConstant(default_value_); |
666 | } |
667 | } |
668 | |
669 | private: |
670 | const DataType dtidx_; |
671 | template <typename Tin> |
672 | using EnableIfBfloat16OrHalf = |
673 | typename std::enable_if<std::is_same<Tin, bfloat16>::value || |
674 | std::is_same<Tin, Eigen::half>::value, |
675 | int>::type; |
676 | template <typename Tin> |
677 | using EnableIfNotBfloat16OrHalf = |
678 | typename std::enable_if<!std::is_same<Tin, bfloat16>::value && |
679 | !std::is_same<Tin, Eigen::half>::value, |
680 | int>::type; |
681 | |
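// fetch_val reads one input row; for bfloat16/half inputs it casts the row
// to float so that the accumulation in ReduceImpl() happens at float
// precision.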
682 | template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0> |
683 | EIGEN_ALWAYS_INLINE auto fetch_val( |
684 | const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) { |
685 | return input_flat.template chip<0>(index); |
686 | } |
687 | |
688 | template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0> |
689 | EIGEN_ALWAYS_INLINE auto fetch_val( |
690 | const typename TTypes<Tin>::ConstMatrix& input_flat, Tindex index) { |
691 | return input_flat.template chip<0>(index).template cast<float>(); |
692 | } |
693 | |
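// Returns the factor multiplied into the unrolled sum in Reduce() for small
// segments (num < 10): 1/num for mean, 1/sqrt(num) for sqrtn, and 1
// otherwise. Segments with num >= 10 are instead divided once after the
// summation in ReduceImpl().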
694 | template <typename Tout> |
695 | EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) { |
696 | Tout m(1); |
697 | if (is_mean_ && (num < 10)) { |
698 | m = Tout(num); |
699 | } |
700 | if (is_sqrtn_ && (num < 10)) { |
701 | m = Tout(sqrt(num)); |
702 | } |
703 | return Tout(1) / m; |
704 | } |
705 | |
706 | template <typename Tin, typename Tindex, EnableIfNotBfloat16OrHalf<Tin> = 0> |
707 | int64_t Reduce( |
708 | const typename TTypes<Tin>::ConstMatrix& input_flat, |
709 | const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start, |
710 | int64_t num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out, |
711 | Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) { |
712 | return ReduceImpl<Tin, Tindex, Tin>(input_flat, indices_vec, start, num, |
713 | out, get_scaling_factor<Tin>(num)); |
714 | } |
715 | |
716 | template <typename Tin, typename Tindex, EnableIfBfloat16OrHalf<Tin> = 0> |
717 | int64_t Reduce( |
718 | const typename TTypes<Tin>::ConstMatrix& input_flat, |
719 | const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start, |
720 | int64_t num, Eigen::TensorChippingOp<0, typename TTypes<Tin>::Matrix> out, |
721 | Eigen::TensorChippingOp<0, typename TTypes<float>::Matrix> temp) { |
722 | int64_t res = |
723 | ReduceImpl<Tin, Tindex, float>(input_flat, indices_vec, start, num, |
724 | temp, get_scaling_factor<float>(num)); |
725 | out = temp.template cast<Tin>(); |
726 | return res; |
727 | } |
728 | |
729 | template <typename Tin, typename Tindex, typename Tout> |
730 | int64_t ReduceImpl( |
731 | const typename TTypes<Tin>::ConstMatrix& input_flat, |
732 | const typename TTypes<Tindex>::ConstVec& indices_vec, int64_t start, |
733 | int64_t num, |
734 | Eigen::TensorChippingOp<0, typename TTypes<Tout>::Matrix> out, |
735 | const Tout scaling_factor) { |
736 | #define INDEX(n, i) \ |
737 | const auto index##n = indices_vec(start + (i)); \ |
738 | if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i); |
739 | |
740 | #define L(n) fetch_val<Tin, Tindex>(input_flat, index##n) |
741 | |
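// For num > 1 the reduction below is manually unrolled by 8 rows. The switch
// consumes the first chunk (between 2 and 9 rows, depending on num % 8), and
// the trailing loop then adds 8 rows per iteration. For segments with fewer
// than 10 rows the mean/sqrtn scaling is already folded into scaling_factor;
// larger segments are divided once at the end instead.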
742 | if (num == 1) { |
743 | INDEX(0, 0); |
744 | out = L(0); |
745 | } else { |
746 | int64_t r = num & 7; |
747 | switch (r) { |
748 | case 2: { |
749 | INDEX(0, 0); |
750 | INDEX(1, 1); |
751 | out = (L(0) + L(1)) * scaling_factor; |
752 | break; |
753 | } |
754 | case 3: { |
755 | INDEX(0, 0); |
756 | INDEX(1, 1); |
757 | INDEX(2, 2); |
758 | out = (L(0) + L(1) + L(2)) * scaling_factor; |
759 | break; |
760 | } |
761 | case 4: { |
762 | INDEX(0, 0); |
763 | INDEX(1, 1); |
764 | INDEX(2, 2); |
765 | INDEX(3, 3); |
766 | out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor; |
767 | break; |
768 | } |
769 | case 5: { |
770 | INDEX(0, 0); |
771 | INDEX(1, 1); |
772 | INDEX(2, 2); |
773 | INDEX(3, 3); |
774 | INDEX(4, 4); |
775 | out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor; |
776 | break; |
777 | } |
778 | case 6: { |
779 | INDEX(0, 0); |
780 | INDEX(1, 1); |
781 | INDEX(2, 2); |
782 | INDEX(3, 3); |
783 | INDEX(4, 4); |
784 | INDEX(5, 5); |
785 | out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor; |
786 | break; |
787 | } |
788 | case 7: { |
789 | INDEX(0, 0); |
790 | INDEX(1, 1); |
791 | INDEX(2, 2); |
792 | INDEX(3, 3); |
793 | INDEX(4, 4); |
794 | INDEX(5, 5); |
795 | INDEX(6, 6); |
796 | out = |
797 | (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor; |
798 | break; |
799 | } |
800 | case 0: { |
801 | INDEX(0, 0); |
802 | INDEX(1, 1); |
803 | INDEX(2, 2); |
804 | INDEX(3, 3); |
805 | INDEX(4, 4); |
806 | INDEX(5, 5); |
807 | INDEX(6, 6); |
808 | INDEX(7, 7); |
809 | out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7)) * |
810 | scaling_factor; |
811 | r = 8; |
812 | break; |
813 | } |
814 | case 1: { |
815 | INDEX(0, 0); |
816 | INDEX(1, 1); |
817 | INDEX(2, 2); |
818 | INDEX(3, 3); |
819 | INDEX(4, 4); |
820 | INDEX(5, 5); |
821 | INDEX(6, 6); |
822 | INDEX(7, 7); |
823 | INDEX(8, 8); |
824 | out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) * |
825 | scaling_factor; |
826 | r = 9; |
827 | break; |
828 | } |
829 | } |
830 | for (; r < num; r += 8) { |
831 | INDEX(0, r); |
832 | INDEX(1, r + 1); |
833 | INDEX(2, r + 2); |
834 | INDEX(3, r + 3); |
835 | INDEX(4, r + 4); |
836 | INDEX(5, r + 5); |
837 | INDEX(6, r + 6); |
838 | INDEX(7, r + 7); |
839 | out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7); |
840 | } |
841 | if (is_mean_ && num >= 10) { |
842 | out = out / static_cast<Tout>(num); |
843 | } |
844 | if (is_sqrtn_ && num >= 10) { |
845 | out = out / static_cast<Tout>(sqrt(num)); |
846 | } |
847 | } |
848 | |
849 | return -1; |
850 | #undef L |
851 | #undef INDEX |
852 | } |
853 | |
854 | const bool is_mean_; |
855 | const bool is_sqrtn_; |
856 | const bool has_num_segments_; |
857 | const T default_value_; |
858 | }; |
859 | |
860 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
861 | |
// Specialization for GPU. Must be Async because it may need to wait for a
// device-to-host memcpy before allocating the output.
864 | template <class T, typename Index, typename SegmentId> |
865 | class SparseSegmentReductionOpBase<GPUDevice, T, Index, SegmentId> |
866 | : public AsyncOpKernel { |
867 | public: |
868 | explicit SparseSegmentReductionOpBase(OpKernelConstruction* context, |
869 | bool is_mean, bool is_sqrtn, |
870 | bool has_num_segments, T default_value) |
871 | : AsyncOpKernel(context), |
872 | is_mean_(is_mean), |
873 | is_sqrtn_(is_sqrtn), |
874 | has_num_segments_(has_num_segments), |
875 | default_value_(default_value) {} |
876 | |
877 | void ComputeAsync(OpKernelContext* context, DoneCallback done) override { |
878 | const Tensor& input = context->input(0); |
879 | const Tensor& indices = context->input(1); |
880 | const Tensor& segment_ids = context->input(2); |
881 | |
882 | OP_REQUIRES_OK_ASYNC( |
883 | context, |
884 | internal::ValidateSparseSegmentReduction( |
885 | context, input, indices, segment_ids, has_num_segments_), |
886 | done); |
887 | |
888 | ScratchSpace<SegmentId> last_segment_id_host(context, 1, /*on_host=*/true); |
889 | |
890 | auto create_and_check_output = [this, context, input, indices, segment_ids, |
891 | last_segment_id_host, done]() { |
892 | // Ensure that within the callback, the proper GPU settings are |
893 | // configured. |
894 | auto stream = context->op_device_context()->stream(); |
895 | ScopedActivateExecutorContext scoped_activation{stream->parent()}; |
896 | |
897 | SegmentId last_segment_id = *last_segment_id_host.data(); |
898 | SegmentId output_rows = last_segment_id + 1; |
899 | OP_REQUIRES_ASYNC(context, output_rows > 0, |
900 | errors::InvalidArgument("segment ids must be >= 0" ), |
901 | done); |
902 | |
903 | TensorShape output_shape = input.shape(); |
904 | output_shape.set_dim(0, output_rows); |
905 | |
906 | Tensor* output = nullptr; |
907 | OP_REQUIRES_OK_ASYNC( |
908 | context, context->allocate_output(0, output_shape, &output), done); |
909 | |
910 | auto input_flat = input.flat_outer_dims<T>(); |
911 | const auto indices_vec = indices.vec<Index>(); |
912 | const auto segment_ids_vec = segment_ids.vec<SegmentId>(); |
913 | auto output_flat = output->flat_outer_dims<T>(); |
914 | |
915 | functor::SparseSegmentReductionFunctor<T, Index, SegmentId> functor; |
916 | OP_REQUIRES_OK_ASYNC( |
917 | context, |
918 | functor(context, is_mean_, is_sqrtn_, default_value_, input_flat, |
919 | indices_vec, segment_ids_vec, output_flat), |
920 | done); |
921 | done(); |
922 | }; |
923 | |
924 | if (has_num_segments_) { |
925 | // No need to do any device to host memcpy, just compute synchronously. |
926 | const Tensor& num_segments_t = context->input(3); |
927 | SegmentId num_segments = |
928 | internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32 |
929 | ? num_segments_t.scalar<int32>()() |
930 | : num_segments_t.scalar<int64_t>()()); |
931 | *last_segment_id_host.mutable_data() = num_segments - 1; |
932 | create_and_check_output(); |
933 | } else { |
934 | const int64_t num_indices = indices.NumElements(); |
935 | // Need to copy last element of segment_ids from device to host, and then |
936 | // asynchronously allocate the output and finish the computation. |
937 | se::DeviceMemoryBase last_segment_id_device( |
938 | const_cast<Tensor&>(segment_ids).template flat<SegmentId>().data() + |
939 | (num_indices - 1)); |
940 | auto stream = context->op_device_context()->stream(); |
941 | OP_REQUIRES_ASYNC( |
942 | context, |
943 | stream |
944 | ->ThenMemcpy(last_segment_id_host.mutable_data(), |
945 | last_segment_id_device, sizeof(SegmentId)) |
946 | .ok(), |
947 | errors::Internal(type_string() + |
948 | ": failed to copy last_segment_id from device" ), |
949 | done); |
950 | context->device() |
951 | ->tensorflow_accelerator_device_info() |
952 | ->event_mgr->ThenExecute(stream, create_and_check_output); |
953 | } |
954 | } |
955 | |
956 | private: |
957 | const bool is_mean_; |
958 | const bool is_sqrtn_; |
959 | const bool has_num_segments_; |
960 | const T default_value_; |
961 | }; |
962 | |
963 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
964 | |
965 | template <typename Device, class T, typename Index, typename SegmentId> |
966 | class SparseSegmentReductionMeanOp |
967 | : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> { |
968 | public: |
969 | explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context) |
970 | : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>( |
971 | context, true /*is_mean*/, false /*is_sqrtn*/, |
972 | false /* has_num_segments */, T(0) /* default_value */) {} |
973 | }; |
974 | |
975 | template <typename Device, class T, typename Index, typename SegmentId> |
976 | class SparseSegmentReductionMeanWithNumSegmentsOp |
977 | : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> { |
978 | public: |
979 | explicit SparseSegmentReductionMeanWithNumSegmentsOp( |
980 | OpKernelConstruction* context) |
981 | : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>( |
982 | context, true /*is_mean*/, false /*is_sqrtn*/, |
983 | true /* has_num_segments */, T(0) /* default_value */) {} |
984 | }; |
985 | |
986 | template <typename Device, class T, typename Index, typename SegmentId> |
987 | class SparseSegmentReductionSqrtNOp |
988 | : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> { |
989 | public: |
990 | explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context) |
991 | : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>( |
992 | context, false /*is_mean*/, true /*is_sqrtn*/, |
993 | false /* has_num_segments */, T(0) /* default_value */) {} |
994 | }; |
995 | |
996 | template <typename Device, class T, typename Index, typename SegmentId> |
997 | class SparseSegmentReductionSqrtNWithNumSegmentsOp |
998 | : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> { |
999 | public: |
1000 | explicit SparseSegmentReductionSqrtNWithNumSegmentsOp( |
1001 | OpKernelConstruction* context) |
1002 | : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>( |
1003 | context, false /*is_mean*/, true /*is_sqrtn*/, |
1004 | true /* has_num_segments */, T(0) /* default_value */) {} |
1005 | }; |
1006 | |
1007 | template <typename Device, class T, typename Index, typename SegmentId> |
1008 | class SparseSegmentReductionSumOp |
1009 | : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> { |
1010 | public: |
1011 | explicit SparseSegmentReductionSumOp(OpKernelConstruction* context) |
1012 | : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>( |
1013 | context, false /*is_mean*/, false /*is_sqrtn*/, |
1014 | false /* has_num_segments */, T(0) /* default_value */) {} |
1015 | }; |
1016 | |
1017 | template <typename Device, class T, typename Index, typename SegmentId> |
1018 | class SparseSegmentReductionSumWithNumSegmentsOp |
1019 | : public SparseSegmentReductionOpBase<Device, T, Index, SegmentId> { |
1020 | public: |
1021 | explicit SparseSegmentReductionSumWithNumSegmentsOp( |
1022 | OpKernelConstruction* context) |
1023 | : SparseSegmentReductionOpBase<Device, T, Index, SegmentId>( |
1024 | context, false /*is_mean*/, false /*is_sqrtn*/, |
1025 | true /* has_num_segments */, T(0) /* default_value */) {} |
1026 | }; |
1027 | |
1028 | namespace functor { |
1029 | |
1030 | template <typename T, typename Index, typename SegmentId> |
1031 | struct SparseSegmentGradFunctor<CPUDevice, T, Index, SegmentId> { |
1032 | void operator()(OpKernelContext* context, |
1033 | SparseSegmentReductionOperation operation, |
1034 | typename TTypes<T>::ConstMatrix input_flat, |
1035 | typename TTypes<Index>::ConstVec indices_vec, |
1036 | typename TTypes<SegmentId>::ConstVec segment_vec, |
1037 | typename TTypes<T>::Matrix output_flat) { |
1038 | const int64_t N = indices_vec.size(); |
1039 | const SegmentId M = output_flat.dimension(0); |
1040 | |
1041 | // Note that similar to SparseSegmentMean, we assume that segment_vec is |
1042 | // already sorted and has non-negative values. |
1043 | const SegmentId num_segments = input_flat.dimension(0); |
1044 | const SegmentId last_segment_id_plus_one = |
1045 | internal::SubtleMustCopy(segment_vec(N - 1)) + 1; |
1046 | OP_REQUIRES(context, last_segment_id_plus_one <= num_segments, |
1047 | errors::InvalidArgument("Invalid number of segments" )); |
1048 | |
1049 | // Compute scaling factors for input. |
1050 | std::vector<double> scaling( |
1051 | (operation == SparseSegmentReductionOperation::kSum ? 0 : num_segments), |
1052 | 0.0); |
1053 | if (operation != SparseSegmentReductionOperation::kSum) { |
1054 | for (int64_t i = 0; i < N; ++i) { |
1055 | const SegmentId idx = internal::SubtleMustCopy(segment_vec(i)); |
1056 | OP_REQUIRES( |
1057 | context, FastBoundsCheck(idx, num_segments), |
errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
num_segments, ")."));
1060 | scaling[idx] += 1; |
1061 | } |
1062 | for (size_t i = 0; i < scaling.size(); ++i) { |
1063 | switch (operation) { |
1064 | case SparseSegmentReductionOperation::kSum: { |
1065 | OP_REQUIRES( |
1066 | context, false, |
1067 | errors::Internal( |
1068 | "Should not happen: sum inside SparseSegmentReductionOp " |
1069 | "scaling generation." )); |
1070 | } |
1071 | case SparseSegmentReductionOperation::kMean: { |
1072 | scaling[i] = 1.0 / std::max(scaling[i], 1.0); |
1073 | break; |
1074 | } |
1075 | case SparseSegmentReductionOperation::kSqrtN: { |
1076 | scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0)); |
1077 | break; |
1078 | } |
1079 | // No default to get compiler warnings for missing cases. |
1080 | } |
1081 | } |
1082 | } |
1083 | |
1084 | output_flat.setZero(); |
1085 | std::vector<bool> is_modified(M, false); |
1086 | |
1087 | for (int64_t i = 0; i < N; ++i) { |
1088 | const Index output_idx = internal::SubtleMustCopy(indices_vec(i)); |
1089 | OP_REQUIRES(context, FastBoundsCheck(output_idx, M), |
errors::InvalidArgument("Index ", output_idx,
" out of range [0, ", M, ")."));
1092 | |
1093 | const SegmentId idx = internal::SubtleMustCopy(segment_vec(i)); |
1094 | OP_REQUIRES( |
1095 | context, FastBoundsCheck(idx, num_segments), |
errors::InvalidArgument("Segment id ", idx, " out of range [0, ",
num_segments, ")."));
1098 | |
1099 | const T scale = (operation == SparseSegmentReductionOperation::kSum |
1100 | ? static_cast<T>(1) |
1101 | : static_cast<T>(scaling[idx])); |
1102 | if (is_modified[output_idx]) { |
1103 | if (scale == 1.0) { |
1104 | output_flat.template chip<0>(output_idx) += |
1105 | input_flat.template chip<0>(idx); |
1106 | } else { |
1107 | output_flat.template chip<0>(output_idx) += |
1108 | input_flat.template chip<0>(idx) * scale; |
1109 | } |
1110 | } else { |
1111 | if (scale == 1.0) { |
1112 | output_flat.template chip<0>(output_idx) = |
1113 | input_flat.template chip<0>(idx); |
1114 | } else { |
1115 | output_flat.template chip<0>(output_idx) = |
1116 | input_flat.template chip<0>(idx) * scale; |
1117 | } |
1118 | } |
1119 | is_modified[output_idx] = true; |
1120 | } |
1121 | } |
1122 | }; |
1123 | |
1124 | } // namespace functor |
1125 | |
1126 | // Implements the common logic for the gradients of SparseSegmentReduction |
1127 | // kernels. |
1128 | // |
1129 | // The template parameters are: |
1130 | // * Device: An Eigen device object, on which the kernel will execute. |
1131 | // * T: The value type. |
1132 | // * Index: The element type of the indices tensor (int32 or int64). |
1133 | // * SegmentId: The element type of the segment_ids tensor (int32 or int64). |
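//
// For example (an illustrative sketch for SparseSegmentMeanGrad):
//   grad        = [[4, 4]]            // gradient w.r.t. the forward output
//   indices     = [0, 2]
//   segment_ids = [0, 0]
//   output_dim0 = 3
//   output      = [[2, 2], [0, 0], [2, 2]]
// Each selected row of the output receives grad[segment_id] scaled by
// 1/count (mean) or 1/sqrt(count) (sqrtn); for the sum variant the scale is
// 1. Rows that are never selected stay zero.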
1134 | template <typename Device, class T, typename Index, typename SegmentId> |
1135 | class SparseSegmentGradOpBase : public OpKernel { |
1136 | public: |
1137 | explicit SparseSegmentGradOpBase(OpKernelConstruction* context, |
1138 | SparseSegmentReductionOperation operation) |
1139 | : OpKernel(context), operation_(operation) {} |
1140 | |
1141 | void Compute(OpKernelContext* context) override { |
1142 | const Tensor& input = context->input(0); |
1143 | const Tensor& indices = context->input(1); |
1144 | const Tensor& segment_ids = context->input(2); |
1145 | const Tensor& output_dim0 = context->input(3); |
1146 | |
1147 | OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()), |
1148 | errors::InvalidArgument("indices should be a vector." )); |
1149 | OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), |
1150 | errors::InvalidArgument("segment_ids should be a vector." )); |
1151 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()), |
1152 | errors::InvalidArgument("output_dim0 should be a scalar." )); |
1153 | |
1154 | const int64_t N = indices.NumElements(); |
1155 | OP_REQUIRES(context, N == segment_ids.NumElements(), |
1156 | errors::InvalidArgument( |
1157 | "segment_ids and indices should have same size." )); |
1158 | const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar<int32>()()); |
1159 | |
1160 | auto input_flat = input.flat_outer_dims<T>(); |
1161 | const auto indices_vec = indices.vec<Index>(); |
1162 | const auto segment_vec = segment_ids.vec<SegmentId>(); |
1163 | |
1164 | TensorShape output_shape = input.shape(); |
1165 | OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, M)); |
1166 | Tensor* output = nullptr; |
1167 | OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); |
1168 | if (M == 0 || N == 0) return; |
1169 | |
1170 | auto output_flat = output->flat_outer_dims<T>(); |
1171 | functor::SparseSegmentGradFunctor<Device, T, Index, SegmentId>()( |
1172 | context, operation_, input_flat, indices_vec, segment_vec, output_flat); |
1173 | } |
1174 | |
1175 | private: |
1176 | const SparseSegmentReductionOperation operation_; |
1177 | }; |
1178 | |
1179 | template <typename Device, class T, typename Index, typename SegmentId> |
1180 | class SparseSegmentSumGradOp |
1181 | : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> { |
1182 | public: |
1183 | explicit SparseSegmentSumGradOp(OpKernelConstruction* context) |
1184 | : SparseSegmentGradOpBase<Device, T, Index, SegmentId>( |
1185 | context, SparseSegmentReductionOperation::kSum) {} |
1186 | }; |
1187 | |
1188 | template <typename Device, class T, typename Index, typename SegmentId> |
1189 | class SparseSegmentMeanGradOp |
1190 | : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> { |
1191 | public: |
1192 | explicit SparseSegmentMeanGradOp(OpKernelConstruction* context) |
1193 | : SparseSegmentGradOpBase<Device, T, Index, SegmentId>( |
1194 | context, SparseSegmentReductionOperation::kMean) {} |
1195 | }; |
1196 | |
1197 | template <typename Device, class T, typename Index, typename SegmentId> |
1198 | class SparseSegmentSqrtNGradOp |
1199 | : public SparseSegmentGradOpBase<Device, T, Index, SegmentId> { |
1200 | public: |
1201 | explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context) |
1202 | : SparseSegmentGradOpBase<Device, T, Index, SegmentId>( |
1203 | context, SparseSegmentReductionOperation::kSqrtN) {} |
1204 | }; |
1205 | |
1206 | } // namespace tensorflow |
1207 | |
1208 | #endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_ |
1209 | |