1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/bank_conflict.hpp"
18
19#include "gpu/jit/ir/fma.hpp"
20#include "gpu/jit/ir/message.hpp"
21#include "gpu/jit/utils/trace.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26namespace jit {
27
28class bank_conflict_attribute_injector_t : public ir_mutator_t {
29public:
30 object_t _mutate(const alloc_t &obj) override {
31 all_buf_sizes_.emplace(obj.buf, obj.size);
32
33 auto new_obj = ir_mutator_t::_mutate(obj);
34 if (bufs_.count(obj.buf) == 0) return new_obj;
35
36 init_attr();
37
38 auto new_attrs = obj.attrs;
39 new_attrs.push_back(attr_);
40 auto &body = new_obj.as<alloc_t>().body;
41 return alloc_t::make(obj.buf, obj.size, obj.kind, new_attrs, body);
42 }
43
44 object_t _mutate(const func_call_t &obj) override {
45 if (is_frozen) return ir_mutator_t::_mutate(obj);
46
47 bool is_mad = obj.func.is<mad_t>();
48 bool is_dpas = obj.func.is<dpas_t>();
49 auto *send = obj.func.as_ptr<send_t>();
50 bool is_load = send && (send->is_load() || send->is_load_2d());
51
52 if (is_mad || is_dpas) {
53 auto dst_buf = ptr_base(obj.args[0]);
54 auto src0_buf = ptr_base(obj.args[1]);
55 auto src1_buf = ptr_base(obj.args[2]);
56 auto src2_buf = ptr_base(obj.args[3]);
57
58 // src0 may be null in some cases, skip it.
59 if (!src0_buf.is_empty()) bufs_.insert(src0_buf);
60 bufs_.insert(src1_buf);
61 bufs_.insert(src2_buf);
62
63 instructions_.insert(obj);
64 } else if (is_load) {
65 // Returns minimal 2^B so that there is x such that:
66 // x * 2^B <= a <= b < (x + 1) * 2^B
67 auto min_pow2_span = [](int a, int b) {
68 int same_left_bits = 0;
69 for (int i = 31; i >= 0; i--) {
70 int b0 = ((uint32_t)a >> i) & 0x1;
71 int b1 = ((uint32_t)b >> i) & 0x1;
72 if (b0 != b1) break;
73 same_left_bits++;
74 }
75 return 1 << (32 - same_left_bits);
76 };
77 auto &buf = send_t::arg_reg_buf(obj);
78 auto &base = (is_var(buf) ? buf : buf.as<ptr_t>().base);
79 int off = (is_var(buf) ? 0 : to_cpp<int>(buf.as<ptr_t>().off));
80 int size = send->payload_size();
81 int span = min_pow2_span(off, off + size - 1);
82 int &min_block_size = all_buf_min_block_sizes[base];
83 min_block_size = std::max(min_block_size, span);
84 }
85 return ir_mutator_t::_mutate(obj);
86 }
87
88private:
89 void init_attr() {
90 if (!attr_.is_empty()) return;
91
92 is_frozen = true;
93 std::vector<stmt_t> instructions;
94 for (auto &s : instructions_)
95 instructions.push_back(s);
96
97 std::vector<expr_t> buf_vec;
98 std::vector<int> buf_sizes;
99 std::vector<int> buf_min_block_sizes;
100 for (auto &buf : bufs_) {
101 buf_vec.push_back(buf);
102 buf_sizes.push_back(all_buf_sizes_.at(buf));
103 auto it = all_buf_min_block_sizes.find(buf);
104 int min_block_size
105 = (it == all_buf_min_block_sizes.end() ? 0 : it->second);
106 buf_min_block_sizes.push_back(min_block_size);
107 }
108 attr_ = bank_conflict_attr_t::make(
109 buf_vec, buf_sizes, buf_min_block_sizes, instructions);
110 }
111
112 static expr_t ptr_base(const expr_t &e) {
113 if (e.is<var_t>()) return e;
114 auto *ptr = e.as_ptr<ptr_t>();
115 if (ptr) return e.as<ptr_t>().base;
116 return expr_t();
117 }
118
119 object_map_t<expr_t, int> all_buf_sizes_;
120 object_map_t<expr_t, int> all_buf_min_block_sizes;
121 object_eq_set_t<expr_t> bufs_;
122 object_eq_set_t<stmt_t> instructions_;
123 bool is_frozen = false;
124
125 alloc_attr_t attr_;
126};
127
128stmt_t inject_bank_conflict_attribute(const stmt_t &s, ir_context_t &ir_ctx) {
129 trace_start();
130 auto ret = bank_conflict_attribute_injector_t().mutate(s);
131 trace_pass("inject_bank_conflict_attribute", ret, ir_ctx);
132 return ret;
133}
134
135} // namespace jit
136} // namespace gpu
137} // namespace impl
138} // namespace dnnl
139