/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/ref_matmul.hpp"

#include "common/c_types_map.hpp"
#include "common/type_helpers.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    const auto &a = CTX_IN_STORAGE(DNNL_ARG_SRC);
    const auto &b = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS);
    const auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS);

    auto &c = CTX_OUT_STORAGE(DNNL_ARG_DST);

    auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
    auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
    const auto &a0 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
    const auto &b0
            = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
    const auto &c0 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);

    const auto a_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto b_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto c_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
    const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));

    // All tensors must have the same number of dimensions (order).
    // For orders above 2, all dimensions beyond the last two are combined
    // into a single batch dimension. For this reason blocked memory formats
    // are not supported.
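    // Illustrative example (not taken from the kernel itself): a 4D case
    //     A: [D1, D0, M, K] x B: [D1, D0, K, N] -> C: [D1, D0, M, N]
    // executes as D1 * D0 independent {M x K} * {K x N} GEMMs.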

    const int last = c_d.ndims() - 1;

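    // Gather bias strides innermost-first. Size-1 dimensions get a zero
    // stride so the single element is broadcast across that dimension.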
    dnnl_dims_t bia_stride {0};
    if (bia_d.data_type() != data_type::undef) {
        const auto &bia_strides = bia_d.blocking_desc().strides;
        for (int i = 0; i < bia_d.ndims(); i++) {
            if (bia_d.dims()[last - i] > 1) {
                bia_stride[i] = bia_strides[last - i];
            } else {
                bia_stride[i] = 0;
            }
        }
    }

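    // Same innermost-first gathering, with zero-stride broadcasting, for the
    // A (src), B (weights), and C (dst) tensors.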
    dnnl_dims_t a_stride {0};
    dnnl_dims_t b_stride {0};
    dnnl_dims_t c_stride {0};
    const auto &a_strides = a_d.blocking_desc().strides;
    const auto &b_strides = b_d.blocking_desc().strides;
    const auto &c_strides = c_d.blocking_desc().strides;
    for (int i = 0; i < c_d.ndims(); i++) {
        if (a_d.dims()[last - i] > 1) { a_stride[i] = a_strides[last - i]; }
        if (b_d.dims()[last - i] > 1) { b_stride[i] = b_strides[last - i]; }
        if (c_d.dims()[last - i] > 1) { c_stride[i] = c_strides[last - i]; }
    }

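    // GEMM dimensions: A is M x K, B is K x N, C is M x N; D0..D3 are the
    // (up to four) batch dimensions, defaulting to 1 when absent.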
    const dim_t D3 = c_d.ndims() > 5 ? c_d.dims()[last - 5] : 1;
    const dim_t D2 = c_d.ndims() > 4 ? c_d.dims()[last - 4] : 1;
    const dim_t D1 = c_d.ndims() > 3 ? c_d.dims()[last - 3] : 1;
    const dim_t D0 = c_d.ndims() > 2 ? c_d.dims()[last - 2] : 1;
    const dim_t M = c_d.dims()[last - 1];
    const dim_t N = c_d.dims()[last];
    const dim_t K = a_d.dims()[last];
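    // A weights-scales mask of 0 means one common scale for the whole tensor
    // (read with stride 0); here any non-zero mask is treated as per-output-
    // channel scaling (read with stride 1 along N).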
    const dim_t wei_scale_stride
            = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 ? 0 : 1;

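    // Pass everything to the OpenCL kernel. The argument indices must match
    // the parameter order of the reference matmul kernel signature.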
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, a);
    arg_list.set(1, b);
    arg_list.set(2, c);
    arg_list.set(3, bias);
    arg_list.set(4, a0);
    arg_list.set(5, b0);
    arg_list.set(6, c0);
    arg_list.set(7, src_scales);
    arg_list.set(8, wei_scales);
    arg_list.set(9, wei_scale_stride);
    arg_list.set(10, dst_scales);
    arg_list.set(11, K);
    arg_list.set(12, N);
    arg_list.set(13, M);
    arg_list.set(14, D0);
    arg_list.set(15, D1);
    arg_list.set(16, D2);
    arg_list.set(17, bia_stride[5]);
    arg_list.set(18, bia_stride[4]);
    arg_list.set(19, bia_stride[3]);
    arg_list.set(20, bia_stride[2]);
    arg_list.set(21, bia_stride[1]);
    arg_list.set(22, bia_stride[0]);
    arg_list.set(23, a_stride[5]);
    arg_list.set(24, a_stride[4]);
    arg_list.set(25, a_stride[3]);
    arg_list.set(26, a_stride[2]);
    arg_list.set(27, a_stride[1]);
    arg_list.set(28, a_stride[0]);
    arg_list.set(29, b_stride[5]);
    arg_list.set(30, b_stride[4]);
    arg_list.set(31, b_stride[3]);
    arg_list.set(32, b_stride[2]);
    arg_list.set(33, b_stride[1]);
    arg_list.set(34, b_stride[0]);
    arg_list.set(35, c_stride[5]);
    arg_list.set(36, c_stride[4]);
    arg_list.set(37, c_stride[3]);
    arg_list.set(38, c_stride[2]);
    arg_list.set(39, c_stride[1]);
    arg_list.set(40, c_stride[0]);

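    // Post-op arguments (e.g. binary post-op sources) are appended after the
    // 41 fixed arguments above.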
    append_post_ops_to_arg_list(ctx, arg_list, 41, pd()->attr()->post_ops_);

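    // One work-item per output column: N items along dimension 1 and one per
    // flattened batch along dimension 2; the kernel is expected to loop over
    // M (and K) internally.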
    size_t gws[3] = {1, (size_t)N, (size_t)(D0 * D1 * D2 * D3)};
    auto nd_range = compute::nd_range_t(gws);

    status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl