1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_HPP |
19 | |
20 | #include "common/utils.hpp" |
21 | #include "cpu/x64/jit_generator.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | template <cpu_isa_t isa> |
29 | struct jit_uni_lstm_cell_postgemm_t { |
30 | jit_uni_lstm_cell_postgemm_t( |
31 | jit_generator *host, int tmp_id_begin, bool use_bf16_emu) |
32 | : host_(host) |
33 | , min_allowed_tmp_vmm_idx_(0) |
34 | , max_allowed_tmp_vmm_idx_(cpu_isa_traits<isa>::n_vregs - 1 |
35 | - (is_superset(isa, avx512_core) && use_bf16_emu ? 4 : 0)) { |
36 | reset_tmp_vmm_idx_range(tmp_id_begin, max_allowed_tmp_vmm_idx_); |
37 | } |
38 | |
39 | protected: |
40 | using injector_t = typename utils::conditional<isa == avx512_core, |
41 | jit_uni_eltwise_injector_f32<avx512_core>, |
42 | jit_uni_eltwise_injector_f32<isa>>::type; |
43 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
44 | const size_t vlen_ = cpu_isa_traits<isa>::vlen; |
45 | |
46 | Vmm get_next_tmp_vmm() { |
47 | const Vmm vmm {current_tmp_id_++}; |
48 | |
49 | if (current_tmp_id_ > tmp_id_last_) reset_vmm_cnt(); |
50 | |
51 | return vmm; |
52 | } |
53 | |
54 | Vmm maybe_get_next_tmp_vmm_for_below_avx2_isa() { |
55 | if (!this->avx2_available_) return get_next_tmp_vmm(); |
56 | |
57 | return Vmm(0); // return 0th register as dummy |
58 | } |
59 | |
60 | void reset_vmm_cnt() { current_tmp_id_ = tmp_id_first_; } |
61 | int get_min_allowed_tmp_vmm_allowed_idx() const { |
62 | return min_allowed_tmp_vmm_idx_; |
63 | } |
64 | int get_max_allowed_tmp_vmm_allowed_idx() const { |
65 | return max_allowed_tmp_vmm_idx_; |
66 | } |
67 | void reset_tmp_vmm_idx_range(int lower_idx, int upper_idx) { |
68 | assert(lower_idx >= get_min_allowed_tmp_vmm_allowed_idx() |
69 | && upper_idx <= get_max_allowed_tmp_vmm_allowed_idx() |
70 | && lower_idx <= upper_idx); |
71 | tmp_id_first_ = lower_idx; |
72 | tmp_id_last_ = upper_idx; |
73 | reset_vmm_cnt(); |
74 | } |
75 | |
76 | Xbyak::Xmm get_next_tmp_xmm() { |
77 | return Xbyak::Xmm(get_next_tmp_vmm().getIdx()); |
78 | } |
79 | |
80 | Vmm vmm_backup(const Vmm &vmm) { |
81 | auto tmp_vmm = vmm; |
82 | if (!this->avx2_available_) { |
83 | tmp_vmm = this->get_next_tmp_vmm(); |
84 | host_->uni_vmovups(tmp_vmm, vmm); |
85 | } |
86 | return tmp_vmm; |
87 | }; |
88 | |
89 | Xbyak::Xmm xmm_backup(const Xbyak::Xmm &xmm) { |
90 | auto tmp_xmm = xmm; |
91 | if (!this->avx2_available_) { |
92 | tmp_xmm = this->get_next_tmp_xmm(); |
93 | host_->uni_vmovss(tmp_xmm, xmm); |
94 | } |
95 | return tmp_xmm; |
96 | }; |
97 | |
98 | void vaddps_rhs_op_mem( |
99 | const Vmm &dst, const Vmm &lhs, const Xbyak::Address &rhs_addr) { |
100 | |
101 | if (avx2_available_) |
102 | host_->uni_vaddps(dst, lhs, rhs_addr); |
103 | else { |
104 | const auto rhs = get_next_tmp_vmm(); |
105 | host_->uni_vmovups(rhs, rhs_addr); |
106 | host_->uni_vaddps(dst, lhs, rhs); |
107 | } |
108 | } |
109 | |
110 | void vfmadd231ps_rhs_op_mem( |
111 | const Vmm &dst, const Vmm &lhs, const Xbyak::Address &rhs_addr) { |
112 | if (avx2_available_) |
113 | host_->uni_vfmadd231ps(dst, lhs, rhs_addr); |
114 | else { |
115 | const auto tmp = get_next_tmp_vmm(); |
116 | host_->uni_vmovups(tmp, rhs_addr); |
117 | const auto &rhs = lhs; |
118 | host_->uni_vfmadd231ps(dst, tmp, rhs); |
119 | } |
120 | } |
121 | |
122 | void vmulps_rhs_op_mem( |
123 | const Vmm &dst, const Vmm &lhs, const Xbyak::Address &rhs_addr) { |
124 | if (avx2_available_) |
125 | host_->uni_vmulps(dst, lhs, rhs_addr); |
126 | else { |
127 | const auto rhs = get_next_tmp_vmm(); |
128 | host_->uni_vmovups(rhs, rhs_addr); |
129 | host_->uni_vmulps(dst, lhs, rhs); |
130 | } |
131 | } |
132 | |
133 | void vaddss_rhs_op_mem(const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs, |
134 | const Xbyak::Address &rhs_addr) { |
135 | if (avx2_available_) |
136 | host_->uni_vaddss(dst, lhs, rhs_addr); |
137 | else { |
138 | const auto rhs = get_next_tmp_xmm(); |
139 | host_->uni_vmovss(rhs, rhs_addr); |
140 | host_->uni_vaddss(dst, lhs, rhs); |
141 | } |
142 | } |
143 | |
144 | void vfmadd231ss_rhs_op_mem(const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs, |
145 | const Xbyak::Address &rhs_addr) { |
146 | if (avx2_available_) |
147 | host_->uni_vfmadd231ss(dst, lhs, rhs_addr); |
148 | else { |
149 | const auto tmp = get_next_tmp_xmm(); |
150 | host_->uni_vmovss(tmp, rhs_addr); |
151 | const auto &rhs = lhs; |
152 | host_->uni_vfmadd231ss(dst, tmp, rhs); |
153 | } |
154 | } |
155 | |
156 | void vmulss_rhs_op_mem(const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs, |
157 | const Xbyak::Address &rhs_addr) { |
158 | if (avx2_available_) |
159 | host_->uni_vmulss(dst, lhs, rhs_addr); |
160 | else { |
161 | const auto rhs = get_next_tmp_xmm(); |
162 | host_->uni_vmovss(rhs, rhs_addr); |
163 | host_->uni_vmulss(dst, lhs, rhs); |
164 | } |
165 | } |
166 | |
167 | protected: |
168 | const bool avx2_available_ = is_superset(isa, avx2); |
169 | |
170 | private: |
171 | jit_generator *host_; |
172 | const int min_allowed_tmp_vmm_idx_; |
173 | const int max_allowed_tmp_vmm_idx_; |
174 | int tmp_id_first_; |
175 | int current_tmp_id_; |
176 | int tmp_id_last_; |
177 | }; |
178 | |
179 | } // namespace x64 |
180 | } // namespace cpu |
181 | } // namespace impl |
182 | } // namespace dnnl |
183 | |
184 | #endif |
185 | |