1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/dp4a.hpp" |
18 | |
19 | #include "gpu/jit/ir/fma.hpp" |
20 | #include "gpu/jit/utils/trace.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | class dp4a_injector_t : public ir_mutator_t { |
28 | public: |
29 | object_t _mutate(const func_call_t &obj) { |
30 | auto *dpas = obj.func.as_ptr<dpas_t>(); |
31 | if (!dpas) return obj; |
32 | |
33 | int M = dpas->exec_size; |
34 | int N = dpas->rcount; |
35 | int K = dpas->sdepth * 4; |
36 | |
37 | auto &dst = dpas_t::arg_dst(obj); |
38 | auto &src0 = dpas_t::arg_src0(obj); |
39 | auto &src1 = dpas_t::arg_src1(obj); |
40 | auto &src2 = dpas_t::arg_src2(obj); |
41 | int dst_size = dpas->dst_type.size(); |
42 | int src0_size = dpas->dst_type.size(); |
43 | int src1_size = dpas->src1_type.size(); |
44 | int src2_size = dpas->src2_type.size(); |
45 | auto dst_type = to_dp4a_type(dpas->dst_type); |
46 | auto src1_type = to_dp4a_type(dpas->src1_type); |
47 | auto src2_type = to_dp4a_type(dpas->src2_type); |
48 | bool is_src0_zero = is_zero(src0); |
49 | |
50 | stmt_t stmt; |
51 | auto _dp4a = dpas_t::make( |
52 | /*is_dpasw=*/false, M, 1, 1, dst_type, src1_type, src2_type); |
53 | auto &dp4a = _dp4a.as<dpas_t>(); |
54 | auto zero = shuffle_t::make_broadcast(0, M); |
55 | int k0 = (is_src0_zero ? -4 : 0); |
56 | for (int k = k0; k < K; k += 4) { |
57 | for (int n = 0; n < N; n++) { |
58 | int dst_off = n * M * dst_size; |
59 | int src0_off = n * M * src0_size; |
60 | int src1_off = k * M * src1_size; |
61 | int src2_off = (n * K + k) * src2_size; |
62 | auto _dst = dst + dst_off; |
63 | auto _src0 = is_src0_zero ? _dst : (src0 + src0_off); |
64 | auto _src1 = src1 + src1_off; |
65 | auto _src2 = src2 + src2_off; |
66 | if (k < 0) { |
67 | stmt = stmt.append(store_t::make(_dst, 0, zero)); |
68 | } else { |
69 | stmt = stmt.append(dp4a(_dst, _src0, _src1, _src2)); |
70 | } |
71 | } |
72 | } |
73 | return std::move(stmt); |
74 | } |
75 | |
76 | private: |
77 | static type_t to_dp4a_type(const type_t &type) { |
78 | if (type.is_x32()) return type; |
79 | if (type.is_s8()) return type_t::s32(); |
80 | if (type.is_u8()) return type_t::u32(); |
81 | ir_error_not_expected(); |
82 | return type_t(); |
83 | }; |
84 | }; |
85 | |
86 | stmt_t inject_dp4a(const stmt_t &s, ir_context_t &ir_ctx) { |
87 | trace_start(); |
88 | auto ret = dp4a_injector_t().mutate(s); |
89 | trace_pass("inject_dp4a" , ret, ir_ctx); |
90 | return ret; |
91 | } |
92 | |
93 | } // namespace jit |
94 | } // namespace gpu |
95 | } // namespace impl |
96 | } // namespace dnnl |
97 | |