Commit f09be0a

Add HF hub mixin (#876)
* Bump version
* Update gitignore
* Update flake
* Add hub mixin
* Fix interpolation
* Add from_pretrained
* Update example
1 parent 3bf4d6e commit f09be0a

File tree

8 files changed: +339 -241 lines changed

.flake8

+1 -1

@@ -1,5 +1,5 @@
 [flake8]
 max-line-length = 119
 exclude =.git,__pycache__,docs/conf.py,build,dist,setup.py,tests,.venv
-ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006,D412
+ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006,D412,F821,E501
 inline-quotes = "

.gitignore

+2

@@ -4,6 +4,8 @@ __pycache__/
 *$py.class
 .idea/
 .venv*
+examples/images*
+examples/annotations*
 
 # C extensions
 *.so

examples/binary_segmentation_intro.ipynb

+183 -237

Large diffs are not rendered by default.

segmentation_models_pytorch/__init__.py

+1

@@ -12,6 +12,7 @@
 from .decoders.pspnet import PSPNet
 from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
 from .decoders.pan import PAN
+from .base.hub_mixin import from_pretrained
 
 from .__version__ import __version__
segmentation_models_pytorch/__version__.py

+1 -1

@@ -1,3 +1,3 @@
-VERSION = (0, 3, 3)
+VERSION = (0, 3, "4dev0")
 
 __version__ = ".".join(map(str, VERSION))
segmentation_models_pytorch/base/hub_mixin.py

+144

@@ -0,0 +1,144 @@
+import json
+from pathlib import Path
+from typing import Optional, Union
+from functools import wraps
+from huggingface_hub import PyTorchModelHubMixin, ModelCard, ModelCardData, hf_hub_download
+
+
+MODEL_CARD = """
+---
+{{ card_data }}
+---
+# {{ model_name }} Model Card
+
+Table of Contents:
+- [Load trained model](#load-trained-model)
+- [Model init parameters](#model-init-parameters)
+- [Model metrics](#model-metrics)
+- [Dataset](#dataset)
+
+## Load trained model
+```python
+import segmentation_models_pytorch as smp
+
+model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
+```
+
+## Model init parameters
+```python
+model_init_params = {{ model_parameters }}
+```
+
+## Model metrics
+{{ metrics | default("[More Information Needed]", true) }}
+
+## Dataset
+Dataset name: {{ dataset | default("[More Information Needed]", true) }}
+
+## More Information
+- Library: {{ repo_url | default("[More Information Needed]", true) }}
+- Docs: {{ docs_url | default("[More Information Needed]", true) }}
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)
+"""
+
+
+def _format_parameters(parameters: dict):
+    params = {k: v for k, v in parameters.items() if not k.startswith("_")}
+    params = [f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"' for k, v in params.items()]
+    params = ",\n".join([f"    {param}" for param in params])
+    params = "{\n" + f"{params}" + "\n}"
+    return params
+
+
+class SMPHubMixin(PyTorchModelHubMixin):
+    def generate_model_card(self, *args, **kwargs) -> ModelCard:
+
+        model_parameters_json = _format_parameters(self._hub_mixin_config)
+        directory = self._save_directory if hasattr(self, "_save_directory") else None
+        repo_id = self._repo_id if hasattr(self, "_repo_id") else None
+        repo_or_directory = repo_id if repo_id is not None else directory
+
+        metrics = self._metrics if hasattr(self, "_metrics") else None
+        dataset = self._dataset if hasattr(self, "_dataset") else None
+
+        if metrics is not None:
+            metrics = json.dumps(metrics, indent=4)
+            metrics = f"```json\n{metrics}\n```"
+
+        model_card_data = ModelCardData(
+            languages=["python"],
+            library_name="segmentation-models-pytorch",
+            license="mit",
+            tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
+            pipeline_tag="image-segmentation",
+        )
+        model_card = ModelCard.from_template(
+            card_data=model_card_data,
+            template_str=MODEL_CARD,
+            repo_url="https://github.com/qubvel/segmentation_models.pytorch",
+            docs_url="https://smp.readthedocs.io/en/latest/",
+            model_parameters=model_parameters_json,
+            save_directory=repo_or_directory,
+            model_name=self.__class__.__name__,
+            metrics=metrics,
+            dataset=dataset,
+        )
+        return model_card
+
+    def _set_attrs_from_kwargs(self, attrs, kwargs):
+        for attr in attrs:
+            if attr in kwargs:
+                setattr(self, f"_{attr}", kwargs.pop(attr))
+
+    def _del_attrs(self, attrs):
+        for attr in attrs:
+            if hasattr(self, f"_{attr}"):
+                delattr(self, f"_{attr}")
+
+    @wraps(PyTorchModelHubMixin.save_pretrained)
+    def save_pretrained(self, save_directory: Union[str, Path], *args, **kwargs) -> Optional[str]:
+
+        # set additional attributes to be used in generate_model_card
+        self._save_directory = save_directory
+        self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
+
+        # set additional attribute to be used in from_pretrained
+        self._hub_mixin_config["_model_class"] = self.__class__.__name__
+
+        try:
+            # call the original save_pretrained
+            result = super().save_pretrained(save_directory, *args, **kwargs)
+        finally:
+            # delete the additional attributes
+            self._del_attrs(["save_directory", "metrics", "dataset"])
+            self._hub_mixin_config.pop("_model_class")
+
+        return result
+
+    @wraps(PyTorchModelHubMixin.push_to_hub)
+    def push_to_hub(self, repo_id: str, *args, **kwargs):
+        self._repo_id = repo_id
+        self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
+        result = super().push_to_hub(repo_id, *args, **kwargs)
+        self._del_attrs(["repo_id", "metrics", "dataset"])
+        return result
+
+    @property
+    def config(self):
+        return self._hub_mixin_config
+
+
+@wraps(PyTorchModelHubMixin.from_pretrained)
+def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
+    config_path = hf_hub_download(
+        pretrained_model_name_or_path, filename="config.json", revision=kwargs.get("revision", None)
+    )
+    with open(config_path, "r") as f:
+        config = json.load(f)
+    model_class_name = config.pop("_model_class")
+
+    import segmentation_models_pytorch as smp
+
+    model_class = getattr(smp, model_class_name)
+    return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
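
For orientation, here is a minimal usage sketch of the mixin defined above. The directory name, encoder choice, metric, and dataset values are illustrative and not part of this commit; the `save_pretrained` / `from_pretrained` calls themselves are the ones the new code exposes.

```python
import segmentation_models_pytorch as smp

# Any SMP architecture gains the mixin through SegmentationModel (see base/model.py below).
model = smp.Unet(encoder_name="resnet34", classes=2)  # hypothetical init parameters

# "metrics" and "dataset" are popped by SMPHubMixin.save_pretrained and rendered
# into the generated model card; the init parameters land in config.json.
model.save_pretrained(
    "./unet-resnet34-demo",          # hypothetical local directory
    metrics={"test_iou": 0.95},      # hypothetical value
    dataset="Oxford-IIIT Pet",       # hypothetical dataset name
)

# Reload from the saved directory via the class method, as the generated model card suggests.
restored = smp.Unet.from_pretrained("./unet-resnet34-demo")

# The package-level helper reads "_model_class" from config.json on the Hub and
# dispatches to the right architecture, so the class need not be known up front:
# restored = smp.from_pretrained("<hub-repo-id>")
```

The `_model_class` key is written by the wrapped `save_pretrained` and removed from the in-memory config afterwards, which is what lets the package-level `from_pretrained` pick the correct class when loading.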

segmentation_models_pytorch/base/model.py

+6 -1

@@ -1,8 +1,13 @@
 import torch
+
 from . import initialization as init
+from .hub_mixin import SMPHubMixin
 
 
-class SegmentationModel(torch.nn.Module):
+class SegmentationModel(
+    torch.nn.Module,
+    SMPHubMixin,
+):
     def initialize(self):
         init.initialize_decoder(self.decoder)
         init.initialize_head(self.segmentation_head)
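
Because every model in the library derives from `SegmentationModel`, the Hub methods become available on all architectures after this change. A hedged sketch of pushing straight to the Hub (the repo id, metric, and dataset name are placeholders; authentication via `huggingface_hub` is assumed):

```python
import segmentation_models_pytorch as smp

model = smp.FPN(encoder_name="resnet34", classes=1)  # hypothetical init parameters

# push_to_hub is wrapped by SMPHubMixin, so "metrics" and "dataset" are consumed
# before the upload and end up in the auto-generated model card.
model.push_to_hub(
    "your-username/fpn-resnet34-demo",  # hypothetical repo id
    metrics={"val_iou": 0.90},          # hypothetical value
    dataset="CamVid",                   # hypothetical dataset name
)
```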

segmentation_models_pytorch/datasets/oxford_pet.py

+1 -1

@@ -87,7 +87,7 @@ def __getitem__(self, *args, **kwargs):
         sample = super().__getitem__(*args, **kwargs)
 
         # resize images
-        image = np.array(Image.fromarray(sample["image"]).resize((256, 256), Image.LINEAR))
+        image = np.array(Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR))
         mask = np.array(Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST))
         trimap = np.array(Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST))
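
The dataset fix above only swaps the old PIL alias `Image.LINEAR` (removed in recent Pillow releases) for `Image.BILINEAR`; the resampling split itself is unchanged: the RGB image is interpolated bilinearly, while the label-valued mask and trimap use nearest-neighbour so class ids are not blended. A small illustrative sketch of that pattern (array shapes are arbitrary):

```python
import numpy as np
from PIL import Image

image = np.zeros((300, 400, 3), dtype=np.uint8)  # hypothetical RGB image
mask = np.zeros((300, 400), dtype=np.uint8)      # hypothetical integer label mask

# Bilinear resampling for continuous image data, nearest-neighbour for labels
# so the resized mask still contains only the original class values.
image_256 = np.array(Image.fromarray(image).resize((256, 256), Image.BILINEAR))
mask_256 = np.array(Image.fromarray(mask).resize((256, 256), Image.NEAREST))
```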