1#pragma once
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
7
8namespace torch {
9namespace jit {
10
// Forward declarations of types that are only named in this header.
// PythonPrintImpl is used solely through a pointer (PythonPrint::pImpl), so
// its definition is not required here.
struct Method;
struct Module;
struct PythonPrintImpl;
14
15struct PrintDepsTable {
16 void add(const c10::NamedTypePtr& type);
17
18 size_t size() const {
19 return table_.size();
20 }
21
22 const c10::NamedTypePtr& operator[](size_t index) const {
23 return table_[index];
24 }
25
26 private:
27 std::vector<c10::NamedTypePtr> table_;
28 std::unordered_set<c10::NamedTypePtr> non_unique_;
29};
30
31struct TORCH_API PythonPrint {
32 PythonPrint(
33 std::vector<IValue>& constant_table,
34 PrintDepsTable& deps_table,
35 c10::TypePrinter type_printer = nullptr,
36 bool enforce_importable = false);
37
38 void printNamedType(const c10::NamedTypePtr& classType);
39 void printFunction(const Function& callee);
40 void printMethod(const Function& callee);
41
42 std::string str() const;
43 const SourceRangeRecords& ranges() const;
44 uint64_t minVersion() const;
45
46 ~PythonPrint();
47
48 private:
49 std::shared_ptr<PythonPrintImpl> pImpl;
50};
51
/// Exported predicate over an operator symbol.
/// NOTE(review): judging by the name, reports whether the Python printer
/// treats `sym` with dedicated handling -- confirm against the definition,
/// which is not visible in this header.
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
53
/// Produces Python code and constants for `module`.
///
/// `module` is taken by const reference. `jit_sources` and `constants` are
/// output parameters passed by pointer (marked as such below); the function
/// itself returns nothing.
/// NOTE(review): whether null output pointers are tolerated, and whether the
/// outputs are cleared or appended to, is not visible here -- confirm against
/// the definition.
TORCH_API void jitModuleToPythonCodeAndConstants(
    const Module& module,
    ExtraFilesMap* jit_sources, // output
    std::vector<IValue>* constants // output
);
59
60} // namespace jit
61} // namespace torch
62