1#pragma once
2
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include <c10/util/hash.h>
8
namespace torch {
namespace autograd {

struct Node;

/// A reference to one particular input of a `Node`.
///
/// An edge is identified by two things: the node it points at (`function`)
/// and which of that node's inputs it targets (`input_nr`). A
/// default-constructed edge holds a null `function` and `input_nr == 0`, and
/// reports itself as invalid.
struct Edge {
  /// Builds an empty (invalid) edge: null `function`, `input_nr` of 0.
  Edge() noexcept : function{}, input_nr{0} {}

  /// Builds an edge targeting input `input_nr_` of `function_`.
  /// Ownership of `function_` is moved into the edge.
  Edge(std::shared_ptr<Node> function_, uint32_t input_nr_) noexcept
      : function{std::move(function_)}, input_nr{input_nr_} {}

  /// True iff this edge points at a node (i.e. `function` is non-null).
  bool is_valid() const noexcept {
    return static_cast<bool>(function);
  }

  /// Two edges are equal when they target the same input slot of the same
  /// node object (pointer identity, not node contents). Needed so `Edge` can
  /// be used as a key in associative containers.
  bool operator==(const Edge& other) const noexcept {
    if (input_nr != other.input_nr) {
      return false;
    }
    return function == other.function;
  }

  /// Negation of `operator==`.
  bool operator!=(const Edge& other) const noexcept {
    return !operator==(other);
  }

  /// The node this edge points to; null for an invalid edge.
  std::shared_ptr<Node> function;

  /// Index of the targeted input of `function`.
  uint32_t input_nr;
};
} // namespace autograd
} // namespace torch
43
// Specializing std::hash is the idiomatic C++11 way to let a custom type be
// used as the key of hash-based containers: it removes the need to pass a
// custom hasher to std::unordered_{map, set}.
// See https://en.cppreference.com/w/cpp/utility/hash for more information.
48namespace std {
49template <>
50struct hash<torch::autograd::Edge> {
51 // These type aliases are required by the standard.
52 using argument_type = torch::autograd::Edge;
53 using return_type = size_t;
54 return_type operator()(const argument_type& edge) const noexcept {
55 return c10::get_hash(edge.function, edge.input_nr);
56 }
57};
58} // namespace std
59