#include <ATen/Context.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/MatrixRef.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/irange.h>

namespace at {

// Given a linear index, return the corresponding multi-dimensional index.
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
static SmallVector<indexing::TensorIndex, kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  SmallVector<indexing::TensorIndex, kVmapStaticDimVecSize> result;
  result.reserve(sizes.size());
  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
    auto remainder = linear_idx % *it;
    result.push_back(remainder);
    linear_idx -= remainder;
    linear_idx /= *it;
  }
  std::reverse(std::begin(result), std::end(result));
  return result;
}
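
// As a sanity check on the loop above, here is an illustrative trace (not
// executed anywhere): computeIndex(7, {3, 4}) walks the sizes in reverse:
//   7 % 4 = 3, then linear_idx = (7 - 3) / 4 = 1
//   1 % 3 = 1, then linear_idx = (1 - 1) / 3 = 0
// yielding [3, 1], which is reversed to the multi-index [1, 3]
// (and indeed 1 * 4 + 3 = 7 recovers the linear index).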

static bool areAllReturnsTensors(const FunctionSchema& schema) {
  return std::all_of(
      schema.returns().begin(),
      schema.returns().end(),
      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
}

static bool areAnyArgumentsTensorList(const FunctionSchema& schema) {
  return std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); });
}
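
// For intuition, here is how these helpers classify a few schemas
// (abbreviated and shown for illustration only):
//   mul.Tensor(Tensor self, Tensor other) -> Tensor
//     all returns are Tensors; no TensorList arguments
//   cat(Tensor[] tensors, int dim=0) -> Tensor
//     has a TensorList argument, so areAnyArgumentsTensorList is true
//   split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
//     returns a Tensor[], so areAllReturnsTensors is false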

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
// add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
static bool isInplaceOp(const c10::FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
  }
  // Check that the first argument is being written to
  const AliasInfo* first_arg_alias_info = schema.arguments().begin()->alias_info();
  if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) {
    return false;
  }
  // Check that none of the other args are being aliased
  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
    const AliasInfo* alias_info = it->alias_info();
    if (alias_info) {
      return false;
    }
  }
  // Check that the first tensor is being returned (i.e., output has a (a!))
  const AliasInfo* return_alias_info = schema.returns()[0].alias_info();
  return return_alias_info && return_alias_info->isWrite();
}
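
// For example (illustrative, not exhaustive):
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
//     in-place: `self` is written to and aliases the single return
//   add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
//     NOT in-place by this definition: the argument being written to is
//     `out`, not the first argument, so the alias check on `self` fails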

static void warnFallback(const c10::FunctionSchema& schema) {
  if (!globalContext().areVmapFallbackWarningsEnabled()) {
    return;
  }
  TORCH_WARN("There is a performance drop because we have not yet implemented ",
             "the batching rule for ", schema.operator_name(), ". ",
             "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
             "If you are using torch.autograd.functional.{jacobian, hessian} ",
             "or torch._vmap_internals.vmap: please switch to using ",
             "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
             "for better operator coverage and performance improvements.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalViews that hold tensors with all of
//   the collective batch dimensions at the front.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   push the sliced (or unsliced) version of each input onto the stack, invoke
//   the operator, and then pop the results off the stack.
void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  warnFallback(schema);

  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
  }
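  // For example (illustrative): if `self` carries batch dims at vmap levels
  // 1 and 3, then self_vmap_levels has exactly bits 1 and 3 set. The check
  // inside the loop below compares this bitset against each BatchedTensor
  // argument's levels to detect a level that is present in an argument but
  // missing from `self`.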

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  SmallVector<Tensor, kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
    auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
      // but it would be better to print out "tensor.add_(...) is not possible".
      // As far as I can tell, there's no official way to recover `add_` here,
      // and there is no way to tell if an operator has method or function variants.
      TORCH_CHECK(false,
          "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
          "there exists a Tensor `other` in extra_args that has more elements ",
          "than `self`. This happened due to `other` being vmapped over but ",
          "`self` not being vmapped over at level ", offending_level, ". ",
          "Please try to use out-of-place operators instead of ", schema.name(), ". ",
          "If said operator is being called inside the PyTorch framework, ",
          "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
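  // For example (illustrative): if the physical view's tensor has sizes
  // [2, 3, 5, 7] and numBatchDims() == 2, then batch_sizes is [2, 3] and
  // num_batches is 6, so the loop below invokes the operator six times,
  // once per (i, j) with i in [0, 2) and j in [0, 3).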
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
  for (const auto linear_idx : c10::irange(num_batches)) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}
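
// End-to-end, this means that something like the following (illustrative,
// using the legacy Python API)
//   vmap(Tensor.add_, in_dims=(0, 0))(torch.rand(B0, 3), torch.rand(B0, 3))
// runs add_ once per batch entry b on self[b] and other[b]; each call writes
// through the indexed view into the physical tensor, so pushing the original
// `self` afterwards returns the updated BatchedTensor.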

static Tensor safeStack(TensorList tensors) {
  auto is_defined = [](const Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // While vmapping through backward functions (to compute batched grad), it
  // is possible for the backward function to return an undefined grad for some
  // grad_input for each example. In that case, we return an undefined grad.
  //
  // It is theoretically possible for *some* of the examples to produce an
  // undefined grad (a kernel could peek at the gradient values and return an
  // undefined tensor if it determines the gradient is full of zeros). We
  // could handle this by treating the undefined grad as a zero-filled tensor
  // of the correct shape while stacking the tensors together. However I expect
  // this to happen very rarely (I have not been able to find an example in our
  // codebase) so we just error out in this case.
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return Tensor();
  }
  TORCH_CHECK(false,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
}
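
// For example (illustrative): stacking four defined shards of shape [5]
// produces a single tensor of shape [4, 5]; four undefined shards produce
// an undefined Tensor(); any mix of defined and undefined shards raises
// the error above.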

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalViews that hold tensors with all of
//   the collective batch dimensions at the front.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   push the sliced (or unsliced) version of each input onto the stack, invoke
//   the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
//   so we stack those tensors together to form the final result.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();

  if (isInplaceOp(schema)) {
    batchedTensorInplaceForLoopFallback(op, stack);
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
      "Batching rule not implemented for ", schema.operator_name(), "; ",
      "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "We could not generate a fallback.");
  TORCH_CHECK(num_returns >= 1,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support operations with no returns.");
  warnFallback(schema);

  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  SmallVector<Tensor, kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto some_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, call `op`, and store the result in
  // `output_shards`.
  //
  // NOTE: [Output shards layout]
  // Assume that the operator has three outputs: a, b, c.
  // The layout of output_shards is as follows:
  //   [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3 ]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
  std::vector<Tensor> output_shards(num_batches * num_returns);

  for (const auto linear_idx : c10::irange(num_batches)) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(returns.size())) {
      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a return
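  // For example (illustrative): with batch_sizes = [2, 3] and per-shard
  // outputs of shape [5], at::stack produces flat_output of shape [6, 5],
  // which is viewed as [2, 3, 5] and then mapped back to a logical
  // BatchedTensor via getPhysicalToLogicalMap().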
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
  for (const auto return_idx : c10::irange(num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    auto flat_output = safeStack(shards);
    // See NOTE [vmap through backward and undefined grad]
    if (!flat_output.defined()) {
      torch::jit::push(stack, flat_output);
      continue;
    }
    VmapDimVector output_sizes(batch_sizes);
    output_sizes.insert(
        output_sizes.end(),
        flat_output.sizes().begin() + 1,
        flat_output.sizes().end());
    torch::jit::push(
        stack,
        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
  }
}

} // namespace at