1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | |
21 | namespace Eigen { |
22 | |
// Noise mode used when padding: selects how the parts of a glimpse that fall
// outside the input image are filled before the overlapping input pixels are
// copied in.
enum ExtractGlimpsesNoiseMode {
  // Uniform noise between the minimum and maximum value of the batch entry.
  UNIFORM = 0,
  // Per-channel Gaussian noise using that channel's mean and standard
  // deviation, clamped to the channel's [min, max] range.
  GAUSSIAN = 1,
  // Zeros.
  ZERO = 2,
};
29 | |
30 | /** ExtractGlimpses |
31 | * \ingroup CXX11_NeuralNetworks_Module |
32 | * |
33 | * \brief Extract glimpses from an input tensor. |
34 | * |
 * The input parameter is expected to be a col-major tensor with a rank of 4
 * (depth, x, y, and batch). The width and height parameters specify the
 * extent of the returned glimpses. The offsets parameter specifies the x, y
 * locations of the center of the glimpses relative to the center of the input
 * image. The vector is expected to contain one IndexPair for each image in the
 * batch dimension. The normalized boolean indicates if incoming coordinates are
 * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each
 * height and width dimension. The centered boolean indicates if incoming
 * coordinates are centered relative to the image, in which case -1.0 and 1.0
 * correspond to minimum and maximum of each dimension while 0.0 corresponds to
 * the center. The noise parameter selects how the parts of a glimpse that fall
 * outside the input image are filled (UNIFORM, GAUSSIAN or ZERO). The version
 * parameter selects the convention used to convert offsets into pixel
 * coordinates, so the exact meaning of an offset depends on version together
 * with normalized and centered.
46 | * |
47 | * The result can be assigned to a tensor of rank equal to that of the input. |
48 | * The result will be laid out in col-major order (depth, x, y, batch). The |
49 | * dimensions of the result will be equal to the dimensions of the input except |
50 | * for width and height which will be equal to the requested glimpse size. |
51 | */ |
52 | namespace { |
53 | |
54 | template <typename Index> |
55 | struct { |
56 | (const Index width, const Index height, |
57 | const std::vector<IndexPair<float> >& offsets, |
58 | const bool normalized, const bool centered, |
59 | const ExtractGlimpsesNoiseMode noise, const int version) |
60 | : width_(width), |
61 | height_(height), |
62 | offsets_(offsets), |
63 | normalized_(normalized), |
64 | centered_(centered), |
65 | noise_(noise), |
66 | version_(version) {} |
67 | |
68 | template <typename Input> |
69 | DSizes<Index, 4> (const Input& input) const { |
70 | typedef typename internal::traits<Input>::Index IndexType; |
71 | typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4, |
72 | internal::traits<Input>::Layout, IndexType> > |
73 | Ref; |
74 | Ref in(input); |
75 | |
76 | DSizes<Index, 4> dims = in.dimensions(); |
77 | |
78 | dims[0] = in.dimension(0); |
79 | dims[1] = width_; |
80 | dims[2] = height_; |
81 | dims[3] = in.dimension(3); |
82 | return dims; |
83 | } |
84 | |
85 | template <typename Input, typename Output, typename Device> |
86 | EIGEN_DEVICE_FUNC void (const Input& input, Output& output, |
87 | const Device& device) const { |
88 | typedef typename internal::traits<Input>::Index IndexType; |
89 | typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4, |
90 | internal::traits<Input>::Layout, IndexType> > |
91 | Ref; |
92 | Ref in(input); |
93 | const Index num_channels = in.dimension(0); |
94 | const Index input_width = in.dimension(1); |
95 | const Index input_height = in.dimension(2); |
96 | const Index batch_size = in.dimension(3); |
97 | eigen_assert(input_width > 0); |
98 | eigen_assert(input_height > 0); |
99 | internal::NormalRandomGenerator<float> gen; |
100 | internal::UniformRandomGenerator<float> unigen; |
101 | |
102 | for (Index i = 0; i < batch_size; ++i) { |
103 | float x = offsets_[i].first, y = offsets_[i].second; |
104 | |
105 | if (version_ == 1) { |
106 | // Un-normalize coordinates back to pixel space if normalized. |
107 | if (normalized_) { |
108 | x *= input_width; |
109 | y *= input_height; |
110 | } |
111 | // Un-center if coordinates are centered on the image center. |
112 | if (centered_) { |
113 | x /= 2.0f; |
114 | y /= 2.0f; |
115 | x += input_width / 2.0f; |
116 | y += input_height / 2.0f; |
117 | } |
118 | // Remove half of the glimpse window. |
119 | x -= width_ / 2.0f; |
120 | y -= height_ / 2.0f; |
121 | } else { |
122 | if (normalized_) { |
123 | // Un-normalize coordinates back to pixel space if normalized. |
124 | x *= input_width; |
125 | y *= input_height; |
126 | if (centered_) { |
127 | // Un-center if coordinates are centered on the image center. |
128 | x /= 2.0f; |
129 | y /= 2.0f; |
130 | x += input_width / 2.0f; |
131 | y += input_height / 2.0f; |
132 | // Remove half of the glimpse window. |
133 | x -= width_ / 2.0f; |
134 | y -= height_ / 2.0f; |
135 | } |
136 | } else { |
137 | if (centered_) { |
138 | x += input_width / 2.0f; |
139 | y += input_height / 2.0f; |
140 | } |
141 | } |
142 | } |
143 | |
144 | const Index offset_x = (Index)x; |
145 | const Index offset_y = (Index)y; |
146 | Index glimpse_width = width_; |
147 | Index glimpse_height = height_; |
148 | bool partial_overlap = false; |
149 | DSizes<Index, 3> slice_offset(0, offset_x, offset_y); |
150 | DSizes<Index, 3> slice_extent(num_channels, width_, height_); |
151 | DSizes<Index, 3> base_offset(0, 0, 0); |
152 | |
153 | if (offset_x < 0) { |
154 | slice_offset[1] = 0; |
155 | glimpse_width = (std::max<Index>)(0, width_ + offset_x); |
156 | slice_extent[1] = glimpse_width; |
157 | base_offset[1] = width_ - glimpse_width; |
158 | partial_overlap = true; |
159 | } else if (offset_x + width_ >= input_width) { |
160 | glimpse_width = (std::max<Index>)(0, input_width - offset_x); |
161 | slice_extent[1] = glimpse_width; |
162 | partial_overlap = true; |
163 | } |
164 | if (offset_y < 0) { |
165 | slice_offset[2] = 0; |
166 | glimpse_height = (std::max<Index>)(0, height_ + offset_y); |
167 | slice_extent[2] = glimpse_height; |
168 | base_offset[2] = height_ - glimpse_height; |
169 | partial_overlap = true; |
170 | } else if (offset_y + height_ >= input_height) { |
171 | glimpse_height = (std::max<Index>)(0, input_height - offset_y); |
172 | slice_extent[2] = glimpse_height; |
173 | partial_overlap = true; |
174 | } |
175 | slice_extent[1] = std::min<Index>(input_width, slice_extent[1]); |
176 | slice_extent[2] = std::min<Index>(input_height, slice_extent[2]); |
177 | |
178 | if (partial_overlap) { |
179 | switch (noise_) { |
180 | case ZERO: { |
181 | // Initialize the glimpse with zero noise. |
182 | output.template chip<3>(i).device(device) = |
183 | output.template chip<3>(i).constant(0); |
184 | } break; |
185 | case UNIFORM: { |
186 | // Initialize the glimpse with uniform noise. |
187 | typedef std::remove_const_t< |
188 | typename internal::traits<Input>::Scalar> |
189 | Scalar; |
190 | TensorFixedSize<Scalar, Sizes<> > mini; |
191 | mini.device(device) = input.template chip<3>(i).minimum(); |
192 | TensorFixedSize<float, Sizes<> > range; |
193 | range.device(device) = (input.template chip<3>(i).maximum() - mini) |
194 | .template cast<float>(); |
195 | |
196 | DSizes<Index, 3> glimpse_size(num_channels, width_, height_); |
197 | TensorMap<Tensor<float, 3> > tmp(nullptr, glimpse_size); |
198 | output.template chip<3>(i).device(device) = |
199 | mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) + |
200 | (tmp.random(unigen) * |
201 | range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size)) |
202 | .template cast<Scalar>(); |
203 | } break; |
204 | case GAUSSIAN: { |
205 | // Initialize the glimpse with white noise: compute the mean and |
206 | // sigma |
207 | // of each channel, and use them to shape the gaussian. |
208 | DSizes<Index, 2> glimpse_size(width_, height_); |
209 | DSizes<Index, 2> input_size(input_width, input_height); |
210 | typedef std::remove_const_t< |
211 | typename internal::traits<Input>::Scalar> |
212 | Scalar; |
213 | |
214 | for (int j = 0; j < num_channels; ++j) { |
215 | TensorFixedSize<Scalar, Sizes<> > mean; |
216 | mean.device(device) = input.template chip<3>(i) |
217 | .template chip<0>(j) |
218 | .template cast<float>() |
219 | .mean(); |
220 | TensorFixedSize<float, Sizes<> > sigma; |
221 | sigma.device(device) = |
222 | (input.template chip<3>(i) |
223 | .template chip<0>(j) |
224 | .template cast<float>() - |
225 | mean.reshape(Sizes<1, 1>()).broadcast(input_size)) |
226 | .square() |
227 | .mean() |
228 | .sqrt(); |
229 | TensorFixedSize<Scalar, Sizes<> > mini; |
230 | mini.device(device) = |
231 | input.template chip<3>(i).template chip<0>(j).minimum(); |
232 | TensorFixedSize<float, Sizes<> > maxi; |
233 | maxi.device(device) = |
234 | input.template chip<3>(i).template chip<0>(j).maximum(); |
235 | |
236 | TensorMap<Tensor<float, 2> > tmp(nullptr, glimpse_size); |
237 | output.template chip<3>(i).template chip<0>(j).device(device) = |
238 | (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) + |
239 | (tmp.random(gen) * |
240 | sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) |
241 | .template cast<Scalar>()) |
242 | .cwiseMin( |
243 | maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) |
244 | .cwiseMax( |
245 | mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size)); |
246 | } |
247 | } break; |
248 | } |
249 | |
250 | // Copy the part of the glimpse that cover the input image if any. |
251 | if (glimpse_width == 0 || glimpse_height == 0) { |
252 | continue; |
253 | } |
254 | output.template chip<3>(i) |
255 | .slice(base_offset, slice_extent) |
256 | .device(device) = |
257 | input.template chip<3>(i).slice(slice_offset, slice_extent); |
258 | } else { |
259 | output.template chip<3>(i).device(device) = |
260 | input.template chip<3>(i).slice(slice_offset, slice_extent); |
261 | } |
262 | } |
263 | } |
264 | |
265 | private: |
266 | const Index ; |
267 | const Index ; |
268 | const std::vector<IndexPair<float> > ; |
269 | const bool ; |
270 | const bool ; |
271 | const ExtractGlimpsesNoiseMode ; |
272 | const int ; |
273 | }; |
274 | } // namespace |
275 | |
276 | template <typename Input> |
277 | EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp< |
278 | const GlimpseExtractionOp<typename internal::traits<Input>::Index>, |
279 | const Input> |
280 | ( |
281 | const Input& input, const typename internal::traits<Input>::Index width, |
282 | const typename internal::traits<Input>::Index height, |
283 | const std::vector<IndexPair<float> >& offsets, const bool normalized = true, |
284 | const bool centered = true, |
285 | const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM, |
286 | const int version = 2) { |
287 | EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor, |
288 | YOU_MADE_A_PROGRAMMING_MISTAKE); |
289 | EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, |
290 | YOU_MADE_A_PROGRAMMING_MISTAKE); |
291 | |
292 | typedef typename internal::traits<Input>::Index Index; |
293 | const GlimpseExtractionOp<Index> op(width, height, offsets, normalized, |
294 | centered, noise, version); |
295 | return input.customOp(op); |
296 | } |
297 | |
298 | } // end namespace Eigen |
299 | |
300 | #endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ |
301 | |