/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/pass/pass.hpp"

#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

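// Collects all variables that are used in a statement but not defined within
// it (e.g. kernel arguments bound by an enclosing scope).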
class external_var_visitor_t : public scope_visitor_t {
public:
    void _visit(const var_t &obj) override {
        if (!is_expr_defined(obj)) external_vars.insert(obj);
    }

    object_eq_set_t<expr_t> external_vars;
};

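// Wraps the statement in empty let statements, one per external variable, so
// that every variable seen by later passes has a definition point.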
stmt_t inject_external_var_let(const stmt_t &_stmt, ir_context_t &ir_ctx) {
    trace_start();
    auto stmt = _stmt;
    external_var_visitor_t v;
    v.visit(stmt);

    for (auto &var : v.external_vars)
        stmt = let_t::make(var, {}, stmt);

    trace_pass("inject_external_var_let", stmt, ir_ctx);
    return stmt;
}

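// Removes casts of boolean vectors to u16 inside send() calls: the send mask
// argument is consumed as a boolean vector, so such casts are redundant there.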
class spurious_send_mask_cast_remover_t : public ir_mutator_t {
public:
    object_t _mutate(const cast_t &obj) override {
        if (in_send_ && obj.is_bool_vec_u16() && obj.expr.type().is_bool())
            return mutate(obj.expr);
        return ir_mutator_t::_mutate(obj);
    }

    object_t _mutate(const func_call_t &obj) override {
        if (!is_func_call<send_t>(obj)) return obj;

        in_send_ = true;
        auto new_obj = ir_mutator_t::_mutate(obj);
        in_send_ = false;
        return new_obj;
    }

private:
    bool in_send_ = false;
};

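// Removes spurious casts on send mask arguments (see the mutator above).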
stmt_t remove_spurious_send_mask_cast(const stmt_t &s, ir_context_t &ir_ctx) {
    spurious_send_mask_cast_remover_t mutator;
    trace_start();
    auto ret = mutator.mutate(s);
    trace_pass("remove_spurious_send_mask_cast", ret, ir_ctx);
    return ret;
}

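// Splits register stores wider than two GRFs into a sequence of stores that
// are at most two GRFs wide each, as wider stores are not supported by the
// hardware. For example, with a 32-byte GRF a densely packed 128-byte store
// is emitted as two 64-byte stores.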
class store_splitter_t : public ir_mutator_t {
public:
    store_splitter_t(ngen::HW hw) : hw_(hw) {}

    object_t _mutate(const store_t &obj) override {
        int elems = obj.value.type().elems();
        int elem_size = obj.value.type().scalar().size();
        int stride = (obj.has_default_stride() ? 1 : obj.stride / elem_size);
        int store_size = elem_size * stride * elems;
        const auto grf_size = ngen::GRF::bytes(hw_);
        // Stores of up to two GRFs are supported as is.
        if (store_size <= 2 * grf_size) return ir_mutator_t::_mutate(obj);

        // Emit one store per two-GRF chunk of the value.
        int step = 2 * grf_size / (stride * elem_size);
        stmt_t new_stmt;
        for (int i = 0; i < elems; i += step) {
            int cur_elems = std::min(step, elems - i);
            ir_assert(math::is_pow2(cur_elems));
            int off = i * stride * elem_size;
            auto store = store_t::make(obj.buf, obj.off + off,
                    split_expr(obj.value, i, i + cur_elems), obj.stride);
            new_stmt = new_stmt.append(store);
        }
        return std::move(new_stmt);
    }

private:
    // Extracts elements [beg, end) of a vector expression, recursing through
    // binary operations and slicing shuffles directly.
    static expr_t split_expr(const expr_t &e, int beg, int end) {
        auto *shuffle = e.as_ptr<shuffle_t>();
        if (shuffle) return shuffle_t::make(shuffle, beg, end);

        auto *binary = e.as_ptr<binary_op_t>();
        if (binary) {
            auto a = split_expr(binary->a, beg, end);
            auto b = split_expr(binary->b, beg, end);
            return binary_op_t::make(binary->op_kind, a, b);
        }
        ir_error_not_expected();
        return expr_t();
    }

    ngen::HW hw_;
};

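// Splits GRF stores that are too wide to be supported in HW.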
stmt_t split_wide_stores(const stmt_t &s, ir_context_t &ir_ctx) {
    trace_start();
    auto ret = store_splitter_t(ir_ctx.hw_cfg().hw()).mutate(s);
    trace_pass("split_wide_stores", ret, ir_ctx);
    return ret;
}

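// Broadcasts the condition of every if statement to the kernel SIMD width so
// that the condition mask matches the SIMD execution size of the generated
// code.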
class if_condition_fixer_t : public ir_mutator_t {
public:
    if_condition_fixer_t(int simd_size) : simd_size_(simd_size) {}

    object_t _mutate(const if_t &obj) override {
        auto _new_obj = ir_mutator_t::_mutate(obj);
        auto &new_obj = _new_obj.as<if_t>();
        auto cond = shuffle_t::make_broadcast(new_obj.cond, simd_size_);
        return if_t::make(cond, new_obj.body, new_obj.else_body);
    }

private:
    int simd_size_;
};

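// Broadcasts scalar if conditions to the configured SIMD size (see the
// mutator above).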
stmt_t fixup_if_conditions(const stmt_t &s, ir_context_t &ir_ctx) {
    trace_start();
    auto ret = if_condition_fixer_t(ir_ctx.exec_cfg().simd()).mutate(s);
    trace_pass("fixup_if_conditions", ret, ir_ctx);
    return ret;
}

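// Estimates peak GRF usage and, if it exceeds the available register budget,
// drops the prefetch statement group: prefetches only affect performance,
// while exceeding the budget would make register allocation fail.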
stmt_t maybe_strip_prefetches(
        const stmt_t &s, ir_context_t &ir_ctx, int reserved_regs) {
    trace_start();
    int ir_usage = get_peak_grf_usage(s, ir_ctx.hw_cfg().grf_size());
    int grf_usage = ir_usage + reserved_regs;
    auto ret = s;
    // Strip prefetches when the kernel would exceed the available registers.
    if (grf_usage > ir_ctx.exec_cfg().regs()) {
        ret = remove_stmt_group(s, stmt_label_t::prefetch());
        ir_warning() << "Dropping prefetches due to lack of available "
                        "registers.\n";
    }
    trace_pass("maybe_strip_prefetches", ret, ir_ctx);
    return ret;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl