/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with image transformations (resize and
// mirror padding) baked into the processing, to optimize latency and memory
// usage.

#define EIGEN_USE_THREADS

#include <string>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for iOS devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
const size_t kResizeCacheSize = (8 * 1024 * 1024);

// Lookup method used when resizing.
enum SamplingMode {
  BILINEAR = 0,
  NEAREST = 1,
};

// Simple utility function used by FusedConv to multithread basic workloads. To
// use it, pass begin and end values for the full workload and a std::function
// that receives a subset of that through the begin and end values for each
// worker's task. The division of the full workload into worker tasks is handled
// by the multithreading logic. Here's an example of how to use it:
// std::vector<float> my_vector(100);
// ...
// FusedConvParallelFor(context, 0, 100,
//   [&my_vector](int64 task_begin, int64 task_end) {
//     for (int64 current = task_begin; current != task_end; ++current) {
//       my_vector[current] *= 10.0f;
//     }
// });
void FusedConvParallelFor(
    OpKernelContext* context, int64_t begin, int64_t end,
    const std::function<void(int64_t, int64_t)>& task_function) {
// On iOS, the thread management imposes a very big performance penalty, so
// just call the function directly with no multithreading.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
  task_function(begin, end);
#else
  auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  thread::ThreadPool* thread_pool = worker_threads.workers;
  const int64_t total_elements = end - begin;
  // This is a bit of an arbitrary number, but was found to work well for
  // typical models we've been profiling on various devices.
  const int64_t element_cost = 10000000;
  thread_pool->ParallelFor(
      total_elements, element_cost,
      [begin, task_function](int64_t begin_offset, int64_t end_offset) {
        const int64_t task_begin = begin + begin_offset;
        const int64_t task_end = begin + end_offset;
        task_function(task_begin, task_end);
      });
#endif
}

// Holds the state needed for the resizing subtasks.
template <class T1>
struct ResizeTaskParameters {
  ResizeTaskParameters() : st(false, false) {}

  int cache_height;
  T1* resize_cache;
  int cache_line_width;
  int input_width;
  int input_depth;
  int top_padding;
  int pad_offset;
  int64_t resized_height;
  ImageResizerState st;
  const T1* input_batch_start;
  int64_t cache_start_x;
  int64_t cache_end_x;
  int left_padding;
  int64_t resized_width;
  int64_t padded_width;
  int64_t padded_height;
};

template <class T1>
struct PerCacheLineParameters {
  PerCacheLineParameters() {}
  PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
      : cache_line_start(other.cache_line_start),
        input_top_row_start(other.input_top_row_start),
        input_bottom_row_start(other.input_bottom_row_start),
        y_lerp(other.y_lerp) {}

  T1* cache_line_start;
  const T1* input_top_row_start;
  const T1* input_bottom_row_start;
  T1 y_lerp;
};

// Helper class to simplify bilinear filtering
template <class T1>
struct SampleRect {
  EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
                                 const T1* in_bottom_left,
                                 const T1* in_bottom_right)
      : top_left(in_top_left),
        top_right(in_top_right),
        bottom_left(in_bottom_left),
        bottom_right(in_bottom_right) {}

  EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
                                        T1 y_lerp) const {
    const T1 top =
        top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
    const T1 bottom = bottom_left[channel] +
                      (bottom_right[channel] - bottom_left[channel]) * x_lerp;
    return top + (bottom - top) * y_lerp;
  }

  const T1* top_left;
  const T1* top_right;
  const T1* bottom_left;
  const T1* bottom_right;
};

// Calculates parameters which remain constant through a resize cache row.
template <class T1>
EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
    int64_t cache_height, int64_t cache_y, T1* resize_cache,
    int64_t cache_line_width, int64_t input_width, int64_t input_depth,
    int64_t top_padding, int64_t pad_offset, int64_t resized_height,
    const ImageResizerState& st, const T1* input_batch_start) {
  PerCacheLineParameters<T1> result;
  // The cache is organized so that the real y values of the resized image map
  // onto the actual cache values through a modulo scheme. This means that as we
  // progress downwards through the image, we keep reusing a small cache and so
  // keep memory usage down.
  int64_t cache_index_y;
  if (cache_y < 0) {
    cache_index_y = cache_height + (cache_y % cache_height);
  } else {
    cache_index_y = cache_y % cache_height;
  }
  result.cache_line_start =
      resize_cache + (cache_index_y * cache_line_width * input_depth);
  // This part is implementing the mirror padding that happens before resizing.
  float in_y = (cache_y - top_padding);
  if (in_y < 0) {
    in_y = -(in_y + 1.0f - pad_offset);
  } else if (in_y >= resized_height) {
    in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
  }
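  // As a worked example of the reflection above: with top_padding = 2 and
  // REFLECT mode (pad_offset == 1), cache_y values 0 and 1 give in_y = -2 and
  // -1, which map back to source rows 2 and 1, so the edge row is not
  // repeated. In SYMMETRIC mode (pad_offset == 0) the same cache_y values map
  // to rows 1 and 0, so the edge row is included in the reflection.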
  // Here's where to do the actual resize.
  in_y *= st.height_scale;
  const int64_t top_y_index = static_cast<int64_t>(std::floor(in_y));
  const int64_t bottom_y_index =
      std::min(static_cast<int64_t>(std::ceil(in_y)), (st.in_height - 1));
  // Lerp is used for bilinear filtering when that's needed.
  result.y_lerp = static_cast<T1>(in_y - top_y_index);
  // Which rows of the original input image to pull the values from.
  result.input_top_row_start =
      input_batch_start + (top_y_index * input_width * input_depth);
  result.input_bottom_row_start =
      input_batch_start + (bottom_y_index * input_width * input_depth);
  return result;
}

template <class T1>
struct PerCachePixelParameters {
  PerCachePixelParameters() {}
  PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
      : cache_line_pixel(other.cache_line_pixel),
        left_x_index(other.left_x_index),
        right_x_index(other.right_x_index),
        x_lerp(other.x_lerp) {}

  T1* cache_line_pixel;
  int64_t left_x_index;
  int64_t right_x_index;
  T1 x_lerp;
};

// Pulls out common parameters used for every resized pixel.
template <class T1>
EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64_t cache_x, int64_t cache_start_x,
                                 T1* cache_line_start, int64_t input_depth,
                                 int64_t left_padding, int64_t pad_offset,
                                 int64_t resized_width,
                                 const ImageResizerState& st) {
  PerCachePixelParameters<T1> result;
  // Figure out where we're going to store the results of our transform.
  const int cache_index_x = cache_x - cache_start_x;
  result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
  // Implement mirror padding by flipping in_x if it's off the edge.
  float in_x = (cache_x - left_padding);
  if (in_x < 0) {
    in_x = -(in_x + 1.0f - pad_offset);
  } else if (in_x >= resized_width) {
    in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
  }
  // Resize the x parameters.
  in_x *= st.width_scale;
  // Get the x coordinates for the left and right pixels to pull from.
  result.left_x_index = static_cast<int64_t>(std::floor(in_x));
  result.right_x_index =
      std::min(static_cast<int64_t>(std::ceil(in_x)), (st.in_width - 1));
  // This x_lerp is used to blend pixels in bilinear filtering.
  result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
  return result;
}

// Combines bilinear resizing and mirror padding into the im2col transformation
// stage of convolution.
template <class T1, class T2, class T3, class TGemmFunctor,
          SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
 public:
  void operator()(OpKernelContext* context, const Tensor& input,
                  int input_batches, int resized_height, int resized_width,
                  int padded_height, int padded_width, int input_depth,
                  const T2* filter_data, int filter_height, int filter_width,
                  int filter_count, int stride_rows, int stride_cols,
                  Padding padding, T3* output_data, int output_height,
                  int output_width, const ImageResizerState& st,
                  int top_padding, int bottom_padding, int left_padding,
                  int right_padding, int pad_offset) {
    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << padded_height << ", "
                   << padded_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }
    OP_REQUIRES(
        context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
        errors::InvalidArgument("Bad sample mode passed in", SampleMode));

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           padded_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - padded_height) /
          2;
    }
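    // For instance, with stride 1 and SAME padding, output_width equals
    // padded_width, so filter_left_offset reduces to (filter_width - 1) / 2
    // (1 for a 3x3 filter). With VALID padding and stride 1 it works out to
    // zero, since the first patch starts flush with the top-left corner.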

    ResizeTaskParameters<T1> task_params;
    task_params.input_depth = input_depth;
    task_params.top_padding = top_padding;
    task_params.pad_offset = pad_offset;
    task_params.resized_height = resized_height;
    task_params.st = st;
    task_params.left_padding = left_padding;
    task_params.resized_width = resized_width;
    task_params.padded_width = padded_width;
    task_params.padded_height = padded_height;

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;

    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const size_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
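    // As a rough example of the chunking: for float data and a 3x3 filter over
    // 256 input channels, filter_value_count is 2304 and each patch takes
    // 9216 bytes, so the 16MB chunk limit used on non-iOS builds holds about
    // 1820 patches before the buffer has to be flushed through the GEMM below.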
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
        [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
          *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
          return OkStatus();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));

    // Create a resize cache memory buffer that will hold the rows of
    // transformed and mirror padded input pixels, ready to be copied
    // into filter patches by im2col.
    // It's laid out like this, in row major order in memory:
    //         < cache line width >
    //   ^    +--------------------+
    // cache  |                    |
    // height |                    |
    //   v    +--------------------+
    // Each cache row contains a cache_line_width number of resized pixels,
    // each with input_depth channels. The cache height is typically less than
    // the full height the resized image would be, so it's filled up
    // incrementally as we progress downwards through the input creating im2col
    // patches.
    task_params.cache_start_x = -filter_left_offset;
    task_params.cache_end_x =
        (((output_width - 1) * stride_cols) - filter_left_offset) +
        filter_width;
    task_params.cache_line_width =
        task_params.cache_end_x - task_params.cache_start_x;
    task_params.cache_height =
        kResizeCacheSize / (task_params.cache_line_width * input_depth);
    const int needed_resize_cache_count =
        filter_height * task_params.cache_line_width * input_depth;
    OP_REQUIRES(context,
                (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
                errors::InvalidArgument("Input too large for resize cache"));
    Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
    std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
        resize_creator =
            [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
              *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
              return OkStatus();
            };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "resize_cache",
                                &resize_cache_resource, resize_creator));

    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    // This buffer is used as a fairly heavy-weight cache for the resized and
    // mirrored inputs to the im2col operation. The problem is that we want to
    // keep the memory usage down by not rendering the fully resized and padded
    // input tensor to the convolution into an entire buffer. The first approach
    // to avoid this was to fold the bilinear filtering and padding spatial
    // transformations into the im2col lookup itself. This successfully reduced
    // memory usage, but because im2col can access an individual pixel for many
    // different patches, the extra overhead of doing the same bilinear lookups
    // repeatedly became too expensive.
    // The resize cache is designed to avoid this problem by keeping a
    // horizontal slice of the resized and padded input to the im2col
    // precalculated, so that repeated accesses to the same pixel from different
    // filter patches can just be copied from this cache. It's organized as a
    // horizontal slice stretching across the whole virtual image, and as high
    // as the filter window, so that as the patch processing moves across all
    // the pixels are present, and before a new row of patches is started any
    // previously calculated rows that are needed are maintained, with new rows
    // calculated as required.
    mutex_lock resize_lock_buffer(resize_cache_resource->mu);
    core::ScopedUnref unref_resized_cache(resize_cache_resource);
    task_params.resize_cache = resize_cache_resource->data;

    const T1* input_data = input.flat<T1>().data();
    const int64_t input_height = input.shape().dim_sizes()[1];
    task_params.input_width = input.shape().dim_sizes()[2];

    int end_cached_lines = std::numeric_limits<int>::min();

    for (int batch = 0; batch < input_batches; ++batch) {
      task_params.input_batch_start =
          input_data +
          (batch * input_height * task_params.input_width * input_depth);
      const int in_y_end =
          ((output_height * stride_rows) - filter_top_offset) + filter_height;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int cache_start_y = std::max(in_y_origin, end_cached_lines);
        const int cache_end_y = std::min(
            in_y_end, std::max((in_y_origin + task_params.cache_height),
                               end_cached_lines));
        if (end_cached_lines < (in_y_origin + filter_height)) {
          // This call breaks up the work required for calculating the mirror
          // padding and resizing across multiple threads.
          FusedConvParallelFor(
              context, cache_start_y, cache_end_y,
              [task_params](int64_t task_cache_start_y,
                            int64_t task_cache_end_y) {
                // This is a long and confusing function, but it's been laid out
                // this way to help with performance on some intensive models.
                // What it's doing is populating a cache of the original input
                // image, after it's been bilinear resized and had its edges
                // mirrored. This allows the following im2col code to access the
                // transformed pixels from this cache, without having to
                // repeatedly apply the expensive bilinear calculations as the
                // same pixels are accessed by different patches.
                // This is most effective when the stride is small and the
                // filter size is large, since that's when pixels are reused
                // most frequently as patches overlap.
                for (int cache_y = task_cache_start_y;
                     cache_y < task_cache_end_y; ++cache_y) {
                  // We organize the cache as a series of rows, each containing
                  // all the transformed pixels for a given line in the image.
                  // This cache is big enough to hold at least a filter's height
                  // worth of rows, but typically more, limited by the size of
                  // the cache buffer.
                  // We don't allocate an entire image's worth of rows though,
                  // because we're trying to keep memory usage down, so as we
                  // progress downwards through the im2col we periodically
                  // refresh the cache so that the next lines that are needed
                  // for that operation are always present.
                  // Work out the parameters that remain constant across the
                  // row we're calculating.
                  PerCacheLineParameters<T1> line_params(
                      CalculatePerCacheLineParameters<T1>(
                          task_params.cache_height, cache_y,
                          task_params.resize_cache,
                          task_params.cache_line_width, task_params.input_width,
                          task_params.input_depth, task_params.top_padding,
                          task_params.pad_offset, task_params.resized_height,
                          task_params.st, task_params.input_batch_start));
                  // Iterate through the resize cache row we're filling in.
                  for (int cache_x = task_params.cache_start_x;
                       cache_x < task_params.cache_end_x; ++cache_x) {
                    // Figure out what we need for the cache pixel we're
                    // populating.
                    PerCachePixelParameters<T1> pixel_params(
                        CalculatePerCachePixelParameters<T1>(
                            cache_x, task_params.cache_start_x,
                            line_params.cache_line_start,
                            task_params.input_depth, task_params.left_padding,
                            task_params.pad_offset, task_params.resized_width,
                            task_params.st));
                    // If the access is off the left, right, top, or bottom of
                    // the resized image, the conv padding means we should set
                    // it to zero.
                    if ((cache_x < 0) ||
                        (cache_x >= task_params.padded_width) ||
                        (cache_y < 0) ||
                        (cache_y >= task_params.padded_height)) {
                      std::fill_n(pixel_params.cache_line_pixel,
                                  task_params.input_depth, T1(0));
                    } else {
                      // There are two different sampling strategies for
                      // resizing. When using nearest, we can just do a
                      // straight copy of the pixel closest to our sample point,
                      // but bilinear requires a more complex calculation.
                      if (SampleMode == NEAREST) {
                        const T1* input_top_left_pixel =
                            line_params.input_top_row_start +
                            (pixel_params.left_x_index *
                             task_params.input_depth);

                        std::copy_n(input_top_left_pixel,
                                    task_params.input_depth,
                                    pixel_params.cache_line_pixel);
                      } else {
                        const SampleRect<T1> rect(
                            line_params.input_top_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_top_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth));
                        for (int in_channel = 0;
                             in_channel < task_params.input_depth;
                             ++in_channel) {
                          pixel_params.cache_line_pixel[in_channel] =
                              rect.BilinearSample(in_channel,
                                                  pixel_params.x_lerp,
                                                  line_params.y_lerp);
                        }
                      }
                    }
                  }
                }
              });
          end_cached_lines = cache_end_y;
        }
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
          const int patch_index = (batch * output_width * output_height) +
                                  (out_y * output_width) + out_x;
          const int patch_index_within_chunk = patch_index % patches_per_chunk;
          T1* im2col_patch_start =
              im2col_buffer + (patch_index_within_chunk * filter_value_count);
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            T1* im2col_row_start =
                im2col_patch_start +
                (filter_y * filter_width * task_params.input_depth);
            const int conv_in_y = in_y_origin + filter_y;
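            // Map the virtual image row back onto the rotating cache rows,
            // using the same modulo scheme that CalculatePerCacheLineParameters
            // used when the cache was filled.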
            int cache_index_y;
            if (conv_in_y < 0) {
              cache_index_y = task_params.cache_height +
                              (conv_in_y % task_params.cache_height);
            } else {
              cache_index_y = conv_in_y % task_params.cache_height;
            }
            T1* cache_line_start =
                task_params.resize_cache +
                (cache_index_y * task_params.cache_line_width *
                 task_params.input_depth);
            T1* cache_filter_row_start =
                cache_line_start + ((in_x_origin - task_params.cache_start_x) *
                                    task_params.input_depth);
            std::copy_n(cache_filter_row_start,
                        (filter_width * task_params.input_depth),
                        im2col_row_start);
          }
          const bool is_last_in_chunk =
              (patch_index_within_chunk == (patches_per_chunk - 1));
          const bool is_last_overall =
              ((batch == (input_batches - 1)) &&
               (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
          if (is_last_in_chunk || is_last_overall) {
            // Now we've assembled a set of image patches into a matrix, apply
            // a GEMM matrix multiply of the patches as rows, times the filter
            // weights in columns, to get partial results in the output
            // matrix.
            const int how_many_patches = patch_index_within_chunk + 1;
            const int m = how_many_patches;
            const int n = filter_count;
            const int k = filter_value_count;
            const int lda = filter_value_count;
            const int ldb = filter_count;
            const int ldc = filter_count;
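            // Here A is the row-major im2col buffer, viewed as a matrix of
            // shape [how_many_patches, filter_value_count], B is the filter
            // data viewed as [filter_value_count, filter_count], and the
            // result C is written straight into the output buffer as
            // [how_many_patches, filter_count].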
            const size_t start_patch_index =
                patch_index - (how_many_patches - 1);
            T3* chunk_output_data =
                output_data + (start_patch_index * filter_count);
            TGemmFunctor gemm_functor;
            gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                         chunk_output_data, ldc);
          }
        }
      }
    }
  }
};

}  // namespace

// Implements a version of convolution with bilinear resizing and mirror padding
// included.
template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
 public:
  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
      : OpKernel(context) {
    if (DoResize) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("resize_align_corners", &align_corners_));
    }
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64_t stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
    const int64_t stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, (input.shape().num_elements() > 0),
                errors::InvalidArgument("Input tensor can't be empty"));

    ImageResizerState st(false, false);
    if (DoResize) {
      st = ImageResizerState(align_corners_, false);
      st.ValidateAndCalculateOutputSize(context);
      if (!context->status().ok()) return;
    } else {
      // Set up the resize parameters to do no scaling at all.
      st.batch_size = input.dim_size(0);
      st.out_height = input.dim_size(1);
      st.out_width = input.dim_size(2);
      st.in_height = input.dim_size(1);
      st.in_width = input.dim_size(2);
      st.channels = input.dim_size(3);
      st.height_scale = 1.0f;
      st.width_scale = 1.0f;
    }
    TensorShape resized_shape(
        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
    int paddings_index;
    int filter_index;
    if (DoResize) {
      paddings_index = 2;
      filter_index = 3;
    } else {
      paddings_index = 1;
      filter_index = 2;
    }
    const Tensor& paddings = context->input(paddings_index);

    const int dims = resized_shape.dims();
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(paddings.shape()) &&
            paddings.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                paddings.shape().DebugString()));
    OP_REQUIRES(
        context, dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));

    OP_REQUIRES(
        context, dims == 4,
        errors::InvalidArgument(
            "Fused mirror padding only supports four-dimensional inputs, but ",
            dims, " requested"));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape padded_shape;
    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
    for (int d = 0; d < dims; ++d) {
      const int32_t before =
          paddings_matrix(d, 0);  // Pad before existing elements.
      const int32_t after =
          paddings_matrix(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, " ", after));
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(
            context,
            before <= resized_shape.dim_size(d) &&
                after <= resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be no greater "
                                    "than the dimension size: ",
                                    before, ", ", after, " greater than ",
                                    resized_shape.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context,
            before < resized_shape.dim_size(d) &&
                after < resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    resized_shape.dim_size(d)));
      }
      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
    }

    OP_REQUIRES(
        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not batches: ",
            paddings.DebugString()));
    OP_REQUIRES(
        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not channels: ",
            paddings.DebugString()));
    const int32_t top_padding = paddings_matrix(1, 0);
    const int32_t bottom_padding = paddings_matrix(1, 1);
    const int32_t left_padding = paddings_matrix(2, 0);
    const int32_t right_padding = paddings_matrix(2, 1);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(filter_index);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, padded_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        padded_shape.DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // We only check the first three dims, since the depth is accessed as an
    // int64 below.
    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = padded_shape.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t padded_rows_raw = padded_shape.dim_size(1);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int padded_rows = static_cast<int>(padded_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));
    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t padded_cols_raw = padded_shape.dim_size(2);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int padded_cols = static_cast<int>(padded_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));
    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));

    // The first dimension for input is batch.
    const int64_t batch_raw = padded_shape.dim_size(0);
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
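    // GetWindowedOutputSize applies the standard convolution arithmetic: for
    // SAME padding out_rows is ceil(padded_rows / stride_rows), and for VALID
    // it is ceil((padded_rows - filter_rows + 1) / stride_rows), with the
    // matching calculation for the columns.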
    TensorShape out_shape =
        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(context, (out_shape.num_elements() > 0),
                errors::InvalidArgument("Output tensor can't be empty"));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
            << ", padded_cols = " << padded_cols
            << ", resized_cols = " << resized_cols
            << ", filter_cols = " << filter_cols
            << ", padded_rows = " << padded_rows
            << ", resized_rows = " << resized_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth << ", DoResize=" << DoResize;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
                 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
                 bottom_padding, left_padding, right_padding, offset_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  bool align_corners_;
  int offset_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};

#define REGISTER_FUSED(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedResizeAndPadConv2D")                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T"),                                        \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       BILINEAR>,                         \
          true>);

TF_CALL_half(REGISTER_FUSED);
TF_CALL_float(REGISTER_FUSED);
TF_CALL_double(REGISTER_FUSED);

#define REGISTER_PAD_ONLY_FUSED(T)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       NEAREST>,                          \
          false>);

TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
TF_CALL_double(REGISTER_PAD_ONLY_FUSED);
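
// The registrations above back the FusedResizeAndPadConv2D op (bilinear
// resize, then mirror pad, then convolution) and the FusedPadConv2D op, which
// skips the resize: it runs with DoResize=false and NEAREST sampling at an
// identity scale, so the cache is effectively a straight copy of the mirror
// padded input.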

}  // namespace tensorflow