Skip to content

Commit c879e2a

Browse files
authored
Enable analyzing nested input- and output-dicts (#212)
* enable analyzing nested input- and output dicts * enable analyzing nested input- and output dicts * skip tests that require torch v1.8 or above when an older version is installed * add test for highly nested dicts, fix error found by it - `elem_bytes` in `LayerInfo.calculate_size(...)` didn't work for nested dicts * `LayerInfo.calculate_size.extract_tensor` now works with `dict` properly - adapted highly_nested_dict_model.out accordingly * simplified `test_highly_nested_dict_model` * `LayerInfo.calculate_size.extract_tensor` now works properly for objects with `tensor`-attribute - Found error in new testcase that comes with this commit * Add docstring to test to explain what exactly it tests * test all edge-cases of `LayerInfo.calculate_size.extract_tensor` * use `dim=0` in `F.softmax` explicitly (implicit use deprecated) * replace custom `torchversion_at_least` with `packaging.version.parse` * modify `EdgecaseInputOutputModel` to increase test-coverage missing: - not hasattr(inputs, "__getitem__") - last return * use torch_nested-package to simplify `LayerInfo.calculate_size` - torch_nested has 99.something% test-coverage - Makes test-coverage for this package much easier - Increases readability & extensibility * Move back from using torch-nested. Fix and use `nested_list_size` instead - Fixes issue#141 - Increases test-coverage - Produces more plausible output for some cases * Fix problem with accessing dicts Fix [issue#214](#215) * Install compressai in workflows
1 parent 5f3e78b commit c879e2a

12 files changed

Lines changed: 335 additions & 29 deletions

.github/workflows/test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ jobs:
5656
python -m pip install --upgrade pip
5757
python -m pip install mypy pytest pytest-cov
5858
pip install torch==${{ matrix.pytorch-version }} torchvision
59+
pip install transformers
60+
pip install compressai
5961
- name: mypy
6062
if: ${{ matrix.pytorch-version == '1.13' }}
6163
run: |

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ pylint
88
pytest
99
pytest-cov
1010
pre-commit
11+
transformers
12+
compressai

tests/fixtures/models.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import math
55
from collections import namedtuple
6-
from typing import Any, cast
6+
from typing import Any, Sequence, cast
77

88
import torch
99
from torch import nn
@@ -323,6 +323,64 @@ def forward(
323323
return x
324324

325325

326+
class ObjectWithTensors:
327+
"""A class with a 'tensors'-attribute."""
328+
329+
def __init__(self, tensors: torch.Tensor | Sequence[Any]) -> None:
330+
self.tensors = tensors
331+
332+
333+
class HighlyNestedDictModel(nn.Module):
334+
"""Model that returns a highly nested dict."""
335+
336+
def __init__(self) -> None:
337+
super().__init__()
338+
self.lin1 = nn.Linear(10, 10)
339+
self.lin2 = nn.Linear(10, 10)
340+
341+
def forward(
342+
self, x: torch.Tensor
343+
) -> dict[str, tuple[dict[str, list[ObjectWithTensors]]]]:
344+
x = self.lin1(x)
345+
x = self.lin2(x)
346+
x = F.softmax(x, dim=0)
347+
return {"foo": ({"bar": [ObjectWithTensors(x)]},)}
348+
349+
350+
class IntWithGetitem(int):
351+
"""An int with a __getitem__ method."""
352+
353+
def __init__(self, tensor: torch.Tensor) -> None:
354+
super().__init__()
355+
self.tensor = tensor
356+
357+
def __int__(self) -> IntWithGetitem:
358+
return self
359+
360+
def __getitem__(self, val: int) -> torch.Tensor:
361+
return self.tensor * val
362+
363+
364+
class EdgecaseInputOutputModel(nn.Module):
365+
"""
366+
For testing LayerInfo.calculate_size.extract_tensor:
367+
368+
case hasattr(inputs, "__getitem__") but not
369+
isinstance(inputs, (list, tuple, dict)).
370+
371+
case not inputs.
372+
"""
373+
374+
def __init__(self) -> None:
375+
super().__init__()
376+
self.linear = nn.Linear(3, 1)
377+
378+
def forward(self, input_list: dict[str, torch.Tensor]) -> dict[str, IntWithGetitem]:
379+
x = input_list["foo"] if input_list else torch.ones(3)
380+
x = self.linear(x)
381+
return {"foo": IntWithGetitem(x)}
382+
383+
326384
class NamedTuple(nn.Module):
327385
"""Model that takes in a NamedTuple as input."""
328386

tests/test_output/bert.out

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
====================================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
====================================================================================================
4+
BertModel [2, 768] --
5+
├─BertEmbeddings: 1-1 [2, 512, 768] --
6+
│ └─Embedding: 2-1 [2, 512, 768] 23,440,896
7+
│ └─Embedding: 2-2 [2, 512, 768] 1,536
8+
│ └─Embedding: 2-3 [1, 512, 768] 393,216
9+
│ └─LayerNorm: 2-4 [2, 512, 768] 1,536
10+
│ └─Dropout: 2-5 [2, 512, 768] --
11+
├─BertEncoder: 1-2 [2, 512, 768] --
12+
│ └─ModuleList: 2-6 -- --
13+
│ │ └─BertLayer: 3-1 [2, 512, 768] 7,087,872
14+
│ │ └─BertLayer: 3-2 [2, 512, 768] 7,087,872
15+
│ │ └─BertLayer: 3-3 [2, 512, 768] 7,087,872
16+
│ │ └─BertLayer: 3-4 [2, 512, 768] 7,087,872
17+
│ │ └─BertLayer: 3-5 [2, 512, 768] 7,087,872
18+
│ │ └─BertLayer: 3-6 [2, 512, 768] 7,087,872
19+
│ │ └─BertLayer: 3-7 [2, 512, 768] 7,087,872
20+
│ │ └─BertLayer: 3-8 [2, 512, 768] 7,087,872
21+
│ │ └─BertLayer: 3-9 [2, 512, 768] 7,087,872
22+
│ │ └─BertLayer: 3-10 [2, 512, 768] 7,087,872
23+
│ │ └─BertLayer: 3-11 [2, 512, 768] 7,087,872
24+
│ │ └─BertLayer: 3-12 [2, 512, 768] 7,087,872
25+
├─BertPooler: 1-3 [2, 768] --
26+
│ └─Linear: 2-7 [2, 768] 590,592
27+
│ └─Tanh: 2-8 [2, 768] --
28+
====================================================================================================
29+
Total params: 109,482,240
30+
Trainable params: 109,482,240
31+
Non-trainable params: 0
32+
Total mult-adds (M): 218.57
33+
====================================================================================================
34+
Input size (MB): 0.01
35+
Forward/backward pass size (MB): 852.50
36+
Params size (MB): 437.93
37+
Estimated Total Size (MB): 1290.45
38+
====================================================================================================

tests/test_output/compressai.out

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
===============================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
===============================================================================================
4+
FactorizedPrior [1, 192, 16, 16] --
5+
├─Sequential: 1-1 [1, 192, 16, 16] --
6+
│ └─Conv2d: 2-1 [1, 128, 128, 128] 9,728
7+
│ └─GDN: 2-2 [1, 128, 128, 128] 16,512
8+
│ │ └─NonNegativeParametrizer: 3-1 [128] --
9+
│ │ └─NonNegativeParametrizer: 3-2 [128, 128] --
10+
│ └─Conv2d: 2-3 [1, 128, 64, 64] 409,728
11+
│ └─GDN: 2-4 [1, 128, 64, 64] 16,512
12+
│ │ └─NonNegativeParametrizer: 3-3 [128] --
13+
│ │ └─NonNegativeParametrizer: 3-4 [128, 128] --
14+
│ └─Conv2d: 2-5 [1, 128, 32, 32] 409,728
15+
│ └─GDN: 2-6 [1, 128, 32, 32] 16,512
16+
│ │ └─NonNegativeParametrizer: 3-5 [128] --
17+
│ │ └─NonNegativeParametrizer: 3-6 [128, 128] --
18+
│ └─Conv2d: 2-7 [1, 192, 16, 16] 614,592
19+
├─EntropyBottleneck: 1-2 [1, 192, 16, 16] 11,712
20+
│ └─LowerBound: 2-8 [192, 1, 256] --
21+
├─Sequential: 1-3 [1, 3, 256, 256] --
22+
│ └─ConvTranspose2d: 2-9 [1, 128, 32, 32] 614,528
23+
│ └─GDN: 2-10 [1, 128, 32, 32] 16,512
24+
│ │ └─NonNegativeParametrizer: 3-7 [128] --
25+
│ │ └─NonNegativeParametrizer: 3-8 [128, 128] --
26+
│ └─ConvTranspose2d: 2-11 [1, 128, 64, 64] 409,728
27+
│ └─GDN: 2-12 [1, 128, 64, 64] 16,512
28+
│ │ └─NonNegativeParametrizer: 3-9 [128] --
29+
│ │ └─NonNegativeParametrizer: 3-10 [128, 128] --
30+
│ └─ConvTranspose2d: 2-13 [1, 128, 128, 128] 409,728
31+
│ └─GDN: 2-14 [1, 128, 128, 128] 16,512
32+
│ │ └─NonNegativeParametrizer: 3-11 [128] --
33+
│ │ └─NonNegativeParametrizer: 3-12 [128, 128] --
34+
│ └─ConvTranspose2d: 2-15 [1, 3, 256, 256] 9,603
35+
===============================================================================================
36+
Total params: 2,998,147
37+
Trainable params: 2,998,147
38+
Non-trainable params: 0
39+
Total mult-adds (G): 12.06
40+
===============================================================================================
41+
Input size (MB): 0.79
42+
Forward/backward pass size (MB): 46.01
43+
Params size (MB): 11.55
44+
Estimated Total Size (MB): 58.34
45+
===============================================================================================
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
==========================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
==========================================================================================
4+
EdgecaseInputOutputModel -- --
5+
├─Linear: 1-1 [1] 4
6+
==========================================================================================
7+
Total params: 4
8+
Trainable params: 4
9+
Non-trainable params: 0
10+
Total mult-adds (M): 0.00
11+
==========================================================================================
12+
Input size (MB): 0.00
13+
Forward/backward pass size (MB): 0.00
14+
Params size (MB): 0.00
15+
Estimated Total Size (MB): 0.00
16+
==========================================================================================
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
==============================================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
==============================================================================================================
4+
T5ForConditionalGeneration [2, 100, 512] --
5+
├─T5Stack: 1-1 [2, 100, 512] 35,332,800
6+
├─T5Stack: 1-2 -- (recursive)
7+
│ └─Embedding: 2-1 [2, 100, 512] 16,449,536
8+
├─T5Stack: 1-3 -- (recursive)
9+
│ └─Dropout: 2-2 [2, 100, 512] --
10+
│ └─ModuleList: 2-3 -- --
11+
│ │ └─T5Block: 3-1 [2, 100, 512] 2,360,512
12+
│ │ └─T5Block: 3-2 [2, 100, 512] 2,360,320
13+
│ │ └─T5Block: 3-3 [2, 100, 512] 2,360,320
14+
│ │ └─T5Block: 3-4 [2, 100, 512] 2,360,320
15+
│ │ └─T5Block: 3-5 [2, 100, 512] 2,360,320
16+
│ │ └─T5Block: 3-6 [2, 100, 512] 2,360,320
17+
│ │ └─T5Block: 3-7 [2, 100, 512] 2,360,320
18+
│ │ └─T5Block: 3-8 [2, 100, 512] 2,360,320
19+
│ └─T5LayerNorm: 2-4 [2, 100, 512] 512
20+
│ └─Dropout: 2-5 [2, 100, 512] --
21+
├─T5Stack: 1-4 [2, 6, 100, 64] 16,449,536
22+
│ └─Embedding: 2-6 [2, 100, 512] (recursive)
23+
│ └─Dropout: 2-7 [2, 100, 512] --
24+
│ └─ModuleList: 2-8 -- --
25+
│ │ └─T5Block: 3-9 [2, 100, 512] 3,147,456
26+
│ │ └─T5Block: 3-10 [2, 100, 512] 3,147,264
27+
│ │ └─T5Block: 3-11 [2, 100, 512] 3,147,264
28+
│ │ └─T5Block: 3-12 [2, 100, 512] 3,147,264
29+
│ │ └─T5Block: 3-13 [2, 100, 512] 3,147,264
30+
│ │ └─T5Block: 3-14 [2, 100, 512] 3,147,264
31+
│ │ └─T5Block: 3-15 [2, 100, 512] 3,147,264
32+
│ │ └─T5Block: 3-16 [2, 100, 512] 3,147,264
33+
│ └─T5LayerNorm: 2-9 [2, 100, 512] 512
34+
│ └─Dropout: 2-10 [2, 100, 512] --
35+
├─Linear: 1-5 [2, 100, 32128] 16,449,536
36+
==============================================================================================================
37+
Total params: 128,743,488
38+
Trainable params: 128,743,488
39+
Non-trainable params: 0
40+
Total mult-adds (M): 186.86
41+
==============================================================================================================
42+
Input size (MB): 0.00
43+
Forward/backward pass size (MB): 217.84
44+
Params size (MB): 307.84
45+
Estimated Total Size (MB): 525.69
46+
==============================================================================================================
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
==========================================================================================
2+
Layer (type:depth-idx) Output Shape Param #
3+
==========================================================================================
4+
HighlyNestedDictModel [10] --
5+
├─Linear: 1-1 [10] 110
6+
├─Linear: 1-2 [10] 110
7+
==========================================================================================
8+
Total params: 220
9+
Trainable params: 220
10+
Non-trainable params: 0
11+
Total mult-adds (M): 0.00
12+
==========================================================================================
13+
Input size (MB): 0.00
14+
Forward/backward pass size (MB): 0.00
15+
Params size (MB): 0.00
16+
Estimated Total Size (MB): 0.00
17+
==========================================================================================

tests/torchinfo_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
ConvLayerB,
1414
CustomParameter,
1515
DictParameter,
16+
EdgecaseInputOutputModel,
1617
EmptyModule,
1718
FakePrunedLayerModel,
19+
HighlyNestedDictModel,
1820
InsideModel,
1921
LinearModel,
2022
LSTMNet,
@@ -344,6 +346,26 @@ def test_module_dict() -> None:
344346
)
345347

346348

349+
def test_highly_nested_dict_model() -> None:
350+
"""
351+
Test the following three if-clauses
352+
from LayerInfo.calculate_size.extract_tensor: 1, 2, 4, 5
353+
(starts counting from 1)
354+
"""
355+
model = HighlyNestedDictModel()
356+
summary(model, input_data=torch.ones(10))
357+
358+
359+
def test_edgecase_input_output_model() -> None:
360+
"""
361+
Test the following two if-clauses
362+
from LayerInfo.calculate_size.extract_tensor: 3
363+
(starts counting from 1) as well as the final return.
364+
"""
365+
model = EdgecaseInputOutputModel()
366+
summary(model, input_data=[{}])
367+
368+
347369
def test_model_with_args() -> None:
348370
summary(RecursiveNet(), input_size=(1, 64, 28, 28), args1="args1", args2="args2")
349371

tests/torchinfo_xl_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import pytest
22
import torch
33
import torchvision # type: ignore[import]
4+
from compressai.zoo import image_models # type: ignore[import]
5+
from packaging import version
6+
from transformers import ( # type: ignore[import]
7+
AutoModelForSeq2SeqLM,
8+
BertConfig,
9+
BertModel,
10+
)
411

512
from tests.fixtures.genotype import GenotypeNetwork # type: ignore[attr-defined]
613
from tests.fixtures.tmva_net import TMVANet # type: ignore[attr-defined]
@@ -143,3 +150,40 @@ def test_google() -> None:
143150
# Check googlenet in training mode since InceptionAux layers are used in
144151
# forward-prop in train mode but not in eval mode.
145152
summary(google_net, (1, 3, 112, 112), depth=7, mode="train")
153+
154+
155+
@pytest.mark.skipif(
156+
version.parse(torch.__version__) < version.parse("1.8"),
157+
reason="FlanT5Small only works for PyTorch v1.8 and above",
158+
)
159+
def test_flan_t5_small() -> None:
160+
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
161+
inputs = {
162+
"input_ids": torch.zeros(2, 100).long(),
163+
"attention_mask": torch.zeros(2, 100).long(),
164+
"labels": torch.zeros(2, 100).long(),
165+
}
166+
summary(model, input_data=inputs)
167+
168+
169+
@pytest.mark.skipif(
170+
version.parse(torch.__version__) < version.parse("1.8"),
171+
reason="BertModel only works for PyTorch v1.8 and above",
172+
)
173+
def test_bert() -> None:
174+
model = BertModel(BertConfig())
175+
summary(
176+
model,
177+
input_size=[(2, 512), (2, 512), (2, 512)],
178+
dtypes=[torch.int, torch.int, torch.int],
179+
device="cpu",
180+
)
181+
182+
183+
@pytest.mark.skipif(
184+
version.parse(torch.__version__) < version.parse("1.8"),
185+
reason="compressai only works for PyTorch v1.8 and above",
186+
)
187+
def test_compressai() -> None:
188+
model = image_models["bmshj2018-factorized"](quality=4, pretrained=True)
189+
summary(model, (1, 3, 256, 256))

0 commit comments

Comments
 (0)