1 | #pragma once |
---|---|
2 | |
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>

#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/shape.h>
6 | |
7 | namespace torch { |
8 | namespace lazy { |
9 | |
10 | class TORCH_API BackendData { |
11 | public: |
12 | struct Info { |
13 | /** |
14 | * Used by Lazy Graph Executor to tag info on BackendData objs |
15 | * */ |
16 | virtual ~Info() = default; |
17 | }; |
18 | /** |
19 | * Represents (Tensor) data stored on a backend device |
20 | * in its native format. |
21 | * */ |
22 | using Handle = int64_t; |
23 | |
24 | BackendData(BackendDevice device, Shape shape) |
25 | : device_(std::move(device)), shape_(std::move(shape)) {} |
26 | |
27 | virtual ~BackendData() = default; |
28 | |
29 | const BackendDevice& device() const { |
30 | return device_; |
31 | } |
32 | |
33 | const Shape& shape() const { |
34 | return shape_; |
35 | } |
36 | |
37 | Info* info() const { |
38 | return info_.get(); |
39 | } |
40 | |
41 | std::shared_ptr<Info> SetInfo(std::shared_ptr<Info> info) { |
42 | std::swap(info, info_); |
43 | return info; |
44 | } |
45 | |
46 | virtual Handle GetHandle() = 0; |
47 | |
48 | virtual void Assign(const BackendData& data) = 0; |
49 | |
50 | virtual bool HasValue() const = 0; |
51 | |
52 | private: |
53 | BackendDevice device_; |
54 | Shape shape_; |
55 | std::shared_ptr<Info> info_; |
56 | }; |
57 | |
58 | using BackendDataPtr = std::shared_ptr<BackendData>; |
59 | |
60 | } // namespace lazy |
61 | } // namespace torch |
62 |