1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_HPP
18#define CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_HPP
19
20#include "common/utils.hpp"
21#include "cpu/x64/jit_generator.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28template <cpu_isa_t isa>
29struct jit_uni_lstm_cell_postgemm_t {
30 jit_uni_lstm_cell_postgemm_t(
31 jit_generator *host, int tmp_id_begin, bool use_bf16_emu)
32 : host_(host)
33 , min_allowed_tmp_vmm_idx_(0)
34 , max_allowed_tmp_vmm_idx_(cpu_isa_traits<isa>::n_vregs - 1
35 - (is_superset(isa, avx512_core) && use_bf16_emu ? 4 : 0)) {
36 reset_tmp_vmm_idx_range(tmp_id_begin, max_allowed_tmp_vmm_idx_);
37 }
38
39protected:
40 using injector_t = typename utils::conditional<isa == avx512_core,
41 jit_uni_eltwise_injector_f32<avx512_core>,
42 jit_uni_eltwise_injector_f32<isa>>::type;
43 using Vmm = typename cpu_isa_traits<isa>::Vmm;
44 const size_t vlen_ = cpu_isa_traits<isa>::vlen;
45
46 Vmm get_next_tmp_vmm() {
47 const Vmm vmm {current_tmp_id_++};
48
49 if (current_tmp_id_ > tmp_id_last_) reset_vmm_cnt();
50
51 return vmm;
52 }
53
54 Vmm maybe_get_next_tmp_vmm_for_below_avx2_isa() {
55 if (!this->avx2_available_) return get_next_tmp_vmm();
56
57 return Vmm(0); // return 0th register as dummy
58 }
59
60 void reset_vmm_cnt() { current_tmp_id_ = tmp_id_first_; }
61 int get_min_allowed_tmp_vmm_allowed_idx() const {
62 return min_allowed_tmp_vmm_idx_;
63 }
64 int get_max_allowed_tmp_vmm_allowed_idx() const {
65 return max_allowed_tmp_vmm_idx_;
66 }
67 void reset_tmp_vmm_idx_range(int lower_idx, int upper_idx) {
68 assert(lower_idx >= get_min_allowed_tmp_vmm_allowed_idx()
69 && upper_idx <= get_max_allowed_tmp_vmm_allowed_idx()
70 && lower_idx <= upper_idx);
71 tmp_id_first_ = lower_idx;
72 tmp_id_last_ = upper_idx;
73 reset_vmm_cnt();
74 }
75
76 Xbyak::Xmm get_next_tmp_xmm() {
77 return Xbyak::Xmm(get_next_tmp_vmm().getIdx());
78 }
79
80 Vmm vmm_backup(const Vmm &vmm) {
81 auto tmp_vmm = vmm;
82 if (!this->avx2_available_) {
83 tmp_vmm = this->get_next_tmp_vmm();
84 host_->uni_vmovups(tmp_vmm, vmm);
85 }
86 return tmp_vmm;
87 };
88
89 Xbyak::Xmm xmm_backup(const Xbyak::Xmm &xmm) {
90 auto tmp_xmm = xmm;
91 if (!this->avx2_available_) {
92 tmp_xmm = this->get_next_tmp_xmm();
93 host_->uni_vmovss(tmp_xmm, xmm);
94 }
95 return tmp_xmm;
96 };
97
98 void vaddps_rhs_op_mem(
99 const Vmm &dst, const Vmm &lhs, const Xbyak::Address &rhs_addr) {
100
101 if (avx2_available_)
102 host_->uni_vaddps(dst, lhs, rhs_addr);
103 else {
104 const auto rhs = get_next_tmp_vmm();
105 host_->uni_vmovups(rhs, rhs_addr);
106 host_->uni_vaddps(dst, lhs, rhs);
107 }
108 }
109
110 void vfmadd231ps_rhs_op_mem(
111 const Vmm &dst, const Vmm &lhs, const Xbyak::Address &rhs_addr) {
112 if (avx2_available_)
113 host_->uni_vfmadd231ps(dst, lhs, rhs_addr);
114 else {
115 const auto tmp = get_next_tmp_vmm();
116 host_->uni_vmovups(tmp, rhs_addr);
117 const auto &rhs = lhs;
118 host_->uni_vfmadd231ps(dst, tmp, rhs);
119 }
120 }
121
122 void vmulps_rhs_op_mem(
123 const Vmm &dst, const Vmm &lhs, const Xbyak::Address &rhs_addr) {
124 if (avx2_available_)
125 host_->uni_vmulps(dst, lhs, rhs_addr);
126 else {
127 const auto rhs = get_next_tmp_vmm();
128 host_->uni_vmovups(rhs, rhs_addr);
129 host_->uni_vmulps(dst, lhs, rhs);
130 }
131 }
132
133 void vaddss_rhs_op_mem(const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs,
134 const Xbyak::Address &rhs_addr) {
135 if (avx2_available_)
136 host_->uni_vaddss(dst, lhs, rhs_addr);
137 else {
138 const auto rhs = get_next_tmp_xmm();
139 host_->uni_vmovss(rhs, rhs_addr);
140 host_->uni_vaddss(dst, lhs, rhs);
141 }
142 }
143
144 void vfmadd231ss_rhs_op_mem(const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs,
145 const Xbyak::Address &rhs_addr) {
146 if (avx2_available_)
147 host_->uni_vfmadd231ss(dst, lhs, rhs_addr);
148 else {
149 const auto tmp = get_next_tmp_xmm();
150 host_->uni_vmovss(tmp, rhs_addr);
151 const auto &rhs = lhs;
152 host_->uni_vfmadd231ss(dst, tmp, rhs);
153 }
154 }
155
156 void vmulss_rhs_op_mem(const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs,
157 const Xbyak::Address &rhs_addr) {
158 if (avx2_available_)
159 host_->uni_vmulss(dst, lhs, rhs_addr);
160 else {
161 const auto rhs = get_next_tmp_xmm();
162 host_->uni_vmovss(rhs, rhs_addr);
163 host_->uni_vmulss(dst, lhs, rhs);
164 }
165 }
166
167protected:
168 const bool avx2_available_ = is_superset(isa, avx2);
169
170private:
171 jit_generator *host_;
172 const int min_allowed_tmp_vmm_idx_;
173 const int max_allowed_tmp_vmm_idx_;
174 int tmp_id_first_;
175 int current_tmp_id_;
176 int tmp_id_last_;
177};
178
179} // namespace x64
180} // namespace cpu
181} // namespace impl
182} // namespace dnnl
183
184#endif
185