1#include <gtest/gtest.h>
2
3#include "test/cpp/jit/test_utils.h"
4
5#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
6
7namespace torch {
8namespace jit {
9
10TEST(CreateAutodiffSubgraphsTest, Basic) {
11 auto graph = build_lstm();
12 CreateAutodiffSubgraphs(graph, /*threshold=*/2);
13 // all of the ops are within the DifferentiableGraph
14 testing::FileCheck()
15 .check_not("aten::mm")
16 ->check_not("aten::sigmoid")
17 ->check_not("aten::tanh")
18 ->check_not("aten::mul")
19 ->check("DifferentiableGraph")
20 ->check_next("return")
21 ->run(*graph);
22}
23
24} // namespace jit
25} // namespace torch
26