1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
17#define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
18
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20
21namespace Eigen {
22
// Selects how the parts of a glimpse that fall outside the input image are
// filled (the padding applied when a glimpse only partially overlaps it).
enum ExtractGlimpsesNoiseMode {
  // Uniform random values scaled to the [min, max] range of the image.
  UNIFORM = 0,
  // Per-channel Gaussian noise using the channel's mean and standard
  // deviation, clamped to the channel's [min, max] range.
  GAUSSIAN = 1,
  // Zeros.
  ZERO = 2
};
29
30/** ExtractGlimpses
31 * \ingroup CXX11_NeuralNetworks_Module
32 *
33 * \brief Extract glimpses from an input tensor.
34 *
35 * The input parameter is expected to be a col-major tensor with a rank of 4
 * (depth, x, y, and batch). The width and height parameters specify the
 * extent of the returned glimpses. The offsets parameter specifies the x, y
 * location of each glimpse; with the default arguments this is the center of
 * the glimpse relative to the center of the input image, and in general its
 * interpretation depends on the normalized, centered and version arguments.
 * The vector is expected to contain one IndexPair for each image in the
40 * batch dimension. The normalized boolean indicates if incoming coordinates are
41 * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each
42 * height and width dimension. The centered boolean indicates if incoming
43 * coordinates are centered relative to the image, in which case -1.0 and 1.0
44 * correspond to minimum and maximum of each dimension while 0.0 corresponds to
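 * the center. The noise argument selects how the parts of a glimpse that fall
 * outside the input image are filled (uniform noise, Gaussian noise, or
 * zeros). The version argument selects which coordinate-conversion rules are
 * applied (version 1 versus any other value).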
46 *
47 * The result can be assigned to a tensor of rank equal to that of the input.
48 * The result will be laid out in col-major order (depth, x, y, batch). The
49 * dimensions of the result will be equal to the dimensions of the input except
50 * for width and height which will be equal to the requested glimpse size.
51 */
52namespace {
53
54template <typename Index>
55struct GlimpseExtractionOp {
56 GlimpseExtractionOp(const Index width, const Index height,
57 const std::vector<IndexPair<float> >& offsets,
58 const bool normalized, const bool centered,
59 const ExtractGlimpsesNoiseMode noise, const int version)
60 : width_(width),
61 height_(height),
62 offsets_(offsets),
63 normalized_(normalized),
64 centered_(centered),
65 noise_(noise),
66 version_(version) {}
67
68 template <typename Input>
69 DSizes<Index, 4> dimensions(const Input& input) const {
70 typedef typename internal::traits<Input>::Index IndexType;
71 typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
72 internal::traits<Input>::Layout, IndexType> >
73 Ref;
74 Ref in(input);
75
76 DSizes<Index, 4> dims = in.dimensions();
77
78 dims[0] = in.dimension(0);
79 dims[1] = width_;
80 dims[2] = height_;
81 dims[3] = in.dimension(3);
82 return dims;
83 }
84
85 template <typename Input, typename Output, typename Device>
86 EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output,
87 const Device& device) const {
88 typedef typename internal::traits<Input>::Index IndexType;
89 typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
90 internal::traits<Input>::Layout, IndexType> >
91 Ref;
92 Ref in(input);
93 const Index num_channels = in.dimension(0);
94 const Index input_width = in.dimension(1);
95 const Index input_height = in.dimension(2);
96 const Index batch_size = in.dimension(3);
97 eigen_assert(input_width > 0);
98 eigen_assert(input_height > 0);
99 internal::NormalRandomGenerator<float> gen;
100 internal::UniformRandomGenerator<float> unigen;
101
102 for (Index i = 0; i < batch_size; ++i) {
103 float x = offsets_[i].first, y = offsets_[i].second;
104
105 if (version_ == 1) {
106 // Un-normalize coordinates back to pixel space if normalized.
107 if (normalized_) {
108 x *= input_width;
109 y *= input_height;
110 }
111 // Un-center if coordinates are centered on the image center.
112 if (centered_) {
113 x /= 2.0f;
114 y /= 2.0f;
115 x += input_width / 2.0f;
116 y += input_height / 2.0f;
117 }
118 // Remove half of the glimpse window.
119 x -= width_ / 2.0f;
120 y -= height_ / 2.0f;
121 } else {
122 if (normalized_) {
123 // Un-normalize coordinates back to pixel space if normalized.
124 x *= input_width;
125 y *= input_height;
126 if (centered_) {
127 // Un-center if coordinates are centered on the image center.
128 x /= 2.0f;
129 y /= 2.0f;
130 x += input_width / 2.0f;
131 y += input_height / 2.0f;
132 // Remove half of the glimpse window.
133 x -= width_ / 2.0f;
134 y -= height_ / 2.0f;
135 }
136 } else {
137 if (centered_) {
138 x += input_width / 2.0f;
139 y += input_height / 2.0f;
140 }
141 }
142 }
143
144 const Index offset_x = (Index)x;
145 const Index offset_y = (Index)y;
146 Index glimpse_width = width_;
147 Index glimpse_height = height_;
148 bool partial_overlap = false;
149 DSizes<Index, 3> slice_offset(0, offset_x, offset_y);
150 DSizes<Index, 3> slice_extent(num_channels, width_, height_);
151 DSizes<Index, 3> base_offset(0, 0, 0);
152
153 if (offset_x < 0) {
154 slice_offset[1] = 0;
155 glimpse_width = (std::max<Index>)(0, width_ + offset_x);
156 slice_extent[1] = glimpse_width;
157 base_offset[1] = width_ - glimpse_width;
158 partial_overlap = true;
159 } else if (offset_x + width_ >= input_width) {
160 glimpse_width = (std::max<Index>)(0, input_width - offset_x);
161 slice_extent[1] = glimpse_width;
162 partial_overlap = true;
163 }
164 if (offset_y < 0) {
165 slice_offset[2] = 0;
166 glimpse_height = (std::max<Index>)(0, height_ + offset_y);
167 slice_extent[2] = glimpse_height;
168 base_offset[2] = height_ - glimpse_height;
169 partial_overlap = true;
170 } else if (offset_y + height_ >= input_height) {
171 glimpse_height = (std::max<Index>)(0, input_height - offset_y);
172 slice_extent[2] = glimpse_height;
173 partial_overlap = true;
174 }
175 slice_extent[1] = std::min<Index>(input_width, slice_extent[1]);
176 slice_extent[2] = std::min<Index>(input_height, slice_extent[2]);
177
178 if (partial_overlap) {
179 switch (noise_) {
180 case ZERO: {
181 // Initialize the glimpse with zero noise.
182 output.template chip<3>(i).device(device) =
183 output.template chip<3>(i).constant(0);
184 } break;
185 case UNIFORM: {
186 // Initialize the glimpse with uniform noise.
187 typedef std::remove_const_t<
188 typename internal::traits<Input>::Scalar>
189 Scalar;
190 TensorFixedSize<Scalar, Sizes<> > mini;
191 mini.device(device) = input.template chip<3>(i).minimum();
192 TensorFixedSize<float, Sizes<> > range;
193 range.device(device) = (input.template chip<3>(i).maximum() - mini)
194 .template cast<float>();
195
196 DSizes<Index, 3> glimpse_size(num_channels, width_, height_);
197 TensorMap<Tensor<float, 3> > tmp(nullptr, glimpse_size);
198 output.template chip<3>(i).device(device) =
199 mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) +
200 (tmp.random(unigen) *
201 range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size))
202 .template cast<Scalar>();
203 } break;
204 case GAUSSIAN: {
205 // Initialize the glimpse with white noise: compute the mean and
206 // sigma
207 // of each channel, and use them to shape the gaussian.
208 DSizes<Index, 2> glimpse_size(width_, height_);
209 DSizes<Index, 2> input_size(input_width, input_height);
210 typedef std::remove_const_t<
211 typename internal::traits<Input>::Scalar>
212 Scalar;
213
214 for (int j = 0; j < num_channels; ++j) {
215 TensorFixedSize<Scalar, Sizes<> > mean;
216 mean.device(device) = input.template chip<3>(i)
217 .template chip<0>(j)
218 .template cast<float>()
219 .mean();
220 TensorFixedSize<float, Sizes<> > sigma;
221 sigma.device(device) =
222 (input.template chip<3>(i)
223 .template chip<0>(j)
224 .template cast<float>() -
225 mean.reshape(Sizes<1, 1>()).broadcast(input_size))
226 .square()
227 .mean()
228 .sqrt();
229 TensorFixedSize<Scalar, Sizes<> > mini;
230 mini.device(device) =
231 input.template chip<3>(i).template chip<0>(j).minimum();
232 TensorFixedSize<float, Sizes<> > maxi;
233 maxi.device(device) =
234 input.template chip<3>(i).template chip<0>(j).maximum();
235
236 TensorMap<Tensor<float, 2> > tmp(nullptr, glimpse_size);
237 output.template chip<3>(i).template chip<0>(j).device(device) =
238 (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) +
239 (tmp.random(gen) *
240 sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
241 .template cast<Scalar>())
242 .cwiseMin(
243 maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size))
244 .cwiseMax(
245 mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size));
246 }
247 } break;
248 }
249
250 // Copy the part of the glimpse that cover the input image if any.
251 if (glimpse_width == 0 || glimpse_height == 0) {
252 continue;
253 }
254 output.template chip<3>(i)
255 .slice(base_offset, slice_extent)
256 .device(device) =
257 input.template chip<3>(i).slice(slice_offset, slice_extent);
258 } else {
259 output.template chip<3>(i).device(device) =
260 input.template chip<3>(i).slice(slice_offset, slice_extent);
261 }
262 }
263 }
264
265 private:
266 const Index width_;
267 const Index height_;
268 const std::vector<IndexPair<float> > offsets_;
269 const bool normalized_;
270 const bool centered_;
271 const ExtractGlimpsesNoiseMode noise_;
272 const int version_;
273};
274} // namespace
275
276template <typename Input>
277EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp<
278 const GlimpseExtractionOp<typename internal::traits<Input>::Index>,
279 const Input>
280ExtractGlimpses(
281 const Input& input, const typename internal::traits<Input>::Index width,
282 const typename internal::traits<Input>::Index height,
283 const std::vector<IndexPair<float> >& offsets, const bool normalized = true,
284 const bool centered = true,
285 const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM,
286 const int version = 2) {
287 EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor,
288 YOU_MADE_A_PROGRAMMING_MISTAKE);
289 EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4,
290 YOU_MADE_A_PROGRAMMING_MISTAKE);
291
292 typedef typename internal::traits<Input>::Index Index;
293 const GlimpseExtractionOp<Index> op(width, height, offsets, normalized,
294 centered, noise, version);
295 return input.customOp(op);
296}
297
298} // end namespace Eigen
299
300#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_
301