1#pragma once
2
3#include "taichi/util/lang_util.h"
4
5namespace taichi::lang {
6
7class Program;
8class IRNode;
9class FrontendContext;
10
11class TI_DLL_EXPORT Callable {
12 public:
13 Program *program{nullptr};
14 std::unique_ptr<IRNode> ir{nullptr};
15 std::unique_ptr<FrontendContext> context{nullptr};
16
17 struct Parameter {
18 bool is_array{
19 false}; // This is true for both ndarray and external array args.
20 std::size_t total_dim{0}; // total dim of array
21
22 /* [arguments with TensorType]
23
24 Taichi used to represent TensorType with the combination of "PrimitiveType"
25 & "element_shape" and there are a bunch of interfaces designed like this (it
26 allows creating TensorType by passing in PrimitiveType + element_shape)
27
28 Here we removed the "element_shape" member in the underlying objects (class
29 Arg, class ExternalTensorExpression, ...), and forced them to use TensorType
30 in their "dtype" member.
31
32 However we kept the interfaces unchanged temporarily, so as to minimize
33 possible regressions.
34 */
35 explicit Parameter(const DataType &dt = PrimitiveType::unknown,
36 bool is_array = false,
37 std::size_t size_unused = 0,
38 int total_dim = 0,
39 std::vector<int> element_shape = {}) {
40 if (dt->is<PrimitiveType>() && element_shape.size() > 0) {
41 this->dt_ =
42 taichi::lang::TypeFactory::get_instance().create_tensor_type(
43 element_shape, dt);
44 } else {
45 this->dt_ = dt;
46 }
47
48 this->is_array = is_array;
49 this->total_dim = total_dim;
50 }
51
52 std::vector<int> get_element_shape() const {
53 return dt_.get_shape();
54 }
55
56 DataType get_element_type() const {
57 return dt_.get_element_type();
58 }
59
60 int get_element_size() const {
61 return data_type_size(dt_);
62 }
63
64 DataType get_dtype() const {
65 return dt_;
66 }
67
68 private:
69 DataType dt_;
70 };
71
72 struct Ret {
73 DataType dt;
74
75 explicit Ret(const DataType &dt = PrimitiveType::unknown) : dt(dt) {
76 }
77 };
78
79 std::vector<Parameter> parameter_list;
80 std::vector<Ret> rets;
81
82 const StructType *ret_type = nullptr;
83
84 Callable();
85 virtual ~Callable();
86
87 int insert_scalar_param(const DataType &dt);
88
89 int insert_arr_param(const DataType &dt,
90 int total_dim,
91 std::vector<int> element_shape);
92 int insert_texture_param(const DataType &dt);
93
94 int insert_ret(const DataType &dt);
95
96 void finalize_rets();
97
98 [[nodiscard]] virtual std::string get_name() const = 0;
99};
100
101} // namespace taichi::lang
102