1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24/*! \brief The union of design space generators. */
25class SpaceGeneratorUnionNode : public SpaceGeneratorNode {
26 public:
27 /*! \brief The array of design space generators unioned, could be recursive. */
28 Array<SpaceGenerator> space_generators;
29
30 void VisitAttrs(tvm::AttrVisitor* v) {
31 SpaceGeneratorNode::VisitAttrs(v);
32 v->Visit("space_generators", &space_generators);
33 }
34
35 void InitializeWithTuneContext(const TuneContext& context) final {
36 SpaceGeneratorNode::InitializeWithTuneContext(context);
37 for (const SpaceGenerator& space_generator : space_generators) {
38 space_generator->InitializeWithTuneContext(context);
39 }
40 }
41
42 Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
43 Array<tir::Schedule> design_spaces;
44 for (const SpaceGenerator& space_generator : space_generators) {
45 // Generate partial design spaces from each design space generator.
46 Array<tir::Schedule> partial = space_generator->GenerateDesignSpace(mod);
47 // Merge the partial design spaces.
48 design_spaces.insert(design_spaces.end(), partial.begin(), partial.end());
49 }
50 return design_spaces;
51 }
52
53 SpaceGenerator Clone() const final {
54 ObjectPtr<SpaceGeneratorUnionNode> n = make_object<SpaceGeneratorUnionNode>(*this);
55 n->space_generators = Array<SpaceGenerator>();
56 for (const SpaceGenerator& space_generator : this->space_generators) {
57 n->space_generators.push_back(space_generator->Clone());
58 }
59 CloneRules(this, n.get());
60 return SpaceGenerator(n);
61 }
62
63 static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion";
64 TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode);
65};
66
67/*!
68 * \brief Create a design space generator as union of given design space generators.
69 * \param space_generators Array of the design space generators to be unioned.
70 * \return The design space generator created.
71 */
72SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array<SpaceGenerator> space_generators,
73 Optional<Array<ScheduleRule>> sch_rules,
74 Optional<Array<Postproc>> postprocs,
75 Optional<Map<Mutator, FloatImm>> mutator_probs) {
76 ObjectPtr<SpaceGeneratorUnionNode> n = make_object<SpaceGeneratorUnionNode>();
77 n->sch_rules = std::move(sch_rules);
78 n->postprocs = std::move(postprocs);
79 n->mutator_probs = std::move(mutator_probs);
80 n->space_generators = std::move(space_generators);
81 return SpaceGenerator(n);
82}
83
// Register SpaceGeneratorUnionNode with TVM's node-type registry, and expose the
// SpaceGenerator::SpaceGeneratorUnion factory as a typed global function under the name below
// (presumably what the Python frontend looks up — confirm against the FFI bindings).
TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion")
    .set_body_typed(SpaceGenerator::SpaceGeneratorUnion);
87
88} // namespace meta_schedule
89} // namespace tvm
90