1 | #pragma once |
2 | |
3 | #include <cstdint> |
4 | #include <vector> |
5 | |
6 | #include "taichi/inc/constants.h" |
7 | #include "taichi/ir/type_utils.h" |
8 | #include "taichi/rhi/device.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | class Program; |
13 | class NdarrayRwAccessorsBank; |
14 | |
15 | class TI_DLL_EXPORT Ndarray { |
16 | public: |
17 | /* Constructs a Ndarray managed by Program. |
18 | * Memory allocation and deallocation is handled by Program. |
19 | * TODO: Ideally Ndarray shouldn't worry about memory alloc/dealloc at all. |
20 | */ |
21 | explicit Ndarray(Program *prog, |
22 | const DataType type, |
23 | const std::vector<int> &shape, |
24 | ExternalArrayLayout layout = ExternalArrayLayout::kNull); |
25 | |
26 | /* Constructs a Ndarray from an existing DeviceAllocation. |
27 | * It doesn't handle the allocation and deallocation. |
28 | * You can see a Ndarray as a view or interpretation of DeviceAllocation |
29 | * with specified dtype & layout. |
30 | */ |
31 | explicit Ndarray(DeviceAllocation &devalloc, |
32 | const DataType type, |
33 | const std::vector<int> &shape, |
34 | ExternalArrayLayout layout = ExternalArrayLayout::kNull); |
35 | |
36 | /* Constructs a Ndarray from an existing DeviceAllocation. |
37 | * This is an overloaded constructor for constructing Ndarray with TensorType |
38 | * elements "type" is expected to be PrimitiveType |
39 | */ |
40 | explicit Ndarray(DeviceAllocation &devalloc, |
41 | const DataType type, |
42 | const std::vector<int> &shape, |
43 | const std::vector<int> &element_shape, |
44 | ExternalArrayLayout layout = ExternalArrayLayout::kNull); |
45 | |
46 | DeviceAllocation ndarray_alloc_{kDeviceNullAllocation}; |
47 | DataType dtype; |
48 | // Invariant: Since ndarray indices are flattened for vector/matrix, this is |
49 | // always true: |
50 | // num_active_indices = shape.size() |
51 | std::vector<int> shape; |
52 | ExternalArrayLayout layout{ExternalArrayLayout::kNull}; |
53 | |
54 | std::vector<int> get_element_shape() const; |
55 | DataType get_element_data_type() const; |
56 | intptr_t get_data_ptr_as_int() const; |
57 | intptr_t get_device_allocation_ptr_as_int() const; |
58 | DeviceAllocation get_device_allocation() const; |
59 | std::size_t get_element_size() const; |
60 | std::size_t get_nelement() const; |
61 | TypedConstant read(const std::vector<int> &I) const; |
62 | template <typename T> |
63 | void write(const std::vector<int> &I, T val) const; |
64 | int64 read_int(const std::vector<int> &i); |
65 | uint64 read_uint(const std::vector<int> &i); |
66 | float64 read_float(const std::vector<int> &i); |
67 | void write_int(const std::vector<int> &i, int64 val); |
68 | void write_float(const std::vector<int> &i, float64 val); |
69 | |
70 | const std::vector<int> &total_shape() const { |
71 | return total_shape_; |
72 | } |
73 | ~Ndarray(); |
74 | |
75 | private: |
76 | std::size_t nelement_{1}; |
77 | std::size_t element_size_{1}; |
78 | std::vector<int> total_shape_; |
79 | |
80 | Program *prog_{nullptr}; |
81 | }; |
82 | |
83 | } // namespace taichi::lang |
84 | |