1#include <gtest/gtest.h>
2
3#include <c10/util/Exception.h>
4#include <torch/csrc/lazy/core/config.h>
5#include <torch/csrc/lazy/core/ir.h>
6#include <torch/csrc/lazy/core/ir_builder.h>
7#include <torch/csrc/lazy/core/ir_metadata.h>
8#include <torch/csrc/lazy/core/ir_util.h>
9
10namespace torch {
11namespace lazy {
12
13class IrUtilNode : public Node {
14 public:
15 explicit IrUtilNode() : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(0)) {}
16 ~IrUtilNode() override = default;
17
18 void AddOperand(Value v) {
19 if (!v.node) {
20 return;
21 }
22 operands_as_outputs_.emplace_back(v.node.get(), v.index);
23 operands_.push_back(std::move(v.node));
24 }
25
26 hash_t hash() const override {
27 return hash_;
28 }
29 hash_t shapeHash() const override {
30 return hash_;
31 }
32
33 private:
34 hash_t hash_;
35};
36
37/* a
38 * / \
39 *b c
40 * \ /
41 * d
42 * Post-order: d c b a
43 */
44TEST(IrUtilTest, BasicTest) {
45 NodePtr a = MakeNode<IrUtilNode>();
46 NodePtr b = MakeNode<IrUtilNode>();
47 NodePtr c = MakeNode<IrUtilNode>();
48 NodePtr d = MakeNode<IrUtilNode>();
49
50 dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0));
51 dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(c, 1));
52 dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(d, 0));
53 dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(d, 0));
54
55 auto postorder = Util::ComputePostOrder({a.get()});
56 EXPECT_EQ(postorder.size(), 4);
57 EXPECT_EQ(postorder.at(0), d.get());
58 EXPECT_EQ(postorder.at(1), c.get());
59 EXPECT_EQ(postorder.at(2), b.get());
60 EXPECT_EQ(postorder.at(3), a.get());
61}
62
63/* a
64 * / \
65 *b---c
66 * Post-order: not valid
67 */
68TEST(IrUtilTest, TestCircle) {
69 NodePtr a = MakeNode<IrUtilNode>();
70 NodePtr b = MakeNode<IrUtilNode>();
71 NodePtr c = MakeNode<IrUtilNode>();
72
73 dynamic_cast<IrUtilNode*>(a.get())->AddOperand(Value(b, 0));
74 dynamic_cast<IrUtilNode*>(b.get())->AddOperand(Value(c, 0));
75 dynamic_cast<IrUtilNode*>(c.get())->AddOperand(Value(a, 0));
76
77 EXPECT_THROW(Util::ComputePostOrder({a.get()}), c10::Error);
78}
79
80} // namespace lazy
81} // namespace torch
82