1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/rnn/jit_diff_weights_peephole.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "cpu/rnn/rnn_utils.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | namespace x64 { |
25 | |
// Configures the diff-weights-peephole kernel for one dhc block:
// captures the data types involved, the block/tail geometry, and sets up
// the io helper (which abstracts dtype-aware vector load/store, incl. bf16).
jit_diff_weights_peephole_t::jit_diff_weights_peephole_t(
        const rnn_utils::rnn_conf_t &rnn, const dim_t dhc_block_size)
    : jit_generator(jit_name())
    // cell-state input keeps the configured src_iter_c data type
    , c_states_dt_(rnn.src_iter_c_dt)
    // scratch gates are bf16 only for bf16 configurations, f32 otherwise
    , scratch_dt_(rnn.is_bf16_conf() ? data_type::bf16 : data_type::f32)
    // accumulation destination is always f32
    , dst_dt_(data_type::f32)
    , compute_block_size_(dhc_block_size)
    // leftover elements that do not fill a full simd vector
    , tail_size_(dhc_block_size % simd_w_)
    // io helper: pick the bf16-capable ISA when available; the tail config
    // describes how to build the opmask used for masked loads/stores
    , io_(this, mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core,
              {c_states_dt_, scratch_dt_, dst_dt_}, {},
              io::io_tail_conf_t {static_cast<std::size_t>(simd_w_),
                      static_cast<std::size_t>(tail_size_), tail_opmask_, 0,
                      reg_tmp_}) {}
39 | |
// Top-level code generation: standard prologue, load kernel arguments,
// one-time setup (tail mask), the main compute loop, then the epilogue.
// The order is fixed — each stage depends on state set up by the previous.
void jit_diff_weights_peephole_t::generate() {
    preamble();
    load_addresses(); // registers <- call_params_t buffer pointers
    init(); // prepare tail opmask if a tail exists
    compute_loop();
    postamble();
}
47 | |
48 | #define PARAM_OFF(x) offsetof(jit_diff_weights_peephole_t::call_params_t, x) |
49 | |
50 | void jit_diff_weights_peephole_t::load_addresses() { |
51 | mov(reg_c_states_, ptr[abi_param1 + PARAM_OFF(c_states)]); |
52 | mov(reg_scratch_gates_, ptr[abi_param1 + PARAM_OFF(scratch_gates)]); |
53 | mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]); |
54 | } |
55 | |
56 | #undef PARAM_OFF |
57 | |
58 | void jit_diff_weights_peephole_t::init() { |
59 | if (tail_size_) { io_.prepare_tail_mask(); } |
60 | } |
61 | |
62 | void jit_diff_weights_peephole_t::compute_loop() { |
63 | |
64 | Xbyak::Label unroll_loop, unroll_loop_tail; |
65 | |
66 | mov(loop_cnt_, compute_block_size_); |
67 | xor_(reg_offset_, reg_offset_); |
68 | |
69 | const size_t offt_max = max_unrolling * simd_w_; |
70 | const size_t full_unroling_steps = compute_block_size_ / offt_max; |
71 | |
72 | if (full_unroling_steps) { |
73 | L(unroll_loop); |
74 | { |
75 | cmp(loop_cnt_, offt_max); |
76 | jl(unroll_loop_tail, T_NEAR); |
77 | |
78 | compute_dst(max_unrolling, false /*tail*/); |
79 | sub(loop_cnt_, offt_max); |
80 | add(reg_offset_, offt_max); |
81 | jmp(unroll_loop); |
82 | } |
83 | } |
84 | |
85 | const size_t full_blocks_left = (compute_block_size_ - tail_size_ |
86 | - (full_unroling_steps * offt_max)) |
87 | / simd_w_; |
88 | |
89 | L(unroll_loop_tail); |
90 | { |
91 | if (full_blocks_left) { |
92 | compute_dst(full_blocks_left, false /*tail*/); |
93 | if (tail_size_) { |
94 | const size_t offt = full_blocks_left * simd_w_; |
95 | add(reg_offset_, offt); |
96 | } |
97 | } |
98 | if (tail_size_) { compute_dst(1u /*unrolling factor*/, true /*tail*/); } |
99 | } |
100 | } |
101 | |
// Emits `unrolling_factor` independent fused-multiply-add groups:
//   dst += scratch_gates * c_states   (elementwise, one simd vector each).
// When `tail` is true a single masked group is emitted using tail_opmask_.
void jit_diff_weights_peephole_t::compute_dst(
        size_t unrolling_factor, bool tail) {

    // Each unroll group needs 3 zmm registers: dst, scratch, c_states.
    static constexpr dim_t number_vmm_single_compute = 3;

    // Register allocation: group g uses zmm(base_idx + g * 3), so groups
    // never alias each other's registers.
    const auto get_compute_zmm = [=](size_t base_idx, size_t unroll_group) {
        return Xbyak::Zmm(base_idx + unroll_group * number_vmm_single_compute);
    };

    // Effective address = base + (reg_offset_ + offt) * sizeof(dt);
    // reg_offset_ holds the running element offset set by compute_loop().
    const auto get_addr = [&](const Xbyak::Reg64 &reg_base, const dim_t offt,
                                  const data_type_t dt) {
        const auto dt_size = types::data_type_size(dt);
        return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
    };

    // Base zmm indices for each operand within an unroll group.
    static constexpr size_t dst_idx = 0;
    static constexpr size_t scratch_idx = 1;
    static constexpr size_t c_states_idx = 2;

    // dtype-aware load/store helpers (handle bf16 up/down-conversion).
    const auto io_dst = io_.at(dst_dt_);
    const auto io_scratch = io_.at(scratch_dt_);
    const auto io_c_states = io_.at(c_states_dt_);

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const auto dst_zmm = get_compute_zmm(dst_idx, unroll_group);
        const auto scratch_zmm = get_compute_zmm(scratch_idx, unroll_group);
        const auto c_states_zmm = get_compute_zmm(c_states_idx, unroll_group);

        // Each group works on the next simd_w_ elements of the block.
        const auto unroll_offset = unroll_group * simd_w_;
        const auto dst_addr = get_addr(reg_dst_, unroll_offset, dst_dt_);
        io_dst->load(dst_addr, dst_zmm, tail);
        io_scratch->load(
                get_addr(reg_scratch_gates_, unroll_offset, scratch_dt_),
                scratch_zmm, tail);
        io_c_states->load(get_addr(reg_c_states_, unroll_offset, c_states_dt_),
                c_states_zmm, tail);
        // Mask the accumulation so tail lanes of dst are left untouched.
        const auto dst_zmm_masked = tail ? dst_zmm | tail_opmask_ : dst_zmm;
        uni_vfmadd231ps(dst_zmm_masked, scratch_zmm, c_states_zmm);
        io_dst->store(dst_zmm, dst_addr, tail);
    }
}
145 | |
146 | } // namespace x64 |
147 | } // namespace cpu |
148 | } // namespace impl |
149 | } // namespace dnnl |
150 | |