1 | /******************************************************************************* |
2 | Copyright (c) The Taichi Authors (2016- ). All Rights Reserved. |
3 | The use of this software is governed by the LICENSE file. |
4 | *******************************************************************************/ |
5 | |
6 | #include "taichi/python/export.h" |
7 | #include "taichi/common/dict.h" |
8 | |
9 | namespace taichi { |
10 | |
11 | template <typename T> |
12 | constexpr std::string get_type_short_name(); |
13 | |
14 | template <> |
15 | std::string get_type_short_name<float32>() { |
16 | return "f" ; |
17 | } |
18 | |
19 | template <> |
20 | std::string get_type_short_name<float64>() { |
21 | return "d" ; |
22 | } |
23 | |
24 | template <> |
25 | std::string get_type_short_name<int>() { |
26 | return "i" ; |
27 | } |
28 | |
29 | template <> |
30 | std::string get_type_short_name<int64>() { |
31 | return "I" ; |
32 | } |
33 | |
34 | template <> |
35 | std::string get_type_short_name<uint64>() { |
36 | return "U" ; |
37 | } |
38 | |
39 | template <typename T> |
40 | struct get_dim {}; |
41 | |
42 | template <int dim, typename T, InstSetExt ISE> |
43 | struct get_dim<VectorND<dim, T, ISE>> { |
44 | constexpr static int value = dim; |
45 | }; |
46 | |
47 | template <int dim, typename T> |
48 | struct VectorInitializer {}; |
49 | |
50 | template <typename T> |
51 | struct VectorInitializer<1, T> { |
52 | static auto get() { |
53 | return py::init<T>(); |
54 | } |
55 | }; |
56 | |
57 | template <typename T> |
58 | struct VectorInitializer<2, T> { |
59 | static auto get() { |
60 | return py::init<T, T>(); |
61 | } |
62 | }; |
63 | |
64 | template <typename T> |
65 | struct VectorInitializer<3, T> { |
66 | static auto get() { |
67 | return py::init<T, T, T>(); |
68 | } |
69 | }; |
70 | |
71 | template <typename T> |
72 | struct VectorInitializer<4, T> { |
73 | static auto get() { |
74 | return py::init<T, T, T, T>(); |
75 | } |
76 | }; |
77 | |
// Maps a component index to the pointer-to-member of the matching named field
// of the vector type: 0 -> x, 1 -> y, 2 -> z, 3 -> w. Indices outside 0..3
// fall back to the empty primary template (no get()).
template <int Index, typename Vec>
struct get_vec_field {};

// Index 0: the `x` component.
template <typename V>
struct get_vec_field<0, V> {
  static auto get() { return &V::x; }
};

// Index 1: the `y` component.
template <typename V>
struct get_vec_field<1, V> {
  static auto get() { return &V::y; }
};

// Index 2: the `z` component.
template <typename V>
struct get_vec_field<2, V> {
  static auto get() { return &V::z; }
};

// Index 3: the `w` component.
template <typename V>
struct get_vec_field<3, V> {
  static auto get() { return &V::w; }
};
108 | |
109 | template <int i, |
110 | typename VEC, |
111 | typename Class, |
112 | std::enable_if_t<get_dim<VEC>::value<i + 1, int> = 0> void |
113 | register_vec_field(Class &cls) { |
114 | } |
115 | |
116 | template <int i, |
117 | typename VEC, |
118 | typename Class, |
119 | std::enable_if_t<get_dim<VEC>::value >= i + 1, int> = 0> |
120 | void register_vec_field(Class &cls) { |
121 | static const char *names[4] = {"x" , "y" , "z" , "w" }; |
122 | cls.def_readwrite(names[i], get_vec_field<i, VEC>::get()); |
123 | } |
124 | |
125 | template <typename T> |
126 | struct VectorRegistration {}; |
127 | |
128 | template <int dim, typename T, InstSetExt ISE> |
129 | struct VectorRegistration<VectorND<dim, T, ISE>> { |
130 | static void run(py::module &m) { |
131 | using Vector = VectorND<dim, T, ISE>; |
132 | |
133 | // e.g. Vector4f |
134 | std::string vector_name = |
135 | std::string("Vector" ) + std::to_string(dim) + get_type_short_name<T>(); |
136 | |
137 | auto cls = py::class_<Vector>(m, vector_name.c_str()); |
138 | cls.def(VectorInitializer<dim, T>::get()) |
139 | .def(py::init<T>()) |
140 | .def("__len__" , [](Vector *) { return Vector::dim; }) |
141 | .def("__getitem__" , [](Vector *vec, int i) { return (*vec)[i]; }); |
142 | |
143 | register_vec_field<0, Vector>(cls); |
144 | register_vec_field<1, Vector>(cls); |
145 | register_vec_field<2, Vector>(cls); |
146 | register_vec_field<3, Vector>(cls); |
147 | } |
148 | }; |
149 | |
150 | void export_math(py::module &m) { |
151 | VectorRegistration<Vector2f>::run(m); |
152 | VectorRegistration<Vector3f>::run(m); |
153 | VectorRegistration<Vector4f>::run(m); |
154 | |
155 | VectorRegistration<Vector2d>::run(m); |
156 | VectorRegistration<Vector3d>::run(m); |
157 | VectorRegistration<Vector4d>::run(m); |
158 | |
159 | VectorRegistration<Vector2i>::run(m); |
160 | VectorRegistration<Vector3i>::run(m); |
161 | VectorRegistration<Vector4i>::run(m); |
162 | } |
163 | |
164 | } // namespace taichi |
165 | |