Code quality

- Refactored with a recent version of black
- Ran isort
- Ran the pre-commit hooks and made sure they all run properly (except pylint/mypy for now)
- Removed requirements.txt
- Added the requirements to setup.py (see the sketch below)
- Excluded the generated modules from pylint in pyproject.toml
jonasteuwen committed May 14, 2021
1 parent c780997 · commit 37b41a8
Showing 52 changed files with 222 additions and 404 deletions.
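As a rough illustration of the requirements.txt removal: the dependencies are now declared directly in setup.py. The actual setup.py diff is not shown on this page, so the sketch below is hypothetical; the package names are placeholders inferred from the imports visible in the diffs (numpy, torch, omegaconf), not the project's real dependency list.

# Hypothetical sketch: declaring the requirements in setup.py
# instead of keeping a separate requirements.txt.
from setuptools import find_packages, setup

setup(
    name="direct",
    packages=find_packages(),
    install_requires=[
        # Placeholder names; the real list lives in the (unshown) setup.py diff.
        "numpy",
        "torch",
        "omegaconf",
    ],
)
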
15 changes: 15 additions & 0 deletions .github/labeler.yml
@@ -0,0 +1,15 @@
python:
- "**/*.py"
- "*.py"

documentation:
- "**/*.md"
- "*.md"
- "**/*.rst"
- "*.rst"

docker:
- "docker/Dockerfile"

ci:
- ".github/**"
1 change: 1 addition & 0 deletions .github/workflows/pylint.yml
@@ -15,6 +15,7 @@ jobs:
      run: |
        python -m pip install --upgrade pip
        pip install pylint
        pip install -e .
    - name: Analysing the code with pylint
      run: |
        pylint direct --errors-only
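
The new pip install -e . step installs the package itself, together with the requirements that now live in setup.py, so that pylint direct --errors-only can resolve the project's imports instead of reporting them as import errors; that is presumably why the dependency move and this workflow change are part of the same commit.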
18 changes: 8 additions & 10 deletions direct/checkpointer.py
@@ -1,21 +1,19 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import datetime
import logging
import pathlib
import torch
import datetime
import warnings
import re
import torch.nn as nn


import warnings
from pickle import UnpicklingError
from typing import Union, Optional, Dict, Mapping, get_args

from direct.types import PathOrString, HasStateDict
from typing import Dict, Mapping, Optional, Union, get_args

from torch.nn.parallel import DistributedDataParallel
import torch
import torch.nn as nn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from direct.types import HasStateDict, PathOrString

# TODO: Rewrite Checkpointer
# There are too many issues with typing and mypy in the checkpointer.
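Without the added/removed colouring of the rendered diff, the isort reshuffles in this and the following hunks can be hard to read. As a general illustration (not the literal contents of direct/checkpointer.py or any other file in this commit), isort arranges a module's imports into alphabetised groups with a blank line between them: standard library first, then third-party packages, then the project's own modules.

# Illustration of the import grouping that isort enforces; the names are examples,
# not an exact copy of any file changed in this commit.
import logging                  # standard library, alphabetised
import pathlib
from typing import Dict, Optional

import numpy as np              # third-party packages
import torch

from direct.types import PathOrString  # first-party (the project itself)

logger = logging.getLogger(__name__)
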
15 changes: 7 additions & 8 deletions direct/common/subsample.py
@@ -5,18 +5,17 @@
# https://github.com/facebookresearch/fastMRI/
# The code can have been adjusted to our needs.

import numpy as np
import torch
import contextlib
import logging
import pathlib

from typing import Tuple, Optional, List
from abc import abstractmethod
from typing import List, Optional, Tuple

from direct.utils import str_to_class
from direct.types import Number
import numpy as np
import torch

import logging
import contextlib
from direct.types import Number
from direct.utils import str_to_class

logger = logging.getLogger(__name__)

3 changes: 2 additions & 1 deletion direct/common/subsample_config.py
@@ -1,7 +1,8 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
from dataclasses import dataclass
from typing import Tuple, Optional, Union, Any
from typing import Any, Optional, Tuple, Union

from omegaconf import MISSING

from direct.config.defaults import BaseConfig
5 changes: 2 additions & 3 deletions direct/config/defaults.py
@@ -2,14 +2,13 @@
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass, field
from omegaconf import MISSING
from typing import Any, List, Optional

from omegaconf import MISSING

from direct.config import BaseConfig
from direct.data.datasets_config import DatasetConfig

from typing import Optional, List, Any


@dataclass
class TensorboardConfig(BaseConfig):
4 changes: 2 additions & 2 deletions direct/data/bbox.py
@@ -1,10 +1,10 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
from typing import List, Union

import numpy as np
import torch

from typing import List, Union


def crop_to_bbox(
data: Union[np.ndarray, torch.Tensor], bbox: List[int], pad_value: int = 0
15 changes: 6 additions & 9 deletions direct/data/datasets.py
@@ -1,19 +1,16 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import numpy as np
import pathlib
import bisect

from typing import Callable, Dict, Optional, Any, List, Union
import pathlib
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Union

from direct.data.h5_data import H5SliceData
from direct.utils import str_to_class, remove_keys
from direct.types import PathOrString


import numpy as np
from torch.utils.data import Dataset, IterableDataset

from direct.data.h5_data import H5SliceData
from direct.types import PathOrString
from direct.utils import remove_keys, str_to_class

try:
import ismrmrd
8 changes: 4 additions & 4 deletions direct/data/datasets_config.py
@@ -1,13 +1,13 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
from dataclasses import dataclass, field
from typing import Tuple, Optional, List

from direct.config.defaults import BaseConfig
from direct.common.subsample_config import MaskingConfig
from typing import List, Optional, Tuple

from omegaconf import MISSING

from direct.common.subsample_config import MaskingConfig
from direct.config.defaults import BaseConfig


@dataclass
class TransformsConfig(BaseConfig):
11 changes: 5 additions & 6 deletions direct/data/h5_data.py
@@ -1,17 +1,16 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import logging
import pathlib
import numpy as np
import h5py
import re
from typing import Any, Dict, List, Optional, Tuple, Union

import h5py
import numpy as np
from torch.utils.data import Dataset
from typing import Dict, Optional, Any, Tuple, List, Union

from direct.utils import cast_as_path
from direct.types import PathOrString

import logging
from direct.utils import cast_as_path

logger = logging.getLogger(__name__)

4 changes: 2 additions & 2 deletions direct/data/lr_scheduler.py
@@ -9,12 +9,12 @@
# - Calls to other subroutines which do not exist in DIRECT.
# - Stylistic changes.

import logging
import math
from bisect import bisect_right
from typing import List
import torch
import logging

import torch

# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
# only on epoch boundaries. We typically use iteration based schedules instead.
14 changes: 6 additions & 8 deletions direct/data/mri_transforms.py
@@ -1,19 +1,17 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import torch
import numpy as np
import warnings
import functools
import torch.nn as nn
import logging
import warnings
from typing import Any, Callable, Dict, Iterable, Optional

from typing import Dict, Any, Callable, Optional, Iterable
import numpy as np
import torch
import torch.nn as nn

from direct.data import transforms as T
from direct.utils import DirectModule, DirectTransform


import logging

logger = logging.getLogger(__name__)


12 changes: 5 additions & 7 deletions direct/data/samplers.py
@@ -5,19 +5,17 @@
# - Docstring to match the rest of the library
# - Calls to other subroutines which do not exist in DIRECT.

import torch
import itertools
import random
import numpy as np
import logging
import math
import random
from typing import List, Optional

from typing import Optional
import numpy as np
import torch
from torch.utils.data.sampler import Sampler

from direct.utils import communication, chunks

from typing import List
from direct.utils import chunks, communication


class DistributedSampler(Sampler):
1 change: 1 addition & 0 deletions direct/data/tests/test_transforms.py
@@ -7,6 +7,7 @@
import numpy as np
import pytest
import torch

from direct.data import transforms
from direct.data.transforms import tensor_to_complex_numpy

9 changes: 4 additions & 5 deletions direct/data/transforms.py
@@ -5,17 +5,16 @@
# https://github.com/facebookresearch/fastMRI/
# The code can have been adjusted to our needs.

from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version

from typing import Union, Optional, List, Tuple, Callable, Any

from direct.utils import is_power_of_two, ensure_list
from direct.data.bbox import crop_to_bbox
from direct.utils import ensure_list, is_power_of_two
from direct.utils.asserts import assert_complex, assert_named, assert_same_shape

from packaging import version

if version.parse(torch.__version__) >= version.parse("1.7.0"):
import torch.fft

53 changes: 21 additions & 32 deletions direct/engine.py
@@ -1,57 +1,46 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
import functools
import gc
import logging
import pathlib
import signal
import sys
import warnings
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Callable, Dict, List, Optional, TypedDict, Union

import torch
import signal
import direct
import numpy as np
import warnings
import functools
import gc

from typing import Optional, Dict, List, Union, Callable, TypedDict
from abc import abstractmethod, ABC

import torch
from torch import nn
from torch.cuda.amp import GradScaler
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.cuda.amp import GradScaler

from torchvision.utils import make_grid

from direct.data.mri_transforms import AddNames
import direct
from direct.checkpointer import Checkpointer
from direct.config.defaults import BaseConfig
from direct.data import transforms as T
from direct.data.bbox import crop_to_largest
from direct.data.datasets import ConcatDataset
from direct.data.mri_transforms import AddNames
from direct.data.samplers import ConcatDatasetBatchSampler
from direct.checkpointer import Checkpointer
from direct.utils.collate import named_collate
from direct.exceptions import ProcessKilledException, TrainingException
from direct.types import PathOrString
from direct.utils import (
communication,
prefix_dict_keys,
evaluate_dict,
normalize_image,
str_to_class,
prefix_dict_keys,
reduce_list_of_dicts,
str_to_class,
)
from direct.data.bbox import crop_to_largest
from direct.utils.collate import named_collate
from direct.utils.events import CommonMetricPrinter, EventStorage, JSONWriter, TensorboardWriter, get_event_storage
from direct.utils.io import write_json
from direct.utils.events import (
get_event_storage,
EventStorage,
JSONWriter,
CommonMetricPrinter,
TensorboardWriter,
)
from direct.data import transforms as T
from direct.config.defaults import BaseConfig
from direct.exceptions import ProcessKilledException, TrainingException
from direct.types import PathOrString

from torchvision.utils import make_grid


DoIterationOutput = namedtuple(
"DoIterationOutput",
(The remaining changed files were not loaded on this page.)
