1 | #pragma once |
2 | |
3 | #include <torch/csrc/lazy/backend/backend_interface.h> |
4 | |
5 | namespace torch { |
6 | namespace lazy { |
7 | |
8 | class TORCH_API TSData : public torch::lazy::BackendData { |
9 | public: |
10 | TSData(const at::Scalar& scalar, const torch::lazy::BackendDevice& device) |
11 | : torch::lazy::BackendData(device, torch::lazy::Shape(scalar.type(), {})), |
12 | scalar(scalar) {} |
13 | |
14 | TSData( |
15 | const at::Tensor& data, |
16 | const torch::lazy::Shape& shape, |
17 | const torch::lazy::BackendDevice& device) |
18 | : torch::lazy::BackendData(device, shape), data_(data) {} |
19 | |
20 | TSData( |
21 | const torch::lazy::Shape& shape, |
22 | const torch::lazy::BackendDevice& device) |
23 | : torch::lazy::BackendData(device, shape) {} |
24 | |
25 | Handle GetHandle() override { |
26 | return reinterpret_cast<int64_t>(this); |
27 | } |
28 | |
29 | void Assign(const torch::lazy::BackendData& data) override { |
30 | data_ = static_cast<const TSData&>(data).data_; |
31 | } |
32 | |
33 | bool HasValue() const override { |
34 | return data_.defined(); |
35 | } |
36 | |
37 | at::Tensor data() { |
38 | return data_; |
39 | } |
40 | |
41 | c10::optional<at::Scalar> scalar; |
42 | |
43 | private: |
44 | at::Tensor data_; |
45 | }; |
46 | |
47 | TORCH_API torch::lazy::BackendImplInterface* GetTSBackendImpl(); |
48 | |
49 | TORCH_API void InitTorchScriptBackend(); |
50 | |
51 | } // namespace lazy |
52 | } // namespace torch |
53 | |