-
Notifications
You must be signed in to change notification settings - Fork 388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trainers: add Instance Segmentation Task #2513
base: main
Are you sure you want to change the base?
Conversation
@ariannasole23 please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.
Contributor License AgreementContribution License AgreementThis Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just need to make the code and tests match our existing trainers and run ruff: https://torchgeo.readthedocs.io/en/latest/user/contributing.html#linters
To solve the import issue, you also need to add 2 lines to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you run ruff on the code to make it more uniform? This will make it easier to review.
test_trainer_instancesegmentation.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file can be deleted
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename to test_instance_segmentation.py
to match the other filename
@pytest.mark.parametrize( | ||
'name', | ||
[ | ||
'agrifieldnet', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a list of files whose configuration is present in tests/conf
. This current list is from the semantic segmentation tests. We need to create a new list (and new tests/conf/*.yaml
files) for instance segmentation datasets. We can start with something like VHR-10 and work from there.
match name: | ||
case 'chabud' | 'cabuar': | ||
pytest.importorskip('h5py', minversion='3.6') | ||
case 'ftw': | ||
pytest.importorskip('pyarrow') | ||
case 'landcoverai': | ||
sha256 = ( | ||
'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' | ||
) | ||
monkeypatch.setattr(LandCoverAI, 'sha256', sha256) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be removed if you aren't using any of these datasets.
monkeypatch.setattr(smp, 'Unet', create_model) | ||
monkeypatch.setattr(smp, 'DeepLabV3Plus', create_model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We aren't using these models (or even smp) for this task, this can be removed
freeze_backbone: Freeze the backbone network to fine-tune the | ||
decoder and segmentation head. | ||
|
||
.. versionadded:: 0.7 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this versionadded to the class docstring?
self.save_hyperparameters() | ||
self.model = None | ||
self.validation_outputs = [] | ||
self.test_outputs = [] | ||
self.configure_models() | ||
self.configure_metrics() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
None of these are needed, they are added by the base class
|
||
if model == 'mask_rcnn': | ||
# Load the Mask R-CNN model with a ResNet50 backbone | ||
self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Weights will be chosen by the user, we can't assume ImageNet weights
|
||
- Uses Mean Average Precision (mAP) for masks (IOU-based metric). | ||
""" | ||
self.metrics = MetricCollection([MeanAveragePrecision(iou_type="segm")]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to use MetricCollection if you only have a single metric
print('\nTRAINING LOSS\n') | ||
print(loss_dict, '\n\n') | ||
print(loss) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can remove print statements, PyTorch Lightning has builtin loggers.
No description provided.