1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Implements quantized eight-bit versions of the convolution operations. |
17 | |
18 | #include <algorithm> |
19 | #include <vector> |
20 | |
21 | #define EIGEN_USE_THREADS |
22 | |
23 | #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK |
24 | #include "public/gemmlowp.h" |
25 | #include "tensorflow/core/framework/kernel_shape_util.h" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/kernels/conv_ops.h" |
29 | #include "tensorflow/core/kernels/meta_support.h" |
30 | #include "tensorflow/core/kernels/quantization_utils.h" |
31 | #include "tensorflow/core/kernels/reference_gemm.h" |
32 | #include "tensorflow/core/lib/core/errors.h" |
33 | #include "tensorflow/core/platform/errors.h" |
34 | #include "tensorflow/core/util/padding.h" |
35 | |
36 | namespace tensorflow { |
37 | |
38 | // This functor implements the convolution operation in as simple a form as |
39 | // possible. It won't give great performance, but it is very useful for |
40 | // stepping through and instrumenting for debugging, creating minimal benchmarks |
41 | // to prototype with, and sharing with teams that want to run this outside of |
42 | // our environment. |
43 | // With that in mind, I've avoided using anything except pretty standard C++ |
44 | // types. This is especially noticeable in the data access through raw array |
45 | // indexing. It's deliberate in this case though, since it makes the underlying |
46 | // memory order very explicit, which is important for both inspecting memory |
47 | // contents during debugging and for specifying what we expect to others. |
48 | // The memory layout of the data is, from biggest stride to smallest: |
49 | // input_data = [input_batches, input_height, input_width, input_depth] |
50 | // filter_data = [filter_height, filter_width, input_depth, filter_count] |
51 | // output_data = [input_batches, output_height, output_width, filter_count] |
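// As a worked example, the input value at coordinates (batch, y, x, channel)
// lives at:
//   input_data[(batch * input_height * input_width * input_depth) +
//              (y * input_width * input_depth) + (x * input_depth) + channel]
// which is exactly the raw indexing expression used in the loops below.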
52 | template <class T1, class T2, class T3> |
53 | class ReferenceConvFunctor { |
54 | public: |
55 | void operator()(OpKernelContext* context, const T1* input_data, |
56 | int input_batches, int input_height, int input_width, |
57 | int input_depth, int input_offset, const T2* filter_data, |
58 | int filter_height, int filter_width, int filter_count, |
59 | int filter_offset, int stride, Padding padding, |
60 | T3* output_data, int output_height, int output_width, |
61 | int output_shift, int output_offset, int output_mult) { |
62 | // Set up some constants we need for the output down-shifting and |
63 | // saturation. |
64 | const int32_t highest = static_cast<int32>(Eigen::NumTraits<T3>::highest()); |
65 | const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest()); |
66 | |
    // When we're converting the 32 bit accumulator to a lower bit depth, we
    // need to add on 0.5 in fixed-point terms to make the operation round half
    // up towards positive infinity, rather than flooring as a plain right
    // shift would.
    // We also need to watch out for the case when there's no down shift,
    // because a left shift by a negative number gives undefined results.
72 | const int32_t rounding = (output_shift < 1) ? 0 : (1 << (output_shift - 1)); |
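    // As a worked example, with output_shift == 2 the accumulator is divided
    // by four, and rounding == 2, so a scaled total of 7 becomes
    // (7 + 2) >> 2 == 2 rather than the floored 7 >> 2 == 1.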
73 | |
    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as
    // the input. It's complicated by stride, which shrinks the output image by
    // a factor, but it also means we end up sampling from outside the borders
    // of the input. These out-of-bounds values are read as zeroes. VALID means
    // only produce output values where the filters can read all their values
    // from within the input image. It effectively removes the margins of the
    // output image compared to the one produced by SAME. Stride complicates
    // this definition though, because it can result in the right and bottom
    // filter patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
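    // For reference, the output sizes these offsets pair up with are computed
    // by the caller as:
    //   SAME:  output_size = ceil(input_size / stride)
    //   VALID: output_size = ceil((input_size - filter_size + 1) / stride)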
88 | int filter_left_offset; |
89 | int filter_top_offset; |
90 | if (padding == VALID) { |
91 | filter_left_offset = |
92 | ((output_width - 1) * stride + filter_width - input_width + 1) / 2; |
93 | filter_top_offset = |
94 | ((output_height - 1) * stride + filter_height - input_height + 1) / 2; |
95 | } else { |
96 | filter_left_offset = |
97 | ((output_width - 1) * stride + filter_width - input_width) / 2; |
98 | filter_top_offset = |
99 | ((output_height - 1) * stride + filter_height - input_height) / 2; |
100 | } |
101 | |
102 | // If we've got multiple images in our input, work through each of them. |
103 | for (int batch = 0; batch < input_batches; ++batch) { |
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
107 | for (int out_y = 0; out_y < output_height; ++out_y) { |
108 | for (int out_x = 0; out_x < output_width; ++out_x) { |
109 | // Each filter kernel produces one output channel. |
110 | for (int out_channel = 0; out_channel < filter_count; ++out_channel) { |
111 | // We're going to calculate a single output value, which means we |
112 | // need to multiply a three dimensional kernel of weights against |
113 | // the current location within the input image. |
114 | /* |
115 | *-------------------------------... |
116 | |\ ^ |
117 | | \in_depth |
118 | | \ v |
119 | | *-------------------------------... |
120 | | | ^ |
121 | | | in_y_origin |
122 | | | v \ |
123 | | |<in_x_origin>*---*^ |
124 | | | \| |filter_height |
125 | . | *---*v |
126 | . | <---> |
127 | . filter_width |
128 | . |
129 | */ |
130 | const int in_x_origin = (out_x * stride) - filter_left_offset; |
131 | const int in_y_origin = (out_y * stride) - filter_top_offset; |
132 | int32_t total = 0; |
133 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
134 | for (int filter_x = 0; filter_x < filter_width; ++filter_x) { |
135 | for (int in_channel = 0; in_channel < input_depth; |
136 | ++in_channel) { |
137 | const int in_x = in_x_origin + filter_x; |
138 | const int in_y = in_y_origin + filter_y; |
139 | int32_t input_value; |
140 | // If the location is outside the bounds of the input image, |
141 | // use zero as a default value. |
142 | if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && |
143 | (in_y < input_height)) { |
144 | const T1 input_source_value = |
145 | input_data[(batch * input_height * input_width * |
146 | input_depth) + |
147 | (in_y * input_width * input_depth) + |
148 | (in_x * input_depth) + in_channel]; |
149 | // We're promoting the T1 type to a higher bit depth here as |
150 | // we do the subtraction. |
151 | input_value = |
152 | static_cast<int32>(input_source_value) - input_offset; |
153 | } else { |
154 | input_value = 0; |
155 | } |
156 | const T2 filter_source_value = |
157 | filter_data[(filter_y * filter_width * input_depth * |
158 | filter_count) + |
159 | (filter_x * input_depth * filter_count) + |
160 | (in_channel * filter_count) + out_channel]; |
161 | // Another promotion to 32 bit, as above. |
162 | const int32_t filter_value = |
163 | static_cast<int32>(filter_source_value) - filter_offset; |
164 | total += (input_value * filter_value); |
165 | } |
166 | } |
167 | } |
168 | // Here we're applying scale factors to compress the 32 bit |
169 | // accumulated total to a potentially lower bit depth. |
170 | const int32_t output = |
171 | ((((total + output_offset) * output_mult) + rounding) >> |
172 | output_shift); |
173 | // We need to saturate the results against the largest and smallest |
174 | // values that can be represented in this type. |
175 | const int32_t top_clamped_output = std::min(output, highest); |
176 | const int32_t clamped_output = std::max(top_clamped_output, lowest); |
177 | output_data[(batch * output_height * output_width * filter_count) + |
178 | (out_y * output_width * filter_count) + |
179 | (out_x * filter_count) + out_channel] = clamped_output; |
180 | } |
181 | } |
182 | } |
183 | } |
184 | } |
185 | }; |
186 | |
187 | // We don't want to allocate a buffer to hold all the patches if the size is |
188 | // going to be extremely large, so break it into chunks if it's bigger than |
189 | // a limit. Each chunk will be processed serially, so we can refill the |
190 | // buffer for the next chunk and reuse it, keeping maximum memory size down. |
191 | // In this case, we've picked 1 megabyte as a reasonable limit, from |
192 | // experimentation. |
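// As a worked example, a 3x3 filter over 256 input channels needs
// 3 * 3 * 256 = 2304 bytes per eight-bit patch, so roughly 455 patches fit
// in each one-megabyte chunk.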
193 | const size_t kMaxChunkSize = (1 * 1024 * 1024); |
194 | |
195 | // Implements convolution as a two stage process, first packing the patches of |
196 | // the input image into columns (im2col) and then running GEMM to produce the |
197 | // final result. |
198 | template <class T1, class T2, class T3> |
199 | class Im2ColConvFunctor { |
200 | public: |
201 | void operator()(OpKernelContext* context, const T1* input_data, |
202 | int input_batches, int input_height, int input_width, |
203 | int input_depth, int input_offset, const T2* filter_data, |
204 | int filter_height, int filter_width, int filter_count, |
205 | int filter_offset, int stride, Padding padding, |
206 | T3* output_data, int output_height, int output_width, |
207 | int output_shift, int output_offset, int output_mult) { |
208 | if (input_offset < 0) { |
209 | // Only log the first few occurrences of this warning. |
210 | static int warning_count = 0; |
211 | if (warning_count < 10) { |
212 | ++warning_count; |
213 | LOG(WARNING) |
214 | << "For kernel '" << context->op_kernel().name() << "' from input '" |
215 | << context->op_kernel().requested_input(0) |
216 | << "': Zero is not representable in the quantized range used by the" |
217 | << " input. This means QuantizedConv2d has to fall back to a slow" |
218 | << " implementation, since the border of zero values can't be" |
219 | << " represented easily. You should try to construct graphs that" |
            << " avoid this situation.";
221 | } |
222 | ReferenceConvFunctor<T1, T2, T3> conv_functor; |
223 | conv_functor(context, input_data, input_batches, input_height, |
224 | input_width, input_depth, input_offset, filter_data, |
225 | filter_height, filter_width, filter_count, filter_offset, |
226 | stride, padding, output_data, output_height, output_width, |
227 | output_shift, output_offset, output_mult); |
228 | return; |
229 | } |
230 | |
    OP_REQUIRES(
        context, output_width > 0,
        errors::InvalidArgument("output_width must be strictly positive"));
    OP_REQUIRES(
        context, output_height > 0,
        errors::InvalidArgument("output_height must be strictly positive"));
237 | int filter_left_offset; |
238 | int filter_top_offset; |
239 | if (padding == VALID) { |
240 | filter_left_offset = |
241 | ((output_width - 1) * stride + filter_width - input_width + 1) / 2; |
242 | filter_top_offset = |
243 | ((output_height - 1) * stride + filter_height - input_height + 1) / 2; |
244 | } else { |
245 | filter_left_offset = |
246 | ((output_width - 1) * stride + filter_width - input_width) / 2; |
247 | filter_top_offset = |
248 | ((output_height - 1) * stride + filter_height - input_height) / 2; |
249 | } |
250 | |
    // The im2col buffer has (number of patches) rows, and (filter value count)
    // columns. It's laid out like this, in row major order in memory:
253 | // < filter value count > |
254 | // ^ +---------------------+ |
255 | // patch | | |
256 | // count | | |
257 | // v +---------------------+ |
258 | // Each patch row contains a filter_width x filter_height patch of the |
259 | // input, with the depth channel as the most contiguous in memory, followed |
260 | // by the width, then the height. This is the standard memory order in the |
261 | // image world if it helps to visualize it. |
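    // As a worked example, with a 3x3 filter over a depth-2 input, each patch
    // row holds 3 * 3 * 2 = 18 values: the two channel values at (y=0, x=0),
    // then (y=0, x=1), and so on across each filter row in turn.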
262 | const int filter_value_count = filter_width * filter_height * input_depth; |
    OP_REQUIRES(context, filter_value_count > 0,
                errors::InvalidArgument(
                    "filter patch must contain at least one element"));
    const int64_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    // Guard the divisions below: a single patch larger than kMaxChunkSize
    // would make patches_per_chunk zero.
    OP_REQUIRES(context, patches_per_chunk > 0,
                errors::InvalidArgument("filter patch too large for buffer"));
    const int64_t chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
    // Because memory allocation is very expensive on mobile platforms, we try
    // to allocate a persistent buffer that will be kept around between calls.
    // We use TensorFlow's resource management to ensure that the memory will
    // be released when the session is over.
276 | Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource; |
277 | std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)> |
278 | creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) { |
279 | #ifdef _MSC_VER |
280 | // MSVC complains about the capture of chunk_value_count which oddly |
281 | // works fine in conv_ops_using_gemm.cc for example. |
282 | // Define chunk_value_count inside the lambda for now. |
283 | const int64 chunk_value_count = |
284 | (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1); |
285 | #endif |
286 | *resource = new Im2ColBufferResource<T1, chunk_value_count>(); |
287 | return OkStatus(); |
288 | }; |
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
292 | // This means that multiple ops can't be run simultaneously on different |
293 | // threads, because we have a single shared resource. The platforms this is |
294 | // aimed at have intra-op parallelism as their focus though, so it shouldn't |
295 | // be an issue. |
296 | mutex_lock lock_buffer(im2col_buffer_resource->mu); |
297 | core::ScopedUnref unref_buffer(im2col_buffer_resource); |
298 | T1* im2col_buffer = im2col_buffer_resource->data; |
299 | |
300 | const int64_t patch_count = (input_batches * output_height * output_width); |
301 | const int64_t chunk_count = |
302 | (patch_count + (patches_per_chunk - 1)) / patches_per_chunk; |
303 | |
304 | for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) { |
305 | const int64_t patch_index_start = chunk_index * patches_per_chunk; |
306 | const int64_t patch_index_end = |
307 | std::min(patch_index_start + patches_per_chunk, patch_count); |
308 | for (int64_t patch_index = patch_index_start; |
309 | patch_index < patch_index_end; ++patch_index) { |
310 | const int64_t batch = patch_index / (output_height * output_width); |
311 | const int64_t out_y = (patch_index / output_width) % output_height; |
312 | const int64_t out_x = patch_index % output_width; |
313 | const T1* input_batch_start = |
314 | input_data + (batch * input_height * input_width * input_depth); |
315 | const int in_y_origin = (out_y * stride) - filter_top_offset; |
316 | const int in_x_origin = (out_x * stride) - filter_left_offset; |
317 | const int patch_index_within_chunk = patch_index % patches_per_chunk; |
318 | T1* im2col_patch_start = |
319 | im2col_buffer + (patch_index_within_chunk * filter_value_count); |
320 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
321 | const int in_y = in_y_origin + filter_y; |
322 | T1* im2col_row_start = |
323 | im2col_patch_start + (filter_y * filter_width * input_depth); |
324 | // If we're off the top or the bottom of the input, fill the |
325 | // whole row with zeroes. |
326 | if ((in_y < 0) || (in_y >= input_height)) { |
            // On Android, memset and memcpy are significantly faster than the
            // more modern std::fill and std::copy equivalents. Note that
            // writing input_offset as a byte value relies on sizeof(T1) being
            // one, which is true for the quint8 this kernel is registered for.
            memset(im2col_row_start, input_offset,
                   (filter_width * input_depth));
331 | } else { |
332 | // What we're doing here is trying to copy and fill the im2col |
333 | // buffer as efficiently as possible, using functions to set or |
334 | // duplicate values en masse. We know we don't have to worry about |
335 | // vertical edges because we dealt with that case above, so we |
336 | // just need to handle filters that overlap the left or right |
337 | // edges. Here's what that looks like: |
338 | // |
339 | // < left_zero_count > < center_copy_count > < right_zero_count > |
340 | // +------------------+---------------------+--------------------+ |
341 | // | (filter) | (image) | (filter) | |
342 | // +------------------+---------------------+--------------------+ |
343 | // in_x_origin 0 input_width in_x_end |
344 | // |
345 | // In reality it's unlikely that a filter patch will be wider |
346 | // than an input, but this shows all the edge cases. |
347 | // We use memset() to set the left and right sections to zeroes |
348 | // and memcpy() to copy over the input data for the center. These |
349 | // are preferred to std::fill and std::copy because they're much |
350 | // faster on Android. |
351 | const int in_x_end = in_x_origin + filter_width; |
352 | const int left_zero_count = std::max(0, 0 - in_x_origin); |
353 | const int right_zero_count = std::max(0, in_x_end - input_width); |
354 | const int center_copy_count = |
355 | filter_width - (left_zero_count + right_zero_count); |
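            // As a worked example, with filter_width 5, input_width 10, and
            // in_x_origin -2: in_x_end is 3, left_zero_count is 2,
            // right_zero_count is 0, and center_copy_count is 3.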
356 | if (left_zero_count > 0) { |
357 | T1* im2col_left_start = im2col_row_start; |
358 | memset(im2col_left_start, input_offset, |
359 | (left_zero_count * input_depth)); |
360 | } |
361 | if (center_copy_count > 0) { |
362 | const T1* input_row_start = |
363 | input_batch_start + (in_y * input_width * input_depth) + |
364 | (std::max(0, in_x_origin) * input_depth); |
365 | T1* im2col_center_start = |
366 | im2col_row_start + (left_zero_count * input_depth); |
367 | memcpy(im2col_center_start, input_row_start, |
368 | (center_copy_count * input_depth)); |
369 | } |
370 | if (right_zero_count > 0) { |
371 | T1* im2col_right_start = |
372 | im2col_row_start + |
373 | ((left_zero_count + center_copy_count) * input_depth); |
374 | memset(im2col_right_start, input_offset, |
375 | (right_zero_count * input_depth)); |
376 | } |
377 | } |
378 | } |
379 | } |
      // Now we've assembled a set of image patches into a matrix, apply a
      // GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get this chunk's rows of the output matrix.
383 | const int how_many_patches = patch_index_end - patch_index_start; |
384 | const bool transpose_a = false; |
385 | const bool transpose_b = false; |
386 | const bool transpose_c = false; |
387 | const int m = how_many_patches; |
388 | const int n = filter_count; |
389 | const int k = filter_value_count; |
390 | const int lda = filter_value_count; |
391 | const int ldb = filter_count; |
392 | const int ldc = filter_count; |
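      // As a worked example, a chunk of 100 patches against 16 filters of
      // shape 3x3x8 gives m = 100, k = 3 * 3 * 8 = 72, and n = 16, with all
      // three matrices laid out row major.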
393 | T3* chunk_output_data = output_data + (patch_index_start * filter_count); |
394 | |
395 | if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() && |
396 | std::is_same<T2, quint8>() && std::is_same<T3, qint32>() && |
397 | (output_offset == 0) && (output_mult == 1) && (output_shift == 0) && |
398 | (transpose_c == false) && (k <= 2048)) { |
399 | meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer, |
400 | filter_data, chunk_output_data, m, n, k, |
401 | -input_offset, -filter_offset, lda, ldb, ldc); |
402 | } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() && |
403 | std::is_same<T3, qint32>() && (output_offset == 0) && |
404 | (output_mult == 1) && (output_shift == 0)) { |
        // The gemmlowp optimized library only works for the particular set of
        // data types and parameters checked in the condition above; anything
        // else falls through to the slower reference implementation in the
        // final branch below.
408 | const uint8* im2col_data_as_uint8 = &(im2col_buffer->value); |
409 | const uint8* filter_data_as_uint8 = &(filter_data->value); |
410 | int32* output_data_as_int32 = &(chunk_output_data->value); |
411 | // All of the transpose_* variables are currently compile-time consts, |
412 | // so we could just hard-code these values too, but that would break if |
413 | // anybody changed those values in the future (e.g. to match the ability |
414 | // of MatMul to specify them as attributes). We're using a verbose |
415 | // approach of deriving the order values from the transpose variables to |
416 | // be able to catch any changes like that. |
417 | static const gemmlowp::MapOrder ResultOrder = |
418 | !transpose_c ? gemmlowp::MapOrder::RowMajor |
419 | : gemmlowp::MapOrder::ColMajor; |
420 | static const gemmlowp::MapOrder LhsOrder = |
421 | !transpose_a ? gemmlowp::MapOrder::RowMajor |
422 | : gemmlowp::MapOrder::ColMajor; |
423 | static const gemmlowp::MapOrder RhsOrder = |
424 | !transpose_b ? gemmlowp::MapOrder::RowMajor |
425 | : gemmlowp::MapOrder::ColMajor; |
426 | gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs( |
427 | im2col_data_as_uint8, m, k, lda); |
428 | gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs( |
429 | filter_data_as_uint8, k, n, ldb); |
430 | gemmlowp::MatrixMap<std::int32_t, ResultOrder> result( |
431 | output_data_as_int32, m, n, ldc); |
432 | const std::tuple<> empty_pipeline = {}; |
433 | |
        auto& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        // Named gemm_context so it doesn't shadow the OpKernelContext
        // parameter, which is also called context.
        TensorflowGemmContext gemm_context(worker_threads.num_threads,
                                           worker_threads.workers);
        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                         gemmlowp::DefaultL8R8BitDepthParams>(
            &gemm_context, lhs, rhs, &result, -input_offset, -filter_offset,
            empty_pipeline);
442 | // Since gemmlowp uses assembly to write to the output, msan won't |
443 | // detect the output buffer as written to, so we mark it manually. |
444 | TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_data_as_int32, |
445 | m * n * sizeof(int32)); |
446 | } else { |
447 | ReferenceGemm<T1, T2, T3>( |
448 | transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer, |
449 | input_offset, lda, filter_data, filter_offset, ldb, |
450 | chunk_output_data, output_shift, output_offset, output_mult, ldc); |
451 | } |
452 | } |
453 | } |
454 | }; |
455 | |
456 | template <class T1, class T2, class T3, |
457 | template <class TF1, class TF2, class TF3> class ConvFunctor> |
458 | class QuantizedConv2DOp : public OpKernel { |
459 | public: |
460 | explicit QuantizedConv2DOp(OpKernelConstruction* context) |
461 | : OpKernel(context) { |
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
474 | std::vector<int32> dilations; |
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    OP_REQUIRES(context, dilations.size() == 4,
                errors::InvalidArgument("Dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument(
                    "Current implementation only supports dilated rate as 1 "
                    "in the row and column dimensions."));
    OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
488 | } |
489 | |
490 | void Compute(OpKernelContext* context) override { |
491 | // Input tensor is of the following dimensions: |
492 | // [ batch, in_rows, in_cols, in_depth ] |
493 | const Tensor& input = context->input(0); |
494 | |
495 | // Input filter is of the following dimensions: |
496 | // [ filter_rows, filter_cols, in_depth, out_depth] |
497 | const Tensor& filter = context->input(1); |
498 | |
499 | // For 2D convolution, there should be 4 dimensions. |
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be rank 4 but is rank ",
                                        input.shape().dims()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be rank 4 but is rank ",
                                        filter.shape().dims()));
506 | |
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()),
                errors::InvalidArgument("min_input must be rank 0 but is rank ",
                                        context->input(2).shape().dims()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(3).shape()),
                errors::InvalidArgument("max_input must be rank 0 but is rank ",
                                        context->input(3).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(4).shape()),
        errors::InvalidArgument("min_filter must be rank 0 but is rank ",
                                context->input(4).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(5).shape()),
        errors::InvalidArgument("max_filter must be rank 0 but is rank ",
                                context->input(5).shape().dims()));
521 | |
522 | const float min_input = context->input(2).flat<float>()(0); |
523 | const float max_input = context->input(3).flat<float>()(0); |
524 | const float min_filter = context->input(4).flat<float>()(0); |
525 | const float max_filter = context->input(5).flat<float>()(0); |
526 | const int32_t offset_input = |
527 | FloatToQuantizedUnclamped<T1>(0.0f, min_input, max_input); |
528 | const int32_t offset_filter = |
529 | FloatToQuantizedUnclamped<T2>(0.0f, min_filter, max_filter); |
530 | const int32_t offset_output = 0; |
531 | const int32_t mult_output = 1; |
532 | const int32_t shift_output = 0; |
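    // As a worked example, a quint8 input with min_input = -1.0f and
    // max_input = 3.0f gives an offset_input of roughly 64, since the real
    // value 0.0 sits a quarter of the way up the [-1.0, 3.0] range.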
533 | |
534 | // The last dimension for input is in_depth. It must be the same as the |
535 | // filter's in_depth. |
536 | const int64_t in_depth = input.dim_size(3); |
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));
541 | |
542 | // The last dimension for filter is out_depth. |
543 | const int64_t out_depth = filter.dim_size(3); |
544 | |
545 | // The second dimension for input is rows/height. |
546 | // The first dimension for filter is rows/height. |
547 | const int64_t input_rows = input.dim_size(1); |
548 | const int64_t filter_rows = filter.dim_size(0); |
549 | |
550 | // The third dimension for input is columns/width. |
551 | // The second dimension for filter is columns/width. |
552 | const int64_t input_cols = input.dim_size(2); |
553 | const int64_t filter_cols = filter.dim_size(1); |
554 | |
555 | // The first dimension for input is batch. |
556 | const int64_t batch = input.dim_size(0); |
557 | |
558 | // For now we take the stride from the second dimension only (we |
559 | // assume row = col stride, and do not support striding on the |
560 | // batch or depth dimension). |
561 | const int stride = strides_[1]; |
562 | |
563 | int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; |
564 | OP_REQUIRES_OK(context, |
565 | GetWindowedOutputSize(input_rows, filter_rows, stride, |
566 | padding_, &out_rows, &pad_rows)); |
567 | OP_REQUIRES_OK(context, |
568 | GetWindowedOutputSize(input_cols, filter_cols, stride, |
569 | padding_, &out_cols, &pad_cols)); |
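    // As a worked example, input_rows = 10, filter_rows = 3, and stride = 2
    // give out_rows = ceil(10 / 2) = 5 with SAME padding, and
    // out_rows = ceil((10 - 3 + 1) / 2) = 4 with VALID padding.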
    OP_REQUIRES(
        context, batch > 0 && out_rows > 0 && out_cols > 0 && out_depth > 0,
        errors::InvalidArgument("Output dimensions must all be positive"));
574 | TensorShape out_shape({batch, out_rows, out_cols, out_depth}); |
575 | |
576 | // Output tensor is of the following dimensions: |
577 | // [ in_batch, out_rows, out_cols, out_depth ] |
578 | Tensor* output = nullptr; |
579 | OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); |
580 | |
581 | // This will call different implementations (e.g. reference or optimized) |
582 | // depending on the template parameter. |
583 | ConvFunctor<T1, T2, T3> conv_functor; |
584 | conv_functor(context, input.flat<T1>().data(), batch, input_rows, |
585 | input_cols, in_depth, offset_input, filter.flat<T2>().data(), |
586 | filter_rows, filter_cols, out_depth, offset_filter, stride, |
587 | padding_, output->flat<T3>().data(), out_rows, out_cols, |
588 | shift_output, offset_output, mult_output); |
589 | |
590 | float min_output_value; |
591 | float max_output_value; |
592 | QuantizationRangeForMultiplication<T1, T2, T3>( |
593 | min_input, max_input, min_filter, max_filter, &min_output_value, |
594 | &max_output_value); |
595 | |
596 | Tensor* output_min = nullptr; |
597 | OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min)); |
598 | output_min->flat<float>()(0) = min_output_value; |
599 | |
600 | Tensor* output_max = nullptr; |
601 | OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max)); |
602 | output_max->flat<float>()(0) = max_output_value; |
603 | } |
604 | |
605 | private: |
606 | std::vector<int32> strides_; |
607 | Padding padding_; |
608 | }; |
609 | |
// Right now we only support taking two eight-bit inputs, and returning the
// results as signed 32-bit integers.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<quint8>("Tfilter")
        .TypeConstraint<qint32>("out_type"),
    QuantizedConv2DOp<quint8, quint8, qint32, Im2ColConvFunctor>);
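// For reference, the registered op consumes six inputs (the quantized input
// and filter tensors, plus scalar min/max float ranges for each) and produces
// three outputs (the qint32 accumulator tensor and scalar min/max floats
// giving the range its values represent).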
619 | |
620 | } // namespace tensorflow |
621 | |