1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/dpas_atomic.hpp" |
18 | |
19 | #include "gpu/jit/ir/fma.hpp" |
20 | #include "gpu/jit/utils/trace.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | stmt_t inject_atomic(const stmt_t &stmt) { |
28 | stmt_t ret = stmt; |
29 | auto stmt_vec = flatten_statements(stmt); |
30 | for (size_t i = 0; i < stmt_vec.size(); i++) { |
31 | bool ok = true; |
32 | ok &= is_func_call<dpas_t>(stmt_vec[i]) // No atomics for DP4As! |
33 | && !dpas_t::is_dp4a_call(stmt_vec[i]); |
34 | ok &= (i + 1 < stmt_vec.size()) && is_func_call<dpas_t>(stmt_vec[i + 1]) |
35 | && !dpas_t::is_dp4a_call(stmt_vec[i + 1]); |
36 | if (ok) { |
37 | auto &cur_src1 = dpas_t::arg_src1(stmt_vec[i]); |
38 | auto &next_src1 = dpas_t::arg_src1(stmt_vec[i + 1]); |
39 | // Compare src1, apply {Atomic} if they are equal. |
40 | if (cur_src1.is_equal(next_src1)) { |
41 | auto &s = stmt_vec[i]; |
42 | auto atomic_attr = instruction_modifier_attr_t::make( |
43 | ngen_proxy::InstructionModifier().with_atomic()); |
44 | ret = substitute(ret, s, atomic_attr.apply_to(s)); |
45 | } |
46 | } |
47 | } |
48 | return ret; |
49 | } |
50 | |
51 | } // namespace jit |
52 | } // namespace gpu |
53 | } // namespace impl |
54 | } // namespace dnnl |
55 | |