1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/bank_conflict.hpp" |
18 | |
19 | #include "gpu/jit/ir/fma.hpp" |
20 | #include "gpu/jit/ir/message.hpp" |
21 | #include "gpu/jit/utils/trace.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
28 | class bank_conflict_attribute_injector_t : public ir_mutator_t { |
29 | public: |
30 | object_t _mutate(const alloc_t &obj) override { |
31 | all_buf_sizes_.emplace(obj.buf, obj.size); |
32 | |
33 | auto new_obj = ir_mutator_t::_mutate(obj); |
34 | if (bufs_.count(obj.buf) == 0) return new_obj; |
35 | |
36 | init_attr(); |
37 | |
38 | auto new_attrs = obj.attrs; |
39 | new_attrs.push_back(attr_); |
40 | auto &body = new_obj.as<alloc_t>().body; |
41 | return alloc_t::make(obj.buf, obj.size, obj.kind, new_attrs, body); |
42 | } |
43 | |
44 | object_t _mutate(const func_call_t &obj) override { |
45 | if (is_frozen) return ir_mutator_t::_mutate(obj); |
46 | |
47 | bool is_mad = obj.func.is<mad_t>(); |
48 | bool is_dpas = obj.func.is<dpas_t>(); |
49 | auto *send = obj.func.as_ptr<send_t>(); |
50 | bool is_load = send && (send->is_load() || send->is_load_2d()); |
51 | |
52 | if (is_mad || is_dpas) { |
53 | auto dst_buf = ptr_base(obj.args[0]); |
54 | auto src0_buf = ptr_base(obj.args[1]); |
55 | auto src1_buf = ptr_base(obj.args[2]); |
56 | auto src2_buf = ptr_base(obj.args[3]); |
57 | |
58 | // src0 may be null in some cases, skip it. |
59 | if (!src0_buf.is_empty()) bufs_.insert(src0_buf); |
60 | bufs_.insert(src1_buf); |
61 | bufs_.insert(src2_buf); |
62 | |
63 | instructions_.insert(obj); |
64 | } else if (is_load) { |
65 | // Returns minimal 2^B so that there is x such that: |
66 | // x * 2^B <= a <= b < (x + 1) * 2^B |
67 | auto min_pow2_span = [](int a, int b) { |
68 | int same_left_bits = 0; |
69 | for (int i = 31; i >= 0; i--) { |
70 | int b0 = ((uint32_t)a >> i) & 0x1; |
71 | int b1 = ((uint32_t)b >> i) & 0x1; |
72 | if (b0 != b1) break; |
73 | same_left_bits++; |
74 | } |
75 | return 1 << (32 - same_left_bits); |
76 | }; |
77 | auto &buf = send_t::arg_reg_buf(obj); |
78 | auto &base = (is_var(buf) ? buf : buf.as<ptr_t>().base); |
79 | int off = (is_var(buf) ? 0 : to_cpp<int>(buf.as<ptr_t>().off)); |
80 | int size = send->payload_size(); |
81 | int span = min_pow2_span(off, off + size - 1); |
82 | int &min_block_size = all_buf_min_block_sizes[base]; |
83 | min_block_size = std::max(min_block_size, span); |
84 | } |
85 | return ir_mutator_t::_mutate(obj); |
86 | } |
87 | |
88 | private: |
89 | void init_attr() { |
90 | if (!attr_.is_empty()) return; |
91 | |
92 | is_frozen = true; |
93 | std::vector<stmt_t> instructions; |
94 | for (auto &s : instructions_) |
95 | instructions.push_back(s); |
96 | |
97 | std::vector<expr_t> buf_vec; |
98 | std::vector<int> buf_sizes; |
99 | std::vector<int> buf_min_block_sizes; |
100 | for (auto &buf : bufs_) { |
101 | buf_vec.push_back(buf); |
102 | buf_sizes.push_back(all_buf_sizes_.at(buf)); |
103 | auto it = all_buf_min_block_sizes.find(buf); |
104 | int min_block_size |
105 | = (it == all_buf_min_block_sizes.end() ? 0 : it->second); |
106 | buf_min_block_sizes.push_back(min_block_size); |
107 | } |
108 | attr_ = bank_conflict_attr_t::make( |
109 | buf_vec, buf_sizes, buf_min_block_sizes, instructions); |
110 | } |
111 | |
112 | static expr_t ptr_base(const expr_t &e) { |
113 | if (e.is<var_t>()) return e; |
114 | auto *ptr = e.as_ptr<ptr_t>(); |
115 | if (ptr) return e.as<ptr_t>().base; |
116 | return expr_t(); |
117 | } |
118 | |
119 | object_map_t<expr_t, int> all_buf_sizes_; |
120 | object_map_t<expr_t, int> all_buf_min_block_sizes; |
121 | object_eq_set_t<expr_t> bufs_; |
122 | object_eq_set_t<stmt_t> instructions_; |
123 | bool is_frozen = false; |
124 | |
125 | alloc_attr_t attr_; |
126 | }; |
127 | |
128 | stmt_t inject_bank_conflict_attribute(const stmt_t &s, ir_context_t &ir_ctx) { |
129 | trace_start(); |
130 | auto ret = bank_conflict_attribute_injector_t().mutate(s); |
131 | trace_pass("inject_bank_conflict_attribute" , ret, ir_ctx); |
132 | return ret; |
133 | } |
134 | |
135 | } // namespace jit |
136 | } // namespace gpu |
137 | } // namespace impl |
138 | } // namespace dnnl |
139 | |