1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/target/compilation_config.h |
22 | * \brief A helper class to collect all the targets in canonical form necessary for compilation. |
23 | */ |
24 | |
25 | #ifndef TVM_TARGET_COMPILATION_CONFIG_H_ |
26 | #define TVM_TARGET_COMPILATION_CONFIG_H_ |
27 | |
28 | #include <tvm/target/virtual_device.h> |
29 | |
30 | #include <string> |
31 | |
32 | namespace tvm { |
33 | |
34 | /*! |
35 | * \brief Gathers the \p Targets and distinguished \p VirtualDevices in canonical form needed to |
36 | * compile a Relay module for execution over possibly heterogeneous devices. Centralizes the |
37 | * validation and canonicalization logic needed to transition from targets supplied by the Python |
38 | * APIs to a single internal representation. Also holds a cache of canonical \p VirtualDevices |
39 | * so that structural equal virtual devices have pointer equal canonical virtual devices. |
40 | * |
41 | * The construction of \p CompilationConfig is idempotent, in that given the same \p PassContext |
42 | * \p ctx and an arbitrary \p Array<Target> \p raw_targets: |
43 | * |
44 | * \code |
45 | * CompilationConfig(ctxt, raw_targets) |
46 | * is structurally equal to |
47 | * CompilationConfig(ctxt, CompilationConfig(ctxt, raw_targets)->primitive_targets) |
48 | * \endcode |
49 | * |
50 | * TODO(mbs): This is subject to change as we rework compilation options in general. This class |
51 | * is probably better called a 'CompositeTarget', and may be better made a sub-class of Target or |
52 | * some other common-target-root class. |
53 | */ |
54 | class CompilationConfigNode : public Object { |
55 | public: |
56 | /*! |
57 | * \brief The host target. Used for 'scalar' data and code (such as shapes and shape |
58 | * functions) and residual Relay expressions and data (such as conditionals and ADTs). |
59 | * Each \p primitive_target below will have this exact target object as its 'host'. |
60 | * |
61 | * Note that it is possible for a \p Target used for primitive operations to be structurally |
62 | * equal to the host \p Target (up to the \p host field.) However the \p Target objects will |
63 | * be distinct, and can be used as keys within a \p Map without collision. |
64 | */ |
65 | Target host_target; |
66 | |
67 | /*! |
68 | * \brief Vector of all available \p Targets for partitioning or compiling primitive tensor |
69 | * operators (kernels). May contain a \p Target for the same device type as for the |
70 | * \p host_target, however the \p host_target should be used for all host computations and data. |
71 | * Each \p Target will have \p host_target as its 'host'. |
72 | * |
73 | * Primitive targets must be unique by their kind name. In this way the |
74 | * \p FindPrimitiveTargetForKind method will find the unique target for the given kind name. |
75 | * This method is used when transitioning from an external codegen "Compiler" attribute value |
76 | * to the external codegen target representing that compiler. |
77 | * |
78 | * It is possible to have multiple primitive targets for the same device type. However given |
79 | * primitive targets left and right where: |
80 | * - left appears before right in the array |
81 | * - left->GetTargetDeviceType() == right->GetTargetDeviceType() |
82 | * then: |
83 | * - right.IsExternalCodegenFor(left) must be true |
84 | * In this way the \p FindPrimitiveTargetForDeviceOrFail method will find the 'most general' |
85 | * target for the requested device type. This method is used when transitioning from a device |
86 | * constraint to the target needed to compile for that device. |
87 | * |
88 | * In the homogeneous case primitive_targets will have just one entry, which will be pointer equal |
89 | * to optional_homogeneous_target. |
90 | * |
91 | * In the homogenous case where the 'host' is the same device as used for compiling kernels it |
92 | * is *not* the case that optional_homogenous_target == host_target. This is because all |
93 | * primitive always have their host field set to the host_target. Ie, it is valid to have: |
94 | * \code |
95 | * host_target=Target("llvm") |
96 | * optional_homogenous_target=Target("llvm", host=host_target) |
97 | * \endcode |
98 | */ |
99 | Array<Target> primitive_targets; |
100 | |
101 | /*! |
102 | * \brief \p VirtualDevice for primitive operators which are not otherwise constrained to a |
103 | * particular device. Used by the PlanDevices pass to determine a virtual device for every |
104 | * sub-expression. |
105 | */ |
106 | VirtualDevice default_primitive_virtual_device = VirtualDevice::FullyUnconstrained(); |
107 | |
108 | /*! \brief VirtualDevice for the host. */ |
109 | VirtualDevice host_virtual_device = VirtualDevice::FullyUnconstrained(); |
110 | |
111 | /*! |
112 | * \brief If defined then compile and/or run in 'homogenous execution mode'. In this mode all |
113 | * primitives are compiled for this target only. |
114 | * |
115 | * This is to support legacy passes which have not been adapted to heterogeneous execution and |
116 | * rely on an implicit global \p Target to be in scope. |
117 | * |
118 | * TODO(mbs): Remove once all passes are 'heterogeneous aware'. |
119 | */ |
120 | Target optional_homogeneous_target; |
121 | |
122 | void VisitAttrs(AttrVisitor* v); |
123 | |
124 | /*! |
125 | * \brief Returns the unique \p Target to use for \p device_type. Fail if no such target exists. |
126 | * |
127 | * This will be the first primitive target with matching device type. |
128 | */ |
129 | Target FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const; |
130 | |
131 | /*! |
132 | * \brief Returns the unique \p Target to use for \p kind_name. Returns null if none such. |
133 | */ |
134 | Optional<Target> FindPrimitiveTargetForKind(const std::string& kind_name) const; |
135 | |
136 | /*! |
137 | * \brief Returns a \p Target structurally equal to \p target, however prefer a structually equal |
138 | * known host or primitive target if the configuration has one. |
139 | */ |
140 | Target CanonicalTarget(const Target& target) const; |
141 | |
142 | /*! |
143 | * \brief Returns a \p VirtualDevice which is structurally equal to \p virtual_device on all its |
144 | * constrained fields, however: |
145 | * - If \p virtual_device has a device type but not a target, fill in a target using |
146 | * \p FindPrimitiveTargetOrFail. This is the one place we allow targets to be defaulted |
147 | * from device types alone. |
148 | * - If \p virtual_device has a target, also canonicalize it using \p CanonicalTarget. |
149 | * The returned object will be unique for the adjusted virtual device w.r.t. all other |
150 | * \p VirtualDevices returned by this method. |
151 | * |
152 | * We call the result the 'canonical' \p VirtualDevice. Two canonical \p VirtualDevices are |
153 | * structurally equal if and only if they are pointer equal. In this way we can build maps |
154 | * from virtual devices using just pointer equality. |
155 | */ |
156 | VirtualDevice CanonicalVirtualDevice(const VirtualDevice& virtual_device) const; |
157 | |
158 | static constexpr const char* _type_key = "CompilationConfig" ; |
159 | TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object) |
160 | |
161 | private: |
162 | /*! |
163 | * \brief Sets the primitive targets, the host target, the default primitive virtual device, and |
164 | * the host virtual device given: |
165 | * - the vector of 'raw' targets (in any order) supplied by one of the TVM entry points. |
166 | * - any "relay.fallback_device_type" attribute on \p pass_ctx. |
167 | * - whether the LLVM backend is available. |
168 | * Will look for a suitable host target in the given primitive targets, but if none found may |
169 | * reuse a raw target or create a default CPU target. |
170 | */ |
171 | void Init(const transform::PassContext& pass_ctx, const Array<Target>& raw_targets); |
172 | |
173 | /*! |
174 | * \brief Returns a freshly constructed CPU \p Target. |
175 | */ |
176 | static Target MakeDefaultCPUTarget(); |
177 | |
178 | /*! |
179 | * \brief A cache of constructed virtual devices. |
180 | */ |
181 | mutable VirtualDeviceCache virtual_device_cache_; |
182 | |
183 | friend class CompilationConfig; |
184 | }; |
185 | |
186 | /*! |
187 | * \brief Managed reference class to \p CompilationConfig |
188 | * |
189 | * \sa CompilationConfig |
190 | */ |
191 | class CompilationConfig : public ObjectRef { |
192 | public: |
193 | /*! |
194 | * \brief Constructs the compilation config given the settings in \p pass_ctx and supplied |
195 | * \p raw_targets. See \p CompilationConfigNode::Init for details. |
196 | */ |
197 | TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, |
198 | const Array<Target>& raw_targets); |
199 | |
200 | TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode); |
201 | }; |
202 | |
203 | } // namespace tvm |
204 | |
205 | #endif // TVM_TARGET_COMPILATION_CONFIG_H_ |
206 | |