1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <ir_iostream.h> |
6 | #include <iter_visitor.h> |
7 | |
8 | #include <iostream> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | //! Prints computation Fusion IR nodes |
16 | //! |
17 | //! IrMathPrinter and IrTransformPrinter allow the splitting up of fusion print |
18 | //! functions. IrMathPrinter as its name implies focuses solely on what tensor |
19 | //! computations are taking place. Resulting TensorView math will reflect the |
20 | //! series of split/merge/computeAts that have taken place, however these |
21 | //! nodes will not be displayed in what is printed. IrTransformPrinter does not |
22 | //! print any mathematical functions and only lists the series of |
23 | //! split/merge calls that were made. Both of these printing methods are |
24 | //! quite verbose on purpose as to show accurately what is represented in the IR |
25 | //! of a fusion. |
26 | // |
27 | //! \sa IrTransformPrinter |
28 | //! |
29 | class TORCH_CUDA_CU_API IrMathPrinter : public IrPrinter { |
30 | public: |
31 | IrMathPrinter(std::ostream& os) : IrPrinter(os) {} |
32 | |
33 | void handle(const Split* const) override {} |
34 | void handle(const Merge* const) override {} |
35 | void handle(const Swizzle2D* const) override {} |
36 | |
37 | void handle(Fusion* f) override { |
38 | IrPrinter::handle(f); |
39 | } |
40 | }; |
41 | |
42 | //! Prints transformation (schedule) Fusion IR nodes |
43 | //! |
44 | //! \sa IrMathPrinter |
45 | //! |
46 | class TORCH_CUDA_CU_API IrTransformPrinter : public IrPrinter { |
47 | public: |
48 | IrTransformPrinter(std::ostream& os) : IrPrinter(os) {} |
49 | |
50 | void handle(Fusion* f) override; |
51 | |
52 | private: |
53 | void printTransforms(TensorView* tv); |
54 | }; |
55 | |
56 | } // namespace cuda |
57 | } // namespace fuser |
58 | } // namespace jit |
59 | } // namespace torch |
60 | |