1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // This file contains a set of different implementations of the two-dimensional |
17 | // convolution operation. The standard TensorFlow Conv2d kernel uses EigenTensor |
18 | // to implement the computation, but this module has a variety of different ways |
// of producing the same result. These methods are designed to be easier to
// understand and to connect to other libraries, so that we can take advantage
// of platforms that have specialized implementations of GEMM, for example.
22 | // |
23 | // The basic interface is a Conv functor object that's templated by the types |
24 | // of the data it will be operating on, and is passed in the arguments needed to |
25 | // calculate the convolution. The simplest implementation of this functor is |
26 | // ReferenceConvFunctor, which is a readable but slow reference version. |
27 | // |
// A faster version, the Im2ColConvFunctor, packs image patches into a matrix
// before calling a matrix multiply. In turn, this can
30 | // use a variety of different methods to calculate the matrix multiplication, |
31 | // or GEMM. The simplest but slowest is the ReferenceGemmFunctor, but the |
32 | // FastGemmFunctor will use whatever optimized libraries are available. By |
33 | // default it uses Eigen, but on Apple platforms it will take advantage of the |
34 | // system's Accelerate BLAS library to get better performance than the standard |
35 | // TensorFlow convolution kernel. |
36 | // |
37 | // The version actually used is defined at the bottom of this file using the |
38 | // REGISTER_KERNEL_BUILDER() macro. To try out different implementations (for |
39 | // example to switch to a reference one for easier debugging) you can swap out |
40 | // the default functors in that call. |
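// For example, to try the reference implementation for float you could
// register
//   Conv2DUsingGemmOp<float, ReferenceConvFunctor<float, float, float>>
// in place of the Im2ColConvFunctor-based default used below.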
41 | // |
42 | // The registration itself is guarded with the USE_GEMM_FOR_CONV macro. The iOS |
43 | // makefile build defines this, but if you want to enable this implementation |
44 | // and disable the standard EigenTensor one in other build setups, you'll need |
45 | // to define it there too. |
46 | |
47 | #define EIGEN_USE_THREADS |
48 | |
49 | #include <string.h> |
50 | |
51 | #include <map> |
52 | #include <vector> |
53 | |
54 | #include "tensorflow/core/framework/bounds_check.h" |
55 | #include "tensorflow/core/framework/kernel_shape_util.h" |
56 | #include "tensorflow/core/framework/numeric_op.h" |
57 | #include "tensorflow/core/framework/op_kernel.h" |
58 | #include "tensorflow/core/framework/register_types.h" |
59 | #include "tensorflow/core/framework/resource_mgr.h" |
60 | #include "tensorflow/core/framework/tensor.h" |
61 | #include "tensorflow/core/framework/tensor_shape.h" |
62 | #include "tensorflow/core/framework/tensor_slice.h" |
63 | #include "tensorflow/core/kernels/conv_ops.h" |
64 | #include "tensorflow/core/kernels/gemm_functors.h" |
65 | #include "tensorflow/core/util/image_resizer_state.h" |
66 | #include "tensorflow/core/util/mirror_pad_mode.h" |
67 | #include "tensorflow/core/util/padding.h" |
68 | #include "tensorflow/core/util/tensor_format.h" |
69 | |
70 | namespace tensorflow { |
71 | |
72 | namespace { |
73 | // This function implements the convolution operation in as simple a form as |
74 | // possible. It won't give great performance, but it is very useful for |
75 | // stepping through and instrumenting for debugging, creating minimal benchmarks |
76 | // to prototype with, and sharing with teams that want to run this outside of |
77 | // our environment. |
78 | // With that in mind, I've avoided using anything except pretty standard C++ |
79 | // types. This is especially noticeable in the data access through raw array |
80 | // indexing. It's deliberate in this case though, since it makes the underlying |
81 | // memory order very explicit, which is important for both inspecting memory |
82 | // contents during debugging and for specifying what we expect to others. |
83 | // The memory layout of the data is, from biggest stride to smallest: |
84 | // input_data = [input_batches, input_height, input_width, input_depth] |
85 | // filter_data = [filter_height, filter_width, input_depth, filter_count] |
86 | // output_data = [input_batches, output_height, output_width, filter_count] |
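// For instance, element (batch, y, x, channel) of input_data lives at flat
// index:
//   ((batch * input_height + y) * input_width + x) * input_depth + channel
// and the filter and output arrays follow the same row-major convention.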
87 | template <class T1, class T2, class T3> |
88 | class ReferenceConvFunctor { |
89 | public: |
90 | void operator()(OpKernelContext* context, const T1* input_data, |
91 | int input_batches, int input_height, int input_width, |
92 | int input_depth, const T2* filter_data, int filter_height, |
93 | int filter_width, int filter_count, int stride_rows, |
94 | int stride_cols, Padding padding, T3* output_data, |
95 | int output_height, int output_width) { |
96 | // The two different padding modes we support can be a bit confusing. SAME |
97 | // means we're trying to produce an output image that's the same size as the |
    // input. It's complicated by stride, which shrinks the output image by a
    // factor, and it means we end up sampling from outside the borders of the
    // input. These out-of-bounds values are read as zeroes. VALID means only
101 | // produce output values where the filters can read all their values from |
102 | // within the input image. It effectively removes the margins of the output |
103 | // image compared to the one produced by SAME. Stride complicates this |
104 | // definition though, because it can result in the right and bottom filter |
105 | // patches sampling from outside the borders if it's greater than 1. |
106 | // Most of the logic for sorting this all out is done before this function, |
107 | // when we calculate the output size, but the positioning of the origin of |
108 | // the filters is different between the two modes, since SAME positions the |
109 | // first filter off the edge of the input. |
110 | int filter_left_offset; |
111 | int filter_top_offset; |
112 | if (padding == VALID) { |
113 | filter_left_offset = |
114 | ((output_width - 1) * stride_cols + filter_width - input_width + 1) / |
115 | 2; |
116 | filter_top_offset = ((output_height - 1) * stride_rows + filter_height - |
117 | input_height + 1) / |
118 | 2; |
119 | } else { |
120 | filter_left_offset = |
121 | ((output_width - 1) * stride_cols + filter_width - input_width) / 2; |
122 | filter_top_offset = |
123 | ((output_height - 1) * stride_rows + filter_height - input_height) / |
124 | 2; |
125 | } |
126 | |
127 | // If we've got multiple images in our input, work through each of them. |
128 | for (int batch = 0; batch < input_batches; ++batch) { |
129 | // Walk through all the output image values, sliding the filter to |
130 | // different positions in the input. |
131 | for (int out_y = 0; out_y < output_height; ++out_y) { |
132 | for (int out_x = 0; out_x < output_width; ++out_x) { |
133 | // Each filter kernel produces one output channel. |
134 | for (int out_channel = 0; out_channel < filter_count; ++out_channel) { |
135 | // We're going to calculate a single output value, which means we |
136 | // need to multiply a three dimensional kernel of weights against |
137 | // the current location within the input image. |
138 | /* |
139 | *-------------------------------... |
140 | |\ ^ |
141 | | \in_depth |
142 | | \ v |
143 | | *-------------------------------... |
144 | | | ^ |
145 | | | in_y_origin |
146 | | | v \ |
147 | | |<in_x_origin>*---*^ |
148 | | | \| |filter_height |
149 | . | *---*v |
150 | . | <---> |
151 | . filter_width |
152 | . |
153 | */ |
154 | const int in_x_origin = (out_x * stride_cols) - filter_left_offset; |
155 | const int in_y_origin = (out_y * stride_rows) - filter_top_offset; |
156 | T3 total(0); |
157 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
158 | for (int filter_x = 0; filter_x < filter_width; ++filter_x) { |
159 | for (int in_channel = 0; in_channel < input_depth; |
160 | ++in_channel) { |
161 | const int in_x = in_x_origin + filter_x; |
162 | const int in_y = in_y_origin + filter_y; |
163 | T1 input_value; |
164 | // If the location is outside the bounds of the input image, |
165 | // use zero as a default value. |
166 | if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && |
167 | (in_y < input_height)) { |
168 | input_value = |
169 | input_data[(batch * input_height * input_width * |
170 | input_depth) + |
171 | (in_y * input_width * input_depth) + |
172 | (in_x * input_depth) + in_channel]; |
173 | } else { |
174 | input_value = T1(0); |
175 | } |
176 | const T2 filter_value = |
177 | filter_data[(filter_y * filter_width * input_depth * |
178 | filter_count) + |
179 | (filter_x * input_depth * filter_count) + |
180 | (in_channel * filter_count) + out_channel]; |
181 | total += (input_value * filter_value); |
182 | } |
183 | } |
184 | } |
185 | output_data[(batch * output_height * output_width * filter_count) + |
186 | (out_y * output_width * filter_count) + |
187 | (out_x * filter_count) + out_channel] = total; |
188 | } |
189 | } |
190 | } |
191 | } |
192 | } |
193 | }; |
194 | |
195 | // We don't want to allocate a buffer to hold all the patches if the size is |
196 | // going to be extremely large, so break it into chunks if it's bigger than |
197 | // a limit. Each chunk will be processed serially, so we can refill the |
198 | // buffer for the next chunk and reuse it, keeping maximum memory size down. |
199 | // In this case, we've picked 16 megabytes as a reasonable limit for Android and |
200 | // other platforms using Eigen, and 1MB for Apple devices, from experimentation. |
201 | #if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM) |
202 | const size_t kMaxChunkSize = (1 * 1024 * 1024); |
203 | #else |
204 | const size_t kMaxChunkSize = (16 * 1024 * 1024); |
205 | #endif |
206 | |
207 | // Implements convolution as a two stage process, first packing the patches of |
208 | // the input image into columns (im2col) and then running GEMM to produce the |
209 | // final result. |
210 | template <class T1, class T2, class T3, class TGemmFunctor> |
211 | class Im2ColConvFunctor { |
212 | public: |
213 | void operator()(OpKernelContext* context, const T1* input_data, |
214 | int input_batches, int input_height, int input_width, |
215 | int input_depth, const T2* filter_data, int filter_height, |
216 | int filter_width, int filter_count, int stride_rows, |
217 | int stride_cols, Padding padding, T3* output_data, |
218 | int output_height, int output_width) { |
219 | if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) || |
220 | (input_depth <= 0)) { |
221 | LOG(WARNING) << "Conv2D was called with bad input dimensions: " |
222 | << input_batches << ", " << input_height << ", " |
223 | << input_width << ", " << input_depth; |
224 | return; |
225 | } |
226 | if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) { |
227 | LOG(WARNING) << "Conv2D was called with bad filter dimensions: " |
228 | << filter_width << ", " << filter_height << ", " |
229 | << filter_count; |
230 | return; |
231 | } |
232 | if ((output_width <= 0) || (output_height <= 0)) { |
233 | LOG(WARNING) << "Conv2D was called with bad output width or height: " |
234 | << output_width << ", " << output_height; |
235 | return; |
236 | } |
237 | |
    // We can just use a GEMM if the im2col is the identity operator, e.g., if
    // the kernel is 1x1 or the input data and filter have the same
    // height/width.
240 | if (filter_height == 1 && filter_width == 1 && stride_rows == 1 && |
241 | stride_cols == 1) { |
242 | // The kernel is 1x1. |
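      // With a 1x1 kernel and unit strides every output value reads exactly
      // one input pixel, so the NHWC input can be treated directly as an
      // [input_batches * input_height * input_width, input_depth] matrix and
      // the filter as an [input_depth, filter_count] matrix; the convolution
      // is then a single row-major GEMM with no patch extraction needed.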
243 | const int m = input_batches * input_height * input_width; |
244 | const int n = filter_count; |
245 | const int k = input_depth; |
246 | const int lda = k; |
247 | const int ldb = filter_count; |
248 | const int ldc = filter_count; |
249 | TGemmFunctor gemm_functor; |
250 | gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb, |
251 | output_data, ldc); |
252 | return; |
253 | } else if (filter_height == input_height && filter_width == input_width && |
254 | padding == VALID) { |
255 | // The input data and filter have the same height/width. |
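      // With VALID padding and a filter that covers the whole image there is
      // exactly one patch per batch element, so each image already forms one
      // row of [input_height * input_width * input_depth] values that can be
      // multiplied against the filter matrix directly.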
256 | const int m = input_batches; |
257 | const int n = filter_count; |
258 | const int k = input_height * input_width * input_depth; |
259 | const int lda = k; |
260 | const int ldb = filter_count; |
261 | const int ldc = filter_count; |
262 | TGemmFunctor gemm_functor; |
263 | gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb, |
264 | output_data, ldc); |
265 | return; |
266 | } |
267 | |
268 | // These calculations define how the patches will be positioned within the |
269 | // input image. The actual definitions are quite complex, and rely on the |
270 | // previously-calculated output size. |
271 | int filter_left_offset; |
272 | int filter_top_offset; |
273 | if (padding == VALID) { |
274 | filter_left_offset = |
275 | ((output_width - 1) * stride_cols + filter_width - input_width + 1) / |
276 | 2; |
277 | filter_top_offset = ((output_height - 1) * stride_rows + filter_height - |
278 | input_height + 1) / |
279 | 2; |
280 | } else { |
281 | filter_left_offset = |
282 | ((output_width - 1) * stride_cols + filter_width - input_width) / 2; |
283 | filter_top_offset = |
284 | ((output_height - 1) * stride_rows + filter_height - input_height) / |
285 | 2; |
286 | } |
287 | |
    // The im2col buffer has # of patches rows, and # of filter values cols.
289 | // It's laid out like this, in row major order in memory: |
290 | // < filter value count > |
291 | // ^ +---------------------+ |
292 | // patch | | |
293 | // count | | |
294 | // v +---------------------+ |
295 | // Each patch row contains a filter_width x filter_height patch of the |
296 | // input, with the depth channel as the most contiguous in memory, followed |
297 | // by the width, then the height. This is the standard memory order in the |
298 | // image world if it helps to visualize it. |
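    // As a concrete example, for a 3x3 filter over a 2-channel input each
    // patch row holds 3 * 3 * 2 = 18 values, ordered as
    // (y=0, x=0, c=0), (y=0, x=0, c=1), (y=0, x=1, c=0), ... with the channel
    // index varying fastest, then x, then y.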
299 | const int filter_value_count = filter_width * filter_height * input_depth; |
300 | OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize, |
301 | errors::InvalidArgument("Im2Col patch too large for buffer" )); |
302 | const int64_t patches_per_chunk = |
303 | kMaxChunkSize / (filter_value_count * sizeof(T1)); |
304 | const int64_t chunk_value_count = |
305 | (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1); |
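    // As a rough example with float inputs (4 bytes each), a 3x3 filter over
    // 64 channels gives filter_value_count = 576, so each patch needs 2304
    // bytes and a 16MB chunk (the non-Apple limit above) holds
    // 16 * 1024 * 1024 / 2304 = 7281 patches, while chunk_value_count is
    // 4194304 float values.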
306 | // Because memory allocation is very expensive on mobile platforms, try to |
307 | // allocate a persistent buffer that will be kept around between calls. We |
308 | // use TensorFlow's resource management to ensure that the memory will be |
309 | // released when the session is over. |
310 | Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource; |
311 | std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)> |
312 | creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) { |
313 | *resource = new Im2ColBufferResource<T1, chunk_value_count>(); |
314 | return OkStatus(); |
315 | }; |
316 | OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( |
317 | "Conv2d" , "im2col_buffer" , |
318 | &im2col_buffer_resource, creator)); |
319 | // This means that multiple ops can't be run simultaneously on different |
320 | // threads, because we have a single shared resource. The platforms this is |
321 | // aimed at have intra-op parallelism as their focus though, so it shouldn't |
322 | // be an issue. |
323 | mutex_lock lock_buffer(im2col_buffer_resource->mu); |
324 | core::ScopedUnref unref_buffer(im2col_buffer_resource); |
325 | T1* im2col_buffer = im2col_buffer_resource->data; |
326 | |
327 | const int64_t patch_count = (input_batches * output_height * output_width); |
328 | const int64_t chunk_count = |
329 | (patch_count + (patches_per_chunk - 1)) / patches_per_chunk; |
330 | for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { |
331 | const int64_t patch_index_start = chunk_index * patches_per_chunk; |
332 | const int64_t patch_index_end = |
333 | std::min(patch_index_start + patches_per_chunk, patch_count); |
334 | for (int64_t patch_index = patch_index_start; |
335 | patch_index < patch_index_end; ++patch_index) { |
336 | const int64_t batch = patch_index / (output_height * output_width); |
337 | const int64_t out_y = (patch_index / output_width) % output_height; |
338 | const int64_t out_x = patch_index % output_width; |
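        // patch_index enumerates output positions in row-major order, so
        // patch_index == (batch * output_height + out_y) * output_width + out_x.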
339 | const T1* input_batch_start = |
340 | input_data + (batch * input_height * input_width * input_depth); |
341 | const int in_y_origin = (out_y * stride_rows) - filter_top_offset; |
342 | const int in_x_origin = (out_x * stride_cols) - filter_left_offset; |
343 | const int patch_index_within_chunk = patch_index % patches_per_chunk; |
344 | T1* im2col_patch_start = |
345 | im2col_buffer + (patch_index_within_chunk * filter_value_count); |
346 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
347 | const int in_y = in_y_origin + filter_y; |
348 | T1* im2col_row_start = |
349 | im2col_patch_start + (filter_y * filter_width * input_depth); |
350 | // If we're off the top or the bottom of the input, fill the |
351 | // whole row with zeroes. |
352 | if ((in_y < 0) || (in_y >= input_height)) { |
353 | T1* im2col_row_end = |
354 | im2col_row_start + (filter_width * input_depth); |
355 | std::fill(im2col_row_start, im2col_row_end, T1(0)); |
356 | } else { |
357 | // What we're doing here is trying to copy and fill the im2col |
358 | // buffer as efficiently as possible, using functions to set or |
359 | // duplicate values en masse. We know we don't have to worry about |
360 | // vertical edges because we dealt with that case above, so we |
361 | // just need to handle filters that overlap the left or right |
362 | // edges. Here's what that looks like: |
363 | // |
364 | // < left_zero_count > < center_copy_count > < right_zero_count > |
365 | // +------------------+---------------------+--------------------+ |
366 | // | (filter) | (image) | (filter) | |
367 | // +------------------+---------------------+--------------------+ |
368 | // in_x_origin 0 input_width in_x_end |
369 | // |
370 | // In reality it's unlikely that a filter patch will be wider |
371 | // than an input, but this shows all the edge cases. |
372 | // We use std::fill() to set the left and right sections to zeroes |
373 | // and std::copy() to copy over the input data for the center. |
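            // For example, if in_x_origin is -1 with filter_width = 3 and
            // input_width = 5, then left_zero_count = 1, center_copy_count = 2
            // and right_zero_count = 0: one column of zero padding followed
            // by two columns copied from the input.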
374 | const int in_x_end = in_x_origin + filter_width; |
375 | const int left_zero_count = std::max(0, 0 - in_x_origin); |
376 | const int right_zero_count = std::max(0, in_x_end - input_width); |
377 | const int center_copy_count = |
378 | filter_width - (left_zero_count + right_zero_count); |
379 | if (left_zero_count > 0) { |
380 | T1* im2col_left_start = im2col_row_start; |
381 | T1* im2col_left_end = |
382 | im2col_left_start + (left_zero_count * input_depth); |
383 | std::fill(im2col_left_start, im2col_left_end, T1(0)); |
384 | } |
385 | if (center_copy_count > 0) { |
386 | const T1* input_row_start = |
387 | input_batch_start + (in_y * input_width * input_depth) + |
388 | (std::max(0, in_x_origin) * input_depth); |
389 | const T1* input_row_end = |
390 | input_row_start + (center_copy_count * input_depth); |
391 | T1* im2col_center_start = |
392 | im2col_row_start + (left_zero_count * input_depth); |
393 | std::copy(input_row_start, input_row_end, im2col_center_start); |
394 | } |
395 | if (right_zero_count > 0) { |
396 | T1* im2col_right_start = |
397 | im2col_row_start + |
398 | ((left_zero_count + center_copy_count) * input_depth); |
399 | T1* im2col_right_end = |
400 | im2col_right_start + (right_zero_count * input_depth); |
401 | std::fill(im2col_right_start, im2col_right_end, T1(0)); |
402 | } |
403 | } |
404 | } |
405 | } |
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
409 | const int how_many_patches = patch_index_end - patch_index_start; |
410 | const int m = how_many_patches; |
411 | const int n = filter_count; |
412 | const int k = filter_value_count; |
413 | const int lda = filter_value_count; |
414 | const int ldb = filter_count; |
415 | const int ldc = filter_count; |
416 | T3* chunk_output_data = output_data + (patch_index_start * filter_count); |
417 | TGemmFunctor gemm_functor; |
418 | gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb, |
419 | chunk_output_data, ldc); |
420 | } |
421 | } |
422 | }; |
423 | |
424 | } // namespace |
425 | |
426 | // This TensorFlow kernel class handles all of the IO and housekeeping for the |
427 | // functors that actually implement the underlying algorithm. To swap in |
428 | // different implementations of the main calculations, use a different |
429 | // TConvFunctor parameter when instantiating the template. |
430 | template <class T, class TConvFunctor> |
431 | class Conv2DUsingGemmOp : public BinaryOp<T> { |
432 | public: |
433 | explicit Conv2DUsingGemmOp(OpKernelConstruction* context) |
434 | : BinaryOp<T>(context) { |
435 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &strides_)); |
436 | string data_format; |
437 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
438 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
439 | errors::InvalidArgument("Invalid data format" )); |
440 | OP_REQUIRES(context, data_format_ == FORMAT_NHWC, |
441 | errors::InvalidArgument( |
442 | "Data format not supported by this kernel" , data_format)); |
443 | OP_REQUIRES(context, strides_.size() == 4, |
444 | errors::InvalidArgument("Sliding window strides field must " |
445 | "specify 4 dimensions" )); |
446 | const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N'); |
447 | const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C'); |
448 | OP_REQUIRES( |
449 | context, stride_n == 1 && stride_c == 1, |
450 | errors::InvalidArgument("Current implementation does not yet support " |
451 | "strides in the batch and depth dimensions." )); |
452 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
453 | } |
454 | |
455 | void Compute(OpKernelContext* context) override { |
456 | // Input tensor is of the following dimensions: |
457 | // [ batch, in_rows, in_cols, in_depth ] |
458 | const Tensor& input = context->input(0); |
459 | |
460 | // Input filter is of the following dimensions: |
461 | // [ filter_rows, filter_cols, in_depth, out_depth] |
462 | const Tensor& filter = context->input(1); |
463 | |
464 | // For 2D convolution, there should be 4 dimensions. |
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));
471 | |
472 | for (int i = 0; i < 3; i++) { |
473 | OP_REQUIRES( |
474 | context, |
475 | FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), |
476 | errors::InvalidArgument("filter too large" )); |
477 | } |
478 | |
479 | // The last dimension for input is in_depth. It must be the same as the |
480 | // filter's in_depth. |
481 | const int64_t in_depth = GetTensorDim(input, data_format_, 'C'); |
482 | OP_REQUIRES(context, in_depth == filter.dim_size(2), |
483 | errors::InvalidArgument( |
484 | "input and filter must have the same depth: " , in_depth, |
485 | " vs " , filter.dim_size(2))); |
486 | |
487 | // The last dimension for filter is out_depth. |
488 | const int out_depth = static_cast<int>(filter.dim_size(3)); |
489 | |
490 | // The second dimension for input is rows/height. |
491 | // The first dimension for filter is rows/height. |
492 | const int64_t input_rows_raw = GetTensorDim(input, data_format_, 'H'); |
493 | OP_REQUIRES( |
494 | context, |
495 | FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), |
496 | errors::InvalidArgument("Input rows too large" )); |
497 | const int input_rows = static_cast<int>(input_rows_raw); |
498 | const int filter_rows = static_cast<int>(filter.dim_size(0)); |
499 | |
500 | // The third dimension for input is columns/width. |
501 | // The second dimension for filter is columns/width. |
502 | const int64_t input_cols_raw = GetTensorDim(input, data_format_, 'W'); |
503 | OP_REQUIRES( |
504 | context, |
505 | FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), |
506 | errors::InvalidArgument("Input cols too large" )); |
507 | const int input_cols = static_cast<int>(input_cols_raw); |
508 | const int filter_cols = static_cast<int>(filter.dim_size(1)); |
509 | |
510 | // The first dimension for input is batch. |
511 | const int64_t batch_raw = GetTensorDim(input, data_format_, 'N'); |
512 | OP_REQUIRES(context, |
513 | FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), |
514 | errors::InvalidArgument("batch is too large" )); |
515 | const int batch = static_cast<int>(batch_raw); |
516 | |
517 | // For now we take the stride from the second and third dimensions only (we |
518 | // do not support striding on the batch or depth dimension). |
519 | const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); |
520 | const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); |
521 | |
522 | int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; |
523 | OP_REQUIRES_OK(context, |
524 | GetWindowedOutputSize(input_rows, filter_rows, stride_rows, |
525 | padding_, &out_rows, &pad_rows)); |
526 | OP_REQUIRES_OK(context, |
527 | GetWindowedOutputSize(input_cols, filter_cols, stride_cols, |
528 | padding_, &out_cols, &pad_cols)); |
529 | TensorShape out_shape = |
530 | ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); |
531 | |
532 | // Output tensor is of the following dimensions: |
533 | // [ in_batch, out_rows, out_cols, out_depth ] |
534 | Tensor* output = nullptr; |
535 | OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); |
536 | |
537 | VLOG(2) << "Conv2D: in_depth = " << in_depth |
538 | << ", input_cols = " << input_cols |
539 | << ", filter_cols = " << filter_cols |
540 | << ", input_rows = " << input_rows |
541 | << ", filter_rows = " << filter_rows |
542 | << ", stride_rows = " << stride_rows |
543 | << ", stride_cols = " << stride_cols |
544 | << ", out_depth = " << out_depth; |
545 | |
546 | // If there is nothing to compute, return. |
547 | if (out_shape.num_elements() == 0) { |
548 | return; |
549 | } |
550 | TConvFunctor conv_functor; |
551 | conv_functor(context, input.flat<T>().data(), batch, input_rows, input_cols, |
552 | in_depth, filter.flat<T>().data(), filter_rows, filter_cols, |
553 | out_depth, stride_rows, stride_cols, padding_, |
554 | output->flat<T>().data(), out_rows, out_cols); |
555 | } |
556 | |
557 | private: |
558 | std::vector<int32> strides_; |
559 | Padding padding_; |
560 | TensorFormat data_format_; |
561 | |
562 | TF_DISALLOW_COPY_AND_ASSIGN(Conv2DUsingGemmOp); |
563 | }; |
564 | |
565 | #define REGISTER_CPU(T) \ |
566 | REGISTER_KERNEL_BUILDER( \ |
567 | Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
568 | Conv2DUsingGemmOp< \ |
569 | T, Im2ColConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>); |
570 | |
571 | // Only register this GEMM-based implementation of Conv2d if the compiler flags |
572 | // request the implementation explicitly, since otherwise it will clash with the |
573 | // default EigenTensor-based kernel. |
574 | #if defined(USE_GEMM_FOR_CONV) |
575 | TF_CALL_bfloat16(REGISTER_CPU); |
576 | TF_CALL_half(REGISTER_CPU); |
577 | TF_CALL_float(REGISTER_CPU); |
578 | TF_CALL_double(REGISTER_CPU); |
579 | TF_CALL_int32(REGISTER_CPU); |
580 | #endif // USE_GEMM_FOR_CONV |
581 | |
582 | } // namespace tensorflow |
583 | |