Ruff: format Jupyter notebooks too #923

Merged: 1 commit, merged on Sep 12, 2024

4,212 changes: 4,211 additions & 1 deletion examples/binary_segmentation_intro.ipynb

Large diffs are not rendered by default.

14 changes: 6 additions & 8 deletions examples/camvid_segmentation_multiclass.ipynb
@@ -492,7 +492,7 @@
" augmentation=get_validation_augmentation(),\n",
")\n",
"\n",
- "#Change to > 0 if not on Windows machine\n",
+ "# Change to > 0 if not on Windows machine\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)"
@@ -545,12 +545,10 @@
"import pytorch_lightning as pl\n",
"import segmentation_models_pytorch as smp\n",
"import torch\n",
- "import torch.nn.functional as F\n",
"from torch.optim import lr_scheduler\n",
"\n",
"\n",
"class CamVidModel(pl.LightningModule):\n",
- "\n",
" def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):\n",
" super().__init__()\n",
" self.model = smp.create_model(\n",
@@ -591,13 +589,14 @@
" mask = mask.long()\n",
"\n",
" # Mask shape\n",
- " assert mask.ndim == 3 # [batch_size, H, W]\n",
+ " assert mask.ndim == 3  # [batch_size, H, W]\n",
"\n",
" # Predict mask logits\n",
" logits_mask = self.forward(image)\n",
- " \n",
- " assert logits_mask.shape[1] == self.number_of_classes # [batch_size, number_of_classes, H, W]\n",
- " \n",
+ "\n",
+ " assert (\n",
+ " logits_mask.shape[1] == self.number_of_classes\n",
+ " ) # [batch_size, number_of_classes, H, W]\n",
"\n",
" # Ensure the logits mask is contiguous\n",
" logits_mask = logits_mask.contiguous()\n",
@@ -1678,7 +1677,6 @@
}
],
"source": [
- "import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Fetch a batch from the test loader\n",
1,271 changes: 1,270 additions & 1 deletion examples/cars segmentation (camvid).ipynb

Large diffs are not rendered by default.

41 changes: 20 additions & 21 deletions examples/save_load_model_and_share_with_hf_hub.ipynb
@@ -64,14 +64,13 @@
"# save the model\n",
"model.save_pretrained(\n",
" \"saved-model-dir/unet-with-metadata/\",\n",
- "\n",
" # additional information to be saved with the model\n",
" # only \"dataset\" and \"metrics\" are supported\n",
" dataset=\"PASCAL VOC\", # only string name is supported\n",
- " metrics={ # should be a dictionary with metric name as key and metric value as value\n",
+ " metrics={  # should be a dictionary with metric name as key and metric value as value\n",
" \"mIoU\": 0.95,\n",
- " \"accuracy\": 0.96\n",
- " }\n",
+ " \"accuracy\": 0.96,\n",
+ " },\n",
")"
]
},
@@ -222,13 +221,10 @@
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
"model.save_pretrained(\n",
" \"qubvel-hf/unet-with-metadata/\",\n",
- " push_to_hub=True, # <---------- push the model to the hub\n",
- " private=False, # <---------- make the model private or or public\n",
+ " push_to_hub=True,  # <---------- push the model to the hub\n",
+ " private=False,  # <---------- make the model private or or public\n",
" dataset=\"PASCAL VOC\",\n",
- " metrics={\n",
- " \"mIoU\": 0.95,\n",
- " \"accuracy\": 0.96\n",
- " }\n",
+ " metrics={\"mIoU\": 0.95, \"accuracy\": 0.96},\n",
")\n",
"\n",
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
@@ -267,10 +263,7 @@
"outputs": [],
"source": [
"# define a preprocessing transform for image that would be used during inference\n",
- "preprocessing_transform = A.Compose([\n",
- " A.Resize(256, 256),\n",
- " A.Normalize()\n",
- "])\n",
+ "preprocessing_transform = A.Compose([A.Resize(256, 256), A.Normalize()])\n",
"\n",
"model = smp.Unet()"
]
@@ -367,15 +360,21 @@
"# You can also save training augmentations to the Hub too (and load it back)!\n",
"#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
"\n",
- "train_augmentations = A.Compose([\n",
- " A.HorizontalFlip(p=0.5),\n",
- " A.RandomBrightnessContrast(p=0.2),\n",
- " A.ShiftScaleRotate(p=0.5),\n",
- "])\n",
+ "train_augmentations = A.Compose(\n",
+ " [\n",
+ " A.HorizontalFlip(p=0.5),\n",
+ " A.RandomBrightnessContrast(p=0.2),\n",
+ " A.ShiftScaleRotate(p=0.5),\n",
+ " ]\n",
+ ")\n",
"\n",
- "train_augmentations.save_pretrained(directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True)\n",
+ "train_augmentations.save_pretrained(\n",
+ " directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True\n",
+ ")\n",
"\n",
- "restored_train_augmentations = A.Compose.from_pretrained(directory_or_repo_on_the_hub, key=\"train\")\n",
+ "restored_train_augmentations = A.Compose.from_pretrained(\n",
+ " directory_or_repo_on_the_hub, key=\"train\"\n",
+ ")\n",
"print(restored_train_augmentations)"
]
},
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -46,6 +46,10 @@ test = [
[project.urls]
Homepage = 'https://github.com/qubvel-org/segmentation_models.pytorch'

+ [tool.ruff]
+ extend-include = ['*.ipynb']
+ fix = true
+
[tool.setuptools.dynamic]
version = {attr = 'segmentation_models_pytorch.__version__.__version__'}

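For context, the new [tool.ruff] table is what brings the example notebooks into Ruff's scope: extend-include adds *.ipynb files to the set of files Ruff lints and formats, and fix = true makes ruff check apply available autofixes by default (the same behavior as passing --fix), which is presumably how the unused imports dropped in the notebook diffs above were cleaned up. A minimal annotated sketch of the added table follows; the comments are editorial and not part of the committed file:

[tool.ruff]
# Also lint and format Jupyter notebooks, not just *.py sources.
extend-include = ['*.ipynb']
# Apply autofixes by default when running `ruff check` (equivalent to --fix).
fix = true

With this configuration in place, running ruff format and ruff check from the repository root should pick up the example notebooks as well as the package sources.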