1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/send.hpp"
18
19#include "gpu/jit/ir/message.hpp"
20#include "gpu/jit/utils/trace.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class buffer_offset_lifter_t : public ir_mutator_t {
28public:
29 object_t _mutate(const func_call_t &obj) {
30 if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj);
31
32 auto &mem_buf = send_t::arg_mem_buf(obj);
33 if (!mem_buf.is<ptr_t>()) return ir_mutator_t::_mutate(obj);
34
35 auto &base = mem_buf.as<ptr_t>().base;
36 auto &off = mem_buf.as<ptr_t>().off;
37
38 std::vector<expr_t> new_args = obj.args;
39 send_t::arg_mem_buf(new_args) = base;
40 send_t::arg_mem_off(new_args) += off;
41 return obj.func.call(new_args, obj.attr);
42 }
43};
44
45stmt_t lift_buffer_offsets_in_send(const stmt_t &s, ir_context_t &ir_ctx) {
46 trace_start();
47 buffer_offset_lifter_t lifter;
48 auto ret = lifter.mutate(s);
49 trace_pass("lift_buffer_offsets_in_send", ret, ir_ctx);
50 return ret;
51}
52
53class send_injector_t : public ir_mutator_t {
54public:
55 send_injector_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
56
57 object_t _mutate(const func_call_t &obj) {
58 auto *send = obj.func.as_ptr<send_t>();
59 if (!send) return ir_mutator_t::_mutate(obj);
60
61 auto &mem_buf = send_t::arg_mem_buf(obj);
62 auto &mem_off = send_t::arg_mem_off(obj);
63 auto &reg_buf = send_t::arg_reg_buf(obj);
64 auto &mask = send_t::arg_mask(obj);
65
66 ir_assert(is_var(mem_buf)) << mem_buf;
67
68 auto header_buf = ir_ctx_.create_tmp_var(type_t::byte_ptr(), "h");
69 auto off_store = simplify_store(
70 send->create_offset_store(header_buf, mem_buf, mem_off));
71
72 if (send->is_2d()) {
73 auto emit_store = [&](const expr_t &e, int off) {
74 auto store = store_t::make(header_buf, off, e);
75 off_store = off_store.append(store);
76 };
77 auto emit_store_s32 = [&](int value, int off) {
78 emit_store(cast(value, type_t::s32()), off);
79 };
80 auto &info = send->block_2d_info;
81 int type_size = send->type.size();
82 emit_store_s32(info.surface_width * type_size - 1,
83 send_t::header_2d_off_surface_width());
84 emit_store_s32(info.surface_height - 1,
85 send_t::header_2d_off_surface_height());
86 emit_store_s32(info.surface_pitch * type_size - 1,
87 send_t::header_2d_off_surface_pitch());
88 emit_store(send_t::arg_x(obj), send_t::header_2d_off_x());
89 emit_store(send_t::arg_y(obj), send_t::header_2d_off_y());
90 uint32_t w_enc = info.width - 1;
91 uint32_t h_enc = info.height - 1;
92 uint32_t count_enc = info.count - 1;
93 emit_store_s32((count_enc << 16) + (h_enc << 8) + w_enc,
94 send_t::header_2d_off_whc());
95 }
96
97 auto new_call = func_call_t::make(
98 obj.func, {mem_buf, header_buf, reg_buf, mask}, obj.attr);
99 auto body = stmt_seq_t::make(off_store, new_call);
100
101 // Allocate header.
102 return alloc_t::make(
103 header_buf, send->header_size(), alloc_kind_t::grf, body);
104 }
105
106private:
107 stmt_t simplify_store(const stmt_t &_store) const {
108 auto &store = _store.as<store_t>();
109
110 auto value = store.value;
111 value = simplify(value, ir_ctx_.cset());
112
113 // Convert to N-ary form and back to expand multiplications. This
114 // helps to find more common subexpressions during the pass.
115 value = nary_op_canonicalize(value);
116 value = nary_op_back_transform(value);
117
118 return store_t::make(store.buf, store.off, value, store.stride);
119 }
120
121 ir_context_t &ir_ctx_;
122};
123
124stmt_t inject_send(const stmt_t &s, ir_context_t &ir_ctx) {
125 trace_start();
126 auto ret = send_injector_t(ir_ctx).mutate(s);
127 trace_pass("inject_send", ret, ir_ctx);
128 return ret;
129}
130
131class send_2d_header_store_lifter_t : public ir_mutator_t {
132public:
133 send_2d_header_store_lifter_t(const stmt_t &root) {
134 auto calls = find_objects<func_call_t>(root);
135 for (auto &c : calls) {
136 if (!is_func_call<send_t>(c)) continue;
137 if (!c.as<func_call_t>().func.as<send_t>().is_2d()) continue;
138 auto header_buf = send_t::arg_mem_off(c);
139 ir_assert(is_var(header_buf)) << header_buf;
140 header_bufs_.insert(header_buf);
141 }
142 }
143
144 object_t _mutate(const alloc_t &obj) override {
145 auto new_obj = ir_mutator_t::_mutate(obj);
146 auto it = stores_.find(obj.buf);
147 if (it == stores_.end()) return new_obj;
148
149 auto &alloc = new_obj.as<alloc_t>();
150 stmt_t header_store;
151 for (auto &s : it->second)
152 header_store = header_store.append(s);
153 it->second.clear();
154
155 auto new_body = header_store.append(alloc.body);
156 return alloc_t::make(
157 alloc.buf, alloc.size, alloc.kind, alloc.attrs, new_body);
158 }
159
160 object_t _mutate(const store_t &obj) override {
161 if (header_bufs_.count(obj.buf) == 0) return obj;
162 // Do not lift address assignments and non-const x and y.
163 int off = to_cpp<int>(obj.off);
164 if (off == 0) return obj;
165 if (utils::one_of(
166 off, send_t::header_2d_off_x(), send_t::header_2d_off_y())
167 && !is_const(obj.value))
168 return obj;
169 stores_[obj.buf].push_back(obj);
170 return stmt_t();
171 }
172
173private:
174 object_set_t<expr_t> header_bufs_;
175 object_map_t<expr_t, std::vector<stmt_t>> stores_;
176};
177
178stmt_t lift_send_2d_header_store(const stmt_t &s, ir_context_t &ir_ctx) {
179 trace_start();
180 auto ret = send_2d_header_store_lifter_t(s).mutate(s);
181 trace_pass("lift_send_2d_header_store", ret, ir_ctx);
182 return ret;
183}
184
185} // namespace jit
186} // namespace gpu
187} // namespace impl
188} // namespace dnnl
189