1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
17#define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
18
#include <cstring>

#include "tensorflow/core/kernels/deep_conv2d.h"
20
21namespace tensorflow {
22
23// Winograd DeepConv2DTransform implementation for 3x3 filters.
24// Details:
25// *) Arithmetic complexity of computations: Shmuel Winograd
26// *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray
27
28template <typename T>
29class WinogradTransform : public DeepConv2DTransform<T> {
30 public:
31 typedef typename DeepConv2DTransform<T>::Shape Shape;
32
33 WinogradTransform()
34 : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {}
35
36 virtual void GetFilterTransformMatrix(const int64_t rows, const int64_t cols,
37 T* transform_matrix) const;
38
39 virtual void GetInputTransformMatrix(const int64_t rows, const int64_t cols,
40 T* transform_matrix) const;
41
42 virtual void GetOutputTransformMatrix(const int64_t rows, const int64_t cols,
43 T* transform_matrix) const;
44
45 virtual const Shape& filter_shape() const { return filter_shape_; }
46 virtual const Shape& input_shape() const { return input_shape_; }
47 virtual const Shape& output_shape() const { return output_shape_; }
48
49 private:
50 const Shape filter_shape_;
51 const Shape input_shape_;
52 const Shape output_shape_;
53};
54
55// The filter transform matrix is the kronecker product 'M * M' of the
56// following matrix 'M':
57//
58// [ 1 0 0 ]
59// [ 1/2 1/2 1/2 ]
60// [ 1/2 -1/2 1/2 ]
61// [ 0 0 1 ]
62//
63// The data layout of 'transform_matrix':
64// [input_tile_spatial_size, filter_spatial_size]
65//
66template <typename T>
67void WinogradTransform<T>::GetFilterTransformMatrix(const int64_t rows,
68 const int64_t cols,
69 T* transform_matrix) const {
70 CHECK_GT(rows, 0);
71 CHECK_GT(cols, 0);
72 memset(transform_matrix, 0, sizeof(T) * rows * cols);
73
74 // Sub matrix [0,0]
75 transform_matrix[0 * cols + 0] = T(1.0);
76
77 transform_matrix[1 * cols + 0] = T(0.5);
78 transform_matrix[1 * cols + 1] = T(0.5);
79 transform_matrix[1 * cols + 2] = T(0.5);
80
81 transform_matrix[2 * cols + 0] = T(0.5);
82 transform_matrix[2 * cols + 1] = T(-0.5);
83 transform_matrix[2 * cols + 2] = T(0.5);
84
85 transform_matrix[3 * cols + 2] = T(1.0);
86
87 // Sub matrix [1,0]
88 transform_matrix[4 * cols + 0] = T(0.5);
89
90 transform_matrix[5 * cols + 0] = T(0.25);
91 transform_matrix[5 * cols + 1] = T(0.25);
92 transform_matrix[5 * cols + 2] = T(0.25);
93
94 transform_matrix[6 * cols + 0] = T(0.25);
95 transform_matrix[6 * cols + 1] = T(-0.25);
96 transform_matrix[6 * cols + 2] = T(0.25);
97
98 transform_matrix[7 * cols + 2] = T(0.5);
99
100 // Sub matrix [1,1]
101 transform_matrix[4 * cols + 3] = T(0.5);
102
103 transform_matrix[5 * cols + 3] = T(0.25);
104 transform_matrix[5 * cols + 4] = T(0.25);
105 transform_matrix[5 * cols + 5] = T(0.25);
106
107 transform_matrix[6 * cols + 3] = T(0.25);
108 transform_matrix[6 * cols + 4] = T(-0.25);
109 transform_matrix[6 * cols + 5] = T(0.25);
110
111 transform_matrix[7 * cols + 5] = T(0.5);
112
113 // Sub matrix [1,2]
114 transform_matrix[4 * cols + 6] = T(0.5);
115
116 transform_matrix[5 * cols + 6] = T(0.25);
117 transform_matrix[5 * cols + 7] = T(0.25);
118 transform_matrix[5 * cols + 8] = T(0.25);
119
120 transform_matrix[6 * cols + 6] = T(0.25);
121 transform_matrix[6 * cols + 7] = T(-0.25);
122 transform_matrix[6 * cols + 8] = T(0.25);
123
124 transform_matrix[7 * cols + 8] = T(0.5);
125
126 // Sub matrix [2,0]
127 transform_matrix[8 * cols + 0] = T(0.5);
128
129 transform_matrix[9 * cols + 0] = T(0.25);
130 transform_matrix[9 * cols + 1] = T(0.25);
131 transform_matrix[9 * cols + 2] = T(0.25);
132
133 transform_matrix[10 * cols + 0] = T(0.25);
134 transform_matrix[10 * cols + 1] = T(-0.25);
135 transform_matrix[10 * cols + 2] = T(0.25);
136
137 transform_matrix[11 * cols + 2] = T(0.5);
138
139 // Sub matrix [2,1]
140 transform_matrix[8 * cols + 3] = T(-0.5);
141
142 transform_matrix[9 * cols + 3] = T(-0.25);
143 transform_matrix[9 * cols + 4] = T(-0.25);
144 transform_matrix[9 * cols + 5] = T(-0.25);
145
146 transform_matrix[10 * cols + 3] = T(-0.25);
147 transform_matrix[10 * cols + 4] = T(0.25);
148 transform_matrix[10 * cols + 5] = T(-0.25);
149
150 transform_matrix[11 * cols + 5] = T(-0.5);
151
152 // Sub matrix [2,2]
153 transform_matrix[8 * cols + 6] = T(0.5);
154
155 transform_matrix[9 * cols + 6] = T(0.25);
156 transform_matrix[9 * cols + 7] = T(0.25);
157 transform_matrix[9 * cols + 8] = T(0.25);
158
159 transform_matrix[10 * cols + 6] = T(0.25);
160 transform_matrix[10 * cols + 7] = T(-0.25);
161 transform_matrix[10 * cols + 8] = T(0.25);
162
163 transform_matrix[11 * cols + 8] = T(0.5);
164
165 // Sub matrix [3,2]
166 transform_matrix[12 * cols + 6] = T(1.0);
167
168 transform_matrix[13 * cols + 6] = T(0.5);
169 transform_matrix[13 * cols + 7] = T(0.5);
170 transform_matrix[13 * cols + 8] = T(0.5);
171
172 transform_matrix[14 * cols + 6] = T(0.5);
173 transform_matrix[14 * cols + 7] = T(-0.5);
174 transform_matrix[14 * cols + 8] = T(0.5);
175
176 transform_matrix[15 * cols + 8] = T(1.0);
177}
178
179// The input transform matrix is the kronecker product 'M * M' of the
180// following matrix 'M':
181//
182// [1 0 -1 0]
183// [0 1 1 0]
184// [0 -1 1 0]
185// [0 1 0 -1]
186//
187// Data layout of 'transform_matrix':
188// [tile_spatial_size, tile_spatial_size]
189//
190template <typename T>
191void WinogradTransform<T>::GetInputTransformMatrix(const int64_t rows,
192 const int64_t cols,
193 T* transform_matrix) const {
194 CHECK_GT(rows, 0);
195 CHECK_GT(cols, 0);
196 memset(transform_matrix, 0, sizeof(T) * rows * cols);
197
198 // Sub matrix [0,0]
199 transform_matrix[0 * cols + 0] = T(1.0);
200 transform_matrix[0 * cols + 2] = T(-1.0);
201
202 transform_matrix[1 * cols + 1] = T(1.0);
203 transform_matrix[1 * cols + 2] = T(1.0);
204
205 transform_matrix[2 * cols + 1] = T(-1.0);
206 transform_matrix[2 * cols + 2] = T(1.0);
207
208 transform_matrix[3 * cols + 1] = T(1.0);
209 transform_matrix[3 * cols + 3] = T(-1.0);
210
211 // Sub matrix [0,2]
212 transform_matrix[0 * cols + 8] = T(-1.0);
213 transform_matrix[0 * cols + 10] = T(1.0);
214
215 transform_matrix[1 * cols + 9] = T(-1.0);
216 transform_matrix[1 * cols + 10] = T(-1.0);
217
218 transform_matrix[2 * cols + 9] = T(1.0);
219 transform_matrix[2 * cols + 10] = T(-1.0);
220
221 transform_matrix[3 * cols + 9] = T(-1.0);
222 transform_matrix[3 * cols + 11] = T(1.0);
223
224 // Sub matrix [1,1]
225 transform_matrix[4 * cols + 4] = T(1.0);
226 transform_matrix[4 * cols + 6] = T(-1.0);
227
228 transform_matrix[5 * cols + 5] = T(1.0);
229 transform_matrix[5 * cols + 6] = T(1.0);
230
231 transform_matrix[6 * cols + 5] = T(-1.0);
232 transform_matrix[6 * cols + 6] = T(1.0);
233
234 transform_matrix[7 * cols + 5] = T(1.0);
235 transform_matrix[7 * cols + 7] = T(-1.0);
236
237 // Sub matrix [1,2]
238 transform_matrix[4 * cols + 8] = T(1.0);
239 transform_matrix[4 * cols + 10] = T(-1.0);
240
241 transform_matrix[5 * cols + 9] = T(1.0);
242 transform_matrix[5 * cols + 10] = T(1.0);
243
244 transform_matrix[6 * cols + 9] = T(-1.0);
245 transform_matrix[6 * cols + 10] = T(1.0);
246
247 transform_matrix[7 * cols + 9] = T(1.0);
248 transform_matrix[7 * cols + 11] = T(-1.0);
249
250 // Sub matrix [2,1]
251 transform_matrix[8 * cols + 4] = T(-1.0);
252 transform_matrix[8 * cols + 6] = T(1.0);
253
254 transform_matrix[9 * cols + 5] = T(-1.0);
255 transform_matrix[9 * cols + 6] = T(-1.0);
256
257 transform_matrix[10 * cols + 5] = T(1.0);
258 transform_matrix[10 * cols + 6] = T(-1.0);
259
260 transform_matrix[11 * cols + 5] = T(-1.0);
261 transform_matrix[11 * cols + 7] = T(1.0);
262
263 // Sub matrix [2,2]
264 transform_matrix[8 * cols + 8] = T(1.0);
265 transform_matrix[8 * cols + 10] = T(-1.0);
266
267 transform_matrix[9 * cols + 9] = T(1.0);
268 transform_matrix[9 * cols + 10] = T(1.0);
269
270 transform_matrix[10 * cols + 9] = T(-1.0);
271 transform_matrix[10 * cols + 10] = T(1.0);
272
273 transform_matrix[11 * cols + 9] = T(1.0);
274 transform_matrix[11 * cols + 11] = T(-1.0);
275
276 // Sub matrix [3,1]
277 transform_matrix[12 * cols + 4] = T(1.0);
278 transform_matrix[12 * cols + 6] = T(-1.0);
279
280 transform_matrix[13 * cols + 5] = T(1.0);
281 transform_matrix[13 * cols + 6] = T(1.0);
282
283 transform_matrix[14 * cols + 5] = T(-1.0);
284 transform_matrix[14 * cols + 6] = T(1.0);
285
286 transform_matrix[15 * cols + 5] = T(1.0);
287 transform_matrix[15 * cols + 7] = T(-1.0);
288
289 // Sub matrix [3,3]
290 transform_matrix[12 * cols + 12] = T(-1.0);
291 transform_matrix[12 * cols + 14] = T(1.0);
292
293 transform_matrix[13 * cols + 13] = T(-1.0);
294 transform_matrix[13 * cols + 14] = T(-1.0);
295
296 transform_matrix[14 * cols + 13] = T(1.0);
297 transform_matrix[14 * cols + 14] = T(-1.0);
298
299 transform_matrix[15 * cols + 13] = T(-1.0);
300 transform_matrix[15 * cols + 15] = T(1.0);
301};
302
303// The output transform matrix is the kronecker product 'M * M' of the
304// following matrix 'M':
305//
306// [1 1 1 0]
307// [0 1 -1 -1]
308//
309// Data layout of 'transform_matrix':
310// [out_tile_spatial_size, tile_spatial_size]
311//
312template <typename T>
313void WinogradTransform<T>::GetOutputTransformMatrix(const int64_t rows,
314 const int64_t cols,
315 T* transform_matrix) const {
316 CHECK_GT(rows, 0);
317 CHECK_GT(cols, 0);
318 memset(transform_matrix, 0, sizeof(T) * rows * cols);
319
320 // Sub matrix [0,0]
321 transform_matrix[0 * cols + 0] = T(1.0);
322 transform_matrix[0 * cols + 1] = T(1.0);
323 transform_matrix[0 * cols + 2] = T(1.0);
324
325 transform_matrix[1 * cols + 1] = T(1.0);
326 transform_matrix[1 * cols + 2] = T(-1.0);
327 transform_matrix[1 * cols + 3] = T(-1.0);
328
329 // Sub matrix [0,1]
330 transform_matrix[0 * cols + 4] = T(1.0);
331 transform_matrix[0 * cols + 5] = T(1.0);
332 transform_matrix[0 * cols + 6] = T(1.0);
333
334 transform_matrix[1 * cols + 5] = T(1.0);
335 transform_matrix[1 * cols + 6] = T(-1.0);
336 transform_matrix[1 * cols + 7] = T(-1.0);
337
338 // Sub matrix [0,2]
339 transform_matrix[0 * cols + 8] = T(1.0);
340 transform_matrix[0 * cols + 9] = T(1.0);
341 transform_matrix[0 * cols + 10] = T(1.0);
342
343 transform_matrix[1 * cols + 9] = T(1.0);
344 transform_matrix[1 * cols + 10] = T(-1.0);
345 transform_matrix[1 * cols + 11] = T(-1.0);
346
347 // Sub matrix [1,1]
348 transform_matrix[2 * cols + 4] = T(1.0);
349 transform_matrix[2 * cols + 5] = T(1.0);
350 transform_matrix[2 * cols + 6] = T(1.0);
351
352 transform_matrix[3 * cols + 5] = T(1.0);
353 transform_matrix[3 * cols + 6] = T(-1.0);
354 transform_matrix[3 * cols + 7] = T(-1.0);
355
356 // Sub matrix [1,2]
357 transform_matrix[2 * cols + 8] = T(-1.0);
358 transform_matrix[2 * cols + 9] = T(-1.0);
359 transform_matrix[2 * cols + 10] = T(-1.0);
360
361 transform_matrix[3 * cols + 9] = T(-1.0);
362 transform_matrix[3 * cols + 10] = T(1.0);
363 transform_matrix[3 * cols + 11] = T(1.0);
364
365 // Sub matrix [1,3]
366 transform_matrix[2 * cols + 12] = T(-1.0);
367 transform_matrix[2 * cols + 13] = T(-1.0);
368 transform_matrix[2 * cols + 14] = T(-1.0);
369
370 transform_matrix[3 * cols + 13] = T(-1.0);
371 transform_matrix[3 * cols + 14] = T(1.0);
372 transform_matrix[3 * cols + 15] = T(1.0);
373};
374
375} // namespace tensorflow
376
377#endif // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_
378