1 | #include <torch/csrc/autograd/function.h> |
2 | #include <torch/csrc/profiler/kineto_shim.h> |
3 | #include <torch/csrc/profiler/util.h> |
4 | |
5 | #include <c10/util/ArrayRef.h> |
6 | #include <c10/util/irange.h> |
7 | #include <fmt/format.h> |
8 | |
9 | #ifdef USE_KINETO |
10 | #include <libkineto.h> |
11 | #endif |
12 | |
13 | namespace torch { |
14 | namespace profiler { |
15 | namespace impl { |
16 | |
// Capture a baseline set of (unix time, approximate time) pairs at
// construction; makeConverter() later diffs against these to calibrate.
ApproximateClockToUnixTimeConverter::ApproximateClockToUnixTimeConverter()
    : start_times_(measurePairs()) {}
19 | |
// Samples the approximate clock and the system (unix) clock as close
// together as possible and returns the paired reading. The statement order
// here is load-bearing: the wall-clock read is bracketed by two approximate
// reads so the midpoint estimate is unbiased.
ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair
ApproximateClockToUnixTimeConverter::measurePair() {
  // Take a measurement on either side to avoid an ordering bias.
  auto fast_0 = getApproximateTime();
  auto wall = std::chrono::system_clock::now();
  auto fast_1 = getApproximateTime();

  TORCH_INTERNAL_ASSERT(fast_1 >= fast_0, "getCount is non-monotonic.");
  auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
      wall.time_since_epoch());

  // `x + (y - x) / 2` is a more numerically stable average than `(x + y) / 2`.
  return {t.count(), fast_0 + (fast_1 - fast_0) / 2};
}
34 | |
35 | ApproximateClockToUnixTimeConverter::time_pairs |
36 | ApproximateClockToUnixTimeConverter::measurePairs() { |
37 | static constexpr auto n_warmup = 5; |
38 | for (C10_UNUSED const auto _ : c10::irange(n_warmup)) { |
39 | getApproximateTime(); |
40 | steady_clock_t::now(); |
41 | } |
42 | |
43 | time_pairs out; |
44 | for (const auto i : c10::irange(out.size())) { |
45 | out[i] = measurePair(); |
46 | } |
47 | return out; |
48 | } |
49 | |
// Builds a closure mapping approximate-clock ticks to unix nanoseconds by
// comparing the pairs captured at construction with a fresh set captured now.
std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
    makeConverter() {
  auto end_times = measurePairs();

  // Compute the real time that passes for each tick of the approximate clock.
  std::array<long double, replicates> scale_factors{};
  for (const auto i : c10::irange(replicates)) {
    auto delta_ns = end_times[i].t_ - start_times_[i].t_;
    auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_;
    // NOTE(review): the ratio is computed in `double` although the array
    // holds `long double`; presumably double precision suffices — confirm.
    scale_factors[i] = (double)delta_ns / (double)delta_approx;
  }
  std::sort(scale_factors.begin(), scale_factors.end());
  // NOTE(review): `replicates / 2 + 1` is one past the middle element of the
  // sorted array; looks like an approximate median is intended — confirm the
  // off-by-one is deliberate.
  long double scale_factor = scale_factors[replicates / 2 + 1];

  // We shift all times by `t0` for better numerics. Double precision only has
  // 16 decimal digits of accuracy, so if we blindly multiply times by
  // `scale_factor` we may suffer from precision loss. The choice of `t0` is
  // mostly arbitrary; we just need a factor that is the correct order of
  // magnitude to bring the intermediate values closer to zero. We are not,
  // however, guaranteed that `t0_approx` is *exactly* the getApproximateTime
  // equivilent of `t0`; it is only an estimate that we have to fine tune.
  auto t0 = start_times_[0].t_;
  auto t0_approx = start_times_[0].approx_t_;
  std::array<double, replicates> t0_correction{};
  for (const auto i : c10::irange(replicates)) {
    auto dt = start_times_[i].t_ - t0;
    auto dt_approx =
        (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor;
    t0_correction[i] = dt - (time_t)dt_approx;
  }
  // Apply the (near-)median correction to fine tune t0.
  t0 += t0_correction[t0_correction.size() / 2 + 1];

  return [=](approx_time_t t_approx) {
    // See above for why this is more stable than `A * t_approx + B`.
    return (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0;
  };
}
87 | |
namespace {
// Process-wide override for SOFT_ASSERT behavior. c10::nullopt means "no
// override installed"; softAssertRaises() then defaults to false.
c10::optional<bool> soft_assert_raises_;
} // namespace
91 | |
// Installs (or, with c10::nullopt, clears) the global override read by
// softAssertRaises().
void setSoftAssertRaises(c10::optional<bool> value) {
  soft_assert_raises_ = value;
}
95 | |
96 | bool softAssertRaises() { |
97 | return soft_assert_raises_.value_or(false); |
98 | } |
99 | |
// Reports a failed SOFT_ASSERT to Kineto as an invariant violation.
// `cond` is the stringified condition; `args` is extra formatted context.
// No-op when built without Kineto support.
void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const char* args) {
#ifdef USE_KINETO
  // Direct initialization instead of default-construct-then-assign.
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}
119 | |
// Overload of logSoftAssert taking the extra-context string as std::string.
// Reports the failed SOFT_ASSERT to Kineto; no-op without Kineto support.
void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const std::string& args) {
#ifdef USE_KINETO
  // Direct initialization instead of default-construct-then-assign.
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}
139 | |
140 | // ---------------------------------------------------------------------------- |
141 | // -- NVTX -------------------------------------------------------------------- |
142 | // ---------------------------------------------------------------------------- |
// Builds the annotation string used for NVTX ranges: the op name, optionally
// followed by the autograd sequence number, the op id, the input shapes, and
// the op ids of the input edges.
std::string getNvtxStr(
    const char* name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes,
    at::RecordFunctionHandle op_id,
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  if (sequence_nr >= -1 || !shapes.empty()) {
    std::string str;
    if (sequence_nr >= 0) {
      str = fmt::format("{}, seq = {}", name, sequence_nr);
    } else if (sequence_nr == -1) {
      str = name;
    } else {
#if defined(USE_ROCM)
      // Only ROCM supports < -1 sequence_nr
      str = name;
#endif
      // NOTE(review): on non-ROCm builds a sequence_nr < -1 leaves `str`
      // empty here, so the suffixes below attach to an empty prefix —
      // confirm this is intended.
    }
    if (op_id > 0) {
      str = fmt::format("{}, op_id = {}", str, op_id);
    }
    if (!shapes.empty()) {
      str = fmt::format("{}, sizes = {}", str, shapesToStr(shapes));
    }
    // Include the op ids of the input edges so
    // you can build the network graph
    if (!input_op_ids.empty()) {
      str = fmt::format(
          "{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids));
    }
    return str;
  } else {
    return name;
  }
}
178 | |
179 | // ---------------------------------------------------------------------------- |
180 | // -- Op context (shapes, call stack) ----------------------------------------- |
181 | // ---------------------------------------------------------------------------- |
// Converts a TorchScript interpreter callstack into (file, line, function)
// entries. Frames lacking source information or a filename are skipped, so
// the result may be shorter than the input.
std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs) {
  std::vector<FileLineFunc> entries;
  entries.reserve(cs.size());
  for (const auto& entry : cs) {
    auto& range = entry.range;
    if (range.source()) {
      auto& src = range.source();
      if (src && src->filename()) {
        // Translate the range's byte offset into an absolute line number.
        auto line =
            src->starting_line_no() + src->lineno_for_offset(range.start());
        // NOTE(review): `entry.filename` is passed as the function-name
        // field; presumably jit::StackEntry::filename holds the frame's
        // name — confirm against the StackEntry definition.
        entries.emplace_back(
            FileLineFunc{*(src->filename()), line, entry.filename});
      }
    }
  }
  return entries;
}
200 | |
201 | std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) { |
202 | std::vector<std::string> cs_str; |
203 | cs_str.reserve(cs.size()); |
204 | for (const auto& entry : cs) { |
205 | std::stringstream loc; |
206 | loc << entry.filename << "(" << entry.line << "): " << entry.funcname; |
207 | cs_str.push_back(loc.str()); |
208 | } |
209 | return cs_str; |
210 | } |
211 | |
// Joins the stack frames with `delim` — the delimiter is appended after
// every entry, including the last — and wraps the result in double quotes.
std::string stacksToStr(
    const std::vector<std::string>& stacks,
    const char* delim) {
  std::string joined;
  // Each frame is copied so it can be normalized in place on Windows.
  for (std::string frame : stacks) {
#ifdef _WIN32
    // replace the windows backslash with forward slash
    std::replace(frame.begin(), frame.end(), '\\', '/');
#endif
    joined += frame;
    joined += delim;
  }
  return "\"" + joined + "\"";
}
230 | |
231 | std::vector<std::vector<int64_t>> flattenList( |
232 | c10::List<c10::IValue> list, |
233 | std::string fn_name) { |
234 | std::vector<std::vector<int64_t>> tensor_dims; |
235 | for (const c10::IValue& input : list) { |
236 | if (input.isTensor()) { |
237 | const at::Tensor& tensor = input.toTensor(); |
238 | if (tensor.defined()) { |
239 | tensor_dims.push_back(input.toTensor().sizes().vec()); |
240 | } |
241 | } |
242 | } |
243 | return tensor_dims; |
244 | } |
245 | |
246 | std::vector<std::vector<int64_t>> inputSizes( |
247 | const at::RecordFunction& fn, |
248 | bool flatten_list_enabled) { |
249 | std::vector<std::vector<int64_t>> sizes; |
250 | sizes.reserve(fn.inputs().size()); |
251 | for (const c10::IValue& input : fn.inputs()) { |
252 | if (input.isTensor()) { |
253 | const at::Tensor& tensor = input.toTensor(); |
254 | if (tensor.defined()) { |
255 | sizes.push_back(input.toTensor().sizes().vec()); |
256 | } else { |
257 | sizes.emplace_back(); |
258 | } |
259 | } else if (input.isList()) { |
260 | std::vector<std::vector<int64_t>> tmp_sizes; |
261 | if (flatten_list_enabled) { |
262 | tmp_sizes = flattenList(input.toList(), std::string(fn.name())); |
263 | } |
264 | // Extend the current sizes array by the array returned from input sizes |
265 | if (!tmp_sizes.empty()) { |
266 | sizes.insert(sizes.end(), tmp_sizes.begin(), tmp_sizes.end()); |
267 | } else { |
268 | sizes.emplace_back(); |
269 | } |
270 | } else { |
271 | sizes.emplace_back(); |
272 | } |
273 | } |
274 | return sizes; |
275 | } |
276 | |
// Renders a list of shapes as a nested bracket list, e.g.
// {{1, 2}, {3}} -> "[[1, 2], [3]]".
// Builds the string by appending instead of re-formatting the whole string
// on every element (the previous fmt::format("{}…", str, …) pattern was
// O(n^2) in the output length).
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
  std::string str("[");
  for (size_t t_idx = 0; t_idx < shapes.size(); ++t_idx) {
    if (t_idx > 0) {
      str += ", ";
    }
    str += "[";
    for (size_t s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) {
      if (s_idx > 0) {
        str += ", ";
      }
      str += std::to_string(shapes[t_idx][s_idx]);
    }
    str += "]";
  }
  str += "]";
  return str;
}
295 | |
296 | std::string inputOpIdsToStr( |
297 | const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) { |
298 | std::string str("[" ); |
299 | int idx = 0; |
300 | |
301 | for (const auto& op_id_info_pair : input_op_ids) { |
302 | if (idx++ > 0) { |
303 | str = fmt::format("{}, " , str); |
304 | } |
305 | // (OpId,OutputNr) |
306 | str = fmt::format( |
307 | "{}({},{})" , str, op_id_info_pair.first, op_id_info_pair.second); |
308 | } |
309 | str = fmt::format("{}]" , str); |
310 | return str; |
311 | } |
312 | |
// Renders dtype names as a bracketed, quoted, comma-separated list,
// e.g. {"float", "int"} -> [\"float\", \"int\"]; an empty input yields "[]".
std::string dtypesToStr(const std::vector<std::string>& types) {
  if (types.empty()) {
    return "[]";
  }
  std::string out("[");
  bool first = true;
  for (const auto& type_name : types) {
    if (!first) {
      out += ", ";
    }
    first = false;
    out += "\"";
    out += type_name;
    out += "\"";
  }
  out += "]";
  return out;
}
328 | |
329 | std::vector<std::string> inputTypes(const at::RecordFunction& fn) { |
330 | std::vector<std::string> types; |
331 | types.reserve(fn.inputs().size()); |
332 | for (const c10::IValue& input : fn.inputs()) { |
333 | if (input.isTensor()) { |
334 | const at::Tensor& tensor = input.toTensor(); |
335 | if (tensor.defined()) { |
336 | types.push_back( |
337 | static_cast<std::string>(input.toTensor().dtype().name())); |
338 | } else { |
339 | types.emplace_back(); |
340 | } |
341 | } else if (input.isScalar() || input.isList()) { |
342 | types.push_back(input.tagKind()); |
343 | } else { |
344 | types.emplace_back(); |
345 | } |
346 | } |
347 | return types; |
348 | } |
349 | |
350 | // ---------------------------------------------------------------------------- |
351 | // -- FLOPS ------------------------------------------------------------------- |
352 | // ---------------------------------------------------------------------------- |
// Positional indices of the stride/padding/dilation/groups arguments in the
// aten::conv2d input list (see its schema).
static constexpr auto kConv2dStride = 3;
static constexpr auto kConv2dPadding = 4;
static constexpr auto kConv2dDilation = 5;
static constexpr auto kConv2dGroups = 6;

// List of supported operators
static constexpr auto kConv2dOp = "aten::conv2d";
static constexpr auto kMMOp = "aten::mm";
static constexpr auto kAddMMOp = "aten::addmm";
static constexpr auto kMulOp = "aten::mul";
static constexpr auto kAddOp = "aten::add";
static constexpr auto kBMMOp = "aten::bmm";
static constexpr auto kBAddBMMOp = "aten::baddbmm";

// Keys used in the extra-args map shared by saveExtraArgs/computeFlops.
static constexpr auto kInputSize = "input_size";
static constexpr auto kWeightSize = "weight_size";
static constexpr auto kGroups = "groups";
static constexpr auto kPadding = "padding";
static constexpr auto kStride = "stride";
static constexpr auto kDilation = "dilation";
static constexpr auto kMatSize = "mat_size";
static constexpr auto kMat1Size = "mat1_size";
static constexpr auto kMat2Size = "mat2_size";
376 | |
377 | static bool validateInput( |
378 | const std::string& op_name, |
379 | size_t min_size, |
380 | c10::ArrayRef<const c10::IValue> inputs, |
381 | const c10::ArrayRef<int>& should_be_tensor) { |
382 | std::stringstream ss; |
383 | if (inputs.size() < min_size) { |
384 | ss << "Failed to save extra arguments for flops computation of op " |
385 | << op_name << ", min size: " << min_size |
386 | << ", actual size: " << inputs.size(); |
387 | TORCH_WARN(ss.str()); |
388 | return false; |
389 | } |
390 | for (auto index : should_be_tensor) { |
391 | if (!inputs[index].isTensor()) { |
392 | ss << "Failed to save extra arguments for flops computation of op " |
393 | << op_name << ", input[" << index << "] must be a tensor." ; |
394 | TORCH_WARN(ss.str()); |
395 | return false; |
396 | } |
397 | } |
398 | return true; |
399 | } |
400 | |
401 | std::unordered_map<std::string, c10::IValue> ( |
402 | const at::RecordFunction& fn) { |
403 | // for specific types of fn, return the saved extra args for computing flops |
404 | std::unordered_map<std::string, c10::IValue> map; |
405 | auto inputs = fn.inputs(); |
406 | std::string fname(fn.name()); |
407 | |
408 | if (inputs.empty()) { |
409 | // Input shape is unavailable, return empty map |
410 | return map; |
411 | } |
412 | |
413 | if (fname == kConv2dOp) { |
414 | bool check = validateInput(fname, kConv2dGroups + 1, inputs, {0, 1}); |
415 | if (!check) { |
416 | return map; |
417 | } |
418 | |
419 | at::Tensor input = inputs[0].toTensor(); |
420 | at::Tensor weight = inputs[1].toTensor(); |
421 | if (weight.sizes().size() != 4) { |
422 | TORCH_WARN( |
423 | "Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor." ); |
424 | return map; |
425 | } |
426 | map[kInputSize] = at::IValue(input.sizes()); |
427 | map[kWeightSize] = at::IValue(weight.sizes()); |
428 | map[kStride] = inputs[kConv2dStride]; |
429 | map[kPadding] = inputs[kConv2dPadding]; |
430 | map[kDilation] = inputs[kConv2dDilation]; |
431 | map[kGroups] = inputs[kConv2dGroups]; |
432 | } else if (fname == kMMOp) { |
433 | bool check = validateInput(fname, 2, inputs, {0, 1}); |
434 | if (!check) { |
435 | return map; |
436 | } |
437 | |
438 | at::Tensor left = inputs[0].toTensor(); |
439 | at::Tensor right = inputs[1].toTensor(); |
440 | map[kMat1Size] = at::IValue(left.sizes()); |
441 | map[kMat2Size] = at::IValue(right.sizes()); |
442 | } else if (fname == kAddMMOp) { |
443 | bool check = validateInput(fname, 3, inputs, {0, 1, 2}); |
444 | if (!check) { |
445 | return map; |
446 | } |
447 | |
448 | // Exact FLOP count depends on scaling factors alpha and beta but |
449 | // just assume these are +=1. |
450 | // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf, |
451 | // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM) |
452 | at::Tensor left = inputs[1].toTensor(); |
453 | at::Tensor right = inputs[2].toTensor(); |
454 | map[kMat1Size] = at::IValue(left.sizes()); |
455 | map[kMat2Size] = at::IValue(right.sizes()); |
456 | } else if (fname == kMulOp) { |
457 | bool check = validateInput(fname, 1, inputs, {0}); |
458 | if (!check) { |
459 | return map; |
460 | } |
461 | |
462 | at::Tensor mat = inputs[0].toTensor(); |
463 | map[kMatSize] = at::IValue(mat.sizes()); |
464 | } else if (fname == kAddOp) { |
465 | bool check = validateInput(fname, 1, inputs, {0}); |
466 | if (!check) { |
467 | return map; |
468 | } |
469 | |
470 | at::Tensor mat = inputs[0].toTensor(); |
471 | map[kMatSize] = at::IValue(mat.sizes()); |
472 | } else if (fname == kBMMOp) { |
473 | bool check = validateInput(fname, 2, inputs, {0, 1}); |
474 | if (!check) { |
475 | return map; |
476 | } |
477 | |
478 | at::Tensor left = inputs[0].toTensor(); |
479 | at::Tensor right = inputs[1].toTensor(); |
480 | map[kMat1Size] = at::IValue(left.sizes()); |
481 | map[kMat2Size] = at::IValue(right.sizes()); |
482 | } else if (fname == kBAddBMMOp) { |
483 | bool check = validateInput(fname, 3, inputs, {0, 1, 2}); |
484 | if (!check) { |
485 | return map; |
486 | } |
487 | |
488 | // Exact FLOP count depends on scaling factors alpha and beta but |
489 | // just assume these are +=1. |
490 | // (similar to http://www.netlib.org/lapack/lawnspdf/lawn41.pdf, |
491 | // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM) |
492 | at::Tensor left = inputs[1].toTensor(); |
493 | at::Tensor right = inputs[2].toTensor(); |
494 | map[kMat1Size] = at::IValue(left.sizes()); |
495 | map[kMat2Size] = at::IValue(right.sizes()); |
496 | } |
497 | |
498 | return map; |
499 | } |
500 | |
501 | uint64_t computeFlops( |
502 | const std::string& op_name, |
503 | const std::unordered_map<std::string, c10::IValue>& ) { |
504 | if (op_name == kConv2dOp) { |
505 | if (extra_args.find(kInputSize) == extra_args.end() || |
506 | extra_args.find(kWeightSize) == extra_args.end() || |
507 | extra_args.find(kGroups) == extra_args.end() || |
508 | extra_args.find(kPadding) == extra_args.end() || |
509 | extra_args.find(kStride) == extra_args.end() || |
510 | extra_args.find(kDilation) == extra_args.end()) { |
511 | TORCH_WARN( |
512 | "Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments." ); |
513 | return 0; |
514 | } |
515 | auto input_sizes_ref = extra_args.at(kInputSize); |
516 | auto kernel_sizes_ref = extra_args.at(kWeightSize); |
517 | auto groups_ref = extra_args.at(kGroups); |
518 | auto padding_ref = extra_args.at(kPadding); |
519 | auto stride_ref = extra_args.at(kStride); |
520 | auto dilation_ref = extra_args.at(kDilation); |
521 | if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) { |
522 | TORCH_WARN( |
523 | "Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes." ); |
524 | return 0; |
525 | } |
526 | if (!padding_ref.isIntList() || !stride_ref.isIntList() || |
527 | !dilation_ref.isIntList()) { |
528 | TORCH_WARN( |
529 | "Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values." ); |
530 | return 0; |
531 | } |
532 | |
533 | const auto input_sizes = input_sizes_ref.toDimVector(); |
534 | const auto kernel_sizes = kernel_sizes_ref.toDimVector(); |
535 | const uint64_t groups = groups_ref.toInt(); |
536 | const std::vector<int64_t> padding = padding_ref.toIntVector(); |
537 | const std::vector<int64_t> stride = stride_ref.toIntVector(); |
538 | const std::vector<int64_t> dilation = dilation_ref.toIntVector(); |
539 | if (input_sizes.size() != 4 || kernel_sizes.size() != 4) { |
540 | TORCH_WARN( |
541 | "Failed to compute flops for op aten::conv2d because both input and weight must be size 4." ); |
542 | return 0; |
543 | } |
544 | if (!groups) { |
545 | TORCH_WARN( |
546 | "Failed to compute flops for op aten::conv2d because group size must not be 0." ); |
547 | return 0; |
548 | } |
549 | if (padding.size() != 2 || dilation.size() != 2) { |
550 | TORCH_WARN( |
551 | "Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2." ); |
552 | return 0; |
553 | } |
554 | if (stride.size() != 2 || (stride[0] * stride[1] == 0)) { |
555 | TORCH_WARN( |
556 | "Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0." ); |
557 | return 0; |
558 | } |
559 | // format of the input is defined in |
560 | // torch.ao.nn.quantized.functional.conv2d() |
561 | uint64_t minibatch = 0, in_channels = 0, input_h = 0, input_w = 0; |
562 | uint64_t out_channels = 0, kernel_h = 0, kernel_w = 0; |
563 | const uint64_t conv2d_multiply_factor = 2; |
564 | std::tie(minibatch, in_channels, input_h, input_w) = std::make_tuple( |
565 | input_sizes[0], input_sizes[1], input_sizes[2], input_sizes[3]); |
566 | std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple( |
567 | kernel_sizes[0], kernel_sizes[1], kernel_sizes[2], kernel_sizes[3]); |
568 | uint64_t output_h = |
569 | (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) / |
570 | stride[0] + |
571 | 1; |
572 | uint64_t output_w = |
573 | (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) / |
574 | stride[1] + |
575 | 1; |
576 | |
577 | return conv2d_multiply_factor * minibatch * output_h * output_w * kernel_h * |
578 | kernel_w * in_channels * out_channels / groups; |
579 | } else if (op_name == kMMOp || op_name == kAddMMOp) { |
580 | if (extra_args.find(kMat1Size) == extra_args.end() || |
581 | extra_args.find(kMat2Size) == extra_args.end()) { |
582 | TORCH_WARN( |
583 | "Calculating flops for " , |
584 | op_name, |
585 | " requires mat1_size and mat2_size in saved arguments." ); |
586 | return 0; |
587 | } |
588 | auto mat1_sizes_ref = extra_args.at(kMat1Size); |
589 | auto mat2_sizes_ref = extra_args.at(kMat2Size); |
590 | if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) { |
591 | TORCH_WARN( |
592 | "Failed to compute flops for op " , |
593 | op_name, |
594 | " because it requires mat1_size and mat2_size to be IntList." ); |
595 | return 0; |
596 | } |
597 | |
598 | const auto mat1_size = mat1_sizes_ref.toDimVector(); |
599 | const auto mat2_size = mat2_sizes_ref.toDimVector(); |
600 | if (mat1_size.empty()) { |
601 | return 0; |
602 | } |
603 | |
604 | int64_t overlap_dim = mat1_size.back(); |
605 | if (overlap_dim == 0) { |
606 | return 0; |
607 | } |
608 | |
609 | const uint64_t gemm_multiply_factor = 2; |
610 | uint64_t flops = 1; |
611 | for (int64_t dim : mat1_size) { |
612 | flops *= dim; |
613 | } |
614 | flops /= overlap_dim; |
615 | for (int64_t dim : mat2_size) { |
616 | flops *= dim; |
617 | } |
618 | flops *= gemm_multiply_factor; |
619 | return flops; |
620 | } else if (op_name == kBMMOp || op_name == kBAddBMMOp) { |
621 | if (extra_args.find(kMat1Size) == extra_args.end() || |
622 | extra_args.find(kMat2Size) == extra_args.end()) { |
623 | TORCH_WARN( |
624 | "Calculating flops for " , |
625 | op_name, |
626 | " requires mat1_size and mat2_size in saved arguments." ); |
627 | return 0; |
628 | } |
629 | auto mat1_sizes_ref = extra_args.at(kMat1Size); |
630 | auto mat2_sizes_ref = extra_args.at(kMat2Size); |
631 | if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) { |
632 | TORCH_WARN( |
633 | "Failed to compute flops for op " , |
634 | op_name, |
635 | " because it requires mat1_size and mat2_size to be IntList." ); |
636 | return 0; |
637 | } |
638 | |
639 | const auto mat1_size = mat1_sizes_ref.toDimVector(); |
640 | const auto mat2_size = mat2_sizes_ref.toDimVector(); |
641 | if (mat1_size.empty()) { |
642 | return 0; |
643 | } |
644 | |
645 | int64_t batch_size = mat1_size.front(); |
646 | if (batch_size == 0) { |
647 | return 0; |
648 | } |
649 | |
650 | int64_t overlap_dim = mat1_size.back(); |
651 | if (overlap_dim == 0) { |
652 | return 0; |
653 | } |
654 | |
655 | const uint64_t gemm_multiply_factor = 2; |
656 | uint64_t flops = 1; |
657 | for (int64_t dim : mat1_size) { |
658 | flops *= dim; |
659 | } |
660 | flops /= overlap_dim; |
661 | flops /= batch_size; |
662 | for (int64_t dim : mat2_size) { |
663 | flops *= dim; |
664 | } |
665 | flops *= gemm_multiply_factor; |
666 | return flops; |
667 | } else if (op_name == kMulOp) { |
668 | if (extra_args.find(kMatSize) == extra_args.end()) { |
669 | TORCH_WARN( |
670 | "Calculating flops for aten::mul.Tensor requires mat_size in saved arguments." ); |
671 | return 0; |
672 | } |
673 | auto mat_sizes = extra_args.at(kMatSize); |
674 | if (!mat_sizes.isIntList()) { |
675 | TORCH_WARN( |
676 | "Failed to compute flops for op aten::mul because it requires mat_size to be IntList." ); |
677 | return 0; |
678 | } |
679 | |
680 | const auto mat_size = mat_sizes.toDimVector(); |
681 | uint64_t flops = 1; |
682 | for (int64_t dim : mat_size) { |
683 | flops *= dim; |
684 | } |
685 | return flops; |
686 | } else if (op_name == kAddOp) { |
687 | if (extra_args.find(kMatSize) == extra_args.end()) { |
688 | TORCH_WARN( |
689 | "Calculating flops for aten::add.Tensor requires mat_size in saved arguments." ); |
690 | return 0; |
691 | } |
692 | auto mat_sizes = extra_args.at(kMatSize); |
693 | if (!mat_sizes.isIntList()) { |
694 | TORCH_WARN( |
695 | "Failed to compute flops for op aten::add because it requires mat_size to be IntList." ); |
696 | return 0; |
697 | } |
698 | |
699 | const auto mat_size = mat_sizes.toDimVector(); |
700 | uint64_t flops = 1; |
701 | for (int64_t dim : mat_size) { |
702 | flops *= dim; |
703 | } |
704 | return flops; |
705 | } |
706 | return 0; |
707 | } |
708 | |
709 | } // namespace impl |
710 | } // namespace profiler |
711 | } // namespace torch |
712 | |