1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/ref_matmul.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace ocl { |
25 | |
26 | status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { |
27 | const auto &a = CTX_IN_STORAGE(DNNL_ARG_SRC); |
28 | const auto &b = CTX_IN_STORAGE(DNNL_ARG_WEIGHTS); |
29 | const auto &bias = CTX_IN_STORAGE(DNNL_ARG_BIAS); |
30 | |
31 | auto &c = CTX_OUT_STORAGE(DNNL_ARG_DST); |
32 | |
33 | auto &src_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
34 | auto &wei_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); |
35 | auto &dst_scales = CTX_IN_STORAGE(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
36 | const auto &a0 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); |
37 | const auto &b0 |
38 | = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); |
39 | const auto &c0 = CTX_IN_STORAGE(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); |
40 | |
41 | const auto a_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md()); |
42 | const auto b_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md()); |
43 | const auto c_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); |
44 | const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1)); |
45 | |
46 | // All tensors must have the same order. |
47 | // If order > 2D, all dimensions above 2 will be combined into a single |
48 | // batch dimension. For this reason block formats are not supported. |
49 | |
50 | const int last = c_d.ndims() - 1; |
51 | |
52 | dnnl_dims_t bia_stride {0}; |
53 | if (bia_d.data_type() != data_type::undef) { |
54 | const auto &bia_strides = bia_d.blocking_desc().strides; |
55 | for (int i = 0; i < bia_d.ndims(); i++) { |
56 | if (bia_d.dims()[last - i] > 1) { |
57 | bia_stride[i] = bia_strides[last - i]; |
58 | } else { |
59 | bia_stride[i] = 0; |
60 | } |
61 | } |
62 | } |
63 | |
64 | dnnl_dims_t a_stride {0}; |
65 | dnnl_dims_t b_stride {0}; |
66 | dnnl_dims_t c_stride {0}; |
67 | const auto &a_strides = a_d.blocking_desc().strides; |
68 | const auto &b_strides = b_d.blocking_desc().strides; |
69 | const auto &c_strides = c_d.blocking_desc().strides; |
70 | for (int i = 0; i < c_d.ndims(); i++) { |
71 | if (a_d.dims()[last - i] > 1) { a_stride[i] = a_strides[last - i]; } |
72 | if (b_d.dims()[last - i] > 1) { b_stride[i] = b_strides[last - i]; } |
73 | if (c_d.dims()[last - i] > 1) { c_stride[i] = c_strides[last - i]; } |
74 | } |
75 | |
76 | const dim_t D3 = c_d.ndims() > 5 ? c_d.dims()[last - 5] : 1; |
77 | const dim_t D2 = c_d.ndims() > 4 ? c_d.dims()[last - 4] : 1; |
78 | const dim_t D1 = c_d.ndims() > 3 ? c_d.dims()[last - 3] : 1; |
79 | const dim_t D0 = c_d.ndims() > 2 ? c_d.dims()[last - 2] : 1; |
80 | const dim_t M = c_d.dims()[last - 1]; |
81 | const dim_t N = c_d.dims()[last]; |
82 | const dim_t K = a_d.dims()[last]; |
83 | const dim_t wei_scale_stride |
84 | = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ == 0 ? 0 : 1; |
85 | |
86 | compute::kernel_arg_list_t arg_list; |
87 | arg_list.set(0, a); |
88 | arg_list.set(1, b); |
89 | arg_list.set(2, c); |
90 | arg_list.set(3, bias); |
91 | arg_list.set(4, a0); |
92 | arg_list.set(5, b0); |
93 | arg_list.set(6, c0); |
94 | arg_list.set(7, src_scales); |
95 | arg_list.set(8, wei_scales); |
96 | arg_list.set(9, wei_scale_stride); |
97 | arg_list.set(10, dst_scales); |
98 | arg_list.set(11, K); |
99 | arg_list.set(12, N); |
100 | arg_list.set(13, M); |
101 | arg_list.set(14, D0); |
102 | arg_list.set(15, D1); |
103 | arg_list.set(16, D2); |
104 | arg_list.set(17, bia_stride[5]); |
105 | arg_list.set(18, bia_stride[4]); |
106 | arg_list.set(19, bia_stride[3]); |
107 | arg_list.set(20, bia_stride[2]); |
108 | arg_list.set(21, bia_stride[1]); |
109 | arg_list.set(22, bia_stride[0]); |
110 | arg_list.set(23, a_stride[5]); |
111 | arg_list.set(24, a_stride[4]); |
112 | arg_list.set(25, a_stride[3]); |
113 | arg_list.set(26, a_stride[2]); |
114 | arg_list.set(27, a_stride[1]); |
115 | arg_list.set(28, a_stride[0]); |
116 | arg_list.set(29, b_stride[5]); |
117 | arg_list.set(30, b_stride[4]); |
118 | arg_list.set(31, b_stride[3]); |
119 | arg_list.set(32, b_stride[2]); |
120 | arg_list.set(33, b_stride[1]); |
121 | arg_list.set(34, b_stride[0]); |
122 | arg_list.set(35, c_stride[5]); |
123 | arg_list.set(36, c_stride[4]); |
124 | arg_list.set(37, c_stride[3]); |
125 | arg_list.set(38, c_stride[2]); |
126 | arg_list.set(39, c_stride[1]); |
127 | arg_list.set(40, c_stride[0]); |
128 | |
129 | append_post_ops_to_arg_list(ctx, arg_list, 41, pd()->attr()->post_ops_); |
130 | |
131 | size_t gws[3] = {1, (size_t)N, (size_t)(D0 * D1 * D2 * D3)}; |
132 | auto nd_range = compute::nd_range_t(gws); |
133 | |
134 | status_t status = parallel_for(ctx, nd_range, kernel_, arg_list); |
135 | return status; |
136 | } |
137 | |
138 | } // namespace ocl |
139 | } // namespace gpu |
140 | } // namespace impl |
141 | } // namespace dnnl |
142 | |