1 | #pragma once |
2 | |
3 | #include "taichi/util/lang_util.h" |
4 | |
5 | namespace taichi::lang { |
6 | |
7 | class Program; |
8 | class IRNode; |
9 | class FrontendContext; |
10 | |
11 | class TI_DLL_EXPORT Callable { |
12 | public: |
13 | Program *program{nullptr}; |
14 | std::unique_ptr<IRNode> ir{nullptr}; |
15 | std::unique_ptr<FrontendContext> context{nullptr}; |
16 | |
17 | struct Parameter { |
18 | bool is_array{ |
19 | false}; // This is true for both ndarray and external array args. |
20 | std::size_t total_dim{0}; // total dim of array |
21 | |
22 | /* [arguments with TensorType] |
23 | |
24 | Taichi used to represent TensorType with the combination of "PrimitiveType" |
25 | & "element_shape" and there are a bunch of interfaces designed like this (it |
26 | allows creating TensorType by passing in PrimitiveType + element_shape) |
27 | |
28 | Here we removed the "element_shape" member in the underlying objects (class |
29 | Arg, class ExternalTensorExpression, ...), and forced them to use TensorType |
30 | in their "dtype" member. |
31 | |
32 | However we kept the interfaces unchanged temporarily, so as to minimize |
33 | possible regressions. |
34 | */ |
35 | explicit Parameter(const DataType &dt = PrimitiveType::unknown, |
36 | bool is_array = false, |
37 | std::size_t size_unused = 0, |
38 | int total_dim = 0, |
39 | std::vector<int> element_shape = {}) { |
40 | if (dt->is<PrimitiveType>() && element_shape.size() > 0) { |
41 | this->dt_ = |
42 | taichi::lang::TypeFactory::get_instance().create_tensor_type( |
43 | element_shape, dt); |
44 | } else { |
45 | this->dt_ = dt; |
46 | } |
47 | |
48 | this->is_array = is_array; |
49 | this->total_dim = total_dim; |
50 | } |
51 | |
52 | std::vector<int> get_element_shape() const { |
53 | return dt_.get_shape(); |
54 | } |
55 | |
56 | DataType get_element_type() const { |
57 | return dt_.get_element_type(); |
58 | } |
59 | |
60 | int get_element_size() const { |
61 | return data_type_size(dt_); |
62 | } |
63 | |
64 | DataType get_dtype() const { |
65 | return dt_; |
66 | } |
67 | |
68 | private: |
69 | DataType dt_; |
70 | }; |
71 | |
72 | struct Ret { |
73 | DataType dt; |
74 | |
75 | explicit Ret(const DataType &dt = PrimitiveType::unknown) : dt(dt) { |
76 | } |
77 | }; |
78 | |
79 | std::vector<Parameter> parameter_list; |
80 | std::vector<Ret> rets; |
81 | |
82 | const StructType *ret_type = nullptr; |
83 | |
84 | Callable(); |
85 | virtual ~Callable(); |
86 | |
87 | int insert_scalar_param(const DataType &dt); |
88 | |
89 | int insert_arr_param(const DataType &dt, |
90 | int total_dim, |
91 | std::vector<int> element_shape); |
92 | int insert_texture_param(const DataType &dt); |
93 | |
94 | int insert_ret(const DataType &dt); |
95 | |
96 | void finalize_rets(); |
97 | |
98 | [[nodiscard]] virtual std::string get_name() const = 0; |
99 | }; |
100 | |
101 | } // namespace taichi::lang |
102 | |