/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc.
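//
// For reference (a sketch of the op's semantics, matching the docs in
// nn_ops.cc): for a 4-D input x of shape [batch, rows, cols, depth],
//   output(b, r, c, d) =
//       x(b, r, c, d) /
//       (bias + alpha * sum_{k = d - depth_radius}^{d + depth_radius}
//                           x(b, r, c, k)^2) ^ beta
// where the sum is clamped to valid depth indices.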

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#if TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

namespace {

// When the depth is large and beta_ is 0.5 or 1.0, single-threaded
// LRN is faster than the main band matrix approach used below.
// Benchmarks suggest switching to SingleThreadedLRN when depth > 384.
const int kSingleThreadedLRNDepthCutoff = 384;

// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}
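
// For illustration only (not used at runtime): with depth = 5 and
// depth_radius = 1, GetBandMatrix fills `result` with
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
// so contracting the squared input with this matrix sums each element's
// (2 * depth_radius + 1)-wide neighborhood along the depth dimension.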

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchLRN;

template <typename T>
struct LaunchLRN<CPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

#if defined(IS_MOBILE_PLATFORM)
    SingleThreadedLRN(in, batch, rows, cols, depth, output);
#else
    const int nodes = cols * rows;
    if (depth > kSingleThreadedLRNDepthCutoff &&
        (beta_ == T(0.5) || beta_ == T(1))) {
      SingleThreadedLRN(in, batch, rows, cols, depth, output);
      return;
    }

    auto in_shaped = in.shaped<T, 2>({nodes * batch, depth});

    // Contracting the squared input with the band matrix sums the squared
    // values over each element's depth window, i.e.
    //   tmp(n, d) = alpha * sum_{|k - d| <= depth_radius} in(n, k)^2 + bias,
    // and the output is in * tmp^(-beta), with the beta == 1 and beta == 0.5
    // cases specialized below.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
#endif
  }

 private:
  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;

  void SingleThreadedLRN(const Tensor& in, const int batch, const int rows,
                         const int cols, const int depth, Tensor* out) {
    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_in(
        in.flat<T>().data(), depth, batch * rows * cols);

    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_out(
        out->flat<T>().data(), depth, batch * rows * cols);

    const int double_depth_radius = depth_radius_ * 2;
    Eigen::Matrix<T, Eigen::Dynamic, 1> padded_square(data_in.rows() +
                                                      double_depth_radius);
    padded_square.setZero();
    for (int r = 0; r < data_in.cols(); ++r) {
      // Do local response normalization for data_in(:, r). First, compute the
      // squares and store them in a buffer for repeated use.
      padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
          data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
      // Then, compute the scale and write it to data_out.
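      // The scale for depth index i is bias_ + the sum of padded_square over
      // the window [i, i + 2 * depth_radius_] (alpha_ is already folded into
      // padded_square above). The loop below maintains that window sum
      // incrementally rather than recomputing it for every index.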
      T accumulated_scale(0);
      for (int i = 0; i < double_depth_radius; ++i) {
        accumulated_scale += padded_square(i);
      }
      for (int i = 0; i < data_in.rows(); ++i) {
        accumulated_scale += padded_square(i + double_depth_radius);
        data_out(i, r) = bias_ + accumulated_scale;
        accumulated_scale -= padded_square(i);
      }
    }

    if (beta_ == T(1)) {
      data_out.array() = data_in.array() * data_out.array().inverse();
    } else if (beta_ == T(0.5)) {
      data_out.array() = data_in.array() * data_out.array().rsqrt();
    } else {
      data_out.array() =
          data_in.array() * (data_out.array().log() * -beta_).exp();
    }
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
struct LaunchLRN<GPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
#if GOOGLE_CUDA
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    se::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(se::dnn::DataLayout::kBatchYXDepth);

    se::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_data = StreamExecutorUtil::AsDeviceMemory<T>(in);
    auto output_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
                                          input_data, &output_data)
            .ok();
    OP_REQUIRES(context, status,
                errors::Internal("NormalizeWithDimensions launch failed"));
#elif TENSORFLOW_USE_ROCM
    // For NHWC input/output tensors, convert to NCHW because it's the only
    // supported format in MIOpen for now.

    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    Tensor transformed_input;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       ShapeFromFormat(FORMAT_NCHW, in.shape(), FORMAT_NHWC),
                       &transformed_input));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
                                           in.tensor<T, 4>(),
                                           transformed_input.tensor<T, 4>());

    Tensor transformed_output;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DataTypeToEnum<T>::value,
                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
                     &transformed_output));

    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_data =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());
    auto output_data =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
                                          input_data, &output_data)
            .ok();
    OP_REQUIRES(context, status,
                errors::Internal("NormalizeWithDimensions launch failed"));

    // Convert the result back to NHWC once the MIOpen kernel finishes.
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(transformed_output).template tensor<T, 4>(),
        output->tensor<T, 4>());
#endif
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class LRNOp : public OpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
    int64_t depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in = context->input(0);
    OP_REQUIRES(context, in.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));
    OP_REQUIRES(
        context,
        FastBoundsCheck(in.NumElements(), std::numeric_limits<int>::max()),
        errors::InvalidArgument("argument to LRN too large"));
    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    OP_REQUIRES(context,
                (static_cast<int64_t>(depth) + depth_radius_) <=
                    std::numeric_limits<int>::max(),
                errors::InvalidArgument("depth ", depth, " + depth_radius ",
                                        depth_radius_, " exceeds int max."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRN<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if !defined(IS_MOBILE_PLATFORM)

template <typename Device, typename T>
struct LaunchLRNGrad;

template <typename T>
struct LaunchLRNGrad<CPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius),
        bias_(bias),
        alpha_(alpha),
        beta_(beta),
        alpha_beta_2_(T(-2) * alpha * beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);
    const auto nodes = cols * rows;
    auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
    auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
    auto activations = out_image.shaped<T, 2>({nodes * batch, depth});

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();

    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        for (int64_t j = 0; j < depth; ++j) {
          // Let y be the LRN activations and x be the inputs along the depth
          // dimension. (LRN operates independently along rows, cols, and
          // batch.)
          // With r = depth_radius and
          //   N = bias + alpha * sum_{j = i - r}^{i + r} x_j^2,
          // we have
          //   y_i = x_i / N^beta
          //   dy_i/dx_i = (N^beta - x_i * beta * N^(beta-1) * 2 * alpha * x_i)
          //               / N^(2*beta)
          //   dy_i/dx_j = (      - x_i * beta * N^(beta-1) * 2 * alpha * x_j)
          //               / N^(2*beta)
          //
          // NOTE(keveman): We could compute N as (y_i / x_i)^(1/beta), but
          // that is numerically unstable for small values of x_i. We compute
          // N explicitly here to avoid that.
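          //
          // Concretely, the loop below accumulates, for every k in the depth
          // window around j,
          //   grads(i, j) * dy_j/dx_k
          // into out_shaped(i, k); summed over j, this yields the gradient
          // with respect to x_k.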

          T gs = grads_shaped(i, j);
          if (gs == T(0)) continue;

          int64_t depth_begin = std::max<int64_t>(0, j - depth_radius_);
          int64_t depth_end = std::min<int64_t>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64_t k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          T pre_computed_pow = Eigen::numext::pow(norm, -beta_);
          T activations_ab2 = alpha_beta_2_ * activations(i, j);
          for (int64_t k = depth_begin; k < depth_end; ++k) {
            T dyi = in_shaped(i, k) * activations_ab2 / norm;
            if (k == j) {
              dyi += pre_computed_pow;
            }
            dyi *= gs;
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
                dyi;
          }
        }
      }
    };
    auto worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
  T alpha_beta_2_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
struct LaunchLRNGrad<GPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
#if GOOGLE_CUDA
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);

    se::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(se::dnn::DataLayout::kBatchYXDepth);

    se::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(in_grads);
    auto input_image_data = StreamExecutorUtil::AsDeviceMemory<T>(in_image);
    auto output_image_data = StreamExecutorUtil::AsDeviceMemory<T>(out_image);
    auto output_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeBackwardWithDimensions(
                normalize_desc, dimensions_desc, input_image_data,
                output_image_data, input_grads_data, &output_grads_data)
            .ok();
    OP_REQUIRES(
        context, status,
        errors::Internal("NormalizeBackwardWithDimensions launch failed"));
#elif TENSORFLOW_USE_ROCM
    // For NHWC input/output tensors, convert to NCHW because it's the only
    // supported format in MIOpen for now.
    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);

    Tensor transformed_in_grads;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, in_grads.shape(),
                                                FORMAT_NHWC),
                                &transformed_in_grads));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(), in_grads.tensor<T, 4>(),
        transformed_in_grads.tensor<T, 4>());

    Tensor transformed_in_image;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, in_image.shape(),
                                                FORMAT_NHWC),
                                &transformed_in_image));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(), in_image.tensor<T, 4>(),
        transformed_in_image.tensor<T, 4>());

    Tensor transformed_out_image;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, out_image.shape(),
                                                FORMAT_NHWC),
                                &transformed_out_image));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(), out_image.tensor<T, 4>(),
        transformed_out_image.tensor<T, 4>());

    Tensor transformed_output;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DataTypeToEnum<T>::value,
                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
                     &transformed_output));

    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_grads_data =
        AsDeviceMemory(transformed_in_grads.template flat<T>().data(),
                       transformed_in_grads.template flat<T>().size());
    auto input_image_data =
        AsDeviceMemory(transformed_in_image.template flat<T>().data(),
                       transformed_in_image.template flat<T>().size());
    auto output_image_data =
        AsDeviceMemory(transformed_out_image.template flat<T>().data(),
                       transformed_out_image.template flat<T>().size());
    auto output_grads_data =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    static int64_t NormalizeBackwardScratchSize = GetDnnWorkspaceLimit(
        // default value is in bytes despite the name of the environment
        // variable
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
    );

    DnnScratchAllocator scratch_allocator(NormalizeBackwardScratchSize,
                                          context);
    bool status = stream
                      ->ThenNormalizeBackwardWithDimensions(
                          normalize_desc, dimensions_desc, input_image_data,
                          output_image_data, input_grads_data,
                          &output_grads_data, &scratch_allocator)
                      .ok();
    OP_REQUIRES(
        context, status,
        errors::Internal("NormalizeBackwardWithDimensions launch failed"));

    // Convert the result back to NHWC once the MIOpen kernel finishes.
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(transformed_output).template tensor<T, 4>(),
        output->tensor<T, 4>());
#endif
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class LRNGradOp : public OpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    int64_t depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in_grads = context->input(0);
    const Tensor& in_image = context->input(1);
    const Tensor& out_image = context->input(2);

    OP_REQUIRES(context, in_grads.dims() == 4 && in_image.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);
    OP_REQUIRES(
        context,
        in_image.dim_size(0) == batch && in_image.dim_size(1) == rows &&
            in_image.dim_size(2) == cols && in_image.dim_size(3) == depth &&
            out_image.dims() == 4 && out_image.dim_size(0) == batch &&
            out_image.dim_size(1) == rows && out_image.dim_size(2) == cols &&
            out_image.dim_size(3) == depth,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRNGrad<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in_grads, in_image, out_image, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNGradOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // !defined(IS_MOBILE_PLATFORM)

}  // namespace tensorflow