1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp"
18#include "common/dnnl_thread.hpp"
19#include "cpu/rnn/rnn_utils.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace cpu {
24namespace x64 {
25
26src_layer_iter_transpose_t::src_layer_iter_transpose_t(const int src_ld,
27 const int dst_ld, const int rows, const int cols,
28 jit_brgemm_trans_src_t *const kernel_transpose)
29 : src_ld_(src_ld)
30 , dst_ld_(dst_ld)
31 , src_rows_(rows)
32 , src_cols_(cols)
33 , kernel_transpose_(kernel_transpose) {};
34
35template <typename Dt>
36void src_layer_iter_transpose_t::execute(const Dt *src, Dt *dst) const {
37 static constexpr int block_size = 16;
38 const auto rows_div = std::div(src_rows_, block_size);
39 const auto rows_tail = rows_div.rem;
40 const auto rows_blks = rows_div.quot + (rows_tail > 0 ? 1 : 0);
41 const auto cols_div = std::div(src_cols_, block_size);
42 const auto cols_tail = cols_div.rem;
43 const auto cols_blks = cols_div.quot + (cols_tail > 0 ? 1 : 0);
44
45 parallel_nd(cols_blks, rows_blks, [&](dim_t c, dim_t r) {
46 const auto current_rows
47 = (rows_tail && r == rows_blks - 1) ? rows_tail : block_size;
48 const auto current_cols
49 = (cols_tail && c == cols_blks - 1) ? cols_tail : block_size;
50
51 auto ctx = jit_brgemm_trans_src_t::ctx_t();
52 ctx.src = (void *)(src + (r * src_ld_ + c) * block_size);
53 ctx.tr_src = (void *)(dst + (c * dst_ld_ + r) * block_size);
54 ctx.current_gemm_batch = 1;
55 ctx.current_M = current_cols;
56 ctx.current_K = current_rows;
57
58 (*kernel_transpose_)(&ctx);
59 });
60}
61
62template void src_layer_iter_transpose_t::execute<float>(
63 const float *, float *) const;
64template void src_layer_iter_transpose_t::execute<bfloat16_t>(
65 const bfloat16_t *, bfloat16_t *) const;
66
67} // namespace x64
68} // namespace cpu
69} // namespace impl
70} // namespace dnnl
71