1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_GEMM_X8S8S32X_INNER_PRODUCT_HPP
18#define CPU_GEMM_X8S8S32X_INNER_PRODUCT_HPP
19
#include <assert.h>

#include <memory>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/gemm/gemm.hpp"
#include "cpu/gemm_inner_product_utils.hpp"

#include "cpu/cpu_inner_product_pd.hpp"
#include "cpu/scale_utils.hpp"
#if DNNL_X64
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#endif
38
39namespace dnnl {
40namespace impl {
41namespace cpu {
42
43struct gemm_x8s8s32x_inner_product_fwd_t : public primitive_t {
44 struct pd_t : public cpu_inner_product_fwd_pd_t {
45 using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
46
47 DECLARE_COMMON_PD_T(src_md()->data_type == data_type::u8
48 ? IGEMM_S8U8S32_IMPL_STR
49 : IGEMM_S8S8S32_IMPL_STR,
50 gemm_x8s8s32x_inner_product_fwd_t, USE_GLOBAL_SCRATCHPAD);
51
52 status_t init(engine_t *engine) {
53 using namespace data_type;
54
55 const bool ok = is_fwd() && !has_zero_dim_memory()
56 && utils::one_of(src_md()->data_type, s8, u8)
57 && weights_md()->data_type == s8
58 && utils::one_of(dst_md()->data_type, f32, s32, s8, u8)
59 && IMPLICATION(with_bias(),
60 utils::one_of(
61 weights_md(1)->data_type, f32, s32, s8, u8))
62 && attr()->has_default_values(
63 primitive_attr_t::skip_mask_t::scales_runtime
64 | primitive_attr_t::skip_mask_t::post_ops,
65 dst_md()->data_type)
66 && attr()->post_ops_.check_sum_consistent_dt(
67 dst_md()->data_type)
68 && scales_mask_ok()
69 && set_default_params() == status::success
70 && dense_gemm_consitency_check(
71 src_md(), weights_md(), dst_md())
72 && attr_.set_default_formats(dst_md(0)) == status::success
73 && inner_product_utils::post_ops_ok(
74 attr()->post_ops_, &dst_md_);
75
76 if (!ok) return status::unimplemented;
77
78 bool do_sum = attr()->post_ops_.find(primitive_kind::sum) >= 0;
79 dst_is_acc_
80 = utils::one_of(dst_md()->data_type, s32, f32) && !do_sum;
81
82 init_scratchpad();
83
84 return status::success;
85 }
86
87 bool dst_is_acc_;
88
89 protected:
90 bool scales_mask_ok() const {
91 using namespace data_type;
92 const std::vector<int> supported_args
93 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
94 bool ok = attr()->scales_.has_default_values(supported_args);
95 for (int arg : supported_args) {
96 const auto &mask = attr()->scales_.get(arg).mask_;
97 if (arg == DNNL_ARG_WEIGHTS)
98 ok = ok && (mask == 0 || mask == (1 << 0));
99 else
100 ok = ok && (mask == 0);
101 }
102 return ok;
103 }
104
105 private:
106 void init_scratchpad() {
107 auto scratchpad = scratchpad_registry().registrar();
108 if (!dst_is_acc_) {
109 scratchpad.template book<int32_t>(
110 memory_tracking::names::key_iprod_int_dat_in_acc_dt,
111 MB() * OC());
112 }
113
114 book_precomputed_scales(scratchpad, attr()->scales_, OC());
115 }
116 };
117
118 gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {}
119
120 status_t init(engine_t *engine) override {
121 CHECK(safe_ptr_assign(pp_kernel_,
122 inner_product_utils::pp_kernel_t::create(pd(), false)));
123 return pp_kernel_->create_kernel();
124 }
125
126 status_t execute(const exec_ctx_t &ctx) const override {
127 return execute_forward(ctx);
128 }
129
130private:
131 status_t execute_forward(const exec_ctx_t &ctx) const;
132 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
133
134 std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_;
135};
136
137} // namespace cpu
138} // namespace impl
139} // namespace dnnl
140
141#endif
142
143// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
144