1#pragma once
2
#include <ATen/core/ivalue.h>

#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
4
5namespace torch {
6namespace jit {
7
8// Used in torch.package and TorchScript serialization to coordinate
9// sharing of storages between models. Also used to create deterministic
10// naming for storages.
11class TORCH_API SerializationStorageContext {
12 public:
13 explicit SerializationStorageContext() = default;
14 SerializationStorageContext operator=(const SerializationStorageContext&) =
15 delete;
16 SerializationStorageContext(const SerializationStorageContext&) = delete;
17
18 uint64_t getOrAddStorage(c10::Storage storage) {
19 if (!hasStorage(storage)) {
20 uint64_t size = storage_id_map_.size();
21 storage_id_map_[storage] = size;
22 }
23 return storage_id_map_[storage];
24 }
25
26 bool hasStorage(c10::Storage storage) {
27 return storage_id_map_.find(storage) != storage_id_map_.end();
28 }
29
30 ~SerializationStorageContext() = default;
31
32 private:
33 class StorageSerializationHash {
34 public:
35 size_t operator()(const c10::Storage& storage) const {
36 return std::hash<void*>()(
37 reinterpret_cast<void*>(storage.unsafeGetStorageImpl()));
38 }
39 };
40
41 class StorageSerializationEqual {
42 public:
43 bool operator()(const c10::Storage& lhs, const c10::Storage& rhs) const {
44 return lhs.unsafeGetStorageImpl() == rhs.unsafeGetStorageImpl();
45 }
46 };
47
48 std::unordered_map<
49 c10::Storage,
50 uint64_t,
51 StorageSerializationHash,
52 StorageSerializationEqual>
53 storage_id_map_;
54};
55
56// Used in torch.package and TorchScript deserialization to coordinate
57// sharing of storages between models.
58class TORCH_API DeserializationStorageContext {
59 public:
60 explicit DeserializationStorageContext() = default;
61 DeserializationStorageContext operator=(
62 const DeserializationStorageContext&) = delete;
63 DeserializationStorageContext(const DeserializationStorageContext&) = delete;
64
65 void addStorage(const std::string& name, c10::Storage storage) {
66 TORCH_INTERNAL_ASSERT(!hasStorage(name));
67 name_storage_map_.insert({name, storage});
68 }
69
70 bool hasStorage(const std::string& name) {
71 return name_storage_map_.find(name) != name_storage_map_.end();
72 }
73
74 c10::Storage getStorage(const std::string& name) {
75 TORCH_INTERNAL_ASSERT(hasStorage(name));
76 return name_storage_map_.find(name)->second;
77 }
78 ~DeserializationStorageContext() = default;
79
80 private:
81 std::unordered_map<std::string, c10::Storage> name_storage_map_;
82};
83
84} // namespace jit
85} // namespace torch
86