1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
17#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
18
19#include "ruy/matrix.h" // from @ruy
20#include "ruy/mul_params.h" // from @ruy
21#include "ruy/ruy.h" // from @ruy
22#include "tensorflow/lite/kernels/cpu_backend_context.h"
23#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
24#include "tensorflow/lite/kernels/internal/compatibility.h"
25
26namespace tflite {
27namespace cpu_backend_gemm {
28namespace detail {
29
30inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
31 switch (cache_policy) {
32 case CachePolicy::kNeverCache:
33 return ruy::CachePolicy::kNeverCache;
34 case CachePolicy::kCacheIfLargeSpeedup:
35 return ruy::CachePolicy::kCacheIfLargeSpeedup;
36 case CachePolicy::kAlwaysCache:
37 return ruy::CachePolicy::kAlwaysCache;
38 default:
39 TFLITE_DCHECK(false);
40 return ruy::CachePolicy::kNeverCache;
41 }
42}
43
44template <typename Scalar, typename DataPointer>
45void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
46 ruy::Matrix<Scalar>* dst, bool use_caching = false) {
47 ruy::Order ruy_order = params.order == Order::kColMajor
48 ? ruy::Order::kColMajor
49 : ruy::Order::kRowMajor;
50 ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order,
51 dst->mutable_layout());
52 // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
53 // It does care whether we assign to it a Scalar* or a const Scalar*.
54 dst->set_data(data_ptr);
55 dst->set_zero_point(params.zero_point);
56 if (use_caching) {
57 dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
58 }
59}
60
61// Floating-point case.
62template <typename AccumScalar, typename DstScalar,
63 QuantizationFlavor quantization_flavor>
64struct MakeRuyMulParamsImpl final {
65 static void Run(
66 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
67 ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
68 static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
69 "");
70 ruy_mul_params->set_bias(params.bias);
71 ruy_mul_params->set_clamp_min(params.clamp_min);
72 ruy_mul_params->set_clamp_max(params.clamp_max);
73 }
74};
75
76// Integer-quantized case with destination type narrower than int32
77template <typename DstScalar, QuantizationFlavor quantization_flavor>
78struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
79 final {
80 static void Run(
81 const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
82 ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
83 static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
84 if (quantization_flavor ==
85 QuantizationFlavor::kIntegerWithUniformMultiplier) {
86 ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
87 ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
88 }
89 if (quantization_flavor ==
90 QuantizationFlavor::kIntegerWithPerRowMultiplier) {
91 ruy_mul_params->set_multiplier_fixedpoint_perchannel(
92 params.multiplier_fixedpoint_perchannel);
93 ruy_mul_params->set_multiplier_exponent_perchannel(
94 params.multiplier_exponent_perchannel);
95 }
96 ruy_mul_params->set_bias(params.bias);
97 ruy_mul_params->set_clamp_min(params.clamp_min);
98 ruy_mul_params->set_clamp_max(params.clamp_max);
99 }
100};
101
102// Raw-integer case with destination type int32.
103template <QuantizationFlavor quantization_flavor>
104struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
105 final {
106 static void Run(
107 const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
108 ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
109 ruy_mul_params->set_bias(params.bias);
110 }
111};
112
113template <typename AccumScalar, typename DstScalar,
114 QuantizationFlavor quantization_flavor>
115void MakeRuyMulParams(
116 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
117 ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
118 MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
119 params, ruy_mul_params);
120}
121
122template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
123 typename DstScalar, QuantizationFlavor quantization_flavor>
124struct GemmImplUsingRuy {
125 static void Run(
126 const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
127 const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
128 const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
129 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
130 CpuBackendContext* context) {
131 ruy::Matrix<LhsScalar> ruy_lhs;
132 ruy::Matrix<RhsScalar> ruy_rhs;
133 ruy::Matrix<DstScalar> ruy_dst;
134 MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs, context->use_caching());
135 MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs, context->use_caching());
136 MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
137
138 ruy::MulParams<AccumScalar, DstScalar> ruy_mul_params;
139 MakeRuyMulParams(params, &ruy_mul_params);
140
141 ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, context->ruy_context(),
142 &ruy_dst);
143 }
144};
145
146} // namespace detail
147} // namespace cpu_backend_gemm
148} // namespace tflite
149
150#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
151