1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/unroll.hpp"
18
19#include "gpu/jit/utils/trace.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace gpu {
24namespace jit {
25
26class unrolling_updater_t : public ir_mutator_t {
27public:
28 object_t _mutate(const let_t &obj) override {
29 if (level_ == 0) {
30 // Skip top-level let statements.
31 return ir_mutator_t::_mutate(obj);
32 }
33 lets_.push_back(&obj);
34 auto new_body = mutate(obj.body);
35 if (!lets_.back()) {
36 // Let was moved to the innermost loop.
37 lets_.pop_back();
38 return new_body;
39 }
40 lets_.pop_back();
41 if (new_body.is_same(obj.body)) return obj;
42 return let_t::make(obj.var, obj.value, new_body);
43 }
44
45 object_t _mutate(const for_t &obj) override {
46 if (in_compute_loop_) level_++;
47 found_loop_ = false;
48 auto new_obj = ir_mutator_t::_mutate(obj);
49 if (in_compute_loop_) level_--;
50 if (!found_loop_) {
51 // Innermost loop, inject let statements.
52 auto body = get_stmt_body(new_obj);
53 for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) {
54 body = let_t::make((*it)->var, (*it)->value, body);
55 *it = nullptr;
56 }
57 new_obj = replace_stmt_body(new_obj, body);
58 }
59 found_loop_ = true;
60 return new_obj;
61 }
62
63 object_t _mutate(const stmt_group_t &obj) override {
64 if (obj.label == stmt_label_t::compute_loop()) {
65 in_compute_loop_ = true;
66 }
67 auto new_obj = ir_mutator_t::_mutate(obj);
68 if (obj.label == stmt_label_t::compute_loop()) {
69 in_compute_loop_ = false;
70 }
71 return new_obj;
72 }
73
74private:
75 bool found_loop_ = false;
76 bool in_compute_loop_ = false;
77 int level_ = 0;
78 std::vector<const let_t *> lets_;
79};
80
81stmt_t update_loops_for_unrolling(const stmt_t &s, ir_context_t &ir_ctx) {
82 trace_start();
83 auto ret = unrolling_updater_t().mutate(s);
84 trace_pass("update_loops_for_unrolling", ret, ir_ctx);
85 return ret;
86}
87
88class loop_unroller_t : public ir_mutator_t {
89public:
90 loop_unroller_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {}
91
92 object_t _mutate(const for_t &obj) override {
93 auto new_obj = ir_mutator_t::_mutate(obj);
94 auto &_for = new_obj.as<for_t>();
95 // No unrolling.
96 if (_for.unroll == 1) return new_obj;
97
98 ir_assert(is_const(_for.init))
99 << "Can't unroll loop with non-const bound: " << _for.init;
100 ir_assert(is_const(_for.bound))
101 << "Can't unroll loop with non-const bound: " << _for.bound;
102
103 auto init = to_cpp<int>(_for.init);
104 auto bound = to_cpp<int>(_for.bound);
105
106 ir_assert(_for.unroll == (bound - init))
107 << "Only full loop unroll is supported.";
108
109 stmt_t ret;
110 for (int i = init; i < bound; i++) {
111 auto iter_stmt = substitute(
112 _for.body, _for.var, to_expr(i, _for.var.type()));
113 iter_stmt = rename_let_alloc(iter_stmt, i - init);
114 ret = ret.append(iter_stmt);
115 }
116 return std::move(ret);
117 }
118
119private:
120 stmt_t rename_let_alloc(const stmt_t &s, int idx) {
121 auto lets = find_objects<let_t>(s);
122 auto ret = s;
123 for (auto &_let : lets) {
124 auto &let = _let.as<let_t>();
125 auto &var = let.var.as<var_t>();
126 auto new_var = ir_ctx_.create_tmp_var(var.type, var.name);
127 ret = substitute(ret, let.var, new_var);
128 }
129 auto allocs = find_objects<alloc_t>(s);
130 for (auto &_alloc : allocs) {
131 auto &alloc = _alloc.as<alloc_t>();
132 auto &buf = alloc.buf.as<var_t>();
133 auto new_buf = ir_ctx_.create_tmp_var(buf.type, buf.name);
134 ret = substitute(ret, alloc.buf, new_buf);
135 }
136 return ret;
137 }
138
139 ir_context_t &ir_ctx_;
140};
141
142stmt_t unroll_loops(const stmt_t &s, ir_context_t &ir_ctx) {
143 trace_start();
144 auto ret = loop_unroller_t(ir_ctx).mutate(s);
145 trace_pass("unroll_loops", ret, ir_ctx);
146 return ret;
147}
148
149} // namespace jit
150} // namespace gpu
151} // namespace impl
152} // namespace dnnl
153