1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | /** Recursively deduplicate multiple uses of the same module by |
9 | * creating an instance clone for each use of the module, which means |
10 | * the type will be the same as before and all the attributes will be |
11 | * copied, then we'll change the use of the original module to the use |
12 | * of cloned module in the Graph. |
13 | * |
14 | * This is done to ensure that modules can survive destructive passes |
15 | * without changing model behavior. For example, here: |
16 | * |
17 | * x = self.conv1(x) |
18 | * x = self.relu(x) |
19 | * x = self.conv2(x) |
20 | * x = self.relu(x) |
21 | * |
22 | * self.relu needs to be deduplicated for potential future destructive passes |
23 | * to work properly. |
24 | */ |
25 | TORCH_API void DedupModuleUses(Module& module); |
26 | |
27 | } // namespace jit |
28 | } // namespace torch |
29 | |