1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/compilation_config.h
22 * \brief A helper class to collect all the targets in canonical form necessary for compilation.
23 */
24
25#ifndef TVM_TARGET_COMPILATION_CONFIG_H_
26#define TVM_TARGET_COMPILATION_CONFIG_H_
27
28#include <tvm/target/virtual_device.h>
29
30#include <string>
31
32namespace tvm {
33
34/*!
35 * \brief Gathers the \p Targets and distinguished \p VirtualDevices in canonical form needed to
36 * compile a Relay module for execution over possibly heterogeneous devices. Centralizes the
37 * validation and canonicalization logic needed to transition from targets supplied by the Python
38 * APIs to a single internal representation. Also holds a cache of canonical \p VirtualDevices
39 * so that structural equal virtual devices have pointer equal canonical virtual devices.
40 *
41 * The construction of \p CompilationConfig is idempotent, in that given the same \p PassContext
42 * \p ctx and an arbitrary \p Array<Target> \p raw_targets:
43 *
44 * \code
45 * CompilationConfig(ctxt, raw_targets)
46 * is structurally equal to
47 * CompilationConfig(ctxt, CompilationConfig(ctxt, raw_targets)->primitive_targets)
48 * \endcode
49 *
50 * TODO(mbs): This is subject to change as we rework compilation options in general. This class
51 * is probably better called a 'CompositeTarget', and may be better made a sub-class of Target or
52 * some other common-target-root class.
53 */
54class CompilationConfigNode : public Object {
55 public:
56 /*!
57 * \brief The host target. Used for 'scalar' data and code (such as shapes and shape
58 * functions) and residual Relay expressions and data (such as conditionals and ADTs).
59 * Each \p primitive_target below will have this exact target object as its 'host'.
60 *
61 * Note that it is possible for a \p Target used for primitive operations to be structurally
62 * equal to the host \p Target (up to the \p host field.) However the \p Target objects will
63 * be distinct, and can be used as keys within a \p Map without collision.
64 */
65 Target host_target;
66
67 /*!
68 * \brief Vector of all available \p Targets for partitioning or compiling primitive tensor
69 * operators (kernels). May contain a \p Target for the same device type as for the
70 * \p host_target, however the \p host_target should be used for all host computations and data.
71 * Each \p Target will have \p host_target as its 'host'.
72 *
73 * Primitive targets must be unique by their kind name. In this way the
74 * \p FindPrimitiveTargetForKind method will find the unique target for the given kind name.
75 * This method is used when transitioning from an external codegen "Compiler" attribute value
76 * to the external codegen target representing that compiler.
77 *
78 * It is possible to have multiple primitive targets for the same device type. However given
79 * primitive targets left and right where:
80 * - left appears before right in the array
81 * - left->GetTargetDeviceType() == right->GetTargetDeviceType()
82 * then:
83 * - right.IsExternalCodegenFor(left) must be true
84 * In this way the \p FindPrimitiveTargetForDeviceOrFail method will find the 'most general'
85 * target for the requested device type. This method is used when transitioning from a device
86 * constraint to the target needed to compile for that device.
87 *
88 * In the homogeneous case primitive_targets will have just one entry, which will be pointer equal
89 * to optional_homogeneous_target.
90 *
91 * In the homogenous case where the 'host' is the same device as used for compiling kernels it
92 * is *not* the case that optional_homogenous_target == host_target. This is because all
93 * primitive always have their host field set to the host_target. Ie, it is valid to have:
94 * \code
95 * host_target=Target("llvm")
96 * optional_homogenous_target=Target("llvm", host=host_target)
97 * \endcode
98 */
99 Array<Target> primitive_targets;
100
101 /*!
102 * \brief \p VirtualDevice for primitive operators which are not otherwise constrained to a
103 * particular device. Used by the PlanDevices pass to determine a virtual device for every
104 * sub-expression.
105 */
106 VirtualDevice default_primitive_virtual_device = VirtualDevice::FullyUnconstrained();
107
108 /*! \brief VirtualDevice for the host. */
109 VirtualDevice host_virtual_device = VirtualDevice::FullyUnconstrained();
110
111 /*!
112 * \brief If defined then compile and/or run in 'homogenous execution mode'. In this mode all
113 * primitives are compiled for this target only.
114 *
115 * This is to support legacy passes which have not been adapted to heterogeneous execution and
116 * rely on an implicit global \p Target to be in scope.
117 *
118 * TODO(mbs): Remove once all passes are 'heterogeneous aware'.
119 */
120 Target optional_homogeneous_target;
121
122 void VisitAttrs(AttrVisitor* v);
123
124 /*!
125 * \brief Returns the unique \p Target to use for \p device_type. Fail if no such target exists.
126 *
127 * This will be the first primitive target with matching device type.
128 */
129 Target FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const;
130
131 /*!
132 * \brief Returns the unique \p Target to use for \p kind_name. Returns null if none such.
133 */
134 Optional<Target> FindPrimitiveTargetForKind(const std::string& kind_name) const;
135
136 /*!
137 * \brief Returns a \p Target structurally equal to \p target, however prefer a structually equal
138 * known host or primitive target if the configuration has one.
139 */
140 Target CanonicalTarget(const Target& target) const;
141
142 /*!
143 * \brief Returns a \p VirtualDevice which is structurally equal to \p virtual_device on all its
144 * constrained fields, however:
145 * - If \p virtual_device has a device type but not a target, fill in a target using
146 * \p FindPrimitiveTargetOrFail. This is the one place we allow targets to be defaulted
147 * from device types alone.
148 * - If \p virtual_device has a target, also canonicalize it using \p CanonicalTarget.
149 * The returned object will be unique for the adjusted virtual device w.r.t. all other
150 * \p VirtualDevices returned by this method.
151 *
152 * We call the result the 'canonical' \p VirtualDevice. Two canonical \p VirtualDevices are
153 * structurally equal if and only if they are pointer equal. In this way we can build maps
154 * from virtual devices using just pointer equality.
155 */
156 VirtualDevice CanonicalVirtualDevice(const VirtualDevice& virtual_device) const;
157
158 static constexpr const char* _type_key = "CompilationConfig";
159 TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object)
160
161 private:
162 /*!
163 * \brief Sets the primitive targets, the host target, the default primitive virtual device, and
164 * the host virtual device given:
165 * - the vector of 'raw' targets (in any order) supplied by one of the TVM entry points.
166 * - any "relay.fallback_device_type" attribute on \p pass_ctx.
167 * - whether the LLVM backend is available.
168 * Will look for a suitable host target in the given primitive targets, but if none found may
169 * reuse a raw target or create a default CPU target.
170 */
171 void Init(const transform::PassContext& pass_ctx, const Array<Target>& raw_targets);
172
173 /*!
174 * \brief Returns a freshly constructed CPU \p Target.
175 */
176 static Target MakeDefaultCPUTarget();
177
178 /*!
179 * \brief A cache of constructed virtual devices.
180 */
181 mutable VirtualDeviceCache virtual_device_cache_;
182
183 friend class CompilationConfig;
184};
185
186/*!
187 * \brief Managed reference class to \p CompilationConfig
188 *
189 * \sa CompilationConfig
190 */
191class CompilationConfig : public ObjectRef {
192 public:
193 /*!
194 * \brief Constructs the compilation config given the settings in \p pass_ctx and supplied
195 * \p raw_targets. See \p CompilationConfigNode::Init for details.
196 */
197 TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx,
198 const Array<Target>& raw_targets);
199
200 TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode);
201};
202
203} // namespace tvm
204
205#endif // TVM_TARGET_COMPILATION_CONFIG_H_
206