1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <tvm/runtime/module.h> |
21 | #include <tvm/runtime/registry.h> |
22 | #include <tvm/tir/var.h> |
23 | |
24 | #include <string> |
25 | |
26 | #include "text_printer.h" |
27 | |
28 | namespace tvm { |
29 | namespace relay { |
30 | |
31 | class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode { |
32 | public: |
33 | ModelLibraryFormatPrinter(bool show_meta_data, |
34 | const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate, |
35 | bool show_warning) |
36 | : text_printer_{show_meta_data, annotate, show_warning} {} |
37 | |
38 | const char* type_key() const final { return "model_library_format_printer" ; } |
39 | |
40 | std::string Print(const ObjectRef& node) { |
41 | std::ostringstream oss; |
42 | oss << node; |
43 | return oss.str(); |
44 | } |
45 | |
46 | TVMRetValue GetVarName(tir::Var var) { |
47 | TVMRetValue rv; |
48 | std::string var_name; |
49 | if (text_printer_.GetVarName(var, &var_name)) { |
50 | rv = var_name; |
51 | } |
52 | |
53 | return rv; |
54 | } |
55 | |
56 | PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) override { |
57 | if (name == "print" ) { |
58 | return TypedPackedFunc<std::string(ObjectRef)>( |
59 | [sptr_to_self, this](ObjectRef node) { return Print(node); }); |
60 | } else if (name == "get_var_name" ) { |
61 | return TypedPackedFunc<TVMRetValue(tir::Var)>( |
62 | [sptr_to_self, this](tir::Var var) { return GetVarName(var); }); |
63 | } else { |
64 | return PackedFunc(); |
65 | } |
66 | } |
67 | |
68 | private: |
69 | TextPrinter text_printer_; |
70 | }; |
71 | |
72 | TVM_REGISTER_GLOBAL("relay.ir.ModelLibraryFormatPrinter" ) |
73 | .set_body_typed([](bool show_meta_data, |
74 | const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate, |
75 | bool show_warning) { |
76 | return ObjectRef( |
77 | make_object<ModelLibraryFormatPrinter>(show_meta_data, annotate, show_warning)); |
78 | }); |
79 | |
80 | } // namespace relay |
81 | } // namespace tvm |
82 | |