1 | /******************************************************************************* |
2 | * Copyright 2018-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstdint> |
18 | |
19 | #include "oneapi/dnnl/dnnl_types.h" |
20 | |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | #include "cpu/simple_q10n.hpp" |
26 | |
27 | #include "cpu/gemm/f32/ref_gemm_f32.hpp" |
28 | |
29 | #include "cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | |
35 | template <typename b_dt> |
36 | dnnl_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, |
37 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
38 | const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao, |
39 | const b_dt *B, const dim_t *LDB, const b_dt *bo, const float *beta, |
40 | int32_t *C, const dim_t *LDC, const int32_t *co) { |
41 | |
42 | if (*M == 0 || *N == 0 || *K == 0) return dnnl_success; |
43 | |
44 | if (!(utils::one_of(*transa, 'n', 'N', 't', 'T') |
45 | && utils::one_of(*transb, 'n', 'N', 't', 'T'))) |
46 | return dnnl_unimplemented; |
47 | |
48 | bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); |
49 | bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); |
50 | bool AisN = (*transa == 'N' || *transa == 'n'); |
51 | bool BisN = (*transb == 'N' || *transb == 'n'); |
52 | |
53 | dim_t m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; |
54 | size_t sizeA = AisN ? lda * k : lda * m; |
55 | size_t sizeB = BisN ? ldb * n : ldb * k; |
56 | size_t sizeC = ldc * n; |
57 | |
58 | double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); |
59 | double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); |
60 | double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); |
61 | |
62 | if (utils::any_null(dA, dB, dC)) { |
63 | free(dA); |
64 | free(dB); |
65 | free(dC); |
66 | return dnnl_out_of_memory; |
67 | } |
68 | |
69 | auto da_setter = [=](dim_t i, dim_t j, double v) { dA[j * lda + i] = v; }; |
70 | auto db_setter = [=](dim_t i, dim_t j, double v) { dB[j * ldb + i] = v; }; |
71 | |
72 | auto ia_accessor = [=](dim_t i, dim_t j) { return A[j * lda + i]; }; |
73 | auto ib_accessor = [=](dim_t i, dim_t j) { return B[j * ldb + i]; }; |
74 | |
75 | const int a_rows = AisN ? m : k; |
76 | const int a_cols = AisN ? k : m; |
77 | dnnl::impl::parallel_nd(a_cols, a_rows, [&](dim_t j, dim_t i) { |
78 | da_setter(i, j, |
79 | static_cast<double>(ia_accessor(i, j)) |
80 | - static_cast<double>(ao[0])); |
81 | }); |
82 | |
83 | const dim_t b_rows = BisN ? k : n; |
84 | const dim_t b_cols = BisN ? n : k; |
85 | dnnl::impl::parallel_nd(b_cols, b_rows, [&](dim_t j, dim_t i) { |
86 | db_setter(i, j, |
87 | static_cast<double>(ib_accessor(i, j)) |
88 | - static_cast<double>(bo[0])); |
89 | }); |
90 | double one = 1.0, zero = 0.0; |
91 | ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, dC, |
92 | LDC, nullptr); |
93 | |
94 | auto i2d = [=](int32_t v) { return static_cast<double>(v); }; |
95 | auto f2d = [=](float v) { return static_cast<double>(v); }; |
96 | |
97 | dnnl::impl::parallel_nd(n, m, [&](dim_t j, dim_t i) { |
98 | double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); |
99 | double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) |
100 | + f2d(*alpha) * dC[i + j * ldc] + coffset; |
101 | C[i + j * ldc] = out_round<int32_t>(saturate<int32_t>(val)); |
102 | }); |
103 | |
104 | free(dA); |
105 | free(dB); |
106 | free(dC); |
107 | return dnnl_success; |
108 | } |
109 | |
110 | template dnnl_status_t ref_gemm_s8x8s32<uint8_t>(const char *transa, |
111 | const char *transb, const char *offsetc, const dim_t *M, const dim_t *N, |
112 | const dim_t *K, const float *alpha, const int8_t *A, const dim_t *LDA, |
113 | const int8_t *ao, const uint8_t *B, const dim_t *LDB, const uint8_t *bo, |
114 | const float *beta, int32_t *C, const dim_t *LDC, const int32_t *co); |
115 | |
116 | template dnnl_status_t ref_gemm_s8x8s32<int8_t>(const char *transa, |
117 | const char *transb, const char *offsetc, const dim_t *M, const dim_t *N, |
118 | const dim_t *K, const float *alpha, const int8_t *A, const dim_t *LDA, |
119 | const int8_t *ao, const int8_t *B, const dim_t *LDB, const int8_t *bo, |
120 | const float *beta, int32_t *C, const dim_t *LDC, const int32_t *co); |
121 | |
122 | } // namespace cpu |
123 | } // namespace impl |
124 | } // namespace dnnl |
125 | |