1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/barrier.hpp"
18
19#include "gpu/jit/ir/message.hpp"
20#include "gpu/jit/utils/trace.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class barrier_optimizer_t : public ir_mutator_t {
28public:
29 object_t _mutate(const for_t &obj) override {
30 loop_level_++;
31 auto new_obj = ir_mutator_t::_mutate(obj);
32 loop_level_--;
33 return new_obj;
34 }
35
36 object_t _mutate(const func_call_t &obj) override {
37 if (is_func_call<send_t>(obj)) {
38 auto &send = obj.func.as<send_t>();
39 if (send.is_slm()) can_remove_barrier_ = false;
40 } else if (obj.func.is_same(funcs::barrier_func())) {
41 bool can_remove = can_remove_barrier_;
42 can_remove_barrier_ = false;
43
44 // If not in a loop and this is the first barrier -> can be removed.
45 if (loop_level_ == 0 && can_remove) return stmt_t();
46 return obj;
47 }
48
49 return obj;
50 }
51
52 // Store doesn't contain nested statements, return as is.
53 object_t _mutate(const store_t &obj) override { return obj; }
54
55private:
56 int loop_level_ = 0;
57 bool can_remove_barrier_ = true;
58};
59
60stmt_t optimize_barrier(const stmt_t &s, ir_context_t &ir_ctx) {
61 trace_start();
62 auto ret = barrier_optimizer_t().mutate(s);
63 trace_pass("optimize_barrier", ret, ir_ctx);
64 return ret;
65}
66
67} // namespace jit
68} // namespace gpu
69} // namespace impl
70} // namespace dnnl
71