1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_REF_MATMUL_HPP
18#define GPU_OCL_REF_MATMUL_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26#include "gpu/gpu_matmul_pd.hpp"
27#include "gpu/gpu_primitive.hpp"
28#include "gpu/gpu_resource.hpp"
29#include "gpu/ocl/ocl_utils.hpp"
30#include "gpu/primitive_conf.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace ocl {
36
37struct ref_matmul_t : public gpu_primitive_t {
38 using gpu_primitive_t::gpu_primitive_t;
39 struct pd_t : public gpu_matmul_pd_t {
40 using gpu_matmul_pd_t::gpu_matmul_pd_t;
41
42 DECLARE_COMMON_PD_T("ocl:ref:any", ref_matmul_t);
43
44 status_t init(engine_t *engine) {
45 using namespace data_type;
46 using smask_t = primitive_attr_t::skip_mask_t;
47
48 src_dt_ = src_md()->data_type;
49 dst_dt_ = dst_md()->data_type;
50 wei_dt_ = weights_md(0)->data_type;
51 bia_dt_ = with_bias() ? weights_md(1)->data_type : data_type::f32;
52
53 bool ok = IMPLICATION(desc()->accum_data_type == s32,
54 attr()->zero_points_.common())
55 && IMPLICATION(desc()->accum_data_type != s32,
56 attr()->zero_points_.has_default_values())
57 && attr()->has_default_values(smask_t::scales_runtime
58 | smask_t::zero_points_runtime | smask_t::post_ops)
59 && attr_scales_ok() && set_default_formats()
60 && !has_blocks()
61 && ((utils::one_of(src_dt_, u8, s8)
62 && utils::one_of(wei_dt_, u8, s8)
63 && utils::one_of(dst_dt_, f32, s8, u8, s32, f16)
64 && IMPLICATION(with_bias(),
65 utils::one_of(
66 bia_dt_, f32, u8, s8, s32)))
67 || ((utils::everyone_is(
68 f32, src_dt_, wei_dt_, dst_dt_)
69 || (utils::everyone_is(
70 f16, src_dt_, wei_dt_)
71 && utils::one_of(
72 dst_dt_, u8, s8, f16))
73 || (utils::everyone_is(
74 bf16, src_dt_, wei_dt_)
75 && utils::one_of(
76 dst_dt_, bf16, f32)))
77 && IMPLICATION(with_bias(),
78 utils::one_of(bia_dt_, f32))))
79 && post_ops_with_binary_ok(attr(), dst_dt_, 6)
80 && attr_.set_default_formats(dst_md(0)) == status::success;
81
82 if (!ok) return status::unimplemented;
83
84 non_default_attrs_ = !attr()->has_default_values();
85 attr_info_ = attr_info_t::create(attr());
86
87 return status::success;
88 }
89
90 bool non_default_attrs_ = false;
91 data_type_t bia_dt_ = data_type::undef;
92 data_type_t src_dt_ = data_type::undef;
93 data_type_t dst_dt_ = data_type::undef;
94 data_type_t wei_dt_ = data_type::undef;
95
96 attr_info_t attr_info_ = {};
97
98 private:
99 bool attr_scales_ok() const {
100 std::vector<int> supported_args
101 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
102 if (!attr()->scales_.has_default_values(supported_args))
103 return false;
104 for (int arg : supported_args) {
105 auto &scales = attr()->scales_.get(arg);
106 if (scales.has_default_values()) continue;
107 int mask = scales.mask_;
108 if (arg == DNNL_ARG_WEIGHTS) {
109 if (!utils::one_of(mask, 0, 1 << (batched() + 1)))
110 return false;
111 } else {
112 if (mask != 0) return false;
113 }
114 }
115 return true;
116 }
117 };
118
119 status_t init(engine_t *engine) override {
120 compute::kernel_ctx_t kernel_ctx;
121
122 kernel_ctx.define_int("DST_NDIMS", pd()->dst_md()->ndims);
123 kernel_ctx.define_int("WITH_BIAS", pd()->with_bias());
124 kernel_ctx.define_int("NON_DEFAULT_ATTRS", pd()->non_default_attrs_);
125
126 kernel_ctx.set_data_type(pd()->dst_dt_);
127 def_attr_info(kernel_ctx, pd()->attr_info_, pd()->attr()->post_ops_);
128
129 def_data_type(kernel_ctx, pd()->src_dt_, "SRC");
130 def_data_type(kernel_ctx, pd()->wei_dt_, "WEI");
131 def_data_type(kernel_ctx, pd()->dst_dt_, "DST");
132 def_data_type(kernel_ctx, pd()->bia_dt_, "BIA");
133 def_data_type(kernel_ctx, pd()->desc()->accum_data_type, "ACC");
134 create_kernel(engine, &kernel_, "ref_matmul", kernel_ctx);
135 if (!kernel_) return status::runtime_error;
136 return status::success;
137 }
138
139 status_t execute(const exec_ctx_t &ctx) const override {
140 return execute_ref(ctx);
141 }
142
143private:
144 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
145 status_t execute_ref(const exec_ctx_t &ctx) const;
146 compute::kernel_t kernel_;
147};
148
149} // namespace ocl
150} // namespace gpu
151} // namespace impl
152} // namespace dnnl
153
154#endif
155