1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./utils.h"
20
21namespace tvm {
22namespace tir {
23
24bool InstructionKindNode::IsPostproc() const {
25 static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc");
26 return this == inst_enter_postproc.get();
27}
28
29Instruction::Instruction(InstructionKind kind, Array<ObjectRef> inputs, Array<ObjectRef> attrs,
30 Array<ObjectRef> outputs) {
31 ObjectPtr<InstructionNode> n = make_object<InstructionNode>();
32 n->kind = std::move(kind);
33 n->inputs = std::move(inputs);
34 n->attrs = std::move(attrs);
35 n->outputs = std::move(outputs);
36 this->data_ = std::move(n);
37}
38
39using InstructionKindRegistry = AttrRegistry<InstructionKindRegEntry, InstructionKind>;
40
41InstructionKind InstructionKind::Get(const String& name) {
42 const InstructionKindRegEntry* reg = InstructionKindRegistry::Global()->Get(name);
43 ICHECK(reg != nullptr) << "AttributeError: Instruction kind " << name << " is not registered";
44 return reg->inst_kind_;
45}
46
47InstructionKindRegEntry::InstructionKindRegEntry(uint32_t reg_index) {
48 this->inst_kind_ = InstructionKind(make_object<InstructionKindNode>());
49}
50
51InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const String& name) {
52 return InstructionKindRegistry::Global()->RegisterOrGet(name);
53}
54
55/**************** Repr ****************/
56
57TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
58 .set_dispatch<InstructionNode>([](const ObjectRef& obj, ReprPrinter* p) {
59 const auto* self = obj.as<InstructionNode>();
60 ICHECK_NOTNULL(self);
61 Array<ObjectRef> inputs;
62 inputs.reserve(self->inputs.size());
63 for (const ObjectRef& obj : self->inputs) {
64 if (!obj.defined()) {
65 inputs.push_back(String("None"));
66 } else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>()) {
67 inputs.push_back(String("_"));
68 } else if (const auto* str_obj = obj.as<StringObj>()) {
69 inputs.push_back(String('"' + std::string(str_obj->data) + '"'));
70 } else if (obj->IsInstance<IntImmNode>() || obj->IsInstance<FloatImmNode>()) {
71 inputs.push_back(obj);
72 } else if (const auto* expr = obj.as<PrimExprNode>()) {
73 PrimExpr new_expr =
74 Substitute(GetRef<PrimExpr>(expr), [](const Var& var) -> Optional<PrimExpr> {
75 ObjectPtr<VarNode> new_var = make_object<VarNode>(*var.get());
76 new_var->name_hint = "_";
77 return Var(new_var);
78 });
79 std::ostringstream os;
80 os << new_expr;
81 inputs.push_back(String(os.str()));
82 } else if (obj.as<IndexMapNode>()) {
83 inputs.push_back(obj);
84 } else {
85 LOG(FATAL) << "TypeError: Stringifying is not supported for type: " << obj->GetTypeKey();
86 throw;
87 }
88 }
89 p->stream << self->kind->f_as_python(
90 /*inputs=*/inputs,
91 /*attrs=*/self->attrs,
92 /*decision=*/NullOpt,
93 /*outputs=*/Array<String>(self->outputs.size(), String("_")));
94 });
95
96/**************** FFI ****************/
97
98TVM_REGISTER_NODE_TYPE(InstructionNode);
99TVM_REGISTER_NODE_TYPE(InstructionKindNode);
100
101TVM_REGISTER_GLOBAL("tir.schedule.InstructionKindGet").set_body_typed(InstructionKind::Get);
102TVM_REGISTER_GLOBAL("tir.schedule.Instruction")
103 .set_body_typed([](InstructionKind kind, Array<ObjectRef> inputs, Array<ObjectRef> attrs,
104 Array<ObjectRef> outputs) -> Instruction {
105 return Instruction(kind, inputs, attrs, outputs);
106 });
107
108} // namespace tir
109} // namespace tvm
110