/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. large 'in_depth' * 'out_depth' products; see the cost models below
// for details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.
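//
// For example, with the WinogradTransform<T> used below (the F(2x2, 3x3)
// Winograd variant), 'd' is a 4x4 input tile (16 values), 'g' is a 3x3
// filter tile (9 values) and 'y' is a 2x2 output tile (4 values), so the
// flattened transform matrices have shapes A: [16, 16], B: [16, 9] and
// C: [4, 16]. 'Ad' and 'Bg' are both 16-vectors per depth, and 'C' reduces
// their element-wise product to the 2x2 output tile. (Illustrative sizes
// only; the authoritative shapes come from the transform object itself.)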

// Approximate cost models for direct and deep convolutions.
static int64_t GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                               int out_tile_rows, int out_tile_cols,
                               int in_depth, int out_depth, int out_rows,
                               int out_cols) {
  // Input transform cost.
  const int64_t input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64_t input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64_t product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64_t output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64_t output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64_t row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64_t col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64_t num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64_t GetDirectConvCost(int filter_rows, int filter_cols,
                                 int in_depth, int out_depth, int out_rows,
                                 int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows * out_cols;
}
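
// Worked example of the two cost models above (illustrative numbers only),
// assuming the F(2x2, 3x3) Winograd transform (4x4 input tile, 2x2 output
// tile), in_depth = out_depth = 64 and a 32x32 output:
//
//   GetDeepConvCost(4, 4, 2, 2, 64, 64, 32, 32)
//     = (16 * 16) * (16*16*64 + 16*64*64 + 4*16*64)
//     = 256 * 86016 = 22020096
//   GetDirectConvCost(3, 3, 64, 64, 32, 32)
//     = 3*3*64*64*32*32 = 37748736
//
// The deep/direct ratio is ~0.58, so DeepConv2D would be selected.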

// Reads the environment variable 'env_var_name'.
// Returns 'true' if the environment variable is enabled, 'false' otherwise.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}
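
// Usage sketch: deep convolution stays off unless the environment variable
// checked below is set to a non-"0" value before the process starts, e.g.
//
//   TF_USE_DEEP_CONV2D=1 python train.py
//
// ('train.py' is a hypothetical example; any TensorFlow binary behaves the
// same way.)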

// Returns true if the convolution can be computed efficiently by DeepConv2D,
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if the flop cost of deep convolution is less than that of direct
  // convolution.
  WinogradTransform<float> t;
  const int64_t deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64_t direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t vectorized_size = args.in_depth / kPacketSize;
    const int64_t scalar_size = args.in_depth % kPacketSize;
    const int64_t input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64_t d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64_t in_scalar_base = vectorized_size * input_stride;
    const int64_t buf_scalar_base = vectorized_size * kPacketSize;
    for (int64_t d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
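
// Sketch of the gather performed above (illustrative numbers only): with
// kPacketSize = 4 and out_depth = 8, the first gathered packet reads
// filter_in[0], filter_in[8], filter_in[16] and filter_in[24] (stride
// 'out_depth') and stores them contiguously in filter_buf[0..3].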

// Computes the transform of 'num_filters' from 'filter_in', starting at
// 'od_start'. Intermediate results (i.e. the output of
// MatMul('transform_matrix', 'filter_in')) are stored in 'out_buffer'. The
// final result is copied from 'out_buffer' to 'filter_out' at the coordinate
// stride required by the transformed filter data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t num_filters,
                  const int64_t shard_rows, const int64_t shard_cols,
                  const T* filter_in, const int64_t in_stride,
                  const int64_t out_stride, const T* transform_matrix,
                  T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64_t in_depth = args.in_depth;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    // Compute the transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;

    // Copy 'out_buffer' to 'filter_out' at the required filter output stride.
    const int64_t scalar_size = in_depth % kPacketSize;
    const int64_t vectorized_size = in_depth / kPacketSize;

    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_buf_base = od * out_depth_stride;
      const int64_t out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t i = 0; i < tile_spatial_size; ++i) {
            const int64_t in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64_t out_base =
                i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64_t d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64_t scalar_base = vectorized_size * kPacketSize;
            for (int64_t d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};
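
// Dimension check for the GEMM above, in the F(2x2, 3x3) Winograd case
// (illustrative only): 'A' is the [16, 9] filter transform, 'B' holds the
// flattened 3x3 filter shards as a [9, in_stride] matrix, and 'C' receives
// the [16, in_stride] transformed coefficients, one 16-vector per
// (filter, shard, depth) column.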

// Transforms 'num_filters' from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute the filter transform of data
// in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t od_limit,
                  const T* filter_in, const T* transform_matrix, T* out_buffer,
                  T* filter_buf, T* filter_out) {
    const int64_t num_filters = od_limit - od_start;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    // Compute the number of filter shards: shards are placed at a stride of
    // two rows/cols, so this computes 1 + ceil(residual / 2).
    const int64_t residual_row =
        std::max(int64_t{0}, args.filter_rows - base_filter_rows);
    const int64_t shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64_t residual_col =
        std::max(int64_t{0}, args.filter_cols - base_filter_cols);
    const int64_t shard_cols = 1 + (residual_col + 2 - 1) / 2;

    // Compute strides to be used for input and output IO.
    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64_t coord_stride = out_depth_stride * args.out_depth;
    const int64_t filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t filter_buf_size = base_filter_spatial_size * num_filters *
                                    shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        const int64_t row_offset = s_r == 0 ? 0 : 1;

        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t col_offset = s_c == 0 ? 0 : 1;
          const int64_t f_r_start = s_r * tile_stride_rows;
          const int64_t f_c_start = s_c * tile_stride_cols;

          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64_t f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64_t b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64_t f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64_t in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64_t buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute the filter transform of data in 'filter_buf' by
    // 'transform_matrix'. Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};
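
// Sharding example (illustrative): with the 3x3 base filter shape of the
// Winograd transform, a 3x3 filter has residual 0 and occupies a single
// 1x1 shard. A hypothetical 5x5 filter would have residual 2 in each
// dimension, giving 2x2 = 4 shards of size 3x3 placed at stride 2;
// non-initial shards skip their first (already covered) row/column, and
// positions past the filter boundary stay zero from the memset above.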

// Transforms all filters from 'filter_in', storing the result in
// 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col, const T* filter_in,
                  T* filter_out) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    const int64_t filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64_t cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64_t filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64_t filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64_t filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64_t filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64_t per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64_t num_filters_cache =
        std::max(int64_t{1},
                 (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64_t num_filters_transform =
        std::min(out_depth, num_filters_cache);

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows,
                  &base_filter_cols, &num_filters_transform, &in_depth,
                  &filter_shards_row, &filter_shards_col, &tile_spatial_size,
                  &filter_in, &transform_matrix,
                  &filter_out](int64_t start, int64_t limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform,
      //    filter_shards_row, filter_shards_col, in_depth]
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output filter transform matrix:
      //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64_t num_filters = limit - start;
      const int64_t od_unroll = num_filters_transform;
      const int64_t od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64_t od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64_t shard_cost = args.filter_rows * args.filter_cols * in_depth *
                               filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth]

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64_t, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<
      T, int64_t, LhsMapper, Traits::mr, Traits::LhsProgress,
      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64_t rows, const int64_t depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64_t rows_;
  const int64_t depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};
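
// Note on the packing above: Eigen's gebp kernel (used by GemmState below)
// expects its left-hand side in the blocked layout produced by
// gemm_pack_lhs, so filters are packed once per tile coordinate and then
// re-used for every batch of input tiles. Here 'rows' is the number of
// filters (out_depth * shard_rows * shard_cols) and 'depth' is 'in_depth'.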

// Packs transformed filter stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64_t tile_spatial_size,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
                  &out_depth, &filter_shards_row, &filter_shards_col,
                  &num_filters](int64_t start, int64_t limit) {
      const int64_t filter_coord_stride = num_filters * in_depth;
      for (int64_t i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth]
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64_t, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64_t, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64_t, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64_t, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64_t rows, const int64_t cols, const int64_t depth,
            const int64_t out_buffer_size, const T* lhs_block,
            const T* rhs_input, T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64_t rows_;
  const int64_t cols_;
  const int64_t depth_;
  const int64_t out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};
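
// Per tile coordinate, the GEMM computed by GemmState is
//
//   [num_filters x in_depth] * [in_depth x num_tiles]
//     -> [num_filters x num_tiles]
//
// i.e. rows_ = out_depth * shard_rows * shard_cols, depth_ = in_depth and
// cols_ = num_tiles (see ComputeConv2D below). Illustrative note only.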

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64_t input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64_t input_scalar_size = args.in_depth % kPacketSize;

    for (int64_t r = 0; r < tile_rows; ++r) {
      const int64_t in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64_t c = 0; c < tile_cols; ++c) {
        const int64_t in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_rows + c);
        // Copy vectorized portion of depth dimension.
        for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64_t d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing
// the final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;
    const int64_t tile_stride_cols = transform->output_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;
    const int64_t num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64_t in_r = in_r_start;
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t num_tiles_base = t * num_tiles_stride;
      const int64_t in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};
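
// Dimension check for the GEMM above, in the F(2x2, 3x3) Winograd case
// (illustrative only): 'A' is the [16, 16] input transform and 'B' packs
// 'num_tiles' horizontally adjacent tiles (stepped by the 2-column output
// stride) as a [16, num_tiles * in_depth] matrix, so one GEMM transforms a
// whole row of tiles at once.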

// Transforms output tiles from 'out_buffer' by 'out_transform_matrix',
// storing the final result in 'output' (intermediate results are stored in
// 'out_transform_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// out_transform_buffer:
//   [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows,
//    shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]
//

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r,
                  const int64_t in_c, const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t out_depth_stride = filter_shards_row * filter_shards_col;
    const int64_t num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to the proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t tile_base = t * num_tiles_stride;

      for (int64_t od = 0; od < args.out_depth; ++od) {
        const int64_t out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64_t sr = 0; sr < filter_shards_row; ++sr) {
          for (int64_t sc = 0; sc < filter_shards_col; ++sc) {
            const int64_t shard_base = sr * filter_shards_col + sc;
            const int64_t out_buf_base =
                tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64_t out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' is used in the calculation of
            // 'out_c_start' because 'num_tiles' progresses along the column
            // dimension.
            const int64_t out_c_start = (in_c + t * tile_stride_cols) +
                                        args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip un-needed outputs.
            }

            // Increment output if not first filter shard.
            const bool inc_output = (sr != 0 || sc != 0);

            for (int64_t ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64_t out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64_t ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64_t out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate out tile index.
                const int64_t out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64_t output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};
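
// In the F(2x2, 3x3) Winograd case, 'A' above is the [4, 16] output
// transform: each column of 16 transformed-domain values collapses to a
// 2x2 spatial output tile, which is then scattered into 'output' (and
// accumulated across filter shards, per 'inc_output' above). Illustrative
// note only.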

template <typename T>
struct Conv2DState {
  Conv2DState(const int64_t tile_spatial_size, const int64_t filter_shards_row,
              const int64_t filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64_t tile_spatial_size;
  const int64_t filter_shards_row;
  const int64_t filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
// *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
// *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
// *) Transforms output tiles, and stores result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64_t in_r,
                  const int64_t in_c, const int64_t num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64_t tile_coord_stride = num_tiles * in_depth;
    const int64_t gemm_out_buf_size = num_tiles * num_filters;
    const int64_t gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64_t i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. large
// in_depth * out_depth).
// Details:
// *) Transforms and packs filters from 'filter' in parallel.
// *) Computes Conv2D parallelized across the 'batch' dimension.
// *) Each thread loops over images in its batch shard, copying 'num_tiles'
//    input tiles into a local buffer, and computing the Conv2D output of
//    these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit, and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is max(out_rows, out_cols), and so is worse
// for smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;

    const int64_t filter_residual_row =
        std::max(int64_t{0}, args.filter_rows - base_filter_rows);
    const int64_t filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64_t filter_residual_col =
        std::max(int64_t{0}, args.filter_cols - base_filter_cols);
    const int64_t filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
                  filter_shards_col, tile_spatial_size, &input,
                  &tile_transform_matrix, &output_transform_matrix,
                  &output](int64_t batch_start, int64_t batch_limit) {
      const int64_t row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64_t col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64_t filter_shard_size = filter_shards_row * filter_shards_col;
      const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64_t cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64_t tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64_t output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
      const int64_t filter_depth_size =
          in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64_t cache_reserve_size =
          small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64_t total_fixed_cost = tile_transform_matrix_size +
                                       output_transform_matrix_size +
                                       cache_reserve_size;

      // Per-tile costs.
      const int64_t buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64_t buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64_t packed_tile_per_tile_size = in_depth;
      const int64_t gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64_t total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64_t num_tiles_cache = std::max(
          int64_t{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64_t num_tiles = std::min(num_tiles_cache, col_tiles);

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1', based on max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64_t buffer1_tile_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer1_size =
          std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate required buffer size for 'buffer2' as max required buffer
      // between input and output transform buffer sizes.
      const int64_t buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      //   packed tile buffer: [num_tiles, in_depth]
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      //   gemm output buffer: [num_tiles, out_depth, shard_rows, shard_cols]
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64_t row_pad = args.pad_rows;
      const int64_t col_pad = args.pad_cols;
      const int64_t unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64_t input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * out_depth;

      const int64_t tile_stride_rows = transform->output_shape().rows;
      const int64_t tile_stride_cols = transform->output_shape().cols;

      for (int64_t b = batch_start; b < batch_limit; ++b) {
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        for (int64_t tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64_t in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64_t tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64_t in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64_t rem_tiles = col_tiles - unroll_col_limit;
            const int64_t in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64_t shard_cost = args.out_rows * args.out_cols * args.out_depth *
                               tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow