1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <utility>
20
21#include "./utils.h"
22
23namespace tvm {
24namespace meta_schedule {
25
26TuneContext::TuneContext(Optional<IRModule> mod, Optional<Target> target,
27 Optional<SpaceGenerator> space_generator,
28 Optional<SearchStrategy> search_strategy, Optional<String> task_name,
29 int num_threads, TRandState rand_state, PackedFunc logger) {
30 CHECK(rand_state == -1 || rand_state >= 0) << "ValueError: Invalid random state: " << rand_state;
31 ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>();
32 n->mod = mod;
33 n->target = target;
34 n->space_generator = space_generator;
35 n->search_strategy = search_strategy;
36 n->task_name = task_name;
37 n->num_threads = num_threads;
38 n->rand_state = support::LinearCongruentialEngine::NormalizeSeed(rand_state);
39 n->logger = logger;
40 data_ = std::move(n);
41}
42
43TuneContext TuneContextNode::Clone() const {
44 ObjectPtr<TuneContextNode> n = make_object<TuneContextNode>(*this);
45 if (this->space_generator.defined()) {
46 n->space_generator = this->space_generator.value()->Clone();
47 }
48 if (this->search_strategy.defined()) {
49 n->search_strategy = this->search_strategy.value()->Clone();
50 }
51 n->rand_state = ForkSeed(&n->rand_state);
52 n->Initialize();
53 return TuneContext(n);
54}
55
56void TuneContextNode::Initialize() {
57 if (this->space_generator.defined()) {
58 this->space_generator.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
59 }
60 if (this->search_strategy.defined()) {
61 this->search_strategy.value()->InitializeWithTuneContext(GetRef<TuneContext>(this));
62 }
63}
64
65TVM_REGISTER_NODE_TYPE(TuneContextNode);
66TVM_REGISTER_GLOBAL("meta_schedule.TuneContext")
67 .set_body_typed([](Optional<IRModule> mod, Optional<Target> target,
68 Optional<SpaceGenerator> space_generator,
69 Optional<SearchStrategy> search_strategy, Optional<String> task_name,
70 int num_threads, TRandState rand_state, PackedFunc logger) -> TuneContext {
71 return TuneContext(mod, target, space_generator, search_strategy, task_name, num_threads,
72 rand_state, logger);
73 });
74TVM_REGISTER_GLOBAL("meta_schedule._SHash2Hex").set_body_typed(SHash2Hex);
75TVM_REGISTER_GLOBAL("meta_schedule.TuneContextInitialize")
76 .set_body_method<TuneContext>(&TuneContextNode::Initialize);
77TVM_REGISTER_GLOBAL("meta_schedule.TuneContextClone")
78 .set_body_method<TuneContext>(&TuneContextNode::Clone);
79
80} // namespace meta_schedule
81} // namespace tvm
82