1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_GEMM_MATMUL_HPP |
18 | #define GPU_OCL_GEMM_MATMUL_HPP |
19 | |
20 | #include "common/gemm_utils.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/primitive_desc_iterator.hpp" |
23 | #include "gpu/gemm/gpu_gemm.hpp" |
24 | #include "gpu/gpu_matmul_pd.hpp" |
25 | #include "gpu/gpu_primitive.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace ocl { |
31 | |
32 | struct gemm_matmul_t : public gpu_primitive_t { |
33 | using gpu_primitive_t::gpu_primitive_t; |
34 | struct pd_t : public gpu_matmul_pd_t { |
35 | pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr, |
36 | const matmul_pd_t *hint_pd) |
37 | : gpu_matmul_pd_t(adesc, attr, hint_pd) {} |
38 | |
39 | pd_t(const pd_t &other) = default; |
40 | |
41 | DECLARE_COMMON_PD_T(gemm_pd_->name(), gemm_matmul_t); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | using namespace data_type; |
45 | |
46 | primitive_attr_t gemm_attr; |
47 | if (!attr()->scales_.has_default_values()) { |
48 | gemm_attr.scales_ = attr()->scales_; |
49 | } |
50 | |
51 | auto map_gemm_zp = [&](int arg, int gemm_arg) { |
52 | if (!attr()->zero_points_.has_default_values(arg)) { |
53 | int mask = 0; |
54 | attr()->zero_points_.get(arg, &mask); |
55 | gemm_attr.zero_points_.set(gemm_arg, mask); |
56 | } |
57 | }; |
58 | |
59 | if (!attr()->zero_points_.has_default_values()) { |
60 | map_gemm_zp(DNNL_ARG_SRC, DNNL_ARG_B); |
61 | map_gemm_zp(DNNL_ARG_WEIGHTS, DNNL_ARG_A); |
62 | map_gemm_zp(DNNL_ARG_DST, DNNL_ARG_C); |
63 | } |
64 | |
65 | if (!attr()->post_ops_.has_default_values()) { |
66 | gemm_attr.post_ops_.copy_from(attr()->post_ops_); |
67 | } |
68 | |
69 | gemm_attr.set_fpmath_mode(attr()->fpmath_mode_); |
70 | |
71 | const auto acc_dt = desc()->accum_data_type; |
72 | |
73 | // We create a gemm_pd and resolve 'any' desc by querying gemm_pd |
74 | bool ok = status::success |
75 | == create_gemm_pd(gemm_pd_, engine, src_md(), |
76 | weights_md(), dst_md(), weights_md(1), |
77 | acc_dt, &gemm_attr) |
78 | && status::success == set_default_params() |
79 | && attr_.set_default_formats(dst_md(0)) == status::success; |
80 | if (!ok) return status::unimplemented; |
81 | |
82 | init_scratchpad(); |
83 | |
84 | return status::success; |
85 | } |
86 | |
87 | std::shared_ptr<primitive_desc_t> gemm_pd_; |
88 | |
89 | private: |
90 | status_t set_default_params() { |
91 | src_md_ = *gemm_pd_->arg_md(DNNL_ARG_SRC_0); |
92 | weights_md_ = *gemm_pd_->arg_md(DNNL_ARG_SRC_1); |
93 | bias_md_ = *gemm_pd_->arg_md(DNNL_ARG_BIAS); |
94 | dst_md_ = *gemm_pd_->arg_md(DNNL_ARG_DST); |
95 | return status::success; |
96 | } |
97 | |
98 | void init_scratchpad() { |
99 | auto scratchpad = scratchpad_registry().registrar(); |
100 | scratchpad.book(memory_tracking::names::key_nested, |
101 | gemm_pd_->scratchpad_registry()); |
102 | } |
103 | }; |
104 | |
105 | status_t init(engine_t *engine) override { |
106 | return create_nested_primitive(gemm_, pd()->gemm_pd_, engine); |
107 | } |
108 | |
109 | status_t execute(const exec_ctx_t &ctx) const override; |
110 | |
111 | private: |
112 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
113 | std::shared_ptr<primitive_t> gemm_; |
114 | }; |
115 | |
116 | } // namespace ocl |
117 | } // namespace gpu |
118 | } // namespace impl |
119 | } // namespace dnnl |
120 | |
121 | #endif |
122 | |