1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | bool InstructionKindNode::IsPostproc() const { |
25 | static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc" ); |
26 | return this == inst_enter_postproc.get(); |
27 | } |
28 | |
29 | Instruction::Instruction(InstructionKind kind, Array<ObjectRef> inputs, Array<ObjectRef> attrs, |
30 | Array<ObjectRef> outputs) { |
31 | ObjectPtr<InstructionNode> n = make_object<InstructionNode>(); |
32 | n->kind = std::move(kind); |
33 | n->inputs = std::move(inputs); |
34 | n->attrs = std::move(attrs); |
35 | n->outputs = std::move(outputs); |
36 | this->data_ = std::move(n); |
37 | } |
38 | |
39 | using InstructionKindRegistry = AttrRegistry<InstructionKindRegEntry, InstructionKind>; |
40 | |
41 | InstructionKind InstructionKind::Get(const String& name) { |
42 | const InstructionKindRegEntry* reg = InstructionKindRegistry::Global()->Get(name); |
43 | ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is not registered" ; |
44 | return reg->inst_kind_; |
45 | } |
46 | |
47 | InstructionKindRegEntry::InstructionKindRegEntry(uint32_t reg_index) { |
48 | this->inst_kind_ = InstructionKind(make_object<InstructionKindNode>()); |
49 | } |
50 | |
51 | InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const String& name) { |
52 | return InstructionKindRegistry::Global()->RegisterOrGet(name); |
53 | } |
54 | |
55 | /**************** Repr ****************/ |
56 | |
57 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
58 | .set_dispatch<InstructionNode>([](const ObjectRef& obj, ReprPrinter* p) { |
59 | const auto* self = obj.as<InstructionNode>(); |
60 | ICHECK_NOTNULL(self); |
61 | Array<ObjectRef> inputs; |
62 | inputs.reserve(self->inputs.size()); |
63 | for (const ObjectRef& obj : self->inputs) { |
64 | if (!obj.defined()) { |
65 | inputs.push_back(String("None" )); |
66 | } else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>()) { |
67 | inputs.push_back(String("_" )); |
68 | } else if (const auto* str_obj = obj.as<StringObj>()) { |
69 | inputs.push_back(String('"' + std::string(str_obj->data) + '"')); |
70 | } else if (obj->IsInstance<IntImmNode>() || obj->IsInstance<FloatImmNode>()) { |
71 | inputs.push_back(obj); |
72 | } else if (const auto* expr = obj.as<PrimExprNode>()) { |
73 | PrimExpr new_expr = |
74 | Substitute(GetRef<PrimExpr>(expr), [](const Var& var) -> Optional<PrimExpr> { |
75 | ObjectPtr<VarNode> new_var = make_object<VarNode>(*var.get()); |
76 | new_var->name_hint = "_" ; |
77 | return Var(new_var); |
78 | }); |
79 | std::ostringstream os; |
80 | os << new_expr; |
81 | inputs.push_back(String(os.str())); |
82 | } else if (obj.as<IndexMapNode>()) { |
83 | inputs.push_back(obj); |
84 | } else { |
85 | LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << obj->GetTypeKey(); |
86 | throw; |
87 | } |
88 | } |
89 | p->stream << self->kind->f_as_python( |
90 | /*inputs=*/inputs, |
91 | /*attrs=*/self->attrs, |
92 | /*decision=*/NullOpt, |
93 | /*outputs=*/Array<String>(self->outputs.size(), String("_" ))); |
94 | }); |
95 | |
96 | /**************** FFI ****************/ |
97 | |
98 | TVM_REGISTER_NODE_TYPE(InstructionNode); |
99 | TVM_REGISTER_NODE_TYPE(InstructionKindNode); |
100 | |
101 | TVM_REGISTER_GLOBAL("tir.schedule.InstructionKindGet" ).set_body_typed(InstructionKind::Get); |
102 | TVM_REGISTER_GLOBAL("tir.schedule.Instruction" ) |
103 | .set_body_typed([](InstructionKind kind, Array<ObjectRef> inputs, Array<ObjectRef> attrs, |
104 | Array<ObjectRef> outputs) -> Instruction { |
105 | return Instruction(kind, inputs, attrs, outputs); |
106 | }); |
107 | |
108 | } // namespace tir |
109 | } // namespace tvm |
110 | |