/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

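// Extracts (batch, height, width, depth, channel) from the shape of
// `value_tensor` according to `data_format`.  For NHWC, all leading
// dimensions are folded into `batch` and the innermost dimension is the
// channel; for NCHW, the channel is dimension 1 and absent trailing spatial
// dimensions default to 1.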
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
                      int32* batch, int32* height, int32* width, int32* depth,
                      int32* channel) {
  *batch = 1;
  *height = 1;
  *width = 1;
  *depth = 1;
  *channel = 1;
  if (data_format == FORMAT_NHWC) {
    int32_t channel_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    for (int32_t i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  } else if (data_format == FORMAT_NCHW) {
    *batch = static_cast<int32>(value_tensor.dim_size(0));
    *channel = static_cast<int32>(value_tensor.dim_size(1));
    *height = static_cast<int32>(value_tensor.dim_size(2));
    if (value_tensor.dims() > 3) {
      *width = static_cast<int32>(value_tensor.dim_size(3));
    }
    if (value_tensor.dims() > 4) {
      *depth = static_cast<int32>(value_tensor.dim_size(4));
    }
  }
}

template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is faster on the CPU than half, and also more precise,
// so use float for the temporary accumulators.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

}  // namespace

template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
 public:
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ", bias.shape()));

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    int channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      // NCHW always keeps the channel in dimension 1 (for 3-D, 4-D and 5-D
      // inputs).
      channel_dim = 1;
    } else {
      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
    }

    OP_REQUIRES(context,
                bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape(), " vs. ", input.shape()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() == 0) return;

    functor::Bias<Device, T> functor;
    const Device& d = context->eigen_device<Device>();
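    // For NCHW inputs with spatial dimensions, dispatch to the rank-2 form of
    // the Bias functor so the bias is broadcast per channel across the batch
    // and spatial extents; otherwise add the bias along the innermost
    // dimension.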
    if (data_format_ == FORMAT_NCHW && input.shape().dims() > 2) {
      functor(d, input.flat_inner_outer_dims<T, 2>(1),
              bias.flat_outer_dims<T, 2>(),
              output->flat_inner_outer_dims<T, 2>(1));
    } else {
      functor(d, input.flat<T>(), bias.vec<T>(), output->flat<T>());
    }
  }

 private:
  TensorFormat data_format_;
};

#define REGISTER_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      BiasOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

template <typename Device, typename T>
class BiasGradOp : public OpKernel {
 public:
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape()));

    OP_REQUIRES(
        context,
        FastBoundsCheck(output_backprop.NumElements(),
                        std::numeric_limits<int32>::max()),
        errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));

    int channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      channel_dim = 1;
    } else {
      channel_dim = output_backprop.shape().dims() - 1;
    }
    Tensor* output = nullptr;
    TensorShape output_shape{output_backprop.shape().dim_size(channel_dim)};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    if (output_backprop.NumElements() == 0) {
      // Eigen often crashes by design on empty tensors, but setZero is safe.
      output->template flat<T>().setZero();
    } else {
      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
      // used.
      using AccumT = typename AccumulatorType<T>::type;
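      // NCHW: view the backprop as (batch, channel, spatial) and sum over the
      // batch and spatial extents, leaving one total per channel.  NHWC: view
      // it as (outer, channel) and sum over the outer extent.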
      if (data_format_ == FORMAT_NCHW) {
        const functor::ReduceMiddleDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>,
            Eigen::internal::SumReducer<T>>
            redux;

        auto flat_outer = output_backprop.flat_outer_dims<T, 3>();
        redux(context->eigen_device<Device>(), flat_outer.dimensions(),
              output_backprop, output, 1);
      } else {
        const functor::ReduceOuterDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>>
            redux;

        auto flat_inner = output_backprop.flat_inner_dims<T, 2>();
        redux(context->eigen_device<Device>(), flat_inner.dimensions(),
              output_backprop, output);
      }
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the CPU implementations.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasGradOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));
    int32_t batch, height, width, depth, channel;
    GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                     &channel);
    OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape().DebugString(), " vs. ", channel, " in ",
                    input.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
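    // Launch the custom BiasGPU kernel; it handles both the NHWC and NCHW
    // layouts via the data_format argument.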
    if (input.NumElements() > 0) {
      BiasGPU<T>::compute(context->template eigen_device<Device>(),
                          input.flat<T>().data(), bias.flat<T>().data(),
                          output->flat<T>().data(), batch, width, height,
                          depth, channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
      BiasOp<GPUDevice, type>);                                         \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      BiasOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL

struct BiasGradAutotuneGroup {
  static string name() { return "BiasGrad"; }
};

class BiasAddGradGPUConfig {
 public:
  BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
  string ToString() const {
    if (mode_ == BiasAddGradGPUMode::kNative) {
      return "native CUDA kernel.";
    }
    if (mode_ == BiasAddGradGPUMode::kReduction) {
      return "cub reduction kernel.";
    }
    return "unknown kernel.";
  }
  BiasAddGradGPUMode get_mode() const { return mode_; }
  void set_mode(BiasAddGradGPUMode val) { mode_ = val; }

  bool operator==(const BiasAddGradGPUConfig& other) const {
    return this->mode_ == other.get_mode();
  }

  bool operator!=(const BiasAddGradGPUConfig& other) const {
    return !(*this == other);
  }

 private:
  BiasAddGradGPUMode mode_;
};

// Encapsulate all the shape information that is used in bias add grad
// operations.
class BiasAddParams {
 public:
  // We use a list to maintain both the shape value and the order (data
  // format).
  using SpatialArray = gtl::InlinedVector<int64_t, 4>;
  BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
                DataType dtype, int device_id)
      : in_shape_(in_shape),
        data_format_(data_format),
        dtype_(dtype),
        device_id_(device_id) {
    for (int64_t val : in_shape_) {
      hash_code_ = Hash64Combine(hash_code_, val);
    }
    hash_code_ = Hash64Combine(hash_code_, data_format);
    hash_code_ = Hash64Combine(hash_code_, dtype);
    hash_code_ = Hash64Combine(hash_code_, device_id);
  }
  bool operator==(const BiasAddParams& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const BiasAddParams& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  string ToString() const {
    // clang-format off
    return strings::StrCat(
        "(", absl::StrJoin(in_shape_, ", "), "), ",
        data_format_, ", ", dtype_, ", ", device_id_);
    // clang-format on
  }

 protected:
  using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;

  ParamsDataType get_data_as_tuple() const {
    return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
  }

  uint64 hash_code_ = 0;

 private:
  SpatialArray in_shape_;
  TensorFormat data_format_;
  DataType dtype_;
  int device_id_;
};

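// Process-wide cache mapping a BiasAddParams key to the BiasAddGrad GPU
// algorithm measured to be faster for that shape, dtype and device.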
typedef AutotuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
                          BiasAddGradGPUConfig>
    AutotuneBiasGrad;

template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NCHW;
    }
  }

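  // Computes the gradient with the hand-written BiasAddGrad GPU kernel (the
  // "native" algorithm).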
  void ComputeWithCustomKernel(OpKernelContext* context,
                               const Tensor& output_backprop, int32_t batch,
                               int32_t width, int32_t height, int32_t depth,
                               int32_t channel, Tensor* output) {
    BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                            output_backprop.template flat<T>().data(),
                            output->flat<T>().data(), batch, width, height,
                            depth, channel, data_format_);
  }

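  // Computes the gradient with cub-based reductions (the "reduction"
  // algorithm).  For NCHW this takes two reduction passes; this is also the
  // deterministic path.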
  void ComputeWithReduceSum(OpKernelContext* context,
                            const Tensor& output_backprop, int32_t batch,
                            int32_t width, int32_t height, int32_t depth,
                            int32_t channel, Tensor* output) {
    if (data_format_ == FORMAT_NCHW) {
      int32_t row_count = batch * channel;
      int32_t col_count = height * width * depth;
      Tensor temp_grad_outputs;
      // For 'NCHW' format, we perform reduction twice: first HW, then N.
      TensorShape temp_grad_output_shape{row_count, col_count};
      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                     temp_grad_output_shape,
                                                     &temp_grad_outputs));
      BiasGradGPU<T>::DoRowReduction(
          context, temp_grad_outputs.flat<T>().data(),
          output_backprop.template flat<T>().data(), row_count, col_count);

      row_count = batch;
      col_count = channel;
      BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
                                     temp_grad_outputs.flat<T>().data(),
                                     row_count, col_count);
    } else {
      // For 'NHWC', we simply apply reduction once on NHW.
      int32_t row_count = batch * height * width * depth;
      int32_t col_count = channel;
      BiasGradGPU<T>::DoColReduction(
          context, const_cast<T*>(output->flat<T>().data()),
          reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
          row_count, col_count);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));
    int32_t batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (channel == 0) return;
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
                                    output->NumElements() * sizeof(T));
    stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
    if (output_backprop.NumElements() <= 0) return;
    if (OpDeterminismRequired()) {
      // ComputeWithReduceSum is the only deterministic algorithm.
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      return;
    }

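    // Autotune: look up the cached algorithm choice for this shape, dtype and
    // device; on a miss, time both algorithms below and cache the faster one.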
    int device_id = stream->parent()->device_ordinal();
    DataType dtype = output_backprop.dtype();
    BiasAddParams bias_parameters = {
        {batch, height * width * depth, channel},
        data_format_,
        dtype,
        device_id,
    };

    // Autotune between the two algorithms: the customized kernel and the
    // reduction-based kernel.
    BiasAddGradGPUConfig algo_config;
    if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
      profiler::ScopedAnnotation trace("bias_grad_autotuning");

      BiasGradGPUProfileResult best_result;
      // Initialize the timer.
      perftools::gputools::Timer timer(stream->parent());
      stream->InitTimer(&timer);
      stream->ThenStartTimer(&timer);
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
      stream->ThenStopTimer(&timer);
      uint64 elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Native algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kNative);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      // Try reduction and profile.
      stream->ThenStartTimer(&timer);
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      stream->ThenStopTimer(&timer);

      elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Reduction algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      algo_config.set_mode(best_result.algorithm());
      AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);

      // Results are already available during autotune, so no need to continue.
      return;
    }

    // Choose the best algorithm based on autotune results.
    if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
    } else {
      // Default to the customized kernel.
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow