1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4
5namespace torch {
6namespace jit {
7
8// propagate autograd zero information through a gradient graph and
9// remove grad_of blocks if present.
10// Note: this is a very limited pass. It only propagates autograd zeros for
11// operations generated by the symbolic autodiff code and cleans up
12// AutogradAdds when possible. Outputs of other nodes are conservatively
13// marked Unknown and not optimized.
14TORCH_API void specializeAutogradZero(std::shared_ptr<Graph> g);
15
16struct ProfilingRecord;
17
18TORCH_API void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr);
19
20} // namespace jit
21} // namespace torch
22