1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ |
18 | |
19 | #include "tensorflow/core/kernels/deep_conv2d.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | // Winograd DeepConv2DTransform implementation for 3x3 filters. |
24 | // Details: |
25 | // *) Arithmetic complexity of computations: Shmuel Winograd |
26 | // *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray |
27 | |
28 | template <typename T> |
29 | class WinogradTransform : public DeepConv2DTransform<T> { |
30 | public: |
31 | typedef typename DeepConv2DTransform<T>::Shape Shape; |
32 | |
33 | WinogradTransform() |
34 | : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {} |
35 | |
36 | virtual void GetFilterTransformMatrix(const int64_t rows, const int64_t cols, |
37 | T* transform_matrix) const; |
38 | |
39 | virtual void GetInputTransformMatrix(const int64_t rows, const int64_t cols, |
40 | T* transform_matrix) const; |
41 | |
42 | virtual void GetOutputTransformMatrix(const int64_t rows, const int64_t cols, |
43 | T* transform_matrix) const; |
44 | |
45 | virtual const Shape& filter_shape() const { return filter_shape_; } |
46 | virtual const Shape& input_shape() const { return input_shape_; } |
47 | virtual const Shape& output_shape() const { return output_shape_; } |
48 | |
49 | private: |
50 | const Shape filter_shape_; |
51 | const Shape input_shape_; |
52 | const Shape output_shape_; |
53 | }; |
54 | |
55 | // The filter transform matrix is the kronecker product 'M * M' of the |
56 | // following matrix 'M': |
57 | // |
58 | // [ 1 0 0 ] |
59 | // [ 1/2 1/2 1/2 ] |
60 | // [ 1/2 -1/2 1/2 ] |
61 | // [ 0 0 1 ] |
62 | // |
63 | // The data layout of 'transform_matrix': |
64 | // [input_tile_spatial_size, filter_spatial_size] |
65 | // |
66 | template <typename T> |
67 | void WinogradTransform<T>::GetFilterTransformMatrix(const int64_t rows, |
68 | const int64_t cols, |
69 | T* transform_matrix) const { |
70 | CHECK_GT(rows, 0); |
71 | CHECK_GT(cols, 0); |
72 | memset(transform_matrix, 0, sizeof(T) * rows * cols); |
73 | |
74 | // Sub matrix [0,0] |
75 | transform_matrix[0 * cols + 0] = T(1.0); |
76 | |
77 | transform_matrix[1 * cols + 0] = T(0.5); |
78 | transform_matrix[1 * cols + 1] = T(0.5); |
79 | transform_matrix[1 * cols + 2] = T(0.5); |
80 | |
81 | transform_matrix[2 * cols + 0] = T(0.5); |
82 | transform_matrix[2 * cols + 1] = T(-0.5); |
83 | transform_matrix[2 * cols + 2] = T(0.5); |
84 | |
85 | transform_matrix[3 * cols + 2] = T(1.0); |
86 | |
87 | // Sub matrix [1,0] |
88 | transform_matrix[4 * cols + 0] = T(0.5); |
89 | |
90 | transform_matrix[5 * cols + 0] = T(0.25); |
91 | transform_matrix[5 * cols + 1] = T(0.25); |
92 | transform_matrix[5 * cols + 2] = T(0.25); |
93 | |
94 | transform_matrix[6 * cols + 0] = T(0.25); |
95 | transform_matrix[6 * cols + 1] = T(-0.25); |
96 | transform_matrix[6 * cols + 2] = T(0.25); |
97 | |
98 | transform_matrix[7 * cols + 2] = T(0.5); |
99 | |
100 | // Sub matrix [1,1] |
101 | transform_matrix[4 * cols + 3] = T(0.5); |
102 | |
103 | transform_matrix[5 * cols + 3] = T(0.25); |
104 | transform_matrix[5 * cols + 4] = T(0.25); |
105 | transform_matrix[5 * cols + 5] = T(0.25); |
106 | |
107 | transform_matrix[6 * cols + 3] = T(0.25); |
108 | transform_matrix[6 * cols + 4] = T(-0.25); |
109 | transform_matrix[6 * cols + 5] = T(0.25); |
110 | |
111 | transform_matrix[7 * cols + 5] = T(0.5); |
112 | |
113 | // Sub matrix [1,2] |
114 | transform_matrix[4 * cols + 6] = T(0.5); |
115 | |
116 | transform_matrix[5 * cols + 6] = T(0.25); |
117 | transform_matrix[5 * cols + 7] = T(0.25); |
118 | transform_matrix[5 * cols + 8] = T(0.25); |
119 | |
120 | transform_matrix[6 * cols + 6] = T(0.25); |
121 | transform_matrix[6 * cols + 7] = T(-0.25); |
122 | transform_matrix[6 * cols + 8] = T(0.25); |
123 | |
124 | transform_matrix[7 * cols + 8] = T(0.5); |
125 | |
126 | // Sub matrix [2,0] |
127 | transform_matrix[8 * cols + 0] = T(0.5); |
128 | |
129 | transform_matrix[9 * cols + 0] = T(0.25); |
130 | transform_matrix[9 * cols + 1] = T(0.25); |
131 | transform_matrix[9 * cols + 2] = T(0.25); |
132 | |
133 | transform_matrix[10 * cols + 0] = T(0.25); |
134 | transform_matrix[10 * cols + 1] = T(-0.25); |
135 | transform_matrix[10 * cols + 2] = T(0.25); |
136 | |
137 | transform_matrix[11 * cols + 2] = T(0.5); |
138 | |
139 | // Sub matrix [2,1] |
140 | transform_matrix[8 * cols + 3] = T(-0.5); |
141 | |
142 | transform_matrix[9 * cols + 3] = T(-0.25); |
143 | transform_matrix[9 * cols + 4] = T(-0.25); |
144 | transform_matrix[9 * cols + 5] = T(-0.25); |
145 | |
146 | transform_matrix[10 * cols + 3] = T(-0.25); |
147 | transform_matrix[10 * cols + 4] = T(0.25); |
148 | transform_matrix[10 * cols + 5] = T(-0.25); |
149 | |
150 | transform_matrix[11 * cols + 5] = T(-0.5); |
151 | |
152 | // Sub matrix [2,2] |
153 | transform_matrix[8 * cols + 6] = T(0.5); |
154 | |
155 | transform_matrix[9 * cols + 6] = T(0.25); |
156 | transform_matrix[9 * cols + 7] = T(0.25); |
157 | transform_matrix[9 * cols + 8] = T(0.25); |
158 | |
159 | transform_matrix[10 * cols + 6] = T(0.25); |
160 | transform_matrix[10 * cols + 7] = T(-0.25); |
161 | transform_matrix[10 * cols + 8] = T(0.25); |
162 | |
163 | transform_matrix[11 * cols + 8] = T(0.5); |
164 | |
165 | // Sub matrix [3,2] |
166 | transform_matrix[12 * cols + 6] = T(1.0); |
167 | |
168 | transform_matrix[13 * cols + 6] = T(0.5); |
169 | transform_matrix[13 * cols + 7] = T(0.5); |
170 | transform_matrix[13 * cols + 8] = T(0.5); |
171 | |
172 | transform_matrix[14 * cols + 6] = T(0.5); |
173 | transform_matrix[14 * cols + 7] = T(-0.5); |
174 | transform_matrix[14 * cols + 8] = T(0.5); |
175 | |
176 | transform_matrix[15 * cols + 8] = T(1.0); |
177 | } |
178 | |
179 | // The input transform matrix is the kronecker product 'M * M' of the |
180 | // following matrix 'M': |
181 | // |
182 | // [1 0 -1 0] |
183 | // [0 1 1 0] |
184 | // [0 -1 1 0] |
185 | // [0 1 0 -1] |
186 | // |
187 | // Data layout of 'transform_matrix': |
188 | // [tile_spatial_size, tile_spatial_size] |
189 | // |
190 | template <typename T> |
191 | void WinogradTransform<T>::GetInputTransformMatrix(const int64_t rows, |
192 | const int64_t cols, |
193 | T* transform_matrix) const { |
194 | CHECK_GT(rows, 0); |
195 | CHECK_GT(cols, 0); |
196 | memset(transform_matrix, 0, sizeof(T) * rows * cols); |
197 | |
198 | // Sub matrix [0,0] |
199 | transform_matrix[0 * cols + 0] = T(1.0); |
200 | transform_matrix[0 * cols + 2] = T(-1.0); |
201 | |
202 | transform_matrix[1 * cols + 1] = T(1.0); |
203 | transform_matrix[1 * cols + 2] = T(1.0); |
204 | |
205 | transform_matrix[2 * cols + 1] = T(-1.0); |
206 | transform_matrix[2 * cols + 2] = T(1.0); |
207 | |
208 | transform_matrix[3 * cols + 1] = T(1.0); |
209 | transform_matrix[3 * cols + 3] = T(-1.0); |
210 | |
211 | // Sub matrix [0,2] |
212 | transform_matrix[0 * cols + 8] = T(-1.0); |
213 | transform_matrix[0 * cols + 10] = T(1.0); |
214 | |
215 | transform_matrix[1 * cols + 9] = T(-1.0); |
216 | transform_matrix[1 * cols + 10] = T(-1.0); |
217 | |
218 | transform_matrix[2 * cols + 9] = T(1.0); |
219 | transform_matrix[2 * cols + 10] = T(-1.0); |
220 | |
221 | transform_matrix[3 * cols + 9] = T(-1.0); |
222 | transform_matrix[3 * cols + 11] = T(1.0); |
223 | |
224 | // Sub matrix [1,1] |
225 | transform_matrix[4 * cols + 4] = T(1.0); |
226 | transform_matrix[4 * cols + 6] = T(-1.0); |
227 | |
228 | transform_matrix[5 * cols + 5] = T(1.0); |
229 | transform_matrix[5 * cols + 6] = T(1.0); |
230 | |
231 | transform_matrix[6 * cols + 5] = T(-1.0); |
232 | transform_matrix[6 * cols + 6] = T(1.0); |
233 | |
234 | transform_matrix[7 * cols + 5] = T(1.0); |
235 | transform_matrix[7 * cols + 7] = T(-1.0); |
236 | |
237 | // Sub matrix [1,2] |
238 | transform_matrix[4 * cols + 8] = T(1.0); |
239 | transform_matrix[4 * cols + 10] = T(-1.0); |
240 | |
241 | transform_matrix[5 * cols + 9] = T(1.0); |
242 | transform_matrix[5 * cols + 10] = T(1.0); |
243 | |
244 | transform_matrix[6 * cols + 9] = T(-1.0); |
245 | transform_matrix[6 * cols + 10] = T(1.0); |
246 | |
247 | transform_matrix[7 * cols + 9] = T(1.0); |
248 | transform_matrix[7 * cols + 11] = T(-1.0); |
249 | |
250 | // Sub matrix [2,1] |
251 | transform_matrix[8 * cols + 4] = T(-1.0); |
252 | transform_matrix[8 * cols + 6] = T(1.0); |
253 | |
254 | transform_matrix[9 * cols + 5] = T(-1.0); |
255 | transform_matrix[9 * cols + 6] = T(-1.0); |
256 | |
257 | transform_matrix[10 * cols + 5] = T(1.0); |
258 | transform_matrix[10 * cols + 6] = T(-1.0); |
259 | |
260 | transform_matrix[11 * cols + 5] = T(-1.0); |
261 | transform_matrix[11 * cols + 7] = T(1.0); |
262 | |
263 | // Sub matrix [2,2] |
264 | transform_matrix[8 * cols + 8] = T(1.0); |
265 | transform_matrix[8 * cols + 10] = T(-1.0); |
266 | |
267 | transform_matrix[9 * cols + 9] = T(1.0); |
268 | transform_matrix[9 * cols + 10] = T(1.0); |
269 | |
270 | transform_matrix[10 * cols + 9] = T(-1.0); |
271 | transform_matrix[10 * cols + 10] = T(1.0); |
272 | |
273 | transform_matrix[11 * cols + 9] = T(1.0); |
274 | transform_matrix[11 * cols + 11] = T(-1.0); |
275 | |
276 | // Sub matrix [3,1] |
277 | transform_matrix[12 * cols + 4] = T(1.0); |
278 | transform_matrix[12 * cols + 6] = T(-1.0); |
279 | |
280 | transform_matrix[13 * cols + 5] = T(1.0); |
281 | transform_matrix[13 * cols + 6] = T(1.0); |
282 | |
283 | transform_matrix[14 * cols + 5] = T(-1.0); |
284 | transform_matrix[14 * cols + 6] = T(1.0); |
285 | |
286 | transform_matrix[15 * cols + 5] = T(1.0); |
287 | transform_matrix[15 * cols + 7] = T(-1.0); |
288 | |
289 | // Sub matrix [3,3] |
290 | transform_matrix[12 * cols + 12] = T(-1.0); |
291 | transform_matrix[12 * cols + 14] = T(1.0); |
292 | |
293 | transform_matrix[13 * cols + 13] = T(-1.0); |
294 | transform_matrix[13 * cols + 14] = T(-1.0); |
295 | |
296 | transform_matrix[14 * cols + 13] = T(1.0); |
297 | transform_matrix[14 * cols + 14] = T(-1.0); |
298 | |
299 | transform_matrix[15 * cols + 13] = T(-1.0); |
300 | transform_matrix[15 * cols + 15] = T(1.0); |
301 | }; |
302 | |
303 | // The output transform matrix is the kronecker product 'M * M' of the |
304 | // following matrix 'M': |
305 | // |
306 | // [1 1 1 0] |
307 | // [0 1 -1 -1] |
308 | // |
309 | // Data layout of 'transform_matrix': |
310 | // [out_tile_spatial_size, tile_spatial_size] |
311 | // |
312 | template <typename T> |
313 | void WinogradTransform<T>::GetOutputTransformMatrix(const int64_t rows, |
314 | const int64_t cols, |
315 | T* transform_matrix) const { |
316 | CHECK_GT(rows, 0); |
317 | CHECK_GT(cols, 0); |
318 | memset(transform_matrix, 0, sizeof(T) * rows * cols); |
319 | |
320 | // Sub matrix [0,0] |
321 | transform_matrix[0 * cols + 0] = T(1.0); |
322 | transform_matrix[0 * cols + 1] = T(1.0); |
323 | transform_matrix[0 * cols + 2] = T(1.0); |
324 | |
325 | transform_matrix[1 * cols + 1] = T(1.0); |
326 | transform_matrix[1 * cols + 2] = T(-1.0); |
327 | transform_matrix[1 * cols + 3] = T(-1.0); |
328 | |
329 | // Sub matrix [0,1] |
330 | transform_matrix[0 * cols + 4] = T(1.0); |
331 | transform_matrix[0 * cols + 5] = T(1.0); |
332 | transform_matrix[0 * cols + 6] = T(1.0); |
333 | |
334 | transform_matrix[1 * cols + 5] = T(1.0); |
335 | transform_matrix[1 * cols + 6] = T(-1.0); |
336 | transform_matrix[1 * cols + 7] = T(-1.0); |
337 | |
338 | // Sub matrix [0,2] |
339 | transform_matrix[0 * cols + 8] = T(1.0); |
340 | transform_matrix[0 * cols + 9] = T(1.0); |
341 | transform_matrix[0 * cols + 10] = T(1.0); |
342 | |
343 | transform_matrix[1 * cols + 9] = T(1.0); |
344 | transform_matrix[1 * cols + 10] = T(-1.0); |
345 | transform_matrix[1 * cols + 11] = T(-1.0); |
346 | |
347 | // Sub matrix [1,1] |
348 | transform_matrix[2 * cols + 4] = T(1.0); |
349 | transform_matrix[2 * cols + 5] = T(1.0); |
350 | transform_matrix[2 * cols + 6] = T(1.0); |
351 | |
352 | transform_matrix[3 * cols + 5] = T(1.0); |
353 | transform_matrix[3 * cols + 6] = T(-1.0); |
354 | transform_matrix[3 * cols + 7] = T(-1.0); |
355 | |
356 | // Sub matrix [1,2] |
357 | transform_matrix[2 * cols + 8] = T(-1.0); |
358 | transform_matrix[2 * cols + 9] = T(-1.0); |
359 | transform_matrix[2 * cols + 10] = T(-1.0); |
360 | |
361 | transform_matrix[3 * cols + 9] = T(-1.0); |
362 | transform_matrix[3 * cols + 10] = T(1.0); |
363 | transform_matrix[3 * cols + 11] = T(1.0); |
364 | |
365 | // Sub matrix [1,3] |
366 | transform_matrix[2 * cols + 12] = T(-1.0); |
367 | transform_matrix[2 * cols + 13] = T(-1.0); |
368 | transform_matrix[2 * cols + 14] = T(-1.0); |
369 | |
370 | transform_matrix[3 * cols + 13] = T(-1.0); |
371 | transform_matrix[3 * cols + 14] = T(1.0); |
372 | transform_matrix[3 * cols + 15] = T(1.0); |
373 | }; |
374 | |
375 | } // namespace tensorflow |
376 | |
377 | #endif // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ |
378 | |