1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | RunnerInput::RunnerInput(String artifact_path, String device_type, Array<ArgInfo> args_info) { |
25 | ObjectPtr<RunnerInputNode> n = make_object<RunnerInputNode>(); |
26 | n->artifact_path = artifact_path; |
27 | n->device_type = device_type; |
28 | n->args_info = args_info; |
29 | this->data_ = n; |
30 | } |
31 | |
32 | RunnerResult::RunnerResult(Optional<Array<FloatImm>> run_secs, Optional<String> error_msg) { |
33 | ObjectPtr<RunnerResultNode> n = make_object<RunnerResultNode>(); |
34 | n->run_secs = run_secs; |
35 | n->error_msg = error_msg; |
36 | this->data_ = n; |
37 | } |
38 | |
39 | RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) { |
40 | ObjectPtr<RunnerFutureNode> n = make_object<RunnerFutureNode>(); |
41 | n->f_done = f_done; |
42 | n->f_result = f_result; |
43 | this->data_ = n; |
44 | } |
45 | |
46 | Runner Runner::PyRunner(Runner::FRun f_run) { |
47 | ObjectPtr<PyRunnerNode> n = make_object<PyRunnerNode>(); |
48 | n->f_run = f_run; |
49 | return Runner(n); |
50 | } |
51 | |
52 | /******** FFI ********/ |
53 | |
54 | TVM_REGISTER_NODE_TYPE(RunnerInputNode); |
55 | TVM_REGISTER_NODE_TYPE(RunnerResultNode); |
56 | TVM_REGISTER_NODE_TYPE(RunnerFutureNode); |
57 | TVM_REGISTER_OBJECT_TYPE(RunnerNode); |
58 | TVM_REGISTER_NODE_TYPE(PyRunnerNode); |
59 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput" ) |
60 | .set_body_typed([](String artifact_path, String device_type, |
61 | Array<ArgInfo> args_info) -> RunnerInput { |
62 | return RunnerInput(artifact_path, device_type, args_info); |
63 | }); |
64 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult" ) |
65 | .set_body_typed([](Array<FloatImm> run_secs, Optional<String> error_msg) -> RunnerResult { |
66 | return RunnerResult(run_secs, error_msg); |
67 | }); |
68 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture" ) |
69 | .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { |
70 | return RunnerFuture(f_done, f_result); |
71 | }); |
72 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone" ) |
73 | .set_body_method<RunnerFuture>(&RunnerFutureNode::Done); |
74 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult" ) |
75 | .set_body_method<RunnerFuture>(&RunnerFutureNode::Result); |
76 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun" ).set_body_method<Runner>(&RunnerNode::Run); |
77 | TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner" ).set_body_typed(Runner::PyRunner); |
78 | |
79 | } // namespace meta_schedule |
80 | } // namespace tvm |
81 | |