1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_GEMM_X8S8S32X_INNER_PRODUCT_HPP |
18 | #define CPU_GEMM_X8S8S32X_INNER_PRODUCT_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include <memory> |
23 | |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/memory_tracking.hpp" |
26 | #include "common/primitive.hpp" |
27 | #include "common/type_helpers.hpp" |
28 | #include "common/utils.hpp" |
29 | |
30 | #include "cpu/gemm/gemm.hpp" |
31 | #include "cpu/gemm_inner_product_utils.hpp" |
32 | |
33 | #include "cpu/cpu_inner_product_pd.hpp" |
34 | #include "cpu/scale_utils.hpp" |
35 | #if DNNL_X64 |
36 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
37 | #endif |
38 | |
39 | namespace dnnl { |
40 | namespace impl { |
41 | namespace cpu { |
42 | |
43 | struct gemm_x8s8s32x_inner_product_fwd_t : public primitive_t { |
44 | struct pd_t : public cpu_inner_product_fwd_pd_t { |
45 | using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; |
46 | |
47 | DECLARE_COMMON_PD_T(src_md()->data_type == data_type::u8 |
48 | ? IGEMM_S8U8S32_IMPL_STR |
49 | : IGEMM_S8S8S32_IMPL_STR, |
50 | gemm_x8s8s32x_inner_product_fwd_t, USE_GLOBAL_SCRATCHPAD); |
51 | |
52 | status_t init(engine_t *engine) { |
53 | using namespace data_type; |
54 | |
55 | const bool ok = is_fwd() && !has_zero_dim_memory() |
56 | && utils::one_of(src_md()->data_type, s8, u8) |
57 | && weights_md()->data_type == s8 |
58 | && utils::one_of(dst_md()->data_type, f32, s32, s8, u8) |
59 | && IMPLICATION(with_bias(), |
60 | utils::one_of( |
61 | weights_md(1)->data_type, f32, s32, s8, u8)) |
62 | && attr()->has_default_values( |
63 | primitive_attr_t::skip_mask_t::scales_runtime |
64 | | primitive_attr_t::skip_mask_t::post_ops, |
65 | dst_md()->data_type) |
66 | && attr()->post_ops_.check_sum_consistent_dt( |
67 | dst_md()->data_type) |
68 | && scales_mask_ok() |
69 | && set_default_params() == status::success |
70 | && dense_gemm_consitency_check( |
71 | src_md(), weights_md(), dst_md()) |
72 | && attr_.set_default_formats(dst_md(0)) == status::success |
73 | && inner_product_utils::post_ops_ok( |
74 | attr()->post_ops_, &dst_md_); |
75 | |
76 | if (!ok) return status::unimplemented; |
77 | |
78 | bool do_sum = attr()->post_ops_.find(primitive_kind::sum) >= 0; |
79 | dst_is_acc_ |
80 | = utils::one_of(dst_md()->data_type, s32, f32) && !do_sum; |
81 | |
82 | init_scratchpad(); |
83 | |
84 | return status::success; |
85 | } |
86 | |
87 | bool dst_is_acc_; |
88 | |
89 | protected: |
90 | bool scales_mask_ok() const { |
91 | using namespace data_type; |
92 | const std::vector<int> supported_args |
93 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; |
94 | bool ok = attr()->scales_.has_default_values(supported_args); |
95 | for (int arg : supported_args) { |
96 | const auto &mask = attr()->scales_.get(arg).mask_; |
97 | if (arg == DNNL_ARG_WEIGHTS) |
98 | ok = ok && (mask == 0 || mask == (1 << 0)); |
99 | else |
100 | ok = ok && (mask == 0); |
101 | } |
102 | return ok; |
103 | } |
104 | |
105 | private: |
106 | void init_scratchpad() { |
107 | auto scratchpad = scratchpad_registry().registrar(); |
108 | if (!dst_is_acc_) { |
109 | scratchpad.template book<int32_t>( |
110 | memory_tracking::names::key_iprod_int_dat_in_acc_dt, |
111 | MB() * OC()); |
112 | } |
113 | |
114 | book_precomputed_scales(scratchpad, attr()->scales_, OC()); |
115 | } |
116 | }; |
117 | |
118 | gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
119 | |
120 | status_t init(engine_t *engine) override { |
121 | CHECK(safe_ptr_assign(pp_kernel_, |
122 | inner_product_utils::pp_kernel_t::create(pd(), false))); |
123 | return pp_kernel_->create_kernel(); |
124 | } |
125 | |
126 | status_t execute(const exec_ctx_t &ctx) const override { |
127 | return execute_forward(ctx); |
128 | } |
129 | |
130 | private: |
131 | status_t execute_forward(const exec_ctx_t &ctx) const; |
132 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
133 | |
134 | std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_; |
135 | }; |
136 | |
137 | } // namespace cpu |
138 | } // namespace impl |
139 | } // namespace dnnl |
140 | |
141 | #endif |
142 | |
143 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
144 | |