1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_MATMUL_BRGEMM_MATMUL_REORDERS_HPP
18#define CPU_X64_MATMUL_BRGEMM_MATMUL_REORDERS_HPP
19
20#include "cpu/reorder/cpu_reorder_pd.hpp"
21#include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28struct brgemm_matmul_matrix_B_reorder_t : public primitive_t {
29 struct pd_t : public cpu_reorder_pd_t {
30 using cpu_reorder_pd_t::cpu_reorder_pd_t;
31
32 DECLARE_COMMON_PD_T("brgemm_matmul_matrix_B_reorder_t",
33 brgemm_matmul_matrix_B_reorder_t);
34
35 // required to re-use brgemm matmul copy_b jit kernels
36 matmul::brgemm_matmul_conf_t matmul_conf_for_reorder_;
37 status_t init(
38 engine_t *engine, engine_t *src_engine, engine_t *dst_engine);
39
40 private:
41 static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
42 const primitive_attr_t *attr, engine_t *src_engine,
43 const memory_desc_t *src_md, engine_t *dst_engine,
44 const memory_desc_t *dst_md);
45
46 void init_scratchpad() {}
47 friend dnnl::impl::impl_list_item_t;
48 };
49
50 brgemm_matmul_matrix_B_reorder_t(const pd_t *apd) : primitive_t(apd) {}
51 status_t init(engine_t *engine) override {
52 CHECK(matmul::create_brgemm_matmul_copy_b(
53 kernel_, &pd()->matmul_conf_for_reorder_));
54
55 return status::success;
56 }
57
58private:
59 status_t execute_body(const exec_ctx_t &ctx) const;
60 status_t execute(const exec_ctx_t &ctx) const override {
61 return execute_body(ctx);
62 }
63
64 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
65 std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t> kernel_;
66};
67
68} // namespace x64
69} // namespace cpu
70} // namespace impl
71} // namespace dnnl
72
73#endif
74