1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "cpu/rnn/rnn_utils.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | namespace x64 { |
25 | |
26 | src_layer_iter_transpose_t::src_layer_iter_transpose_t(const int src_ld, |
27 | const int dst_ld, const int rows, const int cols, |
28 | jit_brgemm_trans_src_t *const kernel_transpose) |
29 | : src_ld_(src_ld) |
30 | , dst_ld_(dst_ld) |
31 | , src_rows_(rows) |
32 | , src_cols_(cols) |
33 | , kernel_transpose_(kernel_transpose) {}; |
34 | |
35 | template <typename Dt> |
36 | void src_layer_iter_transpose_t::execute(const Dt *src, Dt *dst) const { |
37 | static constexpr int block_size = 16; |
38 | const auto rows_div = std::div(src_rows_, block_size); |
39 | const auto rows_tail = rows_div.rem; |
40 | const auto rows_blks = rows_div.quot + (rows_tail > 0 ? 1 : 0); |
41 | const auto cols_div = std::div(src_cols_, block_size); |
42 | const auto cols_tail = cols_div.rem; |
43 | const auto cols_blks = cols_div.quot + (cols_tail > 0 ? 1 : 0); |
44 | |
45 | parallel_nd(cols_blks, rows_blks, [&](dim_t c, dim_t r) { |
46 | const auto current_rows |
47 | = (rows_tail && r == rows_blks - 1) ? rows_tail : block_size; |
48 | const auto current_cols |
49 | = (cols_tail && c == cols_blks - 1) ? cols_tail : block_size; |
50 | |
51 | auto ctx = jit_brgemm_trans_src_t::ctx_t(); |
52 | ctx.src = (void *)(src + (r * src_ld_ + c) * block_size); |
53 | ctx.tr_src = (void *)(dst + (c * dst_ld_ + r) * block_size); |
54 | ctx.current_gemm_batch = 1; |
55 | ctx.current_M = current_cols; |
56 | ctx.current_K = current_rows; |
57 | |
58 | (*kernel_transpose_)(&ctx); |
59 | }); |
60 | } |
61 | |
62 | template void src_layer_iter_transpose_t::execute<float>( |
63 | const float *, float *) const; |
64 | template void src_layer_iter_transpose_t::execute<bfloat16_t>( |
65 | const bfloat16_t *, bfloat16_t *) const; |
66 | |
67 | } // namespace x64 |
68 | } // namespace cpu |
69 | } // namespace impl |
70 | } // namespace dnnl |
71 | |