/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. large 'in_depth' * 'out_depth' products; see the cost models below
// for details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.
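//
// For intuition, with the Winograd F(2x2, 3x3) transform used below (4x4
// input tile, 3x3 filter, 2x2 output tile), each depth slice works out as:
// 'd' has 16 elements, 'g' has 9, 'Ad' and 'Bg' each have 16, their
// element-wise product has 16, and 'C' maps those 16 values back to the 4
// outputs of the 2x2 tile.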

// Approximate cost models for direct and deep convolutions.
static int64_t GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                               int out_tile_rows, int out_tile_cols,
                               int in_depth, int out_depth, int out_rows,
                               int out_cols) {
  // Input transform cost.
  const int64_t input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64_t input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64_t product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64_t output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64_t output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64_t row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64_t col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64_t num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64_t GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                                 int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows * out_cols;
}
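
// For intuition, a hypothetical example: with the 4x4/2x2 Winograd tiles used
// below, a 3x3 filter, in_depth = out_depth = 64, and a 32x32 output, the
// direct cost is 3*3*64*64*32*32 = 37,748,736, while the deep cost is
// (16*16*64 + 16*64*64 + 4*16*64) * 256 tiles = 86,016 * 256 = 22,020,096,
// i.e. roughly 0.58x the direct cost.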

// Reads environment variable 'env_var_name'.
// Returns 'default_val' if the variable is unset; otherwise returns false if
// its value is "0", and true for any other value.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}
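
// For example, running with TF_USE_DEEP_CONV2D=1 enables the deep path below,
// while TF_USE_DEEP_CONV2D=0 (or leaving the variable unset) keeps the direct
// implementation.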

// Returns true if the convolution can be computed efficiently by DeepConv2D;
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if flop cost of deep convolution is less than direct convolution.
  WinogradTransform<float> t;
  const int64_t deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64_t direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along the 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t vectorized_size = args.in_depth / kPacketSize;
    const int64_t scalar_size = args.in_depth % kPacketSize;
    const int64_t input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64_t d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64_t in_scalar_base = vectorized_size * input_stride;
    const int64_t buf_scalar_base = vectorized_size * kPacketSize;
    for (int64_t d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
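
// For intuition, a hypothetical example of CopyFilterDepth above: with
// kPacketSize = 4 and out_depth = 8, the first pgather reads filter_in[0],
// filter_in[8], filter_in[16] and filter_in[24] (the first four 'in_depth'
// entries of one filter, spaced 'out_depth' apart in memory) and stores them
// contiguously in filter_buf[0..3].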

// Computes the transform of 'num_filters' filters from 'filter_in', starting
// at 'od_start'. Intermediate results (i.e. the output of
// MatMul('transform_matrix', 'filter_in')) are stored in 'out_buffer'. The
// final result is copied from 'out_buffer' to 'filter_out' at the coordinate
// stride required by the transformed filter data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t num_filters,
                  const int64_t shard_rows, const int64_t shard_cols,
                  const T* filter_in, const int64_t in_stride,
                  const int64_t out_stride, const T* transform_matrix,
                  T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64_t in_depth = args.in_depth;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    // Compute transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;

    // Copy 'out_buffer' to 'filter_out' at required filter output stride.
    const int64_t scalar_size = in_depth % kPacketSize;
    const int64_t vectorized_size = in_depth / kPacketSize;

    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_buf_base = od * out_depth_stride;
      const int64_t out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t i = 0; i < tile_spatial_size; ++i) {
            const int64_t in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64_t out_base =
                i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64_t d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64_t scalar_base = vectorized_size * kPacketSize;
            for (int64_t d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};

// Transforms 'num_filters' filters from 'filter_in', starting at 'od_start'.
// For each filter, copies data for all filter shards from 'filter_in' into
// 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute the filter transform of data
// in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t od_limit,
                  const T* filter_in, const T* transform_matrix, T* out_buffer,
                  T* filter_buf, T* filter_out) {
    const int64_t num_filters = od_limit - od_start;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    const int64_t residual_row =
        std::max(int64_t{0}, args.filter_rows - base_filter_rows);
    const int64_t shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64_t residual_col =
        std::max(int64_t{0}, args.filter_cols - base_filter_cols);
    const int64_t shard_cols = 1 + (residual_col + 2 - 1) / 2;
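
    // For example, a 3x3 filter with a 3x3 base filter has residual 0 and a
    // single shard per dimension, while a hypothetical 5x5 filter would have
    // residual 2 and 1 + ceil(2 / 2) = 2 shards per dimension.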

    // Compute strides to be used for input and output IO.
    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64_t coord_stride = out_depth_stride * args.out_depth;
    const int64_t filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t filter_buf_size = base_filter_spatial_size * num_filters *
                                    shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        const int64_t row_offset = s_r == 0 ? 0 : 1;

        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t col_offset = s_c == 0 ? 0 : 1;
          const int64_t f_r_start = s_r * tile_stride_rows;
          const int64_t f_c_start = s_c * tile_stride_cols;

          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64_t f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64_t b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64_t f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64_t in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64_t buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col, const T* filter_in,
                  T* filter_out) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    const int64_t filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64_t cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64_t filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64_t filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64_t filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64_t filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64_t per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64_t num_filters_cache =
        std::max(int64_t{1},
                 (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64_t num_filters_transform =
        std::min(out_depth, num_filters_cache);
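
    // For intuition, a hypothetical example with T = float (a budget of
    // 65,536 elements), 4x4 tiles, a 3x3 base filter, in_depth = 64, and a
    // single filter shard: the fixed cost is 16 * 9 = 144, the per-filter
    // cost is 9*64 + 9*64 + 16*64 = 2,176, and roughly 30 filters are
    // transformed per batch.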

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &filter_shards_row,
                  &filter_shards_col, &tile_spatial_size, &filter_in,
                  &transform_matrix,
                  &filter_out](int64_t start, int64_t limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform,
      //    filter_shards_row, filter_shards_col, in_depth]
      //
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output filter transform matrix:
      //   [tile_spatial_size, num_filters_transform, filter_shards_row,
      //    filter_shards_col, in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64_t num_filters = limit - start;
      const int64_t od_unroll = num_filters_transform;
      const int64_t od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64_t od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64_t shard_cost = args.filter_rows * args.filter_cols * in_depth *
                               filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64_t, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<
      T, int64_t, LhsMapper, Traits::mr, Traits::LhsProgress,
      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64_t rows, const int64_t depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64_t rows_;
  const int64_t depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs transformed filters stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64_t tile_spatial_size,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
                  &out_depth, &filter_shards_row, &filter_shards_col,
                  &num_filters](int64_t start, int64_t limit) {
      const int64_t filter_coord_stride = num_filters * in_depth;
      for (int64_t i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64_t, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64_t, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64_t, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64_t, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64_t rows, const int64_t cols, const int64_t depth,
            const int64_t out_buffer_size, const T* lhs_block,
            const T* rhs_input, T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

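  // Note: 'gebp' accumulates into its output, so 'out_buffer_' is zeroed in
  // Compute() to produce a plain product rather than an accumulation.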
  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64_t rows_;
  const int64_t cols_;
  const int64_t depth_;
  const int64_t out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64_t input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64_t input_scalar_size = args.in_depth % kPacketSize;

    for (int64_t r = 0; r < tile_rows; ++r) {
      const int64_t in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64_t c = 0; c < tile_cols; ++c) {
        const int64_t in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_cols + c);
        // Copy vectorized portion of depth dimension.
        for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64_t d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]
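//
// Note that consecutive input tiles overlap: each input tile spans
// 'tile_rows' x 'tile_cols' elements (4x4 for the Winograd transform used
// here), but successive tiles advance by the output tile width (2 columns),
// so neighboring tiles share input columns.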

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;
    const int64_t tile_stride_cols = transform->output_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;
    const int64_t num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64_t in_r = in_r_start;
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t num_tiles_base = t * num_tiles_stride;
      const int64_t in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};

// Transforms output tiles in 'out_buffer' by 'out_transform_matrix', storing
// the final result in 'output' (intermediate results are stored in
// 'out_transform_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// out_transform_buffer:
//   [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows,
//    shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]
//

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r,
                  const int64_t in_c, const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t out_depth_stride = filter_shards_row * filter_shards_col;
    const int64_t num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to the proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t tile_base = t * num_tiles_stride;

      for (int64_t od = 0; od < args.out_depth; ++od) {
        const int64_t out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64_t sr = 0; sr < filter_shards_row; ++sr) {
          for (int64_t sc = 0; sc < filter_shards_col; ++sc) {
            const int64_t shard_base = sr * filter_shards_col + sc;
            const int64_t out_buf_base =
                tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64_t out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' enters the calculation of
            // 'out_c_start' because 'num_tiles' progresses along the column
            // dimension.
            const int64_t out_c_start = (in_c + t * tile_stride_cols) +
                                        args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip un-needed outputs.
            }

            // Accumulate into 'output' unless this is the first filter shard.
            const bool inc_output = !(sr == 0 && sc == 0);

            for (int64_t ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64_t out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64_t ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64_t out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate output tile index.
                const int64_t out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64_t output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};

template <typename T>
struct Conv2DState {
  Conv2DState(const int64_t tile_spatial_size, const int64_t filter_shards_row,
              const int64_t filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64_t tile_spatial_size;
  const int64_t filter_shards_row;
  const int64_t filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
// *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
// *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
// *) Transforms output tiles, and stores the result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64_t in_r,
                  const int64_t in_c, const int64_t num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64_t tile_coord_stride = num_tiles * in_depth;
    const int64_t gemm_out_buf_size = num_tiles * num_filters;
    const int64_t gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64_t i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. large
// 'in_depth' * 'out_depth' products).
// Details:
// *) Transforms and packs filters from 'filter' in parallel.
// *) Computes Conv2D parallelized across the 'batch' dimension.
// *) Each thread loops over images in its batch shard, copying 'num_tiles'
//    input tiles into a local buffer, and computing the Conv2D output of
//    these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit, and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is max(out_rows, out_cols), and so is worse for
// smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;

    const int64_t filter_residual_row =
        std::max(int64_t{0}, args.filter_rows - base_filter_rows);
    const int64_t filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64_t filter_residual_col =
        std::max(int64_t{0}, args.filter_cols - base_filter_cols);
    const int64_t filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
                  filter_shards_col, tile_spatial_size, &input,
                  &tile_transform_matrix, &output_transform_matrix,
                  &output](int64_t batch_start, int64_t batch_limit) {
      const int64_t row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64_t col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64_t filter_shard_size = filter_shards_row * filter_shards_col;
      const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64_t cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64_t tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64_t output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
      const int64_t filter_depth_size =
          in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64_t cache_reserve_size =
          small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64_t total_fixed_cost = tile_transform_matrix_size +
                                       output_transform_matrix_size +
                                       cache_reserve_size;

      // Per-tile costs.
      const int64_t buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64_t buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64_t packed_tile_per_tile_size = in_depth;
      const int64_t gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64_t total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64_t num_tiles_cache = std::max(
          int64_t{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64_t num_tiles = std::min(num_tiles_cache, col_tiles);
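
      // For intuition, a hypothetical example with T = float (65,536-element
      // budget), in_depth = out_depth = 64, and a single filter shard: the
      // filter occupies 64 * 64 = 4,096 elements (under 25% of the budget, so
      // it is reserved), the fixed cost is 256 + 64 + 4,096 = 4,416, the
      // per-tile cost is 1,024 + 1,024 + 64 + 64 = 2,176, and roughly 28
      // tiles are processed together.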

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1' as the max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64_t buffer1_tile_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer1_size =
          std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate the required buffer size for 'buffer2' as the max required
      // buffer between input and output transform buffer sizes.
      const int64_t buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      //   packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      //   gemm output buffer: [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64_t row_pad = args.pad_rows;
      const int64_t col_pad = args.pad_cols;
      const int64_t unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64_t input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * out_depth;

      const int64_t tile_stride_rows = transform->output_shape().rows;
      const int64_t tile_stride_cols = transform->output_shape().cols;

      for (int64_t b = batch_start; b < batch_limit; ++b) {
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        for (int64_t tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64_t in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64_t tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64_t in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64_t rem_tiles = col_tiles - unroll_col_limit;
            const int64_t in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64_t shard_cost = args.out_rows * args.out_cols * args.out_depth *
                               tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow