1#pragma once
2
3#include <c10/macros/Export.h>
4
5#include <ir_iostream.h>
6#include <iter_visitor.h>
7
8#include <iostream>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15//! Prints computation Fusion IR nodes
16//!
//! IrMathPrinter and IrTransformPrinter split the fusion printing functions
//! in two. IrMathPrinter, as its name implies, focuses solely on what tensor
//! computations are taking place. The resulting TensorView math will reflect
//! the series of split/merge/computeAts that have taken place; however, these
//! nodes will not be displayed in what is printed. IrTransformPrinter does not
//! print any mathematical functions and only lists the series of
//! split/merge calls that were made. Both of these printing methods are
//! deliberately verbose so as to show accurately what is represented in the IR
//! of a fusion.
//!
27//! \sa IrTransformPrinter
28//!
29class TORCH_CUDA_CU_API IrMathPrinter : public IrPrinter {
30 public:
31 IrMathPrinter(std::ostream& os) : IrPrinter(os) {}
32
33 void handle(const Split* const) override {}
34 void handle(const Merge* const) override {}
35 void handle(const Swizzle2D* const) override {}
36
37 void handle(Fusion* f) override {
38 IrPrinter::handle(f);
39 }
40};
41
42//! Prints transformation (schedule) Fusion IR nodes
43//!
44//! \sa IrMathPrinter
45//!
46class TORCH_CUDA_CU_API IrTransformPrinter : public IrPrinter {
47 public:
48 IrTransformPrinter(std::ostream& os) : IrPrinter(os) {}
49
50 void handle(Fusion* f) override;
51
52 private:
53 void printTransforms(TensorView* tv);
54};
55
56} // namespace cuda
57} // namespace fuser
58} // namespace jit
59} // namespace torch
60