1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * FFI registration code used for frontend testing purposes. |
22 | * \file ffi_testing.cc |
23 | */ |
24 | #include <tvm/ir/attrs.h> |
25 | #include <tvm/ir/env_func.h> |
26 | #include <tvm/runtime/module.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/te/tensor.h> |
29 | #include <tvm/tir/expr.h> |
30 | |
31 | #include <chrono> |
32 | #include <thread> |
33 | |
34 | namespace tvm { |
35 | // Attrs used to python API |
36 | struct TestAttrs : public AttrsNode<TestAttrs> { |
37 | int axis; |
38 | String name; |
39 | Array<PrimExpr> padding; |
40 | TypedEnvFunc<int(int)> func; |
41 | |
42 | TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs" ) { |
43 | TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( |
44 | "axis field" ); |
45 | TVM_ATTR_FIELD(name).describe("name" ); |
46 | TVM_ATTR_FIELD(padding).describe("padding of input" ).set_default(Array<PrimExpr>({0, 0})); |
47 | TVM_ATTR_FIELD(func) |
48 | .describe("some random env function" ) |
49 | .set_default(TypedEnvFunc<int(int)>(nullptr)); |
50 | } |
51 | }; |
52 | |
53 | TVM_REGISTER_NODE_TYPE(TestAttrs); |
54 | |
55 | TVM_REGISTER_GLOBAL("testing.nop" ).set_body([](TVMArgs args, TVMRetValue* ret) {}); |
56 | |
57 | TVM_REGISTER_GLOBAL("testing.echo" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
58 | *ret = args[0]; |
59 | }); |
60 | |
61 | TVM_REGISTER_GLOBAL("testing.test_wrap_callback" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
62 | PackedFunc pf = args[0]; |
63 | *ret = runtime::TypedPackedFunc<void()>([pf]() { pf(); }); |
64 | }); |
65 | |
66 | TVM_REGISTER_GLOBAL("testing.test_raise_error_callback" ) |
67 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
68 | std::string msg = args[0]; |
69 | *ret = runtime::TypedPackedFunc<void()>([msg]() { LOG(FATAL) << msg; }); |
70 | }); |
71 | |
72 | TVM_REGISTER_GLOBAL("testing.test_check_eq_callback" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
73 | std::string msg = args[0]; |
74 | *ret = |
75 | runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); |
76 | }); |
77 | |
78 | TVM_REGISTER_GLOBAL("testing.device_test" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
79 | Device dev = args[0]; |
80 | int dtype = args[1]; |
81 | int did = args[2]; |
82 | CHECK_EQ(static_cast<int>(dev.device_type), dtype); |
83 | CHECK_EQ(static_cast<int>(dev.device_id), did); |
84 | *ret = dev; |
85 | }); |
86 | |
87 | TVM_REGISTER_GLOBAL("testing.run_check_signal" ).set_body_typed([](int nsec) { |
88 | for (int i = 0; i < nsec; ++i) { |
89 | tvm::runtime::EnvCheckSignals(); |
90 | std::this_thread::sleep_for(std::chrono::seconds(1)); |
91 | } |
92 | LOG(INFO) << "Function finished without catching signal" ; |
93 | }); |
94 | |
95 | TVM_REGISTER_GLOBAL("testing.identity_cpp" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
96 | const auto* identity_func = tvm::runtime::Registry::Get("testing.identity_py" ); |
97 | ICHECK(identity_func != nullptr) |
98 | << "AttributeError: \"testing.identity_py\" is not registered. Please check " |
99 | "if the python module is properly loaded" ; |
100 | *ret = (*identity_func)(args[0]); |
101 | }); |
102 | |
103 | // in src/api_test.cc |
104 | void ErrorTest(int x, int y) { |
105 | // raise ValueError |
106 | CHECK_EQ(x, y) << "ValueError: expect x and y to be equal." ; |
107 | if (x == 1) { |
108 | // raise InternalError. |
109 | LOG(FATAL) << "InternalError: cannot reach here" ; |
110 | } |
111 | } |
112 | |
113 | TVM_REGISTER_GLOBAL("testing.ErrorTest" ).set_body_typed(ErrorTest); |
114 | |
115 | // internal function used for debug and testing purposes |
116 | TVM_REGISTER_GLOBAL("testing.object_use_count" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
117 | runtime::ObjectRef obj = args[0]; |
118 | // substract the current one because we always copy |
119 | // and get another value. |
120 | *ret = (obj.use_count() - 1); |
121 | }); |
122 | |
123 | class FrontendTestModuleNode : public runtime::ModuleNode { |
124 | public: |
125 | const char* type_key() const final { return "frontend_test" ; } |
126 | |
127 | static constexpr const char* kAddFunctionName = "__add_function" ; |
128 | |
129 | virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self); |
130 | |
131 | private: |
132 | std::unordered_map<std::string, PackedFunc> functions_; |
133 | }; |
134 | |
135 | constexpr const char* FrontendTestModuleNode::kAddFunctionName; |
136 | |
137 | PackedFunc FrontendTestModuleNode::GetFunction(const std::string& name, |
138 | const ObjectPtr<Object>& sptr_to_self) { |
139 | if (name == kAddFunctionName) { |
140 | return TypedPackedFunc<void(std::string, PackedFunc)>( |
141 | [this, sptr_to_self](std::string func_name, PackedFunc pf) { |
142 | CHECK_NE(func_name, kAddFunctionName) |
143 | << "func_name: cannot be special function " << kAddFunctionName; |
144 | functions_[func_name] = pf; |
145 | }); |
146 | } |
147 | |
148 | auto it = functions_.find(name); |
149 | if (it == functions_.end()) { |
150 | return PackedFunc(); |
151 | } |
152 | |
153 | return it->second; |
154 | } |
155 | |
156 | runtime::Module NewFrontendTestModule() { |
157 | auto n = make_object<FrontendTestModuleNode>(); |
158 | return runtime::Module(n); |
159 | } |
160 | |
161 | TVM_REGISTER_GLOBAL("testing.FrontendTestModule" ).set_body_typed(NewFrontendTestModule); |
162 | |
163 | TVM_REGISTER_GLOBAL("testing.sleep_in_ffi" ).set_body_typed([](double timeout) { |
164 | std::chrono::duration<int64_t, std::nano> duration(static_cast<int64_t>(timeout * 1e9)); |
165 | std::this_thread::sleep_for(duration); |
166 | }); |
167 | |
168 | } // namespace tvm |
169 | |