1 | #pragma once |
2 | |
3 | // The InputBuffer class accumulates a list of Variables for use by a |
4 | // function. It implements logic to avoid modifying the passed |
5 | // values in-place (adding an input twice will accumulate the result). |
6 | // This behaviour is needed and used only in backward graphs. |
7 | |
8 | #include <memory> |
9 | #include <utility> |
10 | #include <vector> |
11 | |
12 | #include <c10/core/Stream.h> |
13 | #include <c10/util/Optional.h> |
14 | #include <torch/csrc/autograd/variable.h> |
15 | |
16 | namespace torch { |
17 | namespace autograd { |
18 | |
19 | struct InputBuffer { |
20 | explicit InputBuffer(size_t size) : buffer(size) {} |
21 | InputBuffer(const InputBuffer& other) = delete; |
22 | InputBuffer(InputBuffer&& other) = default; |
23 | explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)){}; |
24 | InputBuffer& operator=(InputBuffer&& other) = default; |
25 | |
26 | // Accumulates the variable at a specified index. |
27 | // The optional CUDA streams determine which stream the accumulation |
28 | // is run on and how the addition is synchronized. |
29 | void add( |
30 | size_t pos, |
31 | Variable&& var, |
32 | const c10::optional<c10::Stream>& opt_producer_stream, |
33 | const c10::optional<c10::Stream>& opt_consumer_stream); |
34 | |
35 | at::Device device() const; |
36 | |
37 | Variable operator[](size_t pos) { |
38 | return buffer[pos]; |
39 | } |
40 | |
41 | // Returns the inputs as a list of variables. Destroys given InputBuffer. |
42 | static std::vector<Variable> variables(InputBuffer&& g); |
43 | |
44 | std::vector<Variable> buffer; |
45 | }; |
46 | |
47 | } // namespace autograd |
48 | } // namespace torch |
49 | |