1 | #include <torch/csrc/jit/passes/remove_dropout.h> |
---|---|
2 | |
3 | namespace torch { |
4 | namespace jit { |
5 | |
6 | namespace { |
7 | bool isDropoutRemovable(const Node* node) { |
8 | const auto inputs = node->inputs(); |
9 | TORCH_INTERNAL_ASSERT(inputs.size() == 3); |
10 | const Value* training_input = inputs[2]; |
11 | auto optional_ivalue = toIValue(training_input); |
12 | if (!optional_ivalue) { |
13 | return false; |
14 | } |
15 | const IValue& val = optional_ivalue.value(); |
16 | TORCH_INTERNAL_ASSERT(val.isBool()); |
17 | const bool is_training = val.toBool(); |
18 | return !is_training; |
19 | } |
20 | |
21 | void removeDropoutImpl(Block* block) { |
22 | std::vector<Node*> deleted_nodes; |
23 | |
24 | for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) { |
25 | Node* node = *it; |
26 | for (auto block : node->blocks()) { |
27 | removeDropoutImpl(block); |
28 | } |
29 | if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") || |
30 | node->kind() == c10::Symbol::fromQualString("aten::dropout_") || |
31 | node->kind() == c10::Symbol::fromQualString("aten::feature_dropout") || |
32 | node->kind() == |
33 | c10::Symbol::fromQualString("aten::feature_dropout_")) && |
34 | isDropoutRemovable(*it)) { |
35 | // Input tensor of dropout. |
36 | Value* input_value = node->inputs()[0]; |
37 | // Output tensor. |
38 | Value* output_value = node->outputs()[0]; |
39 | output_value->replaceAllUsesWith(input_value); |
40 | deleted_nodes.push_back(node); |
41 | } |
42 | } |
43 | for (auto del_node : deleted_nodes) { |
44 | del_node->destroy(); |
45 | } |
46 | } |
47 | } // namespace |
48 | |
49 | void removeDropout(std::shared_ptr<Graph>& graph) { |
50 | removeDropoutImpl(graph->block()); |
51 | } |
52 | |
53 | void removeDropout(script::Module& module) { |
54 | TORCH_CHECK( |
55 | !module.hasattr("training") || !module.is_training(), |
56 | "Dropout removal module in training mode is not yet supported"); |
57 | auto graph = module.get_method("forward").graph(); |
58 | removeDropout(graph); |
59 | } |
60 | |
61 | } // namespace jit |
62 | } // namespace torch |
63 |