We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d57620c commit 7f2bed3Copy full SHA for 7f2bed3
1 file changed
torchinfo/torchinfo.py
@@ -500,7 +500,11 @@ def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
500
"""Calculates the total memory of all tensors stored in data."""
501
result = traverse_input_data(
502
data,
503
- action_fn=lambda data: sys.getsizeof(data.storage()),
+ action_fn=lambda data: sys.getsizeof(
504
+ data.untyped_storage()
505
+ if hasattr(data, "untyped_storage")
506
+ else data.storage()
507
+ ),
508
aggregate_fn=(
509
# We don't need the dictionary keys in this case
510
lambda data: (lambda d: sum(d.values()))
0 commit comments