#include <ATen/core/interned_strings.h>
#include <ATen/core/jit_type.h>
#include <c10/core/Device.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/device_type_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/library.h>
#include <memory>
#include <utility>

namespace torch {
namespace jit {

namespace {

using Tensor = at::Tensor;
using Device = at::Device;

using PropRule = std::function<bool(Node*)>;
/*
A propagation rule takes a Node and applies the relevant device
properties to the Tensor outputs of that Node (based on the rule
itself).

Returns: bool indicating whether anything was changed.
*/
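// For illustration, a (hypothetical) rule that pins every tensor output to
// CPU could be written directly against this signature:
//
//   PropRule pin_outputs_to_cpu = [](Node* n) {
//     return setReturnsToDevice(n, Device(DeviceType::CPU));
//   };
//
// setReturnsToDeviceRule below builds exactly this kind of closure.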

bool setDeviceType(Value* value, c10::optional<Device> device) {
  auto tensor_type = value->type()->expect<TensorType>();
  bool changed = tensor_type->device() != device;
  if (changed) {
    value->setType(tensor_type->withDevice(device));
  }
  return changed;
}

bool setReturnsToDevice(Node* n, c10::optional<Device> device) {
  bool changed = false;
  for (Value* out : n->outputs()) {
    auto tensor_type = out->type()->cast<TensorType>();
    if (!tensor_type) {
      continue;
    }
    changed |= setDeviceType(out, device);
  }
  return changed;
}

PropRule setReturnsToDeviceRule(DeviceType deviceType) {
  Device device = Device(deviceType);
  return [=](Node* n) { return setReturnsToDevice(n, device); };
}

bool returnFirstArgDeviceRule(Node* n) {
  // Custom rule for ops whose arguments may have mismatched device types
  auto tensor_type = n->inputs()[0]->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
  return setReturnsToDevice(n, tensor_type->device());
}

bool returnSecondArgDeviceRule(Node* n) {
  // Custom rule for ops whose arguments may have mismatched device types
  auto tensor_type = n->inputs()[1]->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
  return setReturnsToDevice(n, tensor_type->device());
}

bool isZerodimCPUTensor(const std::shared_ptr<TensorType>& tensor_type) {
  // A zero-dim CPU tensor is the only input whose device may be
  // overwritten by another input's device. To be conservative, treat a
  // tensor as not a zero-dim CPU tensor whenever its rank or device is
  // unknown.
  bool is_zerodim = tensor_type->symbolic_sizes().rank().value_or(-1) == 0;
  bool is_cpu = tensor_type->device() && tensor_type->device()->is_cpu();
  return is_zerodim && is_cpu;
}

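// Sketch of the eager semantics behind the zero-dim exception: for
//   %r = aten::add(%cuda_tensor, %cpu_scalar, %alpha)
// where %cpu_scalar is a zero-dim CPU tensor, type promotion places the
// result on the CUDA device, so a zero-dim CPU input's device may be
// overwritten by any other input's device.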
bool propWithNoDevice(Node* n) {
  // Propagate if we can verify that all input devices match, except for
  // zero-dim CPU tensors, which any other device can overwrite
  size_t input_num = 0;

  for (; input_num < n->inputs().size(); input_num++) {
    if (n->inputs()[input_num]->type()->cast<TensorType>()) {
      break;
    }
  }
  if (input_num == n->inputs().size()) {
    // No tensor found
    return setReturnsToDevice(n, c10::nullopt);
  }

  auto tensor_type = n->inputs()[input_num]->type()->expect<TensorType>();
  bool only_seen_cpu_zerodim = isZerodimCPUTensor(tensor_type);
  c10::optional<Device> device = tensor_type->device();

  // Now see if all inputs have a consistent device type
  for (input_num++; input_num < n->inputs().size(); input_num++) {
    auto tensor_type = n->inputs()[input_num]->type()->cast<TensorType>();
    if (!tensor_type || isZerodimCPUTensor(tensor_type)) {
      continue;
    }

    if (device != tensor_type->device()) {
      if (only_seen_cpu_zerodim) {
        device = tensor_type->device();
        only_seen_cpu_zerodim = false;
      } else {
        // Bail if the device types do not match
        return setReturnsToDevice(n, c10::nullopt);
      }
    }
  }
  return setReturnsToDevice(n, device);
}

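// Example of the case handled below: overloads such as (roughly)
//   aten::to.device(Tensor(a) self, Device device, ScalarType dtype,
//                   bool non_blocking=False, bool copy=False) -> Tensor(a)
// take an explicit Device argument; when that argument is a compile-time
// constant, it determines the output device regardless of input devices.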
bool defaultDeviceProp(Node* n) {
  // Detect whether the op has a Device object argument, keeping in mind
  // that strings are implicitly convertible to Device
  auto schema = n->maybeSchema();
  if (!schema) {
    return false;
  }
  const auto& arguments = schema->arguments();
  for (size_t i = 0; i < arguments.size(); i++) {
    const Argument& argument = arguments[i];
    if (DeviceObjType::get()->isSubtypeOf(argument.type())) {
      // Optional args are filled in by TorchScript with their default value
      auto input_val = toIValue(n->inputs().at(i));
      if (!input_val.has_value()) {
        // Can't propagate if the device type is dynamic
        return false;
      }
      if (input_val->isNone()) {
        continue;
      }
      if (!input_val->isDevice()) {
        // Bail on union types
        return false;
      }
      Device device = input_val->toDevice();
      return setReturnsToDevice(n, device);
    }
  }
  return propWithNoDevice(n);
}

struct DeviceTypePropagationPass : public PropertyPropBase {
  explicit DeviceTypePropagationPass(std::shared_ptr<Graph> graph)
      : PropertyPropBase(std::move(graph)) {
    buildRuleRegistry();
  }

  // Returns true if the device type was set on at least one tensor value
  bool run() {
    propagateBlock(graph_->block(), false);
    return changed_;
  }

 private:
  void propagateNode(Node* n, bool _ = false) override {
    GRAPH_DEBUG("propagateNode");
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
        return processLoop(n);
      case prim::CallMethod:
      case prim::CallFunction:
        return; // Not handled for now
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // If the output contains no tensor, there is nothing to propagate
      return;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This has already been propagated by something else
      case prim::ListConstruct:
      case prim::ListUnpack:
        return; // Not handled for now
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          return; // Not handled for now
        }
    }
  }

  void processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom rule matching
    auto op = n->maybeOperator();
    if (!op) {
      return;
    }
    auto prop_fn = device_prop_registry_->find(*op);
    if (prop_fn) {
      PropRule rule = *prop_fn;
      changed_ |= rule(n);
      return;
    }
    changed_ |= defaultDeviceProp(n);
  }

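  // Note: the registry lookup in processAtenOps is keyed on the operator,
  // so a custom rule fires only for the exact overloads registered below;
  // all other aten ops fall back to defaultDeviceProp.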
  void buildRuleRegistry() {
    // Build the registry for all of the custom device type rules
    if (device_prop_registry_) {
      return;
    }

    static OperatorMap<PropRule> temp_registry{
        {"aten::cpu(Tensor self) -> Tensor",
         setReturnsToDeviceRule(DeviceType::CPU)},
        {"aten::cuda(Tensor self) -> Tensor",
         setReturnsToDeviceRule(DeviceType::CUDA)},
        {"aten::to_mkldnn(Tensor self, ScalarType? dtype) -> Tensor",
         setReturnsToDeviceRule(DeviceType::MKLDNN)},
        {"aten::reshape_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::view_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::expand_as(Tensor self, Tensor other) -> Tensor",
         returnFirstArgDeviceRule},
        {"aten::type_as(Tensor self, Tensor other) -> Tensor",
         returnSecondArgDeviceRule},
    };
    device_prop_registry_ =
        std::make_unique<OperatorMap<PropRule>>(std::move(temp_registry));
  }
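
  // Extending the registry (a hypothetical sketch; "aten::vulkan" is not an
  // existing op): a conversion op whose target device is fixed by the op
  // itself could reuse the same factory, e.g.
  //   {"aten::vulkan(Tensor self) -> Tensor",
  //    setReturnsToDeviceRule(DeviceType::Vulkan)},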

  static std::unique_ptr<OperatorMap<PropRule>> device_prop_registry_;
  bool changed_ = false;
};

std::unique_ptr<OperatorMap<PropRule>>
    DeviceTypePropagationPass::device_prop_registry_ = nullptr;

} // anonymous namespace

// This analysis propagates input device types (if any) throughout the
// graph.
bool DeviceTypePropagation(std::shared_ptr<Graph>& graph) {
  auto tp = std::make_unique<DeviceTypePropagationPass>(graph);
  bool changed = tp->run();
  if (changed) {
    GRAPH_DUMP("After DeviceTypePropagation pass:", graph);
  }
  return changed;
}
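
// Usage sketch (assumes the caller annotates graph input devices first):
//   graph->inputs().at(0)->setType(
//       graph->inputs().at(0)->type()->expect<TensorType>()->withDevice(
//           c10::Device(c10::DeviceType::CUDA)));
//   DeviceTypePropagation(graph);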

} // namespace jit
} // namespace torch