1 | #ifndef TDL_INCLUDE_IR_CODEGEN_TARGET_H |
2 | #define TDL_INCLUDE_IR_CODEGEN_TARGET_H |
3 | |
// Forward declarations of the LLVM types referenced by this header, so that
// including it does not require pulling in the LLVM headers themselves.
namespace llvm {

// IR entities.
class Value;
class Constant;
class Instruction;
class Function;
class Module;
class LLVMContext;

// Types.
class Type;
class ArrayType;

// IRBuilder and the two policy classes it is instantiated with below.
class ConstantFolder;
class IRBuilderDefaultInserter;
template <typename FolderT, typename InserterT>
class IRBuilder;

} // namespace llvm
18 | |
19 | // typedefs |
20 | namespace triton{ |
21 | namespace codegen{ |
22 | typedef llvm::IRBuilder<llvm::ConstantFolder, |
23 | llvm::IRBuilderDefaultInserter> Builder; |
24 | typedef llvm::LLVMContext LLVMContext; |
25 | typedef llvm::Type Type; |
26 | typedef llvm::Value Value; |
27 | typedef llvm::Module Module; |
28 | typedef llvm::Instruction Instruction; |
29 | typedef llvm::Constant Constant; |
30 | typedef llvm::ArrayType ArrayType; |
31 | typedef llvm::Function Function; |
32 | } |
33 | } |
34 | |
35 | namespace triton{ |
36 | namespace codegen{ |
37 | |
38 | class nvidia_cu_target; |
39 | |
40 | class target { |
41 | public: |
42 | target(bool is_gpu): is_gpu_(is_gpu){} |
43 | virtual ~target() {} |
44 | virtual void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn) = 0; |
45 | virtual Instruction* add_barrier(Module *module, Builder& builder) = 0; |
46 | virtual Instruction* add_memfence(Module *module, Builder& builder) = 0; |
47 | virtual Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax) = 0; |
48 | virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0; |
49 | virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0; |
50 | virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0; |
51 | virtual unsigned guaranteed_alignment() = 0; |
52 | nvidia_cu_target* as_nvidia(); |
53 | bool is_gpu() const; |
54 | |
55 | private: |
56 | bool is_gpu_; |
57 | }; |
58 | |
59 | class amd_cl_target: public target { |
60 | public: |
61 | amd_cl_target(): target(true){} |
62 | void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); |
63 | Instruction* add_barrier(Module *module, Builder& builder); |
64 | Instruction* add_memfence(Module *module, Builder& builder); |
65 | Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax); |
66 | Value* get_local_id(Module *module, Builder& builder, unsigned ax); |
67 | Value* get_block_id(Module *module, Builder& builder, unsigned ax); |
68 | Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); |
69 | unsigned guaranteed_alignment() { return 16; } |
70 | }; |
71 | |
72 | class nvidia_cu_target: public target { |
73 | public: |
74 | nvidia_cu_target(int sm): target(true), sm_(sm){} |
75 | void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); |
76 | Instruction* add_barrier(Module *module, Builder& builder); |
77 | Instruction* add_memfence(Module *module, Builder& builder); |
78 | Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax); |
79 | Value* get_local_id(Module *module, Builder& builder, unsigned ax); |
80 | Value* get_block_id(Module *module, Builder& builder, unsigned ax); |
81 | Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); |
82 | int sm() { return sm_; } |
83 | unsigned guaranteed_alignment() { return 16; } |
84 | |
85 | private: |
86 | int sm_; |
87 | }; |
88 | |
89 | class cpu_target: public target { |
90 | public: |
91 | cpu_target(): target(false){} |
92 | void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); |
93 | Instruction* add_barrier(Module *module, Builder& builder); |
94 | Instruction* add_memfence(Module *module, Builder& builder); |
95 | Value* get_global_offset(Module *module, Builder& builder, unsigned stride, unsigned ax); |
96 | Value* get_local_id(Module *module, Builder& builder, unsigned ax); |
97 | Value* get_block_id(Module *module, Builder& builder, unsigned ax); |
98 | Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); |
99 | unsigned guaranteed_alignment() { return 1; } |
100 | }; |
101 | |
102 | } |
103 | } |
104 | |
105 | #endif |
106 | |