1 | #include <ATen/core/interned_strings.h> |
2 | #include <ATen/core/jit_type.h> |
3 | #include <c10/core/Device.h> |
4 | #include <c10/util/ArrayRef.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <torch/csrc/jit/ir/ir.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/device_type_analysis.h> |
9 | #include <torch/csrc/jit/passes/shape_analysis.h> |
10 | #include <torch/library.h> |
11 | #include <memory> |
12 | #include <utility> |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | namespace { |
18 | |
19 | using Tensor = at::Tensor; |
20 | using Device = at::Device; |
21 | |
22 | using PropRule = std::function<bool(Node*)>; |
23 | /* |
24 | A Propagation Rule takes the Node, and |
25 | applies the relevant properties to the Tensor outputs |
26 | of the Node (based on the rule itself) |
27 | |
28 | Returns: Bool indicating if anything was changed |
29 | */ |
30 | |
31 | bool setDeviceType(Value* value, c10::optional<Device> device) { |
32 | auto tensor_type = value->type()->expect<TensorType>(); |
33 | bool changed = tensor_type->device() != device; |
34 | if (changed) { |
35 | value->setType(tensor_type->withDevice(device)); |
36 | } |
37 | return changed; |
38 | } |
39 | |
40 | bool setReturnsToDevice(Node* n, c10::optional<Device> device) { |
41 | bool changed = false; |
42 | for (Value* out : n->outputs()) { |
43 | auto tensor_type = out->type()->cast<TensorType>(); |
44 | if (!tensor_type) { |
45 | continue; |
46 | } |
47 | changed |= setDeviceType(out, device); |
48 | } |
49 | return changed; |
50 | } |
51 | |
52 | PropRule setReturnstoDeviceRule(DeviceType deviceType) { |
53 | Device device = Device(deviceType); |
54 | return [=](Node* n) { return setReturnsToDevice(n, device); }; |
55 | } |
56 | |
57 | bool returnFirstArgDeviceRule(Node* n) { |
58 | // Custom Rule for when multiple args can have mismatched device types |
59 | auto tensor_type = n->inputs()[0]->type()->cast<TensorType>(); |
60 | TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type" ); |
61 | return setReturnsToDevice(n, tensor_type->device()); |
62 | } |
63 | |
64 | bool returnSecondArgDeviceRule(Node* n) { |
65 | // Custom Rule for when multiple args can have mismatched device types |
66 | auto tensor_type = n->inputs()[1]->type()->cast<TensorType>(); |
67 | TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type" ); |
68 | return setReturnsToDevice(n, tensor_type->device()); |
69 | } |
70 | |
71 | bool isZerodimCPUTensor(std::shared_ptr<TensorType> tensor_type) { |
72 | // CPU devices on zerodim tensors are the only device that can be |
73 | // overwritten by another device. Therefore, to be conservative |
74 | // assume that it is not a zerodim cpu tensor if something is not known. |
75 | bool is_zerodim = tensor_type->symbolic_sizes().rank().value_or(-1) == 0; |
76 | bool is_cpu = tensor_type->device() && tensor_type->device()->is_cpu(); |
77 | return is_zerodim && is_cpu; |
78 | } |
79 | |
80 | bool propWithNoDevice(Node* n) { |
81 | // Propagate if we can verify that all input devices match, |
82 | // except CPU zerodim, which any other type can overwrite |
83 | int input_num = 0; |
84 | |
85 | for (; input_num < n->inputs().size(); input_num++) { |
86 | if (n->inputs()[input_num]->type()->cast<TensorType>()) { |
87 | break; |
88 | } |
89 | } |
90 | if (input_num == n->inputs().size()) { |
91 | // No tensor found |
92 | return setReturnsToDevice(n, c10::nullopt); |
93 | } |
94 | |
95 | auto tensor_type = n->inputs()[input_num]->type()->expect<TensorType>(); |
96 | bool only_seen_cpu_zerodim = isZerodimCPUTensor(tensor_type); |
97 | c10::optional<Device> device = tensor_type->device(); |
98 | |
99 | // Now see if all inputs have a consistent device type |
100 | for (input_num++; input_num < n->inputs().size(); input_num++) { |
101 | auto tensor_type = n->inputs()[input_num]->type()->cast<TensorType>(); |
102 | if (!tensor_type || isZerodimCPUTensor(tensor_type)) { |
103 | continue; |
104 | } |
105 | |
106 | if (device != tensor_type->device()) { |
107 | if (only_seen_cpu_zerodim) { |
108 | device = tensor_type->device(); |
109 | only_seen_cpu_zerodim = false; |
110 | } else { |
111 | // Bail on the type not match case |
112 | return setReturnsToDevice(n, c10::nullopt); |
113 | } |
114 | } |
115 | } |
116 | return setReturnsToDevice(n, device); |
117 | } |
118 | |
119 | bool defaultDeviceProp(Node* n) { |
120 | // Detecting if the op has a device object argument |
121 | // as there is implicit string conversion to device |
122 | auto schema = n->maybeSchema(); |
123 | if (!schema) { |
124 | return false; |
125 | } |
126 | auto arguments = schema->arguments(); |
127 | for (int i = 0; i < arguments.size(); i++) { |
128 | Argument& argument = arguments[i]; |
129 | if (DeviceObjType::get()->isSubtypeOf(argument.type())) { |
130 | // Optional args are filled in by torchscript with default val |
131 | auto input_val = toIValue(n->inputs().at(i)); |
132 | if (!input_val.has_value()) { |
133 | // Can't propagate if there is a dynamic device type |
134 | return false; |
135 | } |
136 | if (input_val->isNone()) { |
137 | continue; |
138 | } |
139 | if (!input_val->isDevice()) { |
140 | // Bail on union types |
141 | return false; |
142 | } |
143 | TORCH_INTERNAL_ASSERT(input_val->isDevice()) |
144 | Device device = input_val->toDevice(); |
145 | return setReturnsToDevice(n, device); |
146 | } |
147 | } |
148 | return propWithNoDevice(n); |
149 | } |
150 | |
// Graph pass that walks every node and propagates tensor device types from
// inputs to outputs, dispatching to per-operator rules registered in
// device_prop_registry_, with defaultDeviceProp as the fallback.
struct DeviceTypePropagationPass : public PropertyPropBase {
  explicit DeviceTypePropagationPass(std::shared_ptr<Graph> graph)
      : PropertyPropBase(graph) {
    buildRuleRegistry();
  }

  // returns true if at least one node has its device type set on a tensor
  // output during propagation
  bool run() {
    propagateBlock(graph_->block(), false);
    return changed_;
  }

 private:
  // Per-node dispatch: control-flow nodes are handled by the base class
  // helpers, calls are skipped, and aten ops go to processAtenOps.
  void propagateNode(Node* n, bool _ = false) override {
    GRAPH_DEBUG("processNode");
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
        return processLoop(n);
      case prim::CallMethod:
      case prim::CallFunction:
        return; // Not handled for now
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // if output contains no tensor, nothing to propagate
      return;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This is already been propagated by something else
      case prim::ListConstruct:
      case prim::ListUnpack:
        return; // Not handled for now
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          return; // Not handled for now
        }
    }
  }

  // Applies the registered custom rule for this operator if one exists,
  // otherwise the default rule; accumulates whether anything changed.
  void processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom Rule Matching
    auto op = n->maybeOperator();
    if (!op) {
      return;
    }
    auto prop_fn = device_prop_registry_->find(*op);
    if (prop_fn) {
      PropRule rule = *prop_fn;
      changed_ |= rule(n);
      return;
    }
    changed_ |= defaultDeviceProp(n);
  }

  // building a registry for all of the custom Device Type rules.
  // Runs once per process: the registry is shared by all pass instances,
  // so later constructions return early.
  void buildRuleRegistry() {
    if (device_prop_registry_)
      return;

    // Function-local static so the map is constructed exactly once, then
    // moved into the shared unique_ptr below (the early-return guard above
    // ensures it is never moved from twice).
    static OperatorMap<PropRule> temp_registry{
        {"aten::cpu(Tensor self) -> Tensor",
         setReturnstoDeviceRule(DeviceType::CPU)},
        {"aten::cuda(Tensor self) -> Tensor",
         setReturnstoDeviceRule(DeviceType::CUDA)},
        {"aten::to_mkldnn(Tensor self, ScalarType? dtype) -> Tensor",
         setReturnstoDeviceRule(DeviceType::MKLDNN)},
        {"aten::reshape_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::view_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::expand_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::type_as(Tensor self, Tensor other) -> Tensor",
         returnSecondArgDeviceRule},
    };
    device_prop_registry_ =
        std::make_unique<OperatorMap<PropRule>>(std::move(temp_registry));
  }

  // Operator -> propagation-rule map, shared across all instances.
  static std::unique_ptr<OperatorMap<PropRule>> device_prop_registry_;
  // Set when any node's output device type is updated by this pass.
  bool changed_ = false;
};
248 | |
// Out-of-line definition of the shared rule registry; populated lazily by
// buildRuleRegistry() on first pass construction.
std::unique_ptr<OperatorMap<PropRule>>
    DeviceTypePropagationPass::device_prop_registry_ = nullptr;
251 | |
252 | } // anonymous namespace |
253 | |
254 | // This analysis propagates input device types (if any) throughout the |
255 | // graph. |
256 | bool DeviceTypePropagation(std::shared_ptr<Graph>& graph) { |
257 | auto tp = std::make_unique<DeviceTypePropagationPass>((graph)); |
258 | bool changed = tp->run(); |
259 | if (changed) { |
260 | GRAPH_DUMP("After TensorPropertyPropagation pass:" , graph); |
261 | } |
262 | return changed; |
263 | } |
264 | |
265 | } // namespace jit |
266 | } // namespace torch |
267 | |