1 | #pragma once |
---|---|
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
// Forward declarations used by the declarations below.
// PythonPrintImpl is only declared in this header; its definition lives in
// another translation unit and is reached through PythonPrint::pImpl.
struct PythonPrintImpl;
struct Method;
struct Module;
14 | |
15 | struct PrintDepsTable { |
16 | void add(const c10::NamedTypePtr& type); |
17 | |
18 | size_t size() const { |
19 | return table_.size(); |
20 | } |
21 | |
22 | const c10::NamedTypePtr& operator[](size_t index) const { |
23 | return table_[index]; |
24 | } |
25 | |
26 | private: |
27 | std::vector<c10::NamedTypePtr> table_; |
28 | std::unordered_set<c10::NamedTypePtr> non_unique_; |
29 | }; |
30 | |
31 | struct TORCH_API PythonPrint { |
32 | PythonPrint( |
33 | std::vector<IValue>& constant_table, |
34 | PrintDepsTable& deps_table, |
35 | c10::TypePrinter type_printer = nullptr, |
36 | bool enforce_importable = false); |
37 | |
38 | void printNamedType(const c10::NamedTypePtr& classType); |
39 | void printFunction(const Function& callee); |
40 | void printMethod(const Function& callee); |
41 | |
42 | std::string str() const; |
43 | const SourceRangeRecords& ranges() const; |
44 | uint64_t minVersion() const; |
45 | |
46 | ~PythonPrint(); |
47 | |
48 | private: |
49 | std::shared_ptr<PythonPrintImpl> pImpl; |
50 | }; |
51 | |
52 | TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym); |
53 | |
54 | TORCH_API void jitModuleToPythonCodeAndConstants( |
55 | const Module& module, |
56 | ExtraFilesMap* jit_sources, // output |
57 | std::vector<IValue>* constants // output |
58 | ); |
59 | |
60 | } // namespace jit |
61 | } // namespace torch |
62 |