1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/resource_mgr.h" |
21 | #include "tensorflow/core/platform/mem.h" |
22 | #include "tensorflow/core/util/tensor_format.h" |
23 | |
24 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
25 | #include "tensorflow/core/kernels/conv_ops_gpu.h" |
26 | #include "tensorflow/core/platform/stream_executor.h" |
27 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
28 | |
29 | namespace tensorflow { |
30 | |
31 | // Forward declaration. |
32 | class OpKernelContext; |
33 | |
// Launcher for a 2-D convolution on `Device` with element type `T` (as its name
// suggests). Only the declaration lives in this header; operator() is defined
// out of line.
//
// operator() parameters (meanings inferred from names — NOTE(review): confirm
// against the out-of-line definition):
//   ctx                        - kernel context of the calling op.
//   use_cudnn                  - whether the cuDNN path may be used.
//   cudnn_use_autotune         - whether cuDNN autotuning is enabled.
//   input, filter              - the convolution operands.
//   row_dilation, col_dilation - dilation along rows / columns.
//   row_stride, col_stride     - stride along rows / columns.
//   padding                    - padding scheme.
//   explicit_paddings          - explicit padding amounts; presumably only
//                                consulted for EXPLICIT padding — verify.
//   output                     - destination tensor (presumably allocated by
//                                the caller — confirm).
//   data_format                - tensor layout of input/output.
template <typename Device, typename T>
struct LaunchConv2DOp {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& input, const Tensor& filter, int row_dilation,
                  int col_dilation, int row_stride, int col_stride,
                  const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings, Tensor* output,
                  TensorFormat data_format);
};
43 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Partial specialization of LaunchConv2DOp for Eigen::GpuDevice. It is declared
// only in CUDA or ROCm builds. Its operator() has exactly the same signature as
// the primary template's and is likewise defined out of line.
template <typename T>
struct LaunchConv2DOp<Eigen::GpuDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& input, const Tensor& filter, int row_dilation,
                  int col_dilation, int row_stride, int col_stride,
                  const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings, Tensor* output,
                  TensorFormat data_format);
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
55 | |
56 | // Used to keep track of persistent memory buffers used within the op. |
57 | // It uses malloc and free to avoid the time cost of initializing the memory. |
58 | template <class T, size_t size> |
59 | struct Im2ColBufferResource : public ResourceBase { |
60 | Im2ColBufferResource<T, size>() { |
61 | data = static_cast<T*>(port::Malloc(size * sizeof(T))); |
62 | } |
63 | ~Im2ColBufferResource<T, size>() { port::Free(data); } |
64 | // This mutex ensures that only a single operation at a time is able to use |
65 | // the buffer memory held by this resource. |
66 | mutex mu; |
67 | T* data; |
68 | string DebugString() const { return "Im2ColBufferResource" ; } |
69 | }; |
70 | |
71 | // Convolution parameters specified by Op attributes. |
72 | struct Conv2DParameters { |
73 | std::vector<int32> dilations; |
74 | std::vector<int32> strides; |
75 | Padding padding; |
76 | TensorFormat data_format; |
77 | std::vector<int64_t> explicit_paddings; |
78 | }; |
79 | |
// Convolution dimensions derived from Conv2DParameters together with the input
// and filter tensors. ComputeConv2DDimension (declared below) fills it in.
//
// Member order and types are part of the layout and aggregate-init order; do
// not reorder. Note the output/padding fields are int64_t, the rest are int.
struct Conv2DDimensions {
  // --- Input tensor ---
  int batch;
  int input_rows;
  int input_cols;
  int in_depth;

  // --- Filter tensor ---
  int filter_rows;
  int filter_cols;
  int patch_depth;  // presumably the filter's input-depth extent — verify
  int out_depth;

  // --- Strides ---
  int stride_rows;
  int stride_cols;

  // --- Dilations ---
  int dilation_rows;
  int dilation_cols;

  // --- Output extents ---
  int64_t out_rows;
  int64_t out_cols;

  // --- Padding applied before/after each spatial axis ---
  int64_t pad_rows_before;
  int64_t pad_rows_after;
  int64_t pad_cols_before;
  int64_t pad_cols_after;
};
105 | |
// Initializes and validates Conv2D parameters configured by OpKernel
// attributes, storing the result in *params.
//
// Returns a non-OK Status when validation fails. NOTE(review): the specific
// checks live in the out-of-line definition — consult it for details.
Status InitConv2DParameters(const OpKernelConstruction* context,
                            Conv2DParameters* params);
110 | |
// Computes and validates convolution dimensions from Conv2D parameters and the
// `input` and `filter` tensors. If the parameters are valid, *dimensions is
// updated with the derived convolution dimensions; otherwise an error Status
// is returned.
Status ComputeConv2DDimension(const Conv2DParameters& params,
                              const Tensor& input, const Tensor& filter,
                              Conv2DDimensions* dimensions);
117 | |
118 | } // namespace tensorflow |
119 | |
120 | #endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ |
121 | |