1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/dp4a.hpp"
18
19#include "gpu/jit/ir/fma.hpp"
20#include "gpu/jit/utils/trace.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class dp4a_injector_t : public ir_mutator_t {
28public:
29 object_t _mutate(const func_call_t &obj) {
30 auto *dpas = obj.func.as_ptr<dpas_t>();
31 if (!dpas) return obj;
32
33 int M = dpas->exec_size;
34 int N = dpas->rcount;
35 int K = dpas->sdepth * 4;
36
37 auto &dst = dpas_t::arg_dst(obj);
38 auto &src0 = dpas_t::arg_src0(obj);
39 auto &src1 = dpas_t::arg_src1(obj);
40 auto &src2 = dpas_t::arg_src2(obj);
41 int dst_size = dpas->dst_type.size();
42 int src0_size = dpas->dst_type.size();
43 int src1_size = dpas->src1_type.size();
44 int src2_size = dpas->src2_type.size();
45 auto dst_type = to_dp4a_type(dpas->dst_type);
46 auto src1_type = to_dp4a_type(dpas->src1_type);
47 auto src2_type = to_dp4a_type(dpas->src2_type);
48 bool is_src0_zero = is_zero(src0);
49
50 stmt_t stmt;
51 auto _dp4a = dpas_t::make(
52 /*is_dpasw=*/false, M, 1, 1, dst_type, src1_type, src2_type);
53 auto &dp4a = _dp4a.as<dpas_t>();
54 auto zero = shuffle_t::make_broadcast(0, M);
55 int k0 = (is_src0_zero ? -4 : 0);
56 for (int k = k0; k < K; k += 4) {
57 for (int n = 0; n < N; n++) {
58 int dst_off = n * M * dst_size;
59 int src0_off = n * M * src0_size;
60 int src1_off = k * M * src1_size;
61 int src2_off = (n * K + k) * src2_size;
62 auto _dst = dst + dst_off;
63 auto _src0 = is_src0_zero ? _dst : (src0 + src0_off);
64 auto _src1 = src1 + src1_off;
65 auto _src2 = src2 + src2_off;
66 if (k < 0) {
67 stmt = stmt.append(store_t::make(_dst, 0, zero));
68 } else {
69 stmt = stmt.append(dp4a(_dst, _src0, _src1, _src2));
70 }
71 }
72 }
73 return std::move(stmt);
74 }
75
76private:
77 static type_t to_dp4a_type(const type_t &type) {
78 if (type.is_x32()) return type;
79 if (type.is_s8()) return type_t::s32();
80 if (type.is_u8()) return type_t::u32();
81 ir_error_not_expected();
82 return type_t();
83 };
84};
85
86stmt_t inject_dp4a(const stmt_t &s, ir_context_t &ir_ctx) {
87 trace_start();
88 auto ret = dp4a_injector_t().mutate(s);
89 trace_pass("inject_dp4a", ret, ir_ctx);
90 return ret;
91}
92
93} // namespace jit
94} // namespace gpu
95} // namespace impl
96} // namespace dnnl
97