1#pragma once
2
#include <utility>

#include <torch/csrc/lazy/backend/backend_interface.h>
4
5namespace torch {
6namespace lazy {
7
8class TORCH_API TSData : public torch::lazy::BackendData {
9 public:
10 TSData(const at::Scalar& scalar, const torch::lazy::BackendDevice& device)
11 : torch::lazy::BackendData(device, torch::lazy::Shape(scalar.type(), {})),
12 scalar(scalar) {}
13
14 TSData(
15 const at::Tensor& data,
16 const torch::lazy::Shape& shape,
17 const torch::lazy::BackendDevice& device)
18 : torch::lazy::BackendData(device, shape), data_(data) {}
19
20 TSData(
21 const torch::lazy::Shape& shape,
22 const torch::lazy::BackendDevice& device)
23 : torch::lazy::BackendData(device, shape) {}
24
25 Handle GetHandle() override {
26 return reinterpret_cast<int64_t>(this);
27 }
28
29 void Assign(const torch::lazy::BackendData& data) override {
30 data_ = static_cast<const TSData&>(data).data_;
31 }
32
33 bool HasValue() const override {
34 return data_.defined();
35 }
36
37 at::Tensor data() {
38 return data_;
39 }
40
41 c10::optional<at::Scalar> scalar;
42
43 private:
44 at::Tensor data_;
45};
46
47TORCH_API torch::lazy::BackendImplInterface* GetTSBackendImpl();
48
49TORCH_API void InitTorchScriptBackend();
50
51} // namespace lazy
52} // namespace torch
53