1/*******************************************************************************
2* Copyright 2018-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cstdint>
18
19#include "oneapi/dnnl/dnnl_types.h"
20
21#include "common/dnnl_thread.hpp"
22#include "common/nstl.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/platform.hpp"
26#include "cpu/simple_q10n.hpp"
27
28#include "cpu/gemm/gemm.hpp"
29
30#include "cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35
36void compensation_init(const char *offsetC, int32_t *compensation, dim_t len,
37 const int32_t *oc) {
38 bool OCisC = (*offsetC == 'C' || *offsetC == 'c');
39 bool OCisF = (*offsetC == 'F' || *offsetC == 'f');
40
41 if (OCisF && (*oc) != 0) {
42 for (dim_t i = 0; i < len; i++)
43 compensation[i] = *oc;
44 } else if (OCisC) {
45 for (dim_t i = 0; i < len; i++)
46 compensation[i] = oc[i];
47 } else {
48 for (dim_t i = 0; i < len; i++)
49 compensation[i] = 0;
50 }
51}
52
53void compensation_compute(bool transa, dim_t m, dim_t k, float alpha,
54 const int8_t *a, dim_t lda, int32_t *compensation) {
55 if (!transa) {
56 const int L2_cache_size = platform::get_per_core_cache_size(2);
57 const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
58 const dim_t npanels = k / blocking_factor;
59 const bool has_tile = k % blocking_factor > 0;
60
61 parallel_nd(npanels, m, [&](dim_t j, dim_t i) {
62 int32_t val = 0;
63 for (dim_t jb = 0; jb < blocking_factor; jb++) {
64 val += a[(i + j * blocking_factor * lda) + jb * lda];
65 }
66 if (alpha != 1.0f) {
67 val = out_round<int32_t>(
68 saturate<int32_t>((double)val * alpha * -128.0));
69 } else {
70 val *= -128;
71 }
72 fetch_and_add(&compensation[i], val);
73 });
74
75 if (has_tile) {
76 parallel_nd(m, [=](dim_t i) {
77 int32_t val = 0;
78 for (dim_t j = npanels * blocking_factor; j < k; j++) {
79 val += a[i + j * lda];
80 }
81 if (alpha != 1.0f) {
82 val = out_round<int32_t>(
83 saturate<int32_t>((double)val * alpha * -128.0));
84 } else {
85 val *= -128;
86 }
87 fetch_and_add(&compensation[i], val);
88 });
89 }
90 } else {
91 parallel_nd(m, [=](dim_t i) {
92 int32_t val = 0;
93 for (dim_t j = 0; j < k; j++) {
94 val += a[j + i * lda];
95 }
96 if (alpha != 1.0f) {
97 val = out_round<int32_t>(
98 saturate<int32_t>((double)val * alpha * -128.0));
99 } else {
100 val *= -128;
101 }
102 compensation[i] += val;
103 });
104 }
105}
106
107void copy_and_shift_b(bool transb, dim_t k, dim_t n, uint8_t *b_u8,
108 dim_t ldb_u8, const int8_t *b_s8, dim_t ldb_s8) {
109 const dim_t b_cols = transb ? k : n;
110
111 parallel_nd(b_cols, [=](dim_t j) {
112 const dim_t b_rows = transb ? n : k;
113
114 uint8_t *pb_u8 = b_u8 + j * ldb_u8;
115 const int8_t *pb_s8 = b_s8 + j * ldb_s8;
116
117 for (dim_t i = 0; i < b_rows; i++) {
118 (*pb_u8) = (*pb_s8) + 128;
119 pb_u8++;
120 pb_s8++;
121 }
122 });
123}
124
125/**
126 * gemm_s8s8s32 operation is defined as follows:
127 * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation
128 *
129 * where
130 * - compensation is a vector of length m that contains computed compensation
131 * that may contain C_offset if applicable. The compensation is applied inside
132 * gemm_s8u8s32 as a C_offset
133 * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128
134 *
135 * What is the compensation:
136 * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied:
137 * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset =
138 * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset
139 * compensation = -alpha * op(A) * B_shift
140 * Since B_shift is a matrix, every element of which is equal to 128 then
141 * - if op(A) = A: compensation contains sum of the elements in each row
142 * scaled by -128 * alpha
143 * - if op(A) = A**T: compensation contains sum of the elements in each column
144 * scaled by -128 * alpha
145 *
 * The rest of the parameters are described in dnnl.h
147 */
148dnnl_status_t simple_gemm_s8s8s32(const char *transA, const char *transB,
149 const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k,
150 const float *alpha, const int8_t *a, const dim_t *lda, const int8_t *oa,
151 const int8_t *b, const dim_t *ldb, const int8_t *ob, const float *beta,
152 int32_t *c, const dim_t *ldc, const int32_t *oc) {
153 if (*oa != 0 || *ob != 0) return dnnl_unimplemented;
154
155 dim_t M = *m, N = *n, K = *k;
156 bool transa = (*transA == 'T' || *transA == 't');
157 bool transb = (*transB == 'T' || *transB == 't');
158 dim_t ld = transb ? N : K;
159
160 uint8_t *b_u8 = (uint8_t *)malloc(
161 sizeof(uint8_t) * K * N, platform::get_cache_line_size());
162 uint8_t ob_u8 = 0;
163 int32_t *compensation = (int32_t *)malloc(
164 sizeof(int32_t) * M, platform::get_cache_line_size());
165
166 if (utils::any_null(b_u8, compensation)) {
167 free(b_u8);
168 free(compensation);
169 return dnnl_out_of_memory;
170 }
171
172 compensation_init(offsetC, compensation, M, oc);
173 compensation_compute(transa, M, K, *alpha, a, *lda, compensation);
174 copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb);
175
176 status_t st = gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa,
177 b_u8, &ld, &ob_u8, beta, c, ldc, compensation);
178 if (st != dnnl_success) return st;
179
180 if ((*offsetC == 'R' || *offsetC == 'r'))
181 parallel_nd(M, N, [=](dim_t i, dim_t j) { c[i + j * *ldc] += oc[j]; });
182
183 free(b_u8);
184 free(compensation);
185
186 return st;
187}
188} // namespace cpu
189} // namespace impl
190} // namespace dnnl
191