Commit

Merge branch 'main' into random_seed
sfalkena authored Sep 20, 2024
2 parents 494fbd7 + 43d0d94 commit af0ffaa
Showing 4 changed files with 10 additions and 6 deletions.
4 changes: 3 additions & 1 deletion .git-blame-ignore-revs
@@ -1,4 +1,6 @@
-# Double quote -> single quote
+# Prettier: double quote -> single quote
 6a5aaf4b93507072d40dcd78114893362c4eaf6e
+# Ruff: double quote -> single quote
+b09122f3e4a9cb422f6747bf33eca02993f67549
 # Prettier
 bd9c75798eede1a4b7d7ecd6203179d3cb5e54dd
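A `.git-blame-ignore-revs` file only takes effect once `git blame` is pointed at it; the usual one-time setup (not part of this diff) is:

```shell
# Tell git blame to skip the formatting-only commits listed in the file
git config blame.ignoreRevsFile .git-blame-ignore-revs
```

After this, `git blame` attributes lines to the last substantive change rather than to the listed reformatting commits.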
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yaml
@@ -33,7 +33,7 @@ jobs:
       - name: Install pip dependencies
         if: steps.cache.outputs.cache-hit != 'true'
         run: |
-          pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac
+          pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac .
           pip cache purge
       - name: List pip dependencies
         run: pip list
8 changes: 4 additions & 4 deletions tests/conf/landcoverai100.yaml
@@ -1,9 +1,9 @@
 model:
   class_path: SemanticSegmentationTask
   init_args:
-    loss: "ce"
-    model: "unet"
-    backbone: "resnet18"
+    loss: 'ce'
+    model: 'unet'
+    backbone: 'resnet18'
     in_channels: 3
     num_classes: 5
     num_filters: 1
@@ -13,4 +13,4 @@ data:
   init_args:
     batch_size: 1
     dict_kwargs:
-      root: "tests/data/landcoverai"
+      root: 'tests/data/landcoverai'
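The double-to-single-quote churn in this file matches the "Ruff: double quote -> single quote" entry added to `.git-blame-ignore-revs` above. As an assumption (the setting itself is not shown in this commit), Ruff's formatter would be configured for single quotes with a `pyproject.toml` fragment like:

```toml
[tool.ruff.format]
# Format string literals with single quotes instead of double quotes
quote-style = "single"
```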
2 changes: 2 additions & 0 deletions torchgeo/datamodules/geo.py
@@ -286,6 +286,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
             batch_sampler=batch_sampler,
             num_workers=self.num_workers,
             collate_fn=self.collate_fn,
+            persistent_workers=self.num_workers > 0,
         )

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
@@ -429,6 +430,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
             shuffle=split == 'train',
             num_workers=self.num_workers,
             collate_fn=self.collate_fn,
+            persistent_workers=self.num_workers > 0,
         )

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
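The `persistent_workers` flag added in both dataloader factories keeps worker processes alive between epochs instead of respawning them, but PyTorch rejects `persistent_workers=True` when `num_workers == 0`, hence the `self.num_workers > 0` guard. A minimal sketch of the same pattern (the toy dataset is illustrative, not torchgeo code):

```python
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Tiny stand-in dataset; item i is just the integer i."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, i: int) -> int:
        return i


num_workers = 0  # in-process loading, e.g. on a CI machine
loader = DataLoader(
    ToyDataset(),
    batch_size=2,
    num_workers=num_workers,
    # Only keep workers alive when worker processes actually exist;
    # persistent_workers=True with num_workers=0 raises ValueError
    persistent_workers=num_workers > 0,
)
```

With `num_workers=0` the guard evaluates to `False` and the loader behaves exactly as before; with any positive worker count it avoids the per-epoch worker startup cost.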
