1/* Copyright 2020 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// Front-end validation code; see the Validate function.
17
18#ifndef RUY_RUY_VALIDATE_H_
19#define RUY_RUY_VALIDATE_H_
20
21#include <cstdint>
22#include <limits>
23#include <type_traits>
24
25#include "ruy/check_macros.h"
26#include "ruy/mat.h"
27#include "ruy/mul_params.h"
28#include "ruy/side_pair.h"
29
30namespace ruy {
31namespace detail {
32
33template <typename Scalar>
34void CheckZeroPoint(Scalar zero_point) {
35 if (std::is_floating_point<Scalar>::value) {
36 RUY_DCHECK(!zero_point);
37 }
38}
39
// Debug-checks (via RUY_DCHECK) that the three zero points form a supported
// combination for the scalar types involved:
//   - each zero point is passed through CheckZeroPoint;
//   - if either input scalar type is std::int16_t, all three zero points
//     must be zero;
//   - the LHS and RHS zero points must not both be equal to the lowest value
//     of their respective types, unless both input types are std::uint8_t.
template <typename LhsScalar, typename RhsScalar, typename DstScalar>
void ValidateZeroPoints(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
                        DstScalar dst_zero_point) {
  CheckZeroPoint(lhs_zero_point);
  CheckZeroPoint(rhs_zero_point);
  CheckZeroPoint(dst_zero_point);

  // For now, support for int16 source types is limited to the
  // symmetric case (zero_point==0) because that appears to be
  // the case in the initial use cases, and that limits complexity
  // in thinking about accumulator overflows.
  const bool has_16bit_input = std::is_same<LhsScalar, std::int16_t>::value ||
                               std::is_same<RhsScalar, std::int16_t>::value;
  if (has_16bit_input) {
    RUY_DCHECK(!lhs_zero_point);
    RUY_DCHECK(!rhs_zero_point);
    RUY_DCHECK(!dst_zero_point);
  }

  // Guard against the case when both LHS and RHS zero_point's are equal to
  // the minimum representable value. In that case, padding with zero_point
  // values will generate the bad case for fast int8 kernels on NEON
  // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
  // into a int16: this is safe except in the bad case -128*-128 + -128*-128.
  // See b/131609283. This only affects the kNeon path but we ban this for all
  // paths in order for ruy to have the same supported parameter space
  // on all paths.
  // We disable this check for now for the case of LhsScalar==RhsScalar==uint8
  // for backwards compatibility with gemmlowp. The issue is still relevant
  // because we convert from uint8 to int8 for the backend kernels.
  //
  // Note: the fixed-width type is spelled std::uint8_t (not ::uint8_t) since
  // <cstdint> only guarantees the declaration in namespace std.
  if (!std::is_same<LhsScalar, std::uint8_t>::value ||
      !std::is_same<RhsScalar, std::uint8_t>::value) {
    RUY_DCHECK(lhs_zero_point != std::numeric_limits<LhsScalar>::lowest() ||
               rhs_zero_point != std::numeric_limits<RhsScalar>::lowest());
  }
}
76
77} // namespace detail
78
79template <typename LhsScalar, typename RhsScalar, typename DstScalar>
80void Validate(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
81 const Mat<DstScalar>& dst) {
82 detail::ValidateZeroPoints(lhs.zero_point, rhs.zero_point, dst.zero_point);
83}
84
85} // namespace ruy
86
87#endif // RUY_RUY_VALIDATE_H_
88