/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/rnn/jit_diff_weights_peephole.hpp"
#include "common/c_types_map.hpp"
#include "cpu/rnn/rnn_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

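// This kernel computes the LSTM diff weights peephole update:
// dst += scratch_gates * c_states, elementwise over a block of dhc
// elements, accumulating into an f32 destination. A scalar sketch of the
// computation being vectorized (assuming f32 inputs; the kernel also
// handles bf16 sources through the io helpers):
//
//     for (dim_t i = 0; i < dhc_block_size; ++i)
//         dst[i] += scratch_gates[i] * c_states[i];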
jit_diff_weights_peephole_t::jit_diff_weights_peephole_t(
        const rnn_utils::rnn_conf_t &rnn, const dim_t dhc_block_size)
    : jit_generator(jit_name())
    , c_states_dt_(rnn.src_iter_c_dt)
    , scratch_dt_(rnn.is_bf16_conf() ? data_type::bf16 : data_type::f32)
    , dst_dt_(data_type::f32)
    , compute_block_size_(dhc_block_size)
    // number of trailing elements that do not fill a full vector
    , tail_size_(dhc_block_size % simd_w_)
    // pick the bf16-capable ISA when available; the io helpers take care
    // of per-data-type load/store conversions and the tail opmask
    , io_(this, mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core,
              {c_states_dt_, scratch_dt_, dst_dt_}, {},
              io::io_tail_conf_t {static_cast<std::size_t>(simd_w_),
                      static_cast<std::size_t>(tail_size_), tail_opmask_, 0,
                      reg_tmp_}) {}

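// Kernel body: standard prologue/epilogue around argument loading, tail
// mask setup, and the main compute loop.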
void jit_diff_weights_peephole_t::generate() {
    preamble();
    load_addresses();
    init();
    compute_loop();
    postamble();
}

#define PARAM_OFF(x) offsetof(jit_diff_weights_peephole_t::call_params_t, x)

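// abi_param1 holds the pointer to call_params_t; load the three buffer
// addresses it carries into dedicated registers.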
void jit_diff_weights_peephole_t::load_addresses() {
    mov(reg_c_states_, ptr[abi_param1 + PARAM_OFF(c_states)]);
    mov(reg_scratch_gates_, ptr[abi_param1 + PARAM_OFF(scratch_gates)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
}

#undef PARAM_OFF

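// The tail opmask is needed only when the block size is not a multiple of
// the SIMD width.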
void jit_diff_weights_peephole_t::init() {
    if (tail_size_) { io_.prepare_tail_mask(); }
}

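// Walks the dhc block in three phases: a loop over fully unrolled steps of
// max_unrolling vectors each, then the remaining full vectors in a single
// step, then a masked tail of fewer than simd_w_ elements.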
void jit_diff_weights_peephole_t::compute_loop() {

    Xbyak::Label unroll_loop, unroll_loop_tail;

    mov(loop_cnt_, compute_block_size_);
    xor_(reg_offset_, reg_offset_);

    const size_t offt_max = max_unrolling * simd_w_;
    const size_t full_unrolling_steps = compute_block_size_ / offt_max;

    if (full_unrolling_steps) {
        L(unroll_loop);
        {
            cmp(loop_cnt_, offt_max);
            jl(unroll_loop_tail, T_NEAR);

            compute_dst(max_unrolling, false /*tail*/);
            sub(loop_cnt_, offt_max);
            add(reg_offset_, offt_max);
            jmp(unroll_loop);
        }
    }

    // full vectors left after the unrolled loop, excluding the tail
    const size_t full_blocks_left = (compute_block_size_ - tail_size_
                                            - (full_unrolling_steps * offt_max))
            / simd_w_;

    L(unroll_loop_tail);
    {
        if (full_blocks_left) {
            compute_dst(full_blocks_left, false /*tail*/);
            if (tail_size_) {
                const size_t offt = full_blocks_left * simd_w_;
                add(reg_offset_, offt);
            }
        }
        if (tail_size_) { compute_dst(1u /*unrolling factor*/, true /*tail*/); }
    }
}

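// Emits one unrolled batch of FMAs. Each unroll group uses three zmm
// registers (dst, scratch_gates, c_states); loads go through the io
// helpers so bf16 sources are converted on load, and the accumulation
// itself is performed in f32.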
void jit_diff_weights_peephole_t::compute_dst(
        size_t unrolling_factor, bool tail) {

    // each unroll group consumes three vector registers
    static constexpr dim_t number_vmm_single_compute = 3;

    const auto get_compute_zmm = [=](size_t base_idx, size_t unroll_group) {
        return Xbyak::Zmm(base_idx + unroll_group * number_vmm_single_compute);
    };

    const auto get_addr = [&](const Xbyak::Reg64 &reg_base, const dim_t offt,
                                  const data_type_t dt) {
        const auto dt_size = types::data_type_size(dt);
        return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
    };

    static constexpr size_t dst_idx = 0;
    static constexpr size_t scratch_idx = 1;
    static constexpr size_t c_states_idx = 2;

    const auto io_dst = io_.at(dst_dt_);
    const auto io_scratch = io_.at(scratch_dt_);
    const auto io_c_states = io_.at(c_states_dt_);

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const auto dst_zmm = get_compute_zmm(dst_idx, unroll_group);
        const auto scratch_zmm = get_compute_zmm(scratch_idx, unroll_group);
        const auto c_states_zmm = get_compute_zmm(c_states_idx, unroll_group);

        const auto unroll_offset = unroll_group * simd_w_;
        const auto dst_addr = get_addr(reg_dst_, unroll_offset, dst_dt_);
        io_dst->load(dst_addr, dst_zmm, tail);
        io_scratch->load(
                get_addr(reg_scratch_gates_, unroll_offset, scratch_dt_),
                scratch_zmm, tail);
        io_c_states->load(get_addr(reg_c_states_, unroll_offset, c_states_dt_),
                c_states_zmm, tail);
        // dst += scratch_gates * c_states (masked in the tail case)
        const auto dst_zmm_masked = tail ? dst_zmm | tail_opmask_ : dst_zmm;
        uni_vfmadd231ps(dst_zmm_masked, scratch_zmm, c_states_zmm);
        io_dst->store(dst_zmm, dst_addr, tail);
    }
}
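
// Usage sketch (hypothetical caller; create_kernel()/operator() come from
// jit_generator, field names from call_params_t in the header):
//
//     jit_diff_weights_peephole_t kernel(rnn, dhc_block_size);
//     kernel.create_kernel();
//     jit_diff_weights_peephole_t::call_params_t p;
//     p.c_states = c_states_ptr;
//     p.scratch_gates = scratch_gates_ptr;
//     p.dst = diff_weights_peephole_ptr;
//     kernel(&p);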

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl