1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <torch/csrc/lazy/core/config.h> |
5 | #include <torch/csrc/lazy/core/ir.h> |
6 | #include <torch/csrc/lazy/core/ir_builder.h> |
7 | #include <torch/csrc/lazy/core/ir_metadata.h> |
8 | #include <torch/csrc/lazy/core/ir_util.h> |
9 | |
10 | namespace torch { |
11 | namespace lazy { |
12 | |
13 | class IrUtilNode : public Node { |
14 | public: |
15 | explicit IrUtilNode() : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(0)) {} |
16 | ~IrUtilNode() override = default; |
17 | |
18 | void AddOperand(Value v) { |
19 | if (!v.node) { |
20 | return; |
21 | } |
22 | operands_as_outputs_.emplace_back(v.node.get(), v.index); |
23 | operands_.push_back(std::move(v.node)); |
24 | } |
25 | |
26 | hash_t hash() const override { |
27 | return hash_; |
28 | } |
29 | hash_t shapeHash() const override { |
30 | return hash_; |
31 | } |
32 | |
33 | private: |
34 | hash_t hash_; |
35 | }; |
36 | |
37 | /* a |
38 | * / \ |
39 | *b c |
40 | * \ / |
41 | * d |
42 | * Post-order: d c b a |
43 | */ |
44 | TEST(IrUtilTest, BasicTest) { |
45 | NodePtr a = MakeNode<IrUtilNode>(); |
46 | NodePtr b = MakeNode<IrUtilNode>(); |
47 | NodePtr c = MakeNode<IrUtilNode>(); |
48 | NodePtr d = MakeNode<IrUtilNode>(); |
49 | |
50 | dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0)); |
51 | dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(c, 1)); |
52 | dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(d, 0)); |
53 | dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(d, 0)); |
54 | |
55 | auto postorder = Util::ComputePostOrder({a.get()}); |
56 | EXPECT_EQ(postorder.size(), 4); |
57 | EXPECT_EQ(postorder.at(0), d.get()); |
58 | EXPECT_EQ(postorder.at(1), c.get()); |
59 | EXPECT_EQ(postorder.at(2), b.get()); |
60 | EXPECT_EQ(postorder.at(3), a.get()); |
61 | } |
62 | |
63 | /* a |
64 | * / \ |
65 | *b---c |
66 | * Post-order: not valid |
67 | */ |
68 | TEST(IrUtilTest, TestCircle) { |
69 | NodePtr a = MakeNode<IrUtilNode>(); |
70 | NodePtr b = MakeNode<IrUtilNode>(); |
71 | NodePtr c = MakeNode<IrUtilNode>(); |
72 | |
73 | dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0)); |
74 | dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(c, 0)); |
75 | dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(a, 0)); |
76 | |
77 | EXPECT_THROW(Util::ComputePostOrder({a.get()}), c10::Error); |
78 | } |
79 | |
80 | } // namespace lazy |
81 | } // namespace torch |
82 | |