Skip to content

Commit 7f2bed3

Browse files
committed
Move to using tensor.untyped_storage() for Pytorch 2.0
1 parent d57620c commit 7f2bed3

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

torchinfo/torchinfo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,11 @@ def get_total_memory_used(data: CORRECTED_INPUT_DATA_TYPE) -> int:
500500
"""Calculates the total memory of all tensors stored in data."""
501501
result = traverse_input_data(
502502
data,
503-
action_fn=lambda data: sys.getsizeof(data.storage()),
503+
action_fn=lambda data: sys.getsizeof(
504+
data.untyped_storage()
505+
if hasattr(data, "untyped_storage")
506+
else data.storage()
507+
),
504508
aggregate_fn=(
505509
# We don't need the dictionary keys in this case
506510
lambda data: (lambda d: sum(d.values()))

0 commit comments

Comments
 (0)