#include "triton/codegen/transform/prefetch.h"
#include "triton/codegen/target.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
#include "triton/ir/instructions.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include "triton/ir/utils.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <unordered_set>
#include <vector>
12
13namespace triton::codegen::transform {
14
15/// find defs till phis
16static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector<ir::instruction*> &ret) {
17 ir::instruction *i = dynamic_cast<ir::instruction*>(v);
18 if (!i || i->get_parent() != bb)
19 return;
20 if (i->get_id() == ir::INST_PHI)
21 return;
22 ret.push_back(i);
23 for (ir::value *op : i->ops())
24 recursive_defs(op, bb, ret);
25}
26
27void prefetch::run(ir::module &mod) {
28 // 1. collect dots that can be prefethced
29 std::vector<ir::dot_inst*> to_prefetch;
30 ir::for_each_instruction(mod, [&](ir::instruction *i) {
31 if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
32 // Now only do prefetching when dot is using tensor cores
33 if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
34 dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
35 (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
36 && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
37 (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
38 && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
39 && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
40 )
41 )
42 return;
43 auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
44 auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
45 if (a && a->get_incoming_block(1) == a->get_parent() &&
46 b && b->get_incoming_block(1) == b->get_parent())
47 to_prefetch.push_back(dot);
48 }
49 });
50
51 assert(to_prefetch.size() <=1 && "Don't know what to do with multiple dots");
52 ir::builder &builder = mod.get_builder();
53 // 2. do the prefetching
54 for (ir::dot_inst* dot : to_prefetch) {
55 auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
56 auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));
57 assert(a->get_incoming_block(0) == b->get_incoming_block(0));
58 ir::basic_block *loop_header = a->get_incoming_block(0);
59 ir::basic_block *loop_body = a->get_parent();
60
61 // mark as prefetched
62 dot->set_prefetched(true);
63
64 // 1. in the loop header (first iteration)
65 builder.set_insert_point(loop_header->get_inst_list().back());
66 assert(a && b);
67 builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0);
68 builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0);
69
70 // 2. at the end of the loop body (next iteration)
71 builder.set_insert_point(loop_body->get_inst_list().back());
72 builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1);
73 builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1);
74
75 prefetched_vals_.insert(a->get_incoming_value(0));
76 prefetched_vals_.insert(b->get_incoming_value(0));
77 // nested phis
78 ir::value* next_a = a->get_incoming_value(1);
79 while (auto* next_a_phi = dynamic_cast<ir::phi_node*>(next_a)) {
80 prefetched_vals_.insert(next_a_phi->get_incoming_value(0));
81 next_a = next_a_phi->get_incoming_value(1);
82 }
83 prefetched_vals_.insert(next_a);
84
85 ir::value* next_b = b->get_incoming_value(1);
86 while (auto* next_b_phi = dynamic_cast<ir::phi_node*>(next_b)) {
87 prefetched_vals_.insert(next_b_phi->get_incoming_value(0));
88 next_b = next_b_phi->get_incoming_value(1);
89 }
90 prefetched_vals_.insert(next_b);
91 }
92
93 // move loads to the beginning of the loop
94 if (tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) {
95 for (ir::function *fn : mod.get_function_list())
96 for (ir::basic_block *bb : fn->blocks()) {
97 // only apply to loop body
98 if (bb->get_predecessors().size() != 2 || bb->get_predecessors()[1] != bb)
99 continue;
100 // record loads (& dependency) to move
101 std::vector<ir::instruction*> loads;
102 // record original inst order
103 std::map<ir::instruction*, size_t> idx_map;
104 size_t idx = 0;
105 for (ir::instruction *inst : bb->get_inst_list()) {
106 if (auto *i = dynamic_cast<ir::masked_load_inst*>(inst))
107 recursive_defs(i, bb, loads);
108 idx_map[inst] = idx;
109 idx++;
110 }
111
112 // remove duplicates & keep the original input order
113 std::sort(loads.begin(), loads.end());
114 loads.erase(std::unique(loads.begin(), loads.end()), loads.end());
115 std::sort(loads.begin(), loads.end(), [&idx_map](ir::instruction *a, ir::instruction *b) {
116 return idx_map[a] < idx_map[b];
117 });
118
119 builder.set_insert_point(bb->get_first_non_phi());
120 auto& inst_list = bb->get_inst_list();
121 for (ir::instruction *i : loads){
122 auto it = std::find(inst_list.begin(), inst_list.end(), i);
123 // make sure we don't invalidate insert point
124 // in case instruction already at the top
125 if(it == builder.get_insert_point())
126 continue;
127 bb->erase(i);
128 builder.insert(i);
129 }
130 }
131 }
132}
133} // namespace triton::codegen::transform
134