1#pragma once
2
3#include <cstdint>
4#include <vector>
5
6#include "taichi/inc/constants.h"
7#include "taichi/ir/type_utils.h"
8#include "taichi/rhi/device.h"
9
10namespace taichi::lang {
11
12class Program;
13class NdarrayRwAccessorsBank;
14
15class TI_DLL_EXPORT Ndarray {
16 public:
17 /* Constructs a Ndarray managed by Program.
18 * Memory allocation and deallocation is handled by Program.
19 * TODO: Ideally Ndarray shouldn't worry about memory alloc/dealloc at all.
20 */
21 explicit Ndarray(Program *prog,
22 const DataType type,
23 const std::vector<int> &shape,
24 ExternalArrayLayout layout = ExternalArrayLayout::kNull);
25
26 /* Constructs a Ndarray from an existing DeviceAllocation.
27 * It doesn't handle the allocation and deallocation.
28 * You can see a Ndarray as a view or interpretation of DeviceAllocation
29 * with specified dtype & layout.
30 */
31 explicit Ndarray(DeviceAllocation &devalloc,
32 const DataType type,
33 const std::vector<int> &shape,
34 ExternalArrayLayout layout = ExternalArrayLayout::kNull);
35
36 /* Constructs a Ndarray from an existing DeviceAllocation.
37 * This is an overloaded constructor for constructing Ndarray with TensorType
38 * elements "type" is expected to be PrimitiveType
39 */
40 explicit Ndarray(DeviceAllocation &devalloc,
41 const DataType type,
42 const std::vector<int> &shape,
43 const std::vector<int> &element_shape,
44 ExternalArrayLayout layout = ExternalArrayLayout::kNull);
45
46 DeviceAllocation ndarray_alloc_{kDeviceNullAllocation};
47 DataType dtype;
48 // Invariant: Since ndarray indices are flattened for vector/matrix, this is
49 // always true:
50 // num_active_indices = shape.size()
51 std::vector<int> shape;
52 ExternalArrayLayout layout{ExternalArrayLayout::kNull};
53
54 std::vector<int> get_element_shape() const;
55 DataType get_element_data_type() const;
56 intptr_t get_data_ptr_as_int() const;
57 intptr_t get_device_allocation_ptr_as_int() const;
58 DeviceAllocation get_device_allocation() const;
59 std::size_t get_element_size() const;
60 std::size_t get_nelement() const;
61 TypedConstant read(const std::vector<int> &I) const;
62 template <typename T>
63 void write(const std::vector<int> &I, T val) const;
64 int64 read_int(const std::vector<int> &i);
65 uint64 read_uint(const std::vector<int> &i);
66 float64 read_float(const std::vector<int> &i);
67 void write_int(const std::vector<int> &i, int64 val);
68 void write_float(const std::vector<int> &i, float64 val);
69
70 const std::vector<int> &total_shape() const {
71 return total_shape_;
72 }
73 ~Ndarray();
74
75 private:
76 std::size_t nelement_{1};
77 std::size_t element_size_{1};
78 std::vector<int> total_shape_;
79
80 Program *prog_{nullptr};
81};
82
83} // namespace taichi::lang
84