1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "proposal.h" |
20 | |
21 | #include <tvm/runtime/container/array.h> |
22 | #include <tvm/runtime/container/map.h> |
23 | #include <tvm/runtime/object.h> |
24 | #include <tvm/runtime/registry.h> |
25 | |
26 | #include <algorithm> |
27 | #include <utility> |
28 | #include <vector> |
29 | |
30 | #include "plan.h" |
31 | |
32 | namespace tvm { |
33 | namespace contrib { |
34 | namespace ethosu { |
35 | namespace cascader { |
36 | |
37 | void ProposalNode::VisitAttrs(AttrVisitor* v) { |
38 | v->Visit("_graph" , &graph_); |
39 | Array<Part> tmp_parts(part_group_.begin(), part_group_.end()); |
40 | v->Visit("_part_group" , &tmp_parts); |
41 | Array<Plan> tmp_plans(plans_.begin(), plans_.end()); |
42 | v->Visit("_plans" , &tmp_plans); |
43 | Map<Tensor, TensorConfig> tmp_tmap(input_tensor_configs_.begin(), input_tensor_configs_.end()); |
44 | v->Visit("_input_tensor_configs" , &tmp_tmap); |
45 | v->Visit("_cascade_region" , &cascade_region_); |
46 | v->Visit("_memory_usage" , &memory_usage_); |
47 | v->Visit("_cycles" , &cycles_); |
48 | } |
49 | |
50 | Proposal::Proposal(const CascaderGraph& graph, const std::vector<Part>& part_group, |
51 | const std::vector<Plan>& plans, const TensorConfigMap& input_tensor_configs, |
52 | const MemoryRegion& cascade_region, int memory_usage, int cycles) { |
53 | auto n = make_object<ProposalNode>(); |
54 | n->graph_ = std::move(graph); |
55 | n->part_group_ = std::move(part_group); |
56 | std::sort(n->part_group_.begin(), n->part_group_.end()); |
57 | n->plans_ = std::move(plans); |
58 | n->input_tensor_configs_ = std::move(input_tensor_configs); |
59 | n->cascade_region_ = std::move(cascade_region); |
60 | n->memory_usage_ = std::move(memory_usage); |
61 | n->cycles_ = cycles; |
62 | data_ = std::move(n); |
63 | } |
64 | |
65 | TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.Proposal" ) |
66 | .set_body_typed([](CascaderGraph graph, Array<Part> part_group, Array<Plan> plans, |
67 | Map<Tensor, TensorConfig> input_tensor_configs, MemoryRegion cascade_region, |
68 | int memory_usage, int cycles) { |
69 | std::vector<Part> spart_group(part_group.begin(), part_group.end()); |
70 | std::vector<Plan> vplans(plans.begin(), plans.end()); |
71 | TensorConfigMap minput_tensor_configs(input_tensor_configs.begin(), |
72 | input_tensor_configs.end()); |
73 | return Proposal(graph, spart_group, vplans, minput_tensor_configs, cascade_region, |
74 | memory_usage, cycles); |
75 | }); |
76 | |
77 | TVM_REGISTER_NODE_TYPE(ProposalNode); |
78 | |
79 | } // namespace cascader |
80 | } // namespace ethosu |
81 | } // namespace contrib |
82 | } // namespace tvm |
83 | |