/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized eight-bit versions of the convolution operations.

#include <algorithm>
#include <vector>

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {

// This functor implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal
// benchmarks to prototype with, and sharing with teams that want to run this
// outside of our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
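// As a worked example of that layout, the input value at
// [batch, y, x, channel] lives at flat offset
//   ((batch * input_height + y) * input_width + x) * input_depth + channel,
// which is the same arithmetic the indexing expressions below spell out.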
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    // Set up some constants we need for the output down-shifting and
    // saturation.
    const int32_t highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
    const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());

    // When we're converting the 32 bit accumulator to a lower bit depth, we
    // need to add on 0.5 in fixed-point terms to make the operation round half
    // up towards positive infinity, rather than rounding down as a plain right
    // shift would.
    // We also need to watch out for the case when there's no down shift,
    // because a left shift by a negative number gives undefined results.
    const int32_t rounding = (output_shift < 1) ? 0 : (1 << (output_shift - 1));
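    // For example, with output_shift == 3 the rounding term is 1 << 2 == 4,
    // so a scaled total of 21 becomes (21 + 4) >> 3 == 3 instead of the
    // truncated 21 >> 3 == 2.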

    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as
    // the input. It's complicated by stride, which shrinks the output image by
    // a factor, but it means we end up sampling from outside the borders of
    // the input. These out-of-bounds values are read as zeroes. VALID means
    // only produce output values where the filters can read all their values
    // from within the input image. It effectively removes the margins of the
    // output image compared to the one produced by SAME. Stride complicates
    // this definition though, because it can result in the right and bottom
    // filter patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }
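    // As a concrete illustration: a 10-wide input with a 3-wide filter and
    // stride 1 gives output_width 10 under SAME and 8 under VALID. The SAME
    // branch then yields filter_left_offset = (9 + 3 - 10) / 2 = 1, so the
    // first filter column hangs one pixel off the left edge, while the VALID
    // branch yields (7 + 3 - 10 + 1) / 2 = 0, keeping every tap inside the
    // image.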

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count; ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |            ^
              |   |       in_y_origin
              |   |            v   \
              |   |<in_x_origin>*---*^
              |   |            \|   |filter_height
              .   |             *---*v
              .   |             <--->
                  .         filter_width
                  .
            */
            const int in_x_origin = (out_x * stride) - filter_left_offset;
            const int in_y_origin = (out_y * stride) - filter_top_offset;
            int32_t total = 0;
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  int32_t input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    const T1 input_source_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                    // We're promoting the T1 type to a higher bit depth here
                    // as we do the subtraction.
                    input_value =
                        static_cast<int32>(input_source_value) - input_offset;
                  } else {
                    input_value = 0;
                  }
                  const T2 filter_source_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  // Another promotion to 32 bit, as above.
                  const int32_t filter_value =
                      static_cast<int32>(filter_source_value) - filter_offset;
                  total += (input_value * filter_value);
                }
              }
            }
            // Here we're applying scale factors to compress the 32 bit
            // accumulated total to a potentially lower bit depth.
            const int32_t output =
                ((((total + output_offset) * output_mult) + rounding) >>
                 output_shift);
            // We need to saturate the results against the largest and smallest
            // values that can be represented in this type.
            const int32_t top_clamped_output = std::min(output, highest);
            const int32_t clamped_output = std::max(top_clamped_output, lowest);
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = clamped_output;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 1 megabyte as a reasonable limit, from
// experimentation.
const size_t kMaxChunkSize = (1 * 1024 * 1024);

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    if (input_offset < 0) {
      // Only log the first few occurrences of this warning.
      static int warning_count = 0;
      if (warning_count < 10) {
        ++warning_count;
        LOG(WARNING)
            << "For kernel '" << context->op_kernel().name() << "' from input '"
            << context->op_kernel().requested_input(0)
            << "': Zero is not representable in the quantized range used by the"
            << " input. This means QuantizedConv2d has to fall back to a slow"
            << " implementation, since the border of zero values can't be"
            << " represented easily. You should try to construct graphs that"
            << " avoid this situation.";
      }
      ReferenceConvFunctor<T1, T2, T3> conv_functor;
      conv_functor(context, input_data, input_batches, input_height,
                   input_width, input_depth, input_offset, filter_data,
                   filter_height, filter_width, filter_count, filter_offset,
                   stride, padding, output_data, output_height, output_width,
                   output_shift, output_offset, output_mult);
      return;
    }

    OP_REQUIRES(
        context, output_width > 0,
        errors::InvalidArgument("output_width must be strictly positive"));
    OP_REQUIRES(
        context, output_height > 0,
        errors::InvalidArgument("output_height must be strictly positive"));
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // The im2col buffer has # of patches rows, and # of filters cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, filter_value_count > 0,
                errors::InvalidArgument(
                    "filter patch must contain at least one element"));
    const int64_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64_t chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
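    // To make the chunking concrete: for quint8 inputs (sizeof(T1) == 1) and a
    // 3x3 filter over 256 input channels, filter_value_count is 2304, so each
    // 1 MB chunk holds 1048576 / 2304 = 455 patches and chunk_value_count is
    // 1048576 elements.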
    // TODO(petewarden) - Memory allocation can be very slow on Android. Can we
    // optimize this by keeping the scratch buffer around?
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
#ifdef _MSC_VER
          // MSVC complains about the use of chunk_value_count here, even
          // though the same pattern compiles fine in conv_ops_using_gemm.cc,
          // so redefine chunk_value_count inside the lambda for now.
          const int64_t chunk_value_count =
              (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
#endif
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return OkStatus();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it
    // shouldn't be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64_t patch_count = (input_batches * output_height * output_width);
    const int64_t chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;

    for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64_t patch_index_start = chunk_index * patches_per_chunk;
      const int64_t patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64_t patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64_t batch = patch_index / (output_height * output_width);
        const int64_t out_y = (patch_index / output_width) % output_height;
        const int64_t out_x = patch_index % output_width;
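        // patch_index enumerates output locations in batch-major, then
        // row-major order, i.e. patch_index is
        // (batch * output_height + out_y) * output_width + out_x, which is
        // what the division and modulo operations above invert.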
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride) - filter_top_offset;
        const int in_x_origin = (out_x * stride) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            // On Android, memset and memcpy are significantly faster than the
            // more modern std::fill and std::copy equivalents.
            memset(im2col_row_start, input_offset,
                   (filter_width * input_depth));
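            // Note that the memset/memcpy sizes in this routine are element
            // counts rather than byte counts; that is only byte-correct
            // because this functor is registered below for one-byte quantized
            // types, so sizeof(T1) == 1. Filling with input_offset writes the
            // quantized encoding of a real value of zero.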
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use memset() to set the left and right sections to zeroes
            // and memcpy() to copy over the input data for the center. These
            // are preferred to std::fill and std::copy because they're much
            // faster on Android.
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
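            // For example, if in_x_origin is -2 for a 5-wide filter over an
            // 8-wide input, then in_x_end is 3, left_zero_count is 2,
            // right_zero_count is 0, and center_copy_count is 3.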
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              memset(im2col_left_start, input_offset,
                     (left_zero_count * input_depth));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              memcpy(im2col_center_start, input_row_start,
                     (center_copy_count * input_depth));
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              memset(im2col_right_start, input_offset,
                     (right_zero_count * input_depth));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
      const int how_many_patches = patch_index_end - patch_index_start;
      const bool transpose_a = false;
      const bool transpose_b = false;
      const bool transpose_c = false;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);
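      // The GEMM below computes an [m x n] result from an [m x k] matrix of
      // im2col patches and a [k x n] matrix of filter weights, so each output
      // row is one spatial location and each output column is one filter.
      // With row-major storage the leading dimensions are just the row
      // widths: k for the patches, n for the filters and the result.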

      if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
          std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
          (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
          (transpose_c == false) && (k <= 2048)) {
        meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
                            filter_data, chunk_output_data, m, n, k,
                            -input_offset, -filter_offset, lda, ldb, ldc);
      } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
                 std::is_same<T3, qint32>() && (output_offset == 0) &&
                 (output_mult == 1) && (output_shift == 0)) {
        // The gemmlowp optimized library only works for a particular set of
        // data types, so check if we meet those requirements and fall back to
        // a slower reference implementation if not.
        const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
        const uint8* filter_data_as_uint8 = &(filter_data->value);
        int32* output_data_as_int32 = &(chunk_output_data->value);
        // All of the transpose_* variables are currently compile-time consts,
        // so we could just hard-code these values too, but that would break if
        // anybody changed those values in the future (e.g. to match the
        // ability of MatMul to specify them as attributes). We're using a
        // verbose approach of deriving the order values from the transpose
        // variables to be able to catch any changes like that.
        static const gemmlowp::MapOrder ResultOrder =
            !transpose_c ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder LhsOrder =
            !transpose_a ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder RhsOrder =
            !transpose_b ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
            im2col_data_as_uint8, m, k, lda);
        gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
            filter_data_as_uint8, k, n, ldb);
        gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
            output_data_as_int32, m, n, ldc);
        const std::tuple<> empty_pipeline = {};

        auto& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        TensorflowGemmContext gemm_context(worker_threads.num_threads,
                                           worker_threads.workers);
        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                         gemmlowp::DefaultL8R8BitDepthParams>(
            &gemm_context, lhs, rhs, &result, -input_offset, -filter_offset,
            empty_pipeline);
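        // gemmlowp adds the supplied lhs/rhs offsets to each stored value
        // before multiplying, so passing -input_offset and -filter_offset
        // makes it accumulate
        // (input - input_offset) * (filter - filter_offset) in 32 bits,
        // matching the reference functor above.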
        // Since gemmlowp uses assembly to write to the output, msan won't
        // detect the output buffer as written to, so we mark it manually.
        TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_data_as_int32,
                                          m * n * sizeof(int32));
      } else {
        ReferenceGemm<T1, T2, T3>(
            transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
            input_offset, lda, filter_data, filter_offset, ldb,
            chunk_output_data, output_shift, output_offset, output_mult, ldc);
      }
    }
  }
};

template <class T1, class T2, class T3,
          template <class TF1, class TF2, class TF3> class ConvFunctor>
class QuantizedConv2DOp : public OpKernel {
 public:
  explicit QuantizedConv2DOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    std::vector<int32> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    OP_REQUIRES(context, dilations.size() == 4,
                errors::InvalidArgument("Dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument(
                    "Current implementation only supports a dilation rate of "
                    "1 in the row and column dimensions."));
    OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be rank 4 but is rank ",
                                        input.shape().dims()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be rank 4 but is rank ",
                                        filter.shape().dims()));

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()),
                errors::InvalidArgument("min_input must be rank 0 but is rank ",
                                        context->input(2).shape().dims()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(3).shape()),
                errors::InvalidArgument("max_input must be rank 0 but is rank ",
                                        context->input(3).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(4).shape()),
        errors::InvalidArgument("min_filter must be rank 0 but is rank ",
                                context->input(4).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(5).shape()),
        errors::InvalidArgument("max_filter must be rank 0 but is rank ",
                                context->input(5).shape().dims()));

    const float min_input = context->input(2).flat<float>()(0);
    const float max_input = context->input(3).flat<float>()(0);
    const float min_filter = context->input(4).flat<float>()(0);
    const float max_filter = context->input(5).flat<float>()(0);
    const int32_t offset_input =
        FloatToQuantizedUnclamped<T1>(0.0f, min_input, max_input);
    const int32_t offset_filter =
        FloatToQuantizedUnclamped<T2>(0.0f, min_filter, max_filter);
    const int32_t offset_output = 0;
    const int32_t mult_output = 1;
    const int32_t shift_output = 0;
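    // offset_input and offset_filter are the quantized codes that represent a
    // real value of 0.0 in each input's [min, max] range; subtracting them
    // inside the convolution turns the uint8 codes back into signed values
    // centered on zero. Leaving the output offset/mult/shift at 0/1/0 means
    // the qint32 result is the raw accumulator, with its scale reported via
    // the min/max outputs computed below.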

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = input.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int64_t out_depth = filter.dim_size(3);

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t input_rows = input.dim_size(1);
    const int64_t filter_rows = filter.dim_size(0);

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t input_cols = input.dim_size(2);
    const int64_t filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int64_t batch = input.dim_size(0);

    // For now we take the stride from the second dimension only (we
    // assume row = col stride, and do not support striding on the
    // batch or depth dimension).
    const int stride = strides_[1];

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride,
                                         padding_, &out_cols, &pad_cols));
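    // GetWindowedOutputSize applies the usual TensorFlow sizing rules: SAME
    // padding gives ceil(input / stride) rows/cols, while VALID gives
    // ceil((input - filter + 1) / stride), e.g. a 10x10 input with a 3x3
    // filter and stride 1 produces a 10x10 (SAME) or 8x8 (VALID) output.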
    CHECK_GT(batch, 0);
    CHECK_GT(out_rows, 0);
    CHECK_GT(out_cols, 0);
    CHECK_GT(out_depth, 0);
    TensorShape out_shape({batch, out_rows, out_cols, out_depth});

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // This will call different implementations (e.g. reference or optimized)
    // depending on the template parameter.
    ConvFunctor<T1, T2, T3> conv_functor;
    conv_functor(context, input.flat<T1>().data(), batch, input_rows,
                 input_cols, in_depth, offset_input, filter.flat<T2>().data(),
                 filter_rows, filter_cols, out_depth, offset_filter, stride,
                 padding_, output->flat<T3>().data(), out_rows, out_cols,
                 shift_output, offset_output, mult_output);

    float min_output_value;
    float max_output_value;
    QuantizationRangeForMultiplication<T1, T2, T3>(
        min_input, max_input, min_filter, max_filter, &min_output_value,
        &max_output_value);
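    // Since the accumulator is emitted unscaled, this range describes what the
    // extreme qint32 values would mean in float terms given the input and
    // filter ranges, so downstream ops can requantize or dequantize the
    // result.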

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_output_value;

    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_output_value;
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
};

// Right now we only support taking two eight-bit inputs, and returning the
// results as signed 32-bit integers.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<quint8>("Tfilter")
        .TypeConstraint<qint32>("out_type"),
    QuantizedConv2DOp<quint8, quint8, qint32, Im2ColConvFunctor>);

}  // namespace tensorflow