1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/target/target.h> |
20 | |
21 | #include "./utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace script { |
25 | namespace printer { |
26 | |
// Register TIRFrameNode with TVM's node-type registry (presumably so the frame type is
// known by its type key for reflection/creation — confirm against TVM_REGISTER_NODE_TYPE).
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
28 | |
29 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
30 | .set_dispatch<IntImm>("" , [](IntImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { |
31 | DataType dtype = imm->dtype; |
32 | if (dtype == d->cfg->int_dtype) { |
33 | return LiteralDoc::Int(imm->value, imm_p->Attr("value" )); |
34 | } else if (dtype == DataType::Bool()) { |
35 | return LiteralDoc::Boolean(imm->value, imm_p->Attr("value" )); |
36 | } else { |
37 | return TIR(d, DType2Str(dtype))->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value" ))}); |
38 | } |
39 | }); |
40 | |
41 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
42 | .set_dispatch<FloatImm>("" , [](FloatImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { |
43 | DataType dtype = imm->dtype; |
44 | if (dtype == d->cfg->float_dtype) { |
45 | return LiteralDoc::Float(imm->value, imm_p->Attr("value" )); |
46 | } else { |
47 | return TIR(d, DType2Str(dtype)) |
48 | ->Call({LiteralDoc::Float(imm->value, imm_p->Attr("value" ))}); |
49 | } |
50 | }); |
51 | |
52 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
53 | .set_dispatch<Range>("" , [](Range range, ObjectPath p, IRDocsifier d) -> Doc { |
54 | return TIR(d, "Range" ) |
55 | ->Call({ |
56 | d->AsDoc<ExprDoc>(range->min, p->Attr("min" )), |
57 | d->AsDoc<ExprDoc>(range->extent, p->Attr("extent" )), |
58 | }); |
59 | }); |
60 | |
61 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
62 | .set_dispatch<PrimType>("" , [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc { |
63 | return TIR(d, DType2Str(ty->dtype)); |
64 | }); |
65 | |
66 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
67 | .set_dispatch<PointerType>("" , [](PointerType ty, ObjectPath ty_p, IRDocsifier d) -> Doc { |
68 | ExprDoc element_type{nullptr}; |
69 | if (const auto* prim_type = ty->element_type.as<PrimTypeNode>()) { |
70 | element_type = LiteralDoc::DataType(prim_type->dtype, // |
71 | ty_p->Attr("element_type" )->Attr("dtype" )); |
72 | } else { |
73 | element_type = d->AsDoc<ExprDoc>(ty->element_type, ty_p->Attr("element_type" )); |
74 | } |
75 | if (ty->storage_scope == "" ) { |
76 | return TIR(d, "Ptr" )->Call({element_type}); |
77 | } else { |
78 | return TIR(d, "Ptr" )->Call( |
79 | {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope" ))}); |
80 | } |
81 | }); |
82 | |
83 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
84 | .set_dispatch<TupleType>("" , [](TupleType ty, ObjectPath p, IRDocsifier d) -> Doc { |
85 | if (ty->fields.empty()) { |
86 | return LiteralDoc::None(p); |
87 | } |
88 | return TIR(d, "Tuple" )->Call(d->AsDoc<ListDoc>(ty->fields, p->Attr("fields" ))->elements); |
89 | }); |
90 | |
91 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
92 | .set_dispatch<Target>("" , [](Target target, ObjectPath p, IRDocsifier d) -> Doc { |
93 | Map<String, ObjectRef> config = target->Export(); |
94 | return TIR(d, "target" )->Call({d->AsDoc<ExprDoc>(config, p)}); |
95 | }); |
96 | |
// Use ReprPrintTIR as the textual repr for the node types printed by the dispatches in
// this file (presumably so their repr goes through the TIR script printer — confirm
// against the TVM_SCRIPT_REPR definition). Target has a dispatch above but is not
// registered here; its repr may be set elsewhere — TODO confirm.
TVM_SCRIPT_REPR(IntImmNode, ReprPrintTIR);
TVM_SCRIPT_REPR(FloatImmNode, ReprPrintTIR);
TVM_SCRIPT_REPR(RangeNode, ReprPrintTIR);
TVM_SCRIPT_REPR(PrimTypeNode, ReprPrintTIR);
TVM_SCRIPT_REPR(PointerTypeNode, ReprPrintTIR);
TVM_SCRIPT_REPR(TupleTypeNode, ReprPrintTIR);
103 | |
104 | } // namespace printer |
105 | } // namespace script |
106 | } // namespace tvm |
107 | |