1 | #include <cassert> |
2 | #include <stdexcept> |
3 | #include "triton/ir/constant.h" |
4 | #include "triton/ir/type.h" |
5 | #include "triton/ir/context.h" |
6 | #include "triton/ir/context_impl.h" |
7 | |
8 | namespace triton{ |
9 | namespace ir{ |
10 | |
11 | |
12 | // constant |
13 | |
14 | constant *constant::get_null_value(type *ty) { |
15 | context &ctx = ty->get_context(); |
16 | switch (ty->get_scalar_ty()->get_type_id()) { |
17 | case type::IntegerTyID: |
18 | return constant_int::get(ty, 0); |
19 | case type::FP16TyID: |
20 | return constant_fp::get(type::get_fp16_ty(ctx), 0); |
21 | case type::BF16TyID: |
22 | return constant_fp::get(type::get_bf16_ty(ctx), 0); |
23 | case type::FP32TyID: |
24 | return constant_fp::get(type::get_fp32_ty(ctx), 0); |
25 | case type::FP64TyID: |
26 | return constant_fp::get(type::get_fp64_ty(ctx), 0); |
27 | default: |
28 | throw std::runtime_error("Cannot create a null constant of that type!" ); |
29 | } |
30 | } |
31 | |
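// FIXME: for floating-point types, get_all_ones_value converts the integer
// 0xFFFFFFFFFFFFFFFF to a double (~1.8e19) instead of producing an all-ones bit pattern.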
33 | |
34 | constant *constant::get_all_ones_value(type *ty) { |
35 | if(ty->is_integer_ty()) |
36 | return constant_int::get(ty, 0xFFFFFFFFFFFFFFFF); |
37 | if(ty->is_floating_point_ty()) |
38 | return constant_fp::get(ty, 0xFFFFFFFFFFFFFFFF); |
39 | throw std::runtime_error("Cannot create all ones value for that type!" ); |
40 | } |
41 | |
42 | // constant_int |
43 | // FIXME use something like APInt |
44 | |
45 | constant_int::constant_int(type *ty, uint64_t value) |
46 | : constant(ty, 0), value_(value){ } |
47 | |
48 | constant_int *constant_int::get(type *ty, uint64_t value) { |
49 | if (!ty->is_integer_ty()) |
50 | throw std::runtime_error("Cannot create constant_int with non integer ty" ); |
51 | context_impl *impl = ty->get_context().p_impl.get(); |
52 | std::unique_ptr<constant_int> &cst = impl->int_constants_[std::make_pair(ty, value)]; |
53 | if(!cst) |
54 | cst.reset(new constant_int(ty, value)); |
55 | return cst.get(); |
56 | } |
57 | |
58 | |
59 | // constant_fp |
60 | // FIXME use something like APFloat |
61 | |
62 | constant_fp::constant_fp(type *ty, double value) |
63 | : constant(ty, 0), value_(value){ } |
64 | |
65 | constant *constant_fp::get_negative_zero(type *ty){ |
66 | double neg_zero = 0; |
67 | return get(ty, neg_zero); |
68 | } |
69 | |
70 | constant *constant_fp::get_zero_value_for_negation(type *ty) { |
71 | if(ty->get_scalar_ty()->is_floating_point_ty()) |
72 | return constant_fp::get(ty, 0); |
73 | return constant::get_null_value(ty); |
74 | } |
75 | |
76 | constant *constant_fp::get(type *ty, double v){ |
77 | context_impl *impl = ty->get_context().p_impl.get(); |
78 | std::unique_ptr<constant_fp> &result = impl->fp_constants_[std::make_pair(ty, v)]; |
79 | if(!result) |
80 | result.reset(new constant_fp(ty, v)); |
81 | return result.get(); |
82 | } |
83 | |
84 | |
85 | // undef value |
86 | undef_value::undef_value(type *ty) |
87 | : constant(ty, 0) { } |
88 | |
89 | undef_value *undef_value::get(type *ty) { |
90 | context_impl *impl = ty->get_context().p_impl.get(); |
91 | std::unique_ptr<undef_value> &result = impl->uv_constants_[ty]; |
92 | if(!result) |
93 | result.reset(new undef_value(ty)); |
94 | return result.get(); |
95 | } |
96 | |
97 | /* global value */ |
98 | global_value::global_value(type *ty, unsigned num_ops, |
99 | linkage_types_t linkage, |
100 | const std::string &name, unsigned addr_space) |
101 | : constant(pointer_type::get(ty, addr_space), num_ops, name), |
102 | linkage_(linkage) { } |
103 | |
104 | |
105 | /* global object */ |
106 | global_object::global_object(type *ty, unsigned num_ops, |
107 | linkage_types_t linkage, |
108 | const std::string &name, unsigned addr_space) |
109 | : global_value(ty, num_ops, linkage, name, addr_space) { } |
110 | |
111 | |
112 | /* alloc const */ |
113 | alloc_const::alloc_const(type *ty, constant_int *size, const std::string &name) |
114 | : global_object(ty, 1, global_value::external, name, 4) { |
115 | set_operand(0, size); |
116 | } |
117 | |
118 | |
119 | } |
120 | } |
121 | |