1 | #pragma once |
2 | |
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include <c10/util/hash.h>
8 | |
namespace torch {
namespace autograd {

struct Node;

/// A link to one specific input of one specific function node.
///
/// `function` names the target node and `input_nr` selects which of that
/// node's inputs this edge refers to. A default-constructed `Edge` has a null
/// `function` and is therefore not valid (see `is_valid()`).
struct Edge {
  /// Builds an invalid edge: `function` is null and `input_nr` is 0.
  Edge() noexcept = default;

  /// Builds an edge to input `input_nr_` of `function_`; the shared pointer
  /// is moved into the edge, so the caller's pointer is left empty when an
  /// rvalue is passed.
  Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
      : function(std::move(function_)), input_nr(input_nr_) {}

  /// Convenience check: true exactly when `function` is non-null.
  bool is_valid() const noexcept {
    return static_cast<bool>(function);
  }

  /// Edges compare equal when they share the same `function` pointer and the
  /// same `input_nr`. Equality is needed for use as a key in associative
  /// (hash) containers.
  bool operator==(const Edge& other) const noexcept {
    if (input_nr != other.input_nr) {
      return false;
    }
    return function == other.function;
  }

  /// Logical negation of `operator==`.
  bool operator!=(const Edge& other) const noexcept {
    return !operator==(other);
  }

  /// The function this `Edge` points to; null for an invalid edge.
  std::shared_ptr<Node> function{nullptr};

  /// Identifies which input of `function` this edge refers to.
  uint32_t input_nr{0};
};

} // namespace autograd
} // namespace torch
43 | |
44 | // The idiomatic way of enabling use of a custom type as the key of hash |
45 | // containers in C++11. This method removes the requirement of having to pass |
46 | // a custom hasher to std::unordered_{map, set}. |
47 | // See http://en.cppreference.com/w/cpp/utility/hash for more information. |
48 | namespace std { |
49 | template <> |
50 | struct hash<torch::autograd::Edge> { |
51 | // These type aliases are required by the standard. |
52 | using argument_type = torch::autograd::Edge; |
53 | using return_type = size_t; |
54 | return_type operator()(const argument_type& edge) const noexcept { |
55 | return c10::get_hash(edge.function, edge.input_nr); |
56 | } |
57 | }; |
58 | } // namespace std |
59 | |