1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ |
17 | #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ |
18 | |
19 | #include "ruy/matrix.h" // from @ruy |
20 | #include "ruy/mul_params.h" // from @ruy |
21 | #include "ruy/ruy.h" // from @ruy |
22 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
23 | #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" |
24 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
25 | |
26 | namespace tflite { |
27 | namespace cpu_backend_gemm { |
28 | namespace detail { |
29 | |
30 | inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) { |
31 | switch (cache_policy) { |
32 | case CachePolicy::kNeverCache: |
33 | return ruy::CachePolicy::kNeverCache; |
34 | case CachePolicy::kCacheIfLargeSpeedup: |
35 | return ruy::CachePolicy::kCacheIfLargeSpeedup; |
36 | case CachePolicy::kAlwaysCache: |
37 | return ruy::CachePolicy::kAlwaysCache; |
38 | default: |
39 | TFLITE_DCHECK(false); |
40 | return ruy::CachePolicy::kNeverCache; |
41 | } |
42 | } |
43 | |
44 | template <typename Scalar, typename DataPointer> |
45 | void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr, |
46 | ruy::Matrix<Scalar>* dst, bool use_caching = false) { |
47 | ruy::Order ruy_order = params.order == Order::kColMajor |
48 | ? ruy::Order::kColMajor |
49 | : ruy::Order::kRowMajor; |
50 | ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order, |
51 | dst->mutable_layout()); |
52 | // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer. |
53 | // It does care whether we assign to it a Scalar* or a const Scalar*. |
54 | dst->set_data(data_ptr); |
55 | dst->set_zero_point(params.zero_point); |
56 | if (use_caching) { |
57 | dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy)); |
58 | } |
59 | } |
60 | |
61 | // Floating-point case. |
62 | template <typename AccumScalar, typename DstScalar, |
63 | QuantizationFlavor quantization_flavor> |
64 | struct MakeRuyMulParamsImpl final { |
65 | static void Run( |
66 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
67 | ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) { |
68 | static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint, |
69 | "" ); |
70 | ruy_mul_params->set_bias(params.bias); |
71 | ruy_mul_params->set_clamp_min(params.clamp_min); |
72 | ruy_mul_params->set_clamp_max(params.clamp_max); |
73 | } |
74 | }; |
75 | |
76 | // Integer-quantized case with destination type narrower than int32 |
77 | template <typename DstScalar, QuantizationFlavor quantization_flavor> |
78 | struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor> |
79 | final { |
80 | static void Run( |
81 | const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params, |
82 | ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) { |
83 | static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "" ); |
84 | if (quantization_flavor == |
85 | QuantizationFlavor::kIntegerWithUniformMultiplier) { |
86 | ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint); |
87 | ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent); |
88 | } |
89 | if (quantization_flavor == |
90 | QuantizationFlavor::kIntegerWithPerRowMultiplier) { |
91 | ruy_mul_params->set_multiplier_fixedpoint_perchannel( |
92 | params.multiplier_fixedpoint_perchannel); |
93 | ruy_mul_params->set_multiplier_exponent_perchannel( |
94 | params.multiplier_exponent_perchannel); |
95 | } |
96 | ruy_mul_params->set_bias(params.bias); |
97 | ruy_mul_params->set_clamp_min(params.clamp_min); |
98 | ruy_mul_params->set_clamp_max(params.clamp_max); |
99 | } |
100 | }; |
101 | |
102 | // Raw-integer case with destination type int32. |
103 | template <QuantizationFlavor quantization_flavor> |
104 | struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor> |
105 | final { |
106 | static void Run( |
107 | const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params, |
108 | ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) { |
109 | ruy_mul_params->set_bias(params.bias); |
110 | } |
111 | }; |
112 | |
113 | template <typename AccumScalar, typename DstScalar, |
114 | QuantizationFlavor quantization_flavor> |
115 | void MakeRuyMulParams( |
116 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
117 | ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) { |
118 | MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run( |
119 | params, ruy_mul_params); |
120 | } |
121 | |
122 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
123 | typename DstScalar, QuantizationFlavor quantization_flavor> |
124 | struct GemmImplUsingRuy { |
125 | static void Run( |
126 | const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, |
127 | const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, |
128 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
129 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
130 | CpuBackendContext* context) { |
131 | ruy::Matrix<LhsScalar> ruy_lhs; |
132 | ruy::Matrix<RhsScalar> ruy_rhs; |
133 | ruy::Matrix<DstScalar> ruy_dst; |
134 | MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching()); |
135 | MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching()); |
136 | MakeRuyMatrix(dst_params, dst_data, &ruy_dst); |
137 | |
138 | ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params; |
139 | MakeRuyMulParams(params, &ruy_mul_params); |
140 | |
141 | ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, context->ruy_context(), |
142 | &ruy_dst); |
143 | } |
144 | }; |
145 | |
146 | } // namespace detail |
147 | } // namespace cpu_backend_gemm |
148 | } // namespace tflite |
149 | |
150 | #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ |
151 | |