#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/target.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include "triton/ir/print.h"

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>
12 | |
13 | namespace triton::codegen::transform { |
14 | |
15 | /// find defs till phis |
16 | static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) { |
17 | ir::instruction *i = dynamic_cast<ir::instruction*>(v); |
18 | if (!i || i->get_parent() != bb) |
19 | return; |
20 | if (i->get_id() == ir::INST_PHI) |
21 | return; |
22 | ret.push_back(i); |
23 | for (ir::value *op : i->ops()) |
24 | recursive_defs(op, bb, ret); |
25 | } |
26 | |
27 | void prefetch::run(ir::module &mod) { |
28 | // 1. collect dots that can be prefethced |
29 | std::vector<ir::dot_inst*> to_prefetch; |
30 | ir::for_each_instruction(mod, [&](ir::instruction *i) { |
31 | if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) { |
32 | // Now only do prefetching when dot is using tensor cores |
33 | if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() || |
34 | dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() || |
35 | (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32() |
36 | && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) || |
37 | (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8) |
38 | && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8) |
39 | && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) |
40 | ) |
41 | ) |
42 | return; |
43 | auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0)); |
44 | auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1)); |
45 | if (a && a->get_incoming_block(1) == a->get_parent() && |
46 | b && b->get_incoming_block(1) == b->get_parent()) |
47 | to_prefetch.push_back(dot); |
48 | } |
49 | }); |
50 | |
51 | assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots" ); |
52 | ir::builder &builder = mod.get_builder(); |
53 | // 2. do the prefetching |
54 | for (ir::dot_inst* dot : to_prefetch) { |
55 | auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0)); |
56 | auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1)); |
57 | assert(a->get_incoming_block(0) == b->get_incoming_block(0)); |
58 | ir::basic_block * = a->get_incoming_block(0); |
59 | ir::basic_block *loop_body = a->get_parent(); |
60 | |
61 | // mark as prefetched |
62 | dot->set_prefetched(true); |
63 | |
64 | // 1. in the loop header (first iteration) |
65 | builder.set_insert_point(loop_header->get_inst_list().back()); |
66 | assert(a && b); |
67 | builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0); |
68 | builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0); |
69 | |
70 | // 2. at the end of the loop body (next iteration) |
71 | builder.set_insert_point(loop_body->get_inst_list().back()); |
72 | builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1); |
73 | builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1); |
74 | |
75 | prefetched_vals_.insert(a->get_incoming_value(0)); |
76 | prefetched_vals_.insert(b->get_incoming_value(0)); |
77 | // nested phis |
78 | ir::value* next_a = a->get_incoming_value(1); |
79 | while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) { |
80 | prefetched_vals_.insert(next_a_phi->get_incoming_value(0)); |
81 | next_a = next_a_phi->get_incoming_value(1); |
82 | } |
83 | prefetched_vals_.insert(next_a); |
84 | |
85 | ir::value* next_b = b->get_incoming_value(1); |
86 | while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) { |
87 | prefetched_vals_.insert(next_b_phi->get_incoming_value(0)); |
88 | next_b = next_b_phi->get_incoming_value(1); |
89 | } |
90 | prefetched_vals_.insert(next_b); |
91 | } |
92 | |
93 | // move loads to the beginning of the loop |
94 | if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) { |
95 | for (ir::function *fn : mod.get_function_list()) |
96 | for (ir::basic_block *bb : fn->blocks()) { |
97 | // only apply to loop body |
98 | if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb) |
99 | continue; |
100 | // record loads (& dependency) to move |
101 | std::vector<ir::instruction*> loads; |
102 | // record original inst order |
103 | std::map<ir::instruction*, size_t> idx_map; |
104 | size_t idx = 0; |
105 | for (ir::instruction *inst : bb->get_inst_list()) { |
106 | if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst)) |
107 | recursive_defs(i, bb, loads); |
108 | idx_map[inst] = idx; |
109 | idx++; |
110 | } |
111 | |
112 | // remove duplicates & keep the original input order |
113 | std::sort(loads.begin(), loads.end()); |
114 | loads.erase(std::unique(loads.begin(), loads.end()), loads.end()); |
115 | std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) { |
116 | return idx_map[a] < idx_map[b]; |
117 | }); |
118 | |
119 | builder.set_insert_point(bb->get_first_non_phi()); |
120 | auto& inst_list = bb->get_inst_list(); |
121 | for (ir::instruction *i : loads){ |
122 | auto it = std::find(inst_list.begin(), inst_list.end(), i); |
123 | // make sure we don't invalidate insert point |
124 | // in case instruction already at the top |
125 | if(it == builder.get_insert_point()) |
126 | continue; |
127 | bb->erase(i); |
128 | builder.insert(i); |
129 | } |
130 | } |
131 | } |
132 | } |
133 | } // namespace triton::codegen::transform |
134 | |