1#include <torch/csrc/jit/passes/remove_dropout.h>
2
3namespace torch {
4namespace jit {
5
6namespace {
7bool isDropoutRemovable(const Node* node) {
8 const auto inputs = node->inputs();
9 TORCH_INTERNAL_ASSERT(inputs.size() == 3);
10 const Value* training_input = inputs[2];
11 auto optional_ivalue = toIValue(training_input);
12 if (!optional_ivalue) {
13 return false;
14 }
15 const IValue& val = optional_ivalue.value();
16 TORCH_INTERNAL_ASSERT(val.isBool());
17 const bool is_training = val.toBool();
18 return !is_training;
19}
20
21void removeDropoutImpl(Block* block) {
22 std::vector<Node*> deleted_nodes;
23
24 for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); it++) {
25 Node* node = *it;
26 for (auto block : node->blocks()) {
27 removeDropoutImpl(block);
28 }
29 if ((node->kind() == c10::Symbol::fromQualString("aten::dropout") ||
30 node->kind() == c10::Symbol::fromQualString("aten::dropout_") ||
31 node->kind() == c10::Symbol::fromQualString("aten::feature_dropout") ||
32 node->kind() ==
33 c10::Symbol::fromQualString("aten::feature_dropout_")) &&
34 isDropoutRemovable(*it)) {
35 // Input tensor of dropout.
36 Value* input_value = node->inputs()[0];
37 // Output tensor.
38 Value* output_value = node->outputs()[0];
39 output_value->replaceAllUsesWith(input_value);
40 deleted_nodes.push_back(node);
41 }
42 }
43 for (auto del_node : deleted_nodes) {
44 del_node->destroy();
45 }
46}
47} // namespace
48
49void removeDropout(std::shared_ptr<Graph>& graph) {
50 removeDropoutImpl(graph->block());
51}
52
53void removeDropout(script::Module& module) {
54 TORCH_CHECK(
55 !module.hasattr("training") || !module.is_training(),
56 "Dropout removal module in training mode is not yet supported");
57 auto graph = module.get_method("forward").graph();
58 removeDropout(graph);
59}
60
61} // namespace jit
62} // namespace torch
63