1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | #include "common/math_utils.hpp" |
19 | #include "cpu/simple_q10n.hpp" |
20 | |
21 | #include "cpu/cpu_primitive.hpp" |
22 | #include "cpu/scale_utils.hpp" |
23 | |
24 | #include "cpu/binary_injector_utils.hpp" |
25 | #include "cpu/gemm/gemm.hpp" |
26 | #include "cpu/gemm_x8s8s32x_inner_product.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | using namespace math; |
33 | using namespace format_tag; |
34 | using namespace memory_tracking::names; |
35 | |
36 | status_t gemm_x8s8s32x_inner_product_fwd_t::execute_forward( |
37 | const exec_ctx_t &ctx) const { |
38 | auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
39 | auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
40 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
41 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
42 | const auto post_ops_binary_rhs_arg_vec |
43 | = binary_injector_utils::prepare_binary_args( |
44 | this->pd()->attr()->post_ops_, ctx); |
45 | |
46 | const dim_t MB = pd()->MB(); |
47 | const dim_t OC = pd()->OC(); |
48 | const dim_t IC = pd()->IC(); |
49 | |
50 | const auto &wmd = *pd()->weights_md(); |
51 | const auto &smd = *pd()->src_md(); |
52 | bool wei_tr = wmd.format_desc.blocking.strides[0] != 1; |
53 | // check if MB is the leading dimension |
54 | bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1; |
55 | |
56 | const dim_t M = OC; |
57 | const dim_t N = MB; |
58 | const dim_t K = pd()->IC_total_padded(); |
59 | const int8_t off_a = 0; |
60 | const int32_t off_c = 0; |
61 | |
62 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
63 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
64 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
65 | |
66 | auto scratchpad = ctx.get_scratchpad_grantor(); |
67 | const float *scales = precompute_scales( |
68 | scratchpad, src_scales, wei_scales, OC, pd()->attr()); |
69 | |
70 | int32_t *acc = pd()->dst_is_acc_ |
71 | ? (int32_t *)dst |
72 | : ctx.get_scratchpad_grantor().template get<int32_t>( |
73 | key_iprod_int_dat_in_acc_dt); |
74 | |
75 | const float onef = 1.0, zerof = 0.0; |
76 | |
77 | if (smd.data_type == data_type::s8) { |
78 | const int8_t off_b = 0; |
79 | const int8_t *src_ = reinterpret_cast<const int8_t *>(src); |
80 | CHECK(gemm_s8x8s32(wei_tr ? "T" : "N" , src_tr ? "T" : "N" , "F" , &M, &N, |
81 | &K, &onef, weights, wei_tr ? &K : &M, &off_a, src_, |
82 | src_tr ? &N : &K, &off_b, &zerof, acc, &M, &off_c)); |
83 | } else if (smd.data_type == data_type::u8) { |
84 | const uint8_t off_b = 0; |
85 | const uint8_t *src_ = reinterpret_cast<const uint8_t *>(src); |
86 | CHECK(gemm_s8x8s32(wei_tr ? "T" : "N" , src_tr ? "T" : "N" , "F" , &M, &N, |
87 | &K, &onef, weights, wei_tr ? &K : &M, &off_a, src_, |
88 | src_tr ? &N : &K, &off_b, &zerof, acc, &M, &off_c)); |
89 | } else { |
90 | assert(!"unsupported data type!" ); |
91 | } |
92 | |
93 | if (!pd()->attr()->has_default_values() |
94 | || pd()->dst_md()->data_type != data_type::s32 |
95 | || pd()->with_bias()) { |
96 | const bool force_sequential |
97 | = pp_kernel_->sequential_kernel() || MB * OC < 2000; |
98 | parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { |
99 | size_t start, end; |
100 | balance211((size_t)(OC * MB), nthr, ithr, start, end); |
101 | const size_t dst_logical_off = start; |
102 | const size_t dim1_off = start % OC; |
103 | (*pp_kernel_)(dst, acc, bias, scales, dst_scales[0], start, |
104 | dst_logical_off, dim1_off, end, 0, 0, nullptr, |
105 | post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx, |
106 | *pd()->dst_md()); |
107 | }); |
108 | } |
109 | |
110 | return status::success; |
111 | } |
112 | |
113 | } // namespace cpu |
114 | } // namespace impl |
115 | } // namespace dnnl |
116 | |