1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ |
18 | |
19 | #include "tensorflow/core/framework/types.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | class OpKernelContext; |
24 | |
25 | // DeepConv2D is a Conv2D implementation specialized for deep (i.e. large |
26 | // in_depth * out_depth product) convolutions (see deep_conv2d.cc for details). |
27 | |
28 | // DeepConv2DTransform is an interface for implementing transforms for |
29 | // DeepConv2D. Implementations must specify transform matrices and |
// input/output/filter shapes. DeepConv2D computes:
31 | // |
32 | // y = C[Ad * Bg] |
33 | // |
34 | // C: output transform matrix |
35 | // A: input data transform matrix |
36 | // B: filter transform matrix |
37 | // d: vectorized 2D data tile |
38 | // g: vectorized 2D filter tile |
39 | // y: vectorized 2D output tile |
40 | |
template <typename T>
class DeepConv2DTransform {
 public:
  // A (rows, cols) pair describing the extent of a 2D shape.
  struct Shape {
    Shape(int64_t num_rows, int64_t num_cols)
        : rows{num_rows}, cols{num_cols} {}
    int64_t rows;
    int64_t cols;
  };

  // Virtual so that implementations can be destroyed through a pointer or
  // reference to this interface.
  virtual ~DeepConv2DTransform() = default;

  // Matrix accessors. Each receives the matrix dimensions ('rows', 'cols')
  // and a caller-supplied 'transform_matrix' buffer of element type T.
  // NOTE(review): the buffer is presumably filled with rows * cols values;
  // the element layout is not specified in this header -- confirm against
  // the implementations.

  // Filter transform matrix ("B" in y = C[Ad * Bg]).
  virtual void GetFilterTransformMatrix(const int64_t rows,
                                        const int64_t cols,
                                        T* transform_matrix) const = 0;

  // Input data transform matrix ("A" in y = C[Ad * Bg]).
  virtual void GetInputTransformMatrix(const int64_t rows,
                                       const int64_t cols,
                                       T* transform_matrix) const = 0;

  // Output transform matrix ("C" in y = C[Ad * Bg]).
  virtual void GetOutputTransformMatrix(const int64_t rows,
                                        const int64_t cols,
                                        T* transform_matrix) const = 0;

  // Shapes that every implementation must specify for the filter, input and
  // output. The references are owned by the implementation.
  // NOTE(review): presumably these are the 2D tile shapes the transform
  // operates on -- confirm against deep_conv2d.cc.
  virtual const Shape& filter_shape() const = 0;
  virtual const Shape& input_shape() const = 0;
  virtual const Shape& output_shape() const = 0;
};
65 | |
// Bundle of Conv2D parameters consumed by the DeepConv2D implementation.
// Every field starts at zero via its in-class initializer, so a
// default-constructed Conv2DArgs is fully initialized.
struct Conv2DArgs {
  // Input layer dimensions.
  int batch = 0;
  int in_rows = 0;
  int in_cols = 0;
  int in_depth = 0;

  // Filter extent and padding, per spatial axis.
  int filter_rows = 0;
  int filter_cols = 0;
  int pad_rows = 0;
  int pad_cols = 0;

  // Output layer dimensions.
  int out_rows = 0;
  int out_cols = 0;
  int out_depth = 0;
};
96 | |
// Decides whether the DeepConv2D implementation may be used for the
// convolution described by the arguments: returns true when it can be used
// and false when it cannot. A false result may stem from the parameters
// themselves, from cost considerations, or from the feature being disabled.
bool CanUseDeepConv2D(int stride_rows,
                      int stride_cols,
                      int filter_rows,
                      int filter_cols,
                      int in_depth,
                      int out_depth,
                      int out_rows,
                      int out_cols);
103 | |
104 | namespace functor { |
105 | |
106 | // Calls DeepConv2D implementation (see deep_conv2d.cc for details). |
107 | template <typename Device, typename T> |
108 | struct DeepConv2D { |
109 | void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input, |
110 | const T* filter, T* output); |
111 | }; |
112 | |
113 | } // namespace functor |
114 | |
115 | } // namespace tensorflow |
116 | |
117 | #endif // TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ |
118 | |