1#pragma once
2
3#ifndef _TRITON_IR_FUNCTION_H_
4#define _TRITON_IR_FUNCTION_H_
5
#include <cassert>
#include <iosfwd>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "value.h"
#include "constant.h"
10
11namespace triton{
12namespace ir{
13
// Forward declarations of IR types that are only referred to through
// pointers in this header (kept alphabetical).
class basic_block;
class function;
class function_type;
class module;
18
19/* Argument */
20class argument: public value{
21 argument(type *ty, const std::string &name, function *parent, unsigned arg_no);
22
23public:
24 static argument* create(type *ty, const std::string &name,
25 function *parent = nullptr, unsigned arg_no = 0);
26 function* get_parent() const;
27 unsigned get_arg_no() const;
28
29 void accept(visitor *v);
30
31private:
32 function *parent_;
33 unsigned arg_no_;
34};
35
/* Attribute */
// Kinds of attributes that can be attached to a function slot.
// `aligned` and `multiple_of` carry an integer payload (see attribute::repr).
enum attribute_kind_t {
  readonly = 0,
  writeonly,
  noalias,
  aligned,
  multiple_of,
  retune,
  not_implemented
};

// A (kind, integer value) pair. Ordered so it can be stored in std::set.
class attribute {
public:
  // `value` is the integer payload. repr() only prints it for `aligned` and
  // `multiple_of`, but it always participates in ordering.
  attribute(attribute_kind_t kind, unsigned value = 0):
    kind_(kind), value_(value){}

  // Lexicographic ordering on (kind, value).
  bool operator<(const attribute& other) const {
    return std::make_pair(kind_, value_) < std::make_pair(other.kind_, other.value_);
  }

  attribute_kind_t get_kind() const {
    return kind_;
  }

  unsigned get_value() const {
    return value_;
  }

  // True for every kind except `multiple_of`.
  // NOTE(review): this also returns true for `retune` and `not_implemented`;
  // presumably the consumer filters those out - confirm before relying on it.
  bool is_llvm_attr() const {
    return kind_ != multiple_of;
  }

  // Textual suffix form of the attribute, e.g. ".aligned(16)".
  // Kinds without a spelling (not_implemented) trip the assert in debug
  // builds and yield "" when NDEBUG is defined.
  std::string repr() const {
    switch(kind_){
      case readonly:    return ".readonly";
      case writeonly:   return ".writeonly";
      case noalias:     return ".noalias";
      case aligned:     return ".aligned(" + std::to_string(value_) + ")";
      case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
      // Previously misspelled ".retunr".
      case retune:      return ".retune";
      default: break;
    }
    assert(false && "attribute::repr: unhandled attribute kind");
    return "";
  }

private:
  attribute_kind_t kind_;
  unsigned value_;
};
86
87/* Function */
88class function: public global_object{
89 typedef std::vector<argument*> args_t;
90 typedef args_t::iterator arg_iterator;
91 typedef args_t::const_iterator const_arg_iterator;
92
93 typedef std::vector<basic_block*> blocks_t;
94 typedef blocks_t::iterator block_iterator;
95 typedef blocks_t::const_iterator const_block_iterator;
96
97 typedef std::map<unsigned, std::set<attribute>> attr_map_t;
98
99private:
100 function(function_type *ty, linkage_types_t linkage,
101 const std::string &name = "", module *parent = nullptr);
102
103public:
104 // accessors
105 const args_t &args() const { return args_; }
106 function_type* get_fn_type() { return fn_ty_; }
107 const function_type* get_fn_type() const { return fn_ty_; }
108 module *get_parent() { return parent_; }
109 const module *get_parent() const { return parent_; }
110
111 // factory methods
112 static function *create(function_type *ty, linkage_types_t linkage,
113 const std::string &name, module *mod);
114 // blocks
115 blocks_t &blocks() { return blocks_; }
116 const blocks_t &blocks() const { return blocks_; }
117 void insert_block(basic_block* block, basic_block *next = nullptr);
118
119 // attributes
120 void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); }
121 const attr_map_t &attrs() { return attrs_; }
122 bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); }
123 std::set<attribute> get_attributes(const argument* arg) { return attrs_[arg->get_arg_no() + 1]; }
124 void set_is_kernel(bool new_val) { is_kernel_ = new_val; }
125 bool get_is_kernel() { return is_kernel_; }
126
127 void print(std::ostream &os);
128
129 // visitor
130 void accept(visitor *v) { v->visit_function(this); }
131
132private:
133 module *parent_;
134 bool init_;
135 function_type *fn_ty_;
136 args_t args_;
137 blocks_t blocks_;
138 attr_map_t attrs_;
139 bool is_kernel_;
140};
141
142}
143}
144
145#endif
146