1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_MATMUL_HPP |
18 | #define GPU_OCL_REF_MATMUL_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | #include "gpu/gpu_matmul_pd.hpp" |
27 | #include "gpu/gpu_primitive.hpp" |
28 | #include "gpu/gpu_resource.hpp" |
29 | #include "gpu/ocl/ocl_utils.hpp" |
30 | #include "gpu/primitive_conf.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace gpu { |
35 | namespace ocl { |
36 | |
37 | struct ref_matmul_t : public gpu_primitive_t { |
38 | using gpu_primitive_t::gpu_primitive_t; |
39 | struct pd_t : public gpu_matmul_pd_t { |
40 | using gpu_matmul_pd_t::gpu_matmul_pd_t; |
41 | |
42 | DECLARE_COMMON_PD_T("ocl:ref:any" , ref_matmul_t); |
43 | |
44 | status_t init(engine_t *engine) { |
45 | using namespace data_type; |
46 | using smask_t = primitive_attr_t::skip_mask_t; |
47 | |
48 | src_dt_ = src_md()->data_type; |
49 | dst_dt_ = dst_md()->data_type; |
50 | wei_dt_ = weights_md(0)->data_type; |
51 | bia_dt_ = with_bias() ? weights_md(1)->data_type : data_type::f32; |
52 | |
53 | bool ok = IMPLICATION(desc()->accum_data_type == s32, |
54 | attr()->zero_points_.common()) |
55 | && IMPLICATION(desc()->accum_data_type != s32, |
56 | attr()->zero_points_.has_default_values()) |
57 | && attr()->has_default_values(smask_t::scales_runtime |
58 | | smask_t::zero_points_runtime | smask_t::post_ops) |
59 | && attr_scales_ok() && set_default_formats() |
60 | && !has_blocks() |
61 | && ((utils::one_of(src_dt_, u8, s8) |
62 | && utils::one_of(wei_dt_, u8, s8) |
63 | && utils::one_of(dst_dt_, f32, s8, u8, s32, f16) |
64 | && IMPLICATION(with_bias(), |
65 | utils::one_of( |
66 | bia_dt_, f32, u8, s8, s32))) |
67 | || ((utils::everyone_is( |
68 | f32, src_dt_, wei_dt_, dst_dt_) |
69 | || (utils::everyone_is( |
70 | f16, src_dt_, wei_dt_) |
71 | && utils::one_of( |
72 | dst_dt_, u8, s8, f16)) |
73 | || (utils::everyone_is( |
74 | bf16, src_dt_, wei_dt_) |
75 | && utils::one_of( |
76 | dst_dt_, bf16, f32))) |
77 | && IMPLICATION(with_bias(), |
78 | utils::one_of(bia_dt_, f32)))) |
79 | && post_ops_with_binary_ok(attr(), dst_dt_, 6) |
80 | && attr_.set_default_formats(dst_md(0)) == status::success; |
81 | |
82 | if (!ok) return status::unimplemented; |
83 | |
84 | non_default_attrs_ = !attr()->has_default_values(); |
85 | attr_info_ = attr_info_t::create(attr()); |
86 | |
87 | return status::success; |
88 | } |
89 | |
90 | bool non_default_attrs_ = false; |
91 | data_type_t bia_dt_ = data_type::undef; |
92 | data_type_t src_dt_ = data_type::undef; |
93 | data_type_t dst_dt_ = data_type::undef; |
94 | data_type_t wei_dt_ = data_type::undef; |
95 | |
96 | attr_info_t attr_info_ = {}; |
97 | |
98 | private: |
99 | bool attr_scales_ok() const { |
100 | std::vector<int> supported_args |
101 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; |
102 | if (!attr()->scales_.has_default_values(supported_args)) |
103 | return false; |
104 | for (int arg : supported_args) { |
105 | auto &scales = attr()->scales_.get(arg); |
106 | if (scales.has_default_values()) continue; |
107 | int mask = scales.mask_; |
108 | if (arg == DNNL_ARG_WEIGHTS) { |
109 | if (!utils::one_of(mask, 0, 1 << (batched() + 1))) |
110 | return false; |
111 | } else { |
112 | if (mask != 0) return false; |
113 | } |
114 | } |
115 | return true; |
116 | } |
117 | }; |
118 | |
119 | status_t init(engine_t *engine) override { |
120 | compute::kernel_ctx_t kernel_ctx; |
121 | |
122 | kernel_ctx.define_int("DST_NDIMS" , pd()->dst_md()->ndims); |
123 | kernel_ctx.define_int("WITH_BIAS" , pd()->with_bias()); |
124 | kernel_ctx.define_int("NON_DEFAULT_ATTRS" , pd()->non_default_attrs_); |
125 | |
126 | kernel_ctx.set_data_type(pd()->dst_dt_); |
127 | def_attr_info(kernel_ctx, pd()->attr_info_, pd()->attr()->post_ops_); |
128 | |
129 | def_data_type(kernel_ctx, pd()->src_dt_, "SRC" ); |
130 | def_data_type(kernel_ctx, pd()->wei_dt_, "WEI" ); |
131 | def_data_type(kernel_ctx, pd()->dst_dt_, "DST" ); |
132 | def_data_type(kernel_ctx, pd()->bia_dt_, "BIA" ); |
133 | def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "ACC" ); |
134 | create_kernel(engine, &kernel_, "ref_matmul" , kernel_ctx); |
135 | if (!kernel_) return status::runtime_error; |
136 | return status::success; |
137 | } |
138 | |
139 | status_t execute(const exec_ctx_t &ctx) const override { |
140 | return execute_ref(ctx); |
141 | } |
142 | |
143 | private: |
144 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
145 | status_t execute_ref(const exec_ctx_t &ctx) const; |
146 | compute::kernel_t kernel_; |
147 | }; |
148 | |
149 | } // namespace ocl |
150 | } // namespace gpu |
151 | } // namespace impl |
152 | } // namespace dnnl |
153 | |
154 | #endif |
155 | |