1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEMM_MATMUL_HPP
18#define GPU_OCL_GEMM_MATMUL_HPP
19
20#include "common/gemm_utils.hpp"
21#include "common/primitive.hpp"
22#include "common/primitive_desc_iterator.hpp"
23#include "gpu/gemm/gpu_gemm.hpp"
24#include "gpu/gpu_matmul_pd.hpp"
25#include "gpu/gpu_primitive.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30namespace ocl {
31
32struct gemm_matmul_t : public gpu_primitive_t {
33 using gpu_primitive_t::gpu_primitive_t;
34 struct pd_t : public gpu_matmul_pd_t {
35 pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr,
36 const matmul_pd_t *hint_pd)
37 : gpu_matmul_pd_t(adesc, attr, hint_pd) {}
38
39 pd_t(const pd_t &other) = default;
40
41 DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_matmul_t);
42
43 status_t init(engine_t *engine) {
44 using namespace data_type;
45
46 primitive_attr_t gemm_attr;
47 if (!attr()->scales_.has_default_values()) {
48 gemm_attr.scales_ = attr()->scales_;
49 }
50
51 auto map_gemm_zp = [&](int arg, int gemm_arg) {
52 if (!attr()->zero_points_.has_default_values(arg)) {
53 int mask = 0;
54 attr()->zero_points_.get(arg, &mask);
55 gemm_attr.zero_points_.set(gemm_arg, mask);
56 }
57 };
58
59 if (!attr()->zero_points_.has_default_values()) {
60 map_gemm_zp(DNNL_ARG_SRC, DNNL_ARG_B);
61 map_gemm_zp(DNNL_ARG_WEIGHTS, DNNL_ARG_A);
62 map_gemm_zp(DNNL_ARG_DST, DNNL_ARG_C);
63 }
64
65 if (!attr()->post_ops_.has_default_values()) {
66 gemm_attr.post_ops_.copy_from(attr()->post_ops_);
67 }
68
69 gemm_attr.set_fpmath_mode(attr()->fpmath_mode_);
70
71 const auto acc_dt = desc()->accum_data_type;
72
73 // We create a gemm_pd and resolve 'any' desc by querying gemm_pd
74 bool ok = status::success
75 == create_gemm_pd(gemm_pd_, engine, src_md(),
76 weights_md(), dst_md(), weights_md(1),
77 acc_dt, &gemm_attr)
78 && status::success == set_default_params()
79 && attr_.set_default_formats(dst_md(0)) == status::success;
80 if (!ok) return status::unimplemented;
81
82 init_scratchpad();
83
84 return status::success;
85 }
86
87 std::shared_ptr<primitive_desc_t> gemm_pd_;
88
89 private:
90 status_t set_default_params() {
91 src_md_ = *gemm_pd_->arg_md(DNNL_ARG_SRC_0);
92 weights_md_ = *gemm_pd_->arg_md(DNNL_ARG_SRC_1);
93 bias_md_ = *gemm_pd_->arg_md(DNNL_ARG_BIAS);
94 dst_md_ = *gemm_pd_->arg_md(DNNL_ARG_DST);
95 return status::success;
96 }
97
98 void init_scratchpad() {
99 auto scratchpad = scratchpad_registry().registrar();
100 scratchpad.book(memory_tracking::names::key_nested,
101 gemm_pd_->scratchpad_registry());
102 }
103 };
104
105 status_t init(engine_t *engine) override {
106 return create_nested_primitive(gemm_, pd()->gemm_pd_, engine);
107 }
108
109 status_t execute(const exec_ctx_t &ctx) const override;
110
111private:
112 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
113 std::shared_ptr<primitive_t> gemm_;
114};
115
116} // namespace ocl
117} // namespace gpu
118} // namespace impl
119} // namespace dnnl
120
121#endif
122