1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file env_func.cc
22 */
23#include <tvm/ir/env_func.h>
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/expr.h>
26
27namespace tvm {
28
29using runtime::PackedFunc;
30using runtime::TVMArgs;
31using runtime::TVMRetValue;
32
33TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
34 .set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
35 auto* op = static_cast<const EnvFuncNode*>(node.get());
36 p->stream << "EnvFunc(" << op->name << ")";
37 });
38
39ObjectPtr<Object> CreateEnvNode(const std::string& name) {
40 auto* f = runtime::Registry::Get(name);
41 ICHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
42 ObjectPtr<EnvFuncNode> n = make_object<EnvFuncNode>();
43 n->func = *f;
44 n->name = name;
45 return n;
46}
47
48EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }
49
50TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);
51
52TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) {
53 EnvFunc env = args[0];
54 ICHECK_GE(args.size(), 1);
55 env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv);
56});
57
58TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) {
59 return n->func;
60});
61
62TVM_REGISTER_NODE_TYPE(EnvFuncNode)
63 .set_creator(CreateEnvNode)
64 .set_repr_bytes([](const Object* n) -> std::string {
65 return static_cast<const EnvFuncNode*>(n)->name;
66 });
67
68} // namespace tvm
69