1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#ifndef CPU_X64_JIT_UNI_POSTOPS_INJECTOR_HPP
17#define CPU_X64_JIT_UNI_POSTOPS_INJECTOR_HPP
18
19#include <functional>
20#include <map>
21#include <memory>
22
23#include "common/c_types_map.hpp"
24#include "common/primitive_attr.hpp"
25#include "common/type_helpers.hpp"
26#include "common/utils.hpp"
27#include "cpu/x64/injectors/injector_utils.hpp"
28#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
29#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
30#include "cpu/x64/jit_generator.hpp"
31#include <initializer_list>
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace x64 {
37namespace injector {
38
39/*
40 * Allows specifying custom injector function for given post-op type - one
41 * function per primitive. There are post-ops type (example: sum) that don't
42 * have specialized injector. They heavily rely on kernel specific intrnals,
43 * which makes the generalization unreasonable. As so user can prepare internal
44 * kernel lambda and pass it explicitly to injector.
45 */
46using lambda_jit_injectors_t
47 = std::map<dnnl_primitive_kind_t, std::function<void()>>;
48
// Forward declaration only; the argument bundle is taken by reference below.
struct post_ops_ok_args_t;

/*
 * Tells whether post-ops injection is supported for the given arguments.
 */
bool is_supported(const post_ops_ok_args_t &post_ops_ok_args);
54
55/*
56 * Main mechanism of handling various post-ops types. It utilizes internally
57 * specialized injectors to generate post-ops code to host primitive. Random
58 * order of post-ops is supported.
59 */
60template <cpu_isa_t isa, typename Vmm = typename cpu_isa_traits<isa>::Vmm>
61class jit_uni_postops_injector_t {
62public:
63 /*
64 * @param host <required> - user primitive where post-ops generated code is
65 * injected
66 * @param post_ops <required> - struct representing requested post-ops chain
67 * @binary_static_params <reguired> - static params needed for binary_injector.
68 * see: jit_uni_binary_injector.hpp for more info.
69 * @param eltwise_static_params <optional> - allows user specify non default
70 * params for eltwise_injector
71 * @param lambda_jit_injectors <optional> - allows user specify custom injector
72 * function for given post-op type
73 */
74 jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops,
75 const binary_injector::static_params_t &binary_static_params);
76 jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops,
77 const binary_injector::static_params_t &binary_static_params,
78 const lambda_jit_injectors_t &lambda_jit_injectors);
79 jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops,
80 const binary_injector::static_params_t &binary_static_params,
81 const eltwise_injector::static_params_t &eltwise_static_params);
82 jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops,
83 const binary_injector::static_params_t &binary_static_params,
84 const eltwise_injector::static_params_t &eltwise_static_params,
85 const lambda_jit_injectors_t &lambda_jit_injectors);
86
87 /*
88 * Generates code of post_ops chain injected to host primitive. Applied to
89 * ordered set of vector registers' indexes.
90 *
91 * @rhs_arg_params: see jit_uni_binary_injector description
92 */
93 void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs,
94 const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params);
95
96 void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs);
97
98 /*
99 * Generates code of post_ops chain injected to host primitive. Applied to
100 * range <start_idx, end_idx) of vector registers' indexes.
101 *
102 * @rhs_arg_params: see jit_uni_binary_injector description
103 */
104 void compute_vector_range(size_t start_idx, size_t end_idx,
105 const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params);
106
107 void compute_vector_range(size_t start_idx, size_t end_idx);
108
109 /*
110 * Generates code of post_ops chain injected to host primitive. Applied to
111 * a single vector register index.
112 *
113 * @rhs_arg_params: see jit_uni_binary_injector description
114 */
115 void compute_vector(size_t idx,
116 const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params);
117 void compute_vector(size_t idx);
118
119 /*
120 * Thin wrapper for eltwise injector specific function
121 */
122 void prepare_table(bool gen_table = true);
123 void set_lambda_injector(lambda_jit_injectors_t::key_type,
124 const lambda_jit_injectors_t::mapped_type &jit_injector);
125
126private:
127 post_ops_t post_ops_;
128 jit_generator *host_;
129 // Key is a numerical order of a post-op in attributes.
130 std::map<int, jit_uni_eltwise_injector_f32<isa, Vmm>>
131 alg_to_eltwise_injector_;
132 std::unique_ptr<binary_injector::jit_uni_binary_injector_t<isa, Vmm>>
133 binary_injector_;
134 lambda_jit_injectors_t lambda_jit_injectors_;
135};
136
// Post-op categories; used as the element type of the accepted-types list in
// post_ops_ok_args_t. Kept as an unscoped enum so the enumerators stay
// reachable without qualification.
enum post_op_type {
    sum = 0,
    eltwise = 1,
    binary = 2,
};
138
139struct post_ops_ok_args_t {
140 post_ops_ok_args_t(const cpu_isa_t isa,
141 const std::vector<post_op_type> &accepted_post_op_types,
142 const post_ops_t &post_ops,
143 const memory_desc_wrapper *dst_d = nullptr,
144 const bool sum_at_pos_0_only = false,
145 const bool sum_requires_scale_one = false,
146 const bool sum_requires_zp_zero = true,
147 const bcast_set_t &enabled_bcast_strategy = default_strategies());
148
149 const cpu_isa_t isa;
150 const std::vector<post_op_type> &accepted_post_op_types;
151 const post_ops_t &post_ops;
152 const memory_desc_wrapper *dst_d;
153 const bool sum_at_pos_0_only;
154 const bool sum_requires_scale_one;
155 const bool sum_requires_zp_zero;
156 const bcast_set_t enabled_bcast_strategy;
157};
158
159bool post_ops_ok(const post_ops_ok_args_t &args);
160
161} // namespace injector
162} // namespace x64
163} // namespace cpu
164} // namespace impl
165} // namespace dnnl
166
167#endif
168