/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains a set of different implementations of the two-dimensional
// convolution operation. The standard TensorFlow Conv2d kernel uses EigenTensor
// to implement the computation, but this module has a variety of different ways
// of producing the same result. These methods are designed to be easier to
// understand and connect to other libraries, so that we can take advantage of
// platforms that have specialized implementations of GEMM, for example.
//
// The basic interface is a Conv functor object that's templated by the types
// of the data it will be operating on, and is passed in the arguments needed to
// calculate the convolution. The simplest implementation of this functor is
// ReferenceConvFunctor, which is a readable but slow reference version.
//
// A faster version uses the approach of packing image patches into a matrix
// before calling a matrix multiply, the Im2ColConvFunctor. In turn, this can
// use a variety of different methods to calculate the matrix multiplication,
// or GEMM. The simplest but slowest is the ReferenceGemmFunctor, but the
// FastGemmFunctor will use whatever optimized libraries are available. By
// default it uses Eigen, but on Apple platforms it will take advantage of the
// system's Accelerate BLAS library to get better performance than the standard
// TensorFlow convolution kernel.
//
// The version actually used is defined at the bottom of this file using the
// REGISTER_KERNEL_BUILDER() macro. To try out different implementations (for
// example to switch to a reference one for easier debugging) you can swap out
// the default functors in that call.
//
// The registration itself is guarded with the USE_GEMM_FOR_CONV macro. The iOS
// makefile build defines this, but if you want to enable this implementation
// and disable the standard EigenTensor one in other build setups, you'll need
// to define it there too.

#define EIGEN_USE_THREADS

#include <string.h>

#include <map>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/util/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {
// This function implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal benchmarks
// to prototype with, and sharing with teams that want to run this outside of
// our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data  = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
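// For example, with this layout the input value at coordinates
// (batch, y, x, channel) lives at
// input_data[((batch * input_height + y) * input_width + x) * input_depth +
//            channel], which is exactly the indexing used in the loops below.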
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as
    // the input. It's complicated by stride, which shrinks the output image by
    // a factor, but it means we end up sampling from outside the borders of
    // the input. These out-of-bounds values are read as zeroes. VALID means
    // only produce output values where the filters can read all their values
    // from within the input image. It effectively removes the margins of the
    // output image compared to the one produced by SAME. Stride complicates
    // this definition though, because it can result in the right and bottom
    // filter patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }
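    // For example, with SAME padding, input_width = 5, filter_width = 3, and
    // stride_cols = 1 (so output_width = 5), the formula above gives
    // filter_left_offset = ((5 - 1) * 1 + 3 - 5) / 2 = 1, placing the first
    // filter patch one pixel off the left edge of the input. With VALID
    // padding the same shapes give output_width = 3 and an offset of
    // ((3 - 1) * 1 + 3 - 5 + 1) / 2 = 0, so every patch stays inside the
    // input.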

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count;
               ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |            ^
              |   |       in_y_origin
              |   |            v   \
              |   |<in_x_origin>*---*^
              |   |             \|   |filter_height
              .   |              *---*v
              .   |              <--->
              .          filter_width
              .
            */
            const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
            const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
            T3 total(0);
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  T1 input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    input_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                  } else {
                    input_value = T1(0);
                  }
                  const T2 filter_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  total += (input_value * filter_value);
                }
              }
            }
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = total;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for Apple devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3, class TGemmFunctor>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << input_height << ", "
                   << input_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }

    // We can just use a GEMM if the im2col is the identity operator, e.g., if
    // the kernel is 1x1 or the input data and filter have the same
    // height/width.
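    // For example, a 1x1 convolution with stride 1 over a
    // [batch, height, width, depth] input is just a matrix multiply of the
    // (batch * height * width) x depth input values by the
    // depth x filter_count filter weights, with no patch extraction needed.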
    if (filter_height == 1 && filter_width == 1 && stride_rows == 1 &&
        stride_cols == 1) {
      // The kernel is 1x1.
      const int m = input_batches * input_height * input_width;
      const int n = filter_count;
      const int k = input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    } else if (filter_height == input_height && filter_width == input_width &&
               padding == VALID) {
      // The input data and filter have the same height/width.
      const int m = input_batches;
      const int n = filter_count;
      const int k = input_height * input_width * input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    }

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }

    // The im2col buffer has # of patches rows, and # of filter values cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
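    // For example, a 3x3 filter over a 64-channel input gives patch rows of
    // 3 * 3 * 64 = 576 values each.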
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const int64_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64_t chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
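    // Continuing the example above, with 4-byte float values each 576-value
    // patch occupies 2,304 bytes, so the 16MB chunk limit works out to
    // 16 * 1024 * 1024 / 2304 = 7,281 patches per chunk.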
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return OkStatus();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64_t patch_count = (input_batches * output_height * output_width);
    const int64_t chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
    for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64_t patch_index_start = chunk_index * patches_per_chunk;
      const int64_t patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64_t patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64_t batch = patch_index / (output_height * output_width);
        const int64_t out_y = (patch_index / output_width) % output_height;
        const int64_t out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            T1* im2col_row_end =
                im2col_row_start + (filter_width * input_depth);
            std::fill(im2col_row_start, im2col_row_end, T1(0));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width      in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use std::fill() to set the left and right sections to zeroes
            // and std::copy() to copy over the input data for the center.
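            // For example, with filter_width = 3 and in_x_origin = -1 (and an
            // input at least two pixels wide), left_zero_count is 1,
            // center_copy_count is 2, and right_zero_count is 0.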
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              T1* im2col_left_end =
                  im2col_left_start + (left_zero_count * input_depth);
              std::fill(im2col_left_start, im2col_left_end, T1(0));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              const T1* input_row_end =
                  input_row_start + (center_copy_count * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              std::copy(input_row_start, input_row_end, im2col_center_start);
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              T1* im2col_right_end =
                  im2col_right_start + (right_zero_count * input_depth);
              std::fill(im2col_right_start, im2col_right_end, T1(0));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
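      // Because the patches were generated in output (NHWC) order, each GEMM
      // output row holds one patch's filter_count results, so the chunk's
      // results can be written directly into the output tensor at the right
      // offset.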
      const int how_many_patches = patch_index_end - patch_index_start;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                   chunk_output_data, ldc);
    }
  }
};

}  // namespace

// This TensorFlow kernel class handles all of the IO and housekeeping for the
// functors that actually implement the underlying algorithm. To swap in
// different implementations of the main calculations, use a different
// TConvFunctor parameter when instantiating the template.
template <class T, class TConvFunctor>
class Conv2DUsingGemmOp : public BinaryOp<T> {
 public:
  explicit Conv2DUsingGemmOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Data format not supported by this kernel", data_format));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension for input is batch.
    const int64_t batch_raw = GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
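    // For example, with SAME padding and stride 2 a 224-wide input gives
    // out_cols = ceil(224 / 2) = 112, while VALID padding with a 3-wide filter
    // at stride 1 gives out_cols = 224 - 3 + 1 = 222.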
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "Conv2D: in_depth = " << in_depth
            << ", input_cols = " << input_cols
            << ", filter_cols = " << filter_cols
            << ", input_rows = " << input_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input.flat<T>().data(), batch, input_rows, input_cols,
                 in_depth, filter.flat<T>().data(), filter_rows, filter_cols,
                 out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DUsingGemmOp);
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DUsingGemmOp<                                        \
          T, Im2ColConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
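
// As noted in the header comment, the functors here can be swapped to try a
// different implementation; for example, registering
// Conv2DUsingGemmOp<T, ReferenceConvFunctor<T, T, T>> instead would use the
// slow but readable reference version for debugging.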

// Only register this GEMM-based implementation of Conv2d if the compiler flags
// request the implementation explicitly, since otherwise it will clash with
// the default EigenTensor-based kernel.
#if defined(USE_GEMM_FOR_CONV)
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

}  // namespace tensorflow