1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * FFI registration code used for frontend testing purposes.
22 * \file ffi_testing.cc
23 */
24#include <tvm/ir/attrs.h>
25#include <tvm/ir/env_func.h>
26#include <tvm/runtime/module.h>
27#include <tvm/runtime/registry.h>
28#include <tvm/te/tensor.h>
29#include <tvm/tir/expr.h>
30
31#include <chrono>
32#include <thread>
33
34namespace tvm {
// Attributes node used to test the Python API bindings for attrs
// (field defaults, bounds and descriptions are declared below).
struct TestAttrs : public AttrsNode<TestAttrs> {
  // Integer field: default 10, lower bound 1, upper bound 10.
  int axis;
  // String field; no default is declared for it.
  String name;
  // Expression array: default {0, 0}.
  Array<PrimExpr> padding;
  // Environment function with signature int(int): default is a null TypedEnvFunc.
  TypedEnvFunc<int(int)> func;

  TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
    TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe(
        "axis field");
    TVM_ATTR_FIELD(name).describe("name");
    TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array<PrimExpr>({0, 0}));
    TVM_ATTR_FIELD(func)
        .describe("some random env function")
        .set_default(TypedEnvFunc<int(int)>(nullptr));
  }
};

// Register TestAttrs as a node type (type key "attrs.TestAttrs", as declared in
// TVM_DECLARE_ATTRS above).
TVM_REGISTER_NODE_TYPE(TestAttrs);
54
55TVM_REGISTER_GLOBAL("testing.nop").set_body([](TVMArgs args, TVMRetValue* ret) {});
56
57TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) {
58 *ret = args[0];
59});
60
61TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) {
62 PackedFunc pf = args[0];
63 *ret = runtime::TypedPackedFunc<void()>([pf]() { pf(); });
64});
65
66TVM_REGISTER_GLOBAL("testing.test_raise_error_callback")
67 .set_body([](TVMArgs args, TVMRetValue* ret) {
68 std::string msg = args[0];
69 *ret = runtime::TypedPackedFunc<void()>([msg]() { LOG(FATAL) << msg; });
70 });
71
72TVM_REGISTER_GLOBAL("testing.test_check_eq_callback").set_body([](TVMArgs args, TVMRetValue* ret) {
73 std::string msg = args[0];
74 *ret =
75 runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y) { CHECK_EQ(x, y) << msg; });
76});
77
78TVM_REGISTER_GLOBAL("testing.device_test").set_body([](TVMArgs args, TVMRetValue* ret) {
79 Device dev = args[0];
80 int dtype = args[1];
81 int did = args[2];
82 CHECK_EQ(static_cast<int>(dev.device_type), dtype);
83 CHECK_EQ(static_cast<int>(dev.device_id), did);
84 *ret = dev;
85});
86
87TVM_REGISTER_GLOBAL("testing.run_check_signal").set_body_typed([](int nsec) {
88 for (int i = 0; i < nsec; ++i) {
89 tvm::runtime::EnvCheckSignals();
90 std::this_thread::sleep_for(std::chrono::seconds(1));
91 }
92 LOG(INFO) << "Function finished without catching signal";
93});
94
95TVM_REGISTER_GLOBAL("testing.identity_cpp").set_body([](TVMArgs args, TVMRetValue* ret) {
96 const auto* identity_func = tvm::runtime::Registry::Get("testing.identity_py");
97 ICHECK(identity_func != nullptr)
98 << "AttributeError: \"testing.identity_py\" is not registered. Please check "
99 "if the python module is properly loaded";
100 *ret = (*identity_func)(args[0]);
101});
102
103// in src/api_test.cc
104void ErrorTest(int x, int y) {
105 // raise ValueError
106 CHECK_EQ(x, y) << "ValueError: expect x and y to be equal.";
107 if (x == 1) {
108 // raise InternalError.
109 LOG(FATAL) << "InternalError: cannot reach here";
110 }
111}
112
113TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest);
114
115// internal function used for debug and testing purposes
116TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRetValue* ret) {
117 runtime::ObjectRef obj = args[0];
118 // substract the current one because we always copy
119 // and get another value.
120 *ret = (obj.use_count() - 1);
121});
122
123class FrontendTestModuleNode : public runtime::ModuleNode {
124 public:
125 const char* type_key() const final { return "frontend_test"; }
126
127 static constexpr const char* kAddFunctionName = "__add_function";
128
129 virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
130
131 private:
132 std::unordered_map<std::string, PackedFunc> functions_;
133};
134
135constexpr const char* FrontendTestModuleNode::kAddFunctionName;
136
137PackedFunc FrontendTestModuleNode::GetFunction(const std::string& name,
138 const ObjectPtr<Object>& sptr_to_self) {
139 if (name == kAddFunctionName) {
140 return TypedPackedFunc<void(std::string, PackedFunc)>(
141 [this, sptr_to_self](std::string func_name, PackedFunc pf) {
142 CHECK_NE(func_name, kAddFunctionName)
143 << "func_name: cannot be special function " << kAddFunctionName;
144 functions_[func_name] = pf;
145 });
146 }
147
148 auto it = functions_.find(name);
149 if (it == functions_.end()) {
150 return PackedFunc();
151 }
152
153 return it->second;
154}
155
156runtime::Module NewFrontendTestModule() {
157 auto n = make_object<FrontendTestModuleNode>();
158 return runtime::Module(n);
159}
160
161TVM_REGISTER_GLOBAL("testing.FrontendTestModule").set_body_typed(NewFrontendTestModule);
162
163TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) {
164 std::chrono::duration<int64_t, std::nano> duration(static_cast<int64_t>(timeout * 1e9));
165 std::this_thread::sleep_for(duration);
166});
167
168} // namespace tvm
169