1#pragma once
2
3#include <c10/macros/Macros.h>
4#include <c10/macros/Export.h>
5#include <c10/util/TypeTraits.h>
6#include <c10/util/TypeList.h>
7#include <c10/util/intrusive_ptr.h>
8#include <c10/util/order_preserving_flat_hash_map.h>
9#include <c10/util/Optional.h>
10#include <ATen/core/TensorBody.h>
11#include <ATen/core/jit_type_base.h>
12
13namespace c10 {
// Forward declarations. IValue and Type are defined in other headers; Dict is
// defined further down in this file but must be nameable earlier (e.g. by the
// detail:: hash/equality functors and the impl:: iterator helpers below).
struct IValue;
template<class Key, class Value> class Dict;
struct Type;
17
namespace impl {

// Key types accepted by a typed Dict<Key, Value>. Dict's static_assert checks
// that Key is contained in this list; the only other accepted combination is
// Key == Value == IValue (the generic dict).
// If you change this list, also update the error message of that static_assert.
using valid_dict_key_types = guts::typelist::typelist<
  int64_t,
  std::string,
  double,
  c10::complex<double>,
  bool,
  at::Tensor
>;
}
29
30namespace detail {
31
// Hash functor for the IValue keys of the underlying map; it is the Hash
// parameter of DictImpl::dict_map_type. Defined out of line.
struct DictKeyHash {
  size_t operator()(const IValue& ivalue) const;
};
35
// Key-equality functor for the IValue keys of the underlying map; it is the
// KeyEqual parameter of DictImpl::dict_map_type. Defined out of line.
struct DictKeyEqualTo {
  bool operator()(const IValue& lhs, const IValue& rhs) const;
};
39
// Reference-counted storage behind a Dict. Every Dict handle holds an
// intrusive_ptr to one of these, so copies of a Dict share the same DictImpl.
// It owns the IValue -> IValue map and the runtime key/value type tags.
struct DictImpl final : public c10::intrusive_ptr_target {
  // Storage map keyed and valued by IValue, using the functors above for
  // hashing and equality. Presumably iterates in insertion order, as the
  // type name suggests — confirm in order_preserving_flat_hash_map.h.
  using dict_map_type = ska_ordered::order_preserving_flat_hash_map<IValue, IValue, DictKeyHash, DictKeyEqualTo>;
  // Runtime type tags for this dict's keys and values.
  struct DictElementTypes final {
    TypePtr keyType;
    TypePtr valueType;
  };

  // Both arguments are taken by value and moved into the members.
  explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_)
    : dict(std::move(dict_))
    , elementTypes(std::move(elementTypes_)) {}
  // The stored entries.
  dict_map_type dict;

  // Key/value type tags; may be replaced in place via Dict::unsafeSet*Type.
  DictElementTypes elementTypes;

  // Returns a newly allocated DictImpl holding a copy of this one's contents.
  // Defined out of line; presumably the backing for Dict::copy() — confirm.
  intrusive_ptr<DictImpl> copy() const;
  // Equality comparison of two DictImpls; defined out of line.
  friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs);
};
57
58}
59
60namespace impl {
// Forward declaration so that DictEntryRef (below) can befriend DictIterator.
template<class Key, class Value, class Iterator> class DictIterator;
62
/**
 * A reference to an entry in the Dict.
 * Use the `key()` and `value()` methods to read the element.
 *
 * Internally this wraps an iterator into the underlying map. The stored key and
 * value are converted to Key / Value (via `.to<T>()`) on every call to key() /
 * value(); nothing is cached in the wrapper.
 */
template<class Key, class Value, class Iterator>
class DictEntryRef final {
public:
  // Wraps `iterator`, a position in the underlying map. key(), value() and
  // setValue() dereference it, so they are only valid on a dereferenceable
  // position (not end()).
  explicit DictEntryRef(Iterator iterator)
    : iterator_(std::move(iterator)) {}

  // The entry's key, converted from the stored element with to<Key>().
  // decltype(auto) preserves whatever type to<Key>() returns (value or reference).
  decltype(auto) key() const {
    return iterator_->first.template to<Key>();
  }

  // The entry's value, converted from the stored element with to<Value>().
  decltype(auto) value() const {
    return iterator_->second.template to<Value>();
  }

  // Overwrites the entry's value in place. The argument is first converted to
  // Value (rejected at compile time if Value is not constructible from it) and
  // then assigned to the stored element. This is a const member because it
  // mutates the referenced map entry, not this wrapper.
  template<class Value_>
  void setValue(Value_&& value) const {
    static_assert(std::is_constructible<Value, Value_>::value, "Wrong type for the value argument of setValue()");
    iterator_->second = Value(std::forward<Value_>(value));
  }

private:
  // allow copying and moving, but only our friends (i.e. the Dict class) can do
  // it. Copying/moving this reference wrapper would be too ambiguous to allow it
  // in the public API.
  DictEntryRef(const DictEntryRef&) = default;
  DictEntryRef& operator=(const DictEntryRef&) = default;
  DictEntryRef(DictEntryRef&&) noexcept = default;
  DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default;

  // Position in the underlying map; advanced directly by DictIterator.
  Iterator iterator_;
  friend class DictIterator<Key, Value, Iterator>;
  friend class Dict<Key, Value>;
};
100
101// this wraps map_type::iterator to make sure user code can't rely
102// on it being the type of the underlying map.
103template<class Key, class Value, class Iterator>
104class DictIterator final {
105public:
106 // C++17 friendly std::iterator implementation
107 using iterator_category = std::forward_iterator_tag;
108 using value_type = DictEntryRef<Key, Value, Iterator>;
109 using difference_type = std::ptrdiff_t;
110 using pointer = value_type*;
111 using reference = value_type&;
112
113 explicit DictIterator() = default;
114 ~DictIterator() = default;
115
116 DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {}
117 DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {}
118 DictIterator& operator=(const DictIterator& rhs) {
119 entryRef_ = rhs.entryRef_;
120 return *this;
121 }
122 DictIterator& operator=(DictIterator&& rhs) noexcept {
123 entryRef_ = std::move(rhs.entryRef_);
124 return *this;
125 }
126
127 DictIterator& operator++() {
128 ++entryRef_.iterator_;
129 return *this;
130 }
131
132 DictIterator operator++(int) {
133 DictIterator copy(*this);
134 ++*this;
135 return copy;
136 }
137
138 const DictEntryRef<Key, Value, Iterator>& operator*() const {
139 return entryRef_;
140 }
141
142 const DictEntryRef<Key, Value, Iterator>* operator->() const {
143 return &entryRef_;
144 }
145
146 friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) {
147 return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_;
148 }
149
150private:
151 explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {}
152
153 const Iterator& get_iterator_() const {
154 return entryRef_.iterator_;
155 }
156
157 friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) {
158 return lhs.get_iterator_() == rhs.get_iterator_();
159 }
160
161 friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) {
162 return lhs.get_iterator_() != rhs.get_iterator_();
163 }
164
165 friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) {
166 return lhs.get_iterator_() < rhs.get_iterator_();
167 }
168
169 friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) {
170 return lhs.get_iterator_() <= rhs.get_iterator_();
171 }
172
173 friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) {
174 return lhs.get_iterator_() > rhs.get_iterator_();
175 }
176
177 friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) {
178 return lhs.get_iterator_() >= rhs.get_iterator_();
179 }
180
181 DictEntryRef<Key, Value, Iterator> entryRef_;
182
183 friend class DictIterator<Key, Value, typename c10::detail::DictImpl::dict_map_type::iterator>;
184 friend class Dict<Key, Value>;
185};
186
// Conversions between a typed Dict<Key, Value> and the generic
// Dict<IValue, IValue>. Declared here so that Dict can befriend them (the
// friend declarations must match these signatures exactly); the definitions
// are not in this header (presumably in Dict_inl.h — confirm).
template<class Key, class Value> Dict<Key, Value> toTypedDict(Dict<IValue, IValue> dict);
template<class Key, class Value> Dict<IValue, IValue> toGenericDict(Dict<Key, Value> dict);
189}
190
191/**
192 * An object of this class stores a map from Key to Value.
193 *
194 * This is a pointer type. After a copy, both Dicts
195 * will share the same storage:
196 *
197 * > Dict<int, string> a;
198 * > Dict<int, string> b = a;
199 * > b.insert(3, "three");
200 * > ASSERT("three" == a.at(3));
201 *
202 * We use this class in the PyTorch kernel API because that
203 * allows us to do optimizations and switch out the underlying
204 * map implementation without breaking backwards compatibility
205 * for the kernel API.
206 */
207template<class Key, class Value>
208class Dict final {
209private:
210 static_assert((std::is_same<IValue, Key>::value && std::is_same<IValue, Value>::value) || guts::typelist::contains<impl::valid_dict_key_types, Key>::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string.");
211
212 // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map.
213 // We intentionally don't offer conversion from/to
214 // order_preserving_flat_hash_map, return references to it or something like that,
215 // because such operations would get expensive if we switch out
216 // the actual map implementation.
217 // This is an intrusive_ptr because Dict is a pointer type.
218 // Invariant: This will never be a nullptr, there will always be a valid
219 // DictImpl.
220 c10::intrusive_ptr<detail::DictImpl> impl_;
221
222 explicit Dict(c10::intrusive_ptr<detail::DictImpl>&& impl);
223 friend struct IValue;
224 template<class K, class V> friend Dict<K, V> impl::toTypedDict(Dict<IValue, IValue>);
225 template<class K, class V> friend Dict<IValue, IValue> impl::toGenericDict(Dict<K, V>);
226
227public:
228 using key_type = Key;
229 using mapped_type = Value;
230 using size_type = typename detail::DictImpl::dict_map_type::size_type;
231 using iterator = impl::DictIterator<Key, Value, typename detail::DictImpl::dict_map_type::iterator>;
232
233 /**
234 * Creates an empty dict.
235 */
236 explicit Dict();
237
238 /**
239 * Create a generic dict with runtime type information.
240 * This only works for c10::impl::GenericDict and is not part of the public API
241 * but only supposed to be used internally by PyTorch.
242 */
243 explicit Dict(TypePtr keyType, TypePtr valueType);
244
245 ~Dict() = default;
246
247 Dict(const Dict&) = default;
248 Dict& operator=(const Dict&) = default;
249
250 /**
251 * Create a new Dict pointing to a deep copy of the same data.
252 * The Dict returned is a new dict with separate storage.
253 * Changes in it are not reflected in the original dict or vice versa.
254 */
255 Dict copy() const;
256
257 /**
258 * Returns an iterator to the first element of the container.
259 * If the container is empty, the returned iterator will be equal to end().
260 */
261 iterator begin() const;
262
263 /**
264 * Returns an iterator to the element following the last element of the container.
265 * This element acts as a placeholder; attempting to access it results in undefined behavior.
266 */
267 iterator end() const;
268
269 /**
270 * Checks if the container has no elements.
271 */
272 bool empty() const;
273
274 /**
275 * Returns the number of elements in the container.
276 */
277 size_type size() const;
278
279 /**
280 * Erases all elements from the container. After this call, size() returns zero.
281 * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators.
282 */
283 void clear() const;
284
285 /**
286 * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key.
287 * May invalidate any references, pointers, or iterators referring to contained elements.
288 *
289 * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place.
290 */
291 template<class Key_, class Value_>
292 std::pair<iterator, bool> insert(Key_&& key, Value_&& value) const;
293
294 /**
295 * If an element with the given key already exists, it is overwritten with the given value.
296 * Otherwise, a new element with the given key and value are inserted.
297 * May invalidate any references, pointers, or iterators referring to contained elements.
298 *
299 * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated.
300 */
301 template<class Key_, class Value_>
302 std::pair<iterator, bool> insert_or_assign(Key_&& key, Value_&& value) const;
303
304 /**
305 * Removes the element pointed to by iter.
306 * May invalidate any references, pointers, or iterators referring to contained elements.
307 * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter.
308 */
309 void erase(iterator iter) const;
310
311 /**
312 * Removes the element with the given key, if it exists.
313 * May invalidate any references, pointers, or iterators referring to contained elements.
314 *
315 * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't.
316 */
317 C10_NODISCARD size_t erase(const Key& key) const;
318
319 /**
320 * Returns the mapped value of the element with key equivalent to key.
321 * If no such element exists, an exception of type std::out_of_range is thrown.
322 */
323 Value at(const Key& key) const;
324
325 /**
326 * Finds an element with key equivalent to key.
327 *
328 * @return Iterator to an element with key equivalent to key.
329 * If no such element is found, past-the-end (see end()) iterator is returned.
330 */
331 iterator find(const Key& key) const;
332
333 /**
334 * Checks if there is an element with key equivalent to key in the container.
335 *
336 * @return true if there is such an element, otherwise false.
337 */
338 bool contains(const Key& key) const;
339
340 /**
341 * Increase the capacity so that at least count elements can be stored without
342 * having to reallocate or rehash.
343 */
344 void reserve(size_type count) const;
345
346 /**
347 * Value equality comparison. This function implements Python-like semantics for
348 * equality: two dicts with the same identity (e.g. same pointer) trivially
349 * compare equal, otherwise each element is compared for equality.
350 */
351 template <class Key_, class Value_>
352 friend bool operator==(
353 const Dict<Key_, Value_>& lhs,
354 const Dict<Key_, Value_>& rhs);
355 template <class Key_, class Value_>
356 friend bool operator!=(
357 const Dict<Key_, Value_>& lhs,
358 const Dict<Key_, Value_>& rhs);
359
360 /**
361 * Identity comparison. Returns true if and only if `rhs` represents the same
362 * Dict object as `this`.
363 */
364 bool is(const Dict& rhs) const;
365
366 // private API for now because the return type will change to TypePtr
367 // instead of optional<TypePtr> once types are mandatory.
368 TypePtr keyType() const;
369 TypePtr valueType() const;
370
371 // [unsafe set type]
372 // These functions mutate the tagged type of this dictionary in place.
373 // There is no checking that the members of the dictionary are instances
374 // of the new types, nor is there a check that other IValues which
375 // hold references to this dictionary have the right static type.
376 // This functionality is used only in the unpickler, where at
377 // creation type the real type of the dictionary is unknown, but
378 // then later recovered from the static type information of the
379 // unpickled object.
380 void unsafeSetKeyType(TypePtr t);
381 void unsafeSetValueType(TypePtr t);
382};
383
384namespace impl {
// GenericDict (Dict<IValue, IValue>) is how IValue stores dicts; key and value
// types are tracked only at runtime (see Dict::keyType() / valueType()). It is,
// however, not part of the public API. Kernels should use Dicts with concrete
// Key, Value types instead (maybe except for some internal prim ops).
using GenericDict = Dict<IValue, IValue>;
389
390}
391}
392
393namespace torch {
394 template<class Key, class Value> using Dict = c10::Dict<Key, Value>;
395}
396
397#include <ATen/core/Dict_inl.h> // IWYU pragma: keep
398