1 | /******************************************************************************* |
2 | * Copyright 2018-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstdint> |
18 | |
19 | #include "oneapi/dnnl/dnnl_types.h" |
20 | |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/nstl.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | #include "cpu/simple_q10n.hpp" |
27 | |
28 | #include "cpu/gemm/gemm.hpp" |
29 | |
30 | #include "cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | |
36 | void compensation_init(const char *offsetC, int32_t *compensation, dim_t len, |
37 | const int32_t *oc) { |
38 | bool OCisC = (*offsetC == 'C' || *offsetC == 'c'); |
39 | bool OCisF = (*offsetC == 'F' || *offsetC == 'f'); |
40 | |
41 | if (OCisF && (*oc) != 0) { |
42 | for (dim_t i = 0; i < len; i++) |
43 | compensation[i] = *oc; |
44 | } else if (OCisC) { |
45 | for (dim_t i = 0; i < len; i++) |
46 | compensation[i] = oc[i]; |
47 | } else { |
48 | for (dim_t i = 0; i < len; i++) |
49 | compensation[i] = 0; |
50 | } |
51 | } |
52 | |
// Accumulates the shift compensation into compensation[0..m-1]:
// for each row i of op(A), adds round(saturate(-128 * alpha * row_sum)).
// A is assumed column-major with leading dimension lda (m-by-k when
// !transa, k-by-m when transa) — consistent with the gemm call below.
// compensation[] must already be initialized (see compensation_init);
// values are added on top of the existing contents.
void compensation_compute(bool transa, dim_t m, dim_t k, float alpha,
        const int8_t *a, dim_t lda, int32_t *compensation) {
    if (!transa) {
        // Non-transposed case: each row sum strides across k columns.
        // Block the k dimension so each panel of columns fits in L2,
        // and parallelize over (panel, row) pairs.
        const int L2_cache_size = platform::get_per_core_cache_size(2);
        const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
        const dim_t npanels = k / blocking_factor;
        const bool has_tile = k % blocking_factor > 0;

        parallel_nd(npanels, m, [&](dim_t j, dim_t i) {
            // Partial sum of row i over the columns of panel j.
            int32_t val = 0;
            for (dim_t jb = 0; jb < blocking_factor; jb++) {
                val += a[(i + j * blocking_factor * lda) + jb * lda];
            }
            if (alpha != 1.0f) {
                // Scale in double then saturate/round to avoid int overflow.
                val = out_round<int32_t>(
                        saturate<int32_t>((double)val * alpha * -128.0));
            } else {
                val *= -128;
            }
            // Multiple panels update the same row concurrently -> atomic add.
            fetch_and_add(&compensation[i], val);
        });

        if (has_tile) {
            // Remainder columns that did not fill a whole panel.
            parallel_nd(m, [=](dim_t i) {
                int32_t val = 0;
                for (dim_t j = npanels * blocking_factor; j < k; j++) {
                    val += a[i + j * lda];
                }
                if (alpha != 1.0f) {
                    val = out_round<int32_t>(
                            saturate<int32_t>((double)val * alpha * -128.0));
                } else {
                    val *= -128;
                }
                fetch_and_add(&compensation[i], val);
            });
        }
    } else {
        // Transposed case: row i of op(A) is column i of A, contiguous in
        // memory, so a plain parallel loop over rows needs no blocking and
        // no atomics (each i is touched by exactly one thread).
        parallel_nd(m, [=](dim_t i) {
            int32_t val = 0;
            for (dim_t j = 0; j < k; j++) {
                val += a[j + i * lda];
            }
            if (alpha != 1.0f) {
                val = out_round<int32_t>(
                        saturate<int32_t>((double)val * alpha * -128.0));
            } else {
                val *= -128;
            }
            compensation[i] += val;
        });
    }
}
106 | |
107 | void copy_and_shift_b(bool transb, dim_t k, dim_t n, uint8_t *b_u8, |
108 | dim_t ldb_u8, const int8_t *b_s8, dim_t ldb_s8) { |
109 | const dim_t b_cols = transb ? k : n; |
110 | |
111 | parallel_nd(b_cols, [=](dim_t j) { |
112 | const dim_t b_rows = transb ? n : k; |
113 | |
114 | uint8_t *pb_u8 = b_u8 + j * ldb_u8; |
115 | const int8_t *pb_s8 = b_s8 + j * ldb_s8; |
116 | |
117 | for (dim_t i = 0; i < b_rows; i++) { |
118 | (*pb_u8) = (*pb_s8) + 128; |
119 | pb_u8++; |
120 | pb_s8++; |
121 | } |
122 | }); |
123 | } |
124 | |
125 | /** |
126 | * gemm_s8s8s32 operation is defined as follows: |
127 | * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation |
128 | * |
129 | * where |
130 | * - compensation is a vector of length m that contains computed compensation |
131 | * that may contain C_offset if applicable. The compensation is applied inside |
132 | * gemm_s8u8s32 as a C_offset |
133 | * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128 |
134 | * |
135 | * What is the compensation: |
136 | * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied: |
137 | * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset = |
138 | * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset |
139 | * compensation = -alpha * op(A) * B_shift |
140 | * Since B_shift is a matrix, every element of which is equal to 128 then |
141 | * - if op(A) = A: compensation contains sum of the elements in each row |
142 | * scaled by -128 * alpha |
143 | * - if op(A) = A**T: compensation contains sum of the elements in each column |
144 | * scaled by -128 * alpha |
145 | * |
 * The rest of the parameters are described in dnnl.h
147 | */ |
148 | dnnl_status_t simple_gemm_s8s8s32(const char *transA, const char *transB, |
149 | const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k, |
150 | const float *alpha, const int8_t *a, const dim_t *lda, const int8_t *oa, |
151 | const int8_t *b, const dim_t *ldb, const int8_t *ob, const float *beta, |
152 | int32_t *c, const dim_t *ldc, const int32_t *oc) { |
153 | if (*oa != 0 || *ob != 0) return dnnl_unimplemented; |
154 | |
155 | dim_t M = *m, N = *n, K = *k; |
156 | bool transa = (*transA == 'T' || *transA == 't'); |
157 | bool transb = (*transB == 'T' || *transB == 't'); |
158 | dim_t ld = transb ? N : K; |
159 | |
160 | uint8_t *b_u8 = (uint8_t *)malloc( |
161 | sizeof(uint8_t) * K * N, platform::get_cache_line_size()); |
162 | uint8_t ob_u8 = 0; |
163 | int32_t *compensation = (int32_t *)malloc( |
164 | sizeof(int32_t) * M, platform::get_cache_line_size()); |
165 | |
166 | if (utils::any_null(b_u8, compensation)) { |
167 | free(b_u8); |
168 | free(compensation); |
169 | return dnnl_out_of_memory; |
170 | } |
171 | |
172 | compensation_init(offsetC, compensation, M, oc); |
173 | compensation_compute(transa, M, K, *alpha, a, *lda, compensation); |
174 | copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb); |
175 | |
176 | status_t st = gemm_s8x8s32(transA, transB, "C" , m, n, k, alpha, a, lda, oa, |
177 | b_u8, &ld, &ob_u8, beta, c, ldc, compensation); |
178 | if (st != dnnl_success) return st; |
179 | |
180 | if ((*offsetC == 'R' || *offsetC == 'r')) |
181 | parallel_nd(M, N, [=](dim_t i, dim_t j) { c[i + j * *ldc] += oc[j]; }); |
182 | |
183 | free(b_u8); |
184 | free(compensation); |
185 | |
186 | return st; |
187 | } |
188 | } // namespace cpu |
189 | } // namespace impl |
190 | } // namespace dnnl |
191 | |