Skip to content

Commit

Permalink
Merge pull request #22 from Ensembl/jalvarez/add_ranged_arg
Browse files Browse the repository at this point in the history
Add numeric argument option for the ArgumentParser class
  • Loading branch information
JAlvarezJarreta authored Sep 4, 2024
2 parents e35e239 + b983678 commit 418438e
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ensembl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""Ensembl Python general-purpose utils library."""

__version__ = "0.4.2"
__version__ = "0.4.3"

__all__ = [
"StrPath",
Expand Down
61 changes: 61 additions & 0 deletions src/ensembl/utils/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,24 @@
from __future__ import annotations

__all__ = [
"ArgumentError",
"ArgumentParser",
]

import argparse
import os
from pathlib import Path
from typing import Callable

from sqlalchemy.engine import make_url, URL

from ensembl.utils import StrPath


class ArgumentError(Exception):
"""An error from creating an argument (optional or positional)."""


class ArgumentParser(argparse.ArgumentParser):
"""Extends `argparse.ArgumentParser` with additional methods and functionality.
Expand Down Expand Up @@ -95,6 +101,34 @@ def _validate_dst_path(self, dst_path: StrPath, exists_ok: bool = False) -> Path
break
return dst_path

def _validate_number(
self,
value: str,
value_type: Callable[[str], int | float],
min_value: int | float | None,
max_value: int | float | None,
) -> int | float:
"""Returns the numeric value if it is of the expected type and it is within the specified range.
Args:
value: String representation of numeric value to check.
value_type: Expected type of the numeric value.
min_value: Minimum value constrain. If `None`, no minimum value constrain.
max_value: Maximum value constrain. If `None`, no maximum value constrain.
"""
# Check if the string representation can be converted to the expected type
try:
result = value_type(value)
except (TypeError, ValueError):
self.error(f"invalid {value_type.__name__} value: {value}")
# Check if numeric value is within range
if (min_value is not None) and (result < min_value):
self.error(f"{value} is lower than minimum value ({min_value})")
if (max_value is not None) and (result > max_value):
self.error(f"{value} is greater than maximum value ({max_value})")
return result

def add_argument(self, *args, **kwargs) -> None: # type: ignore[override]
"""Extends the parent function by excluding the default value in the help text when not provided.
Expand Down Expand Up @@ -139,6 +173,33 @@ def add_argument_url(self, *args, **kwargs) -> None:
kwargs["type"] = make_url
self.add_argument(*args, **kwargs)

# pylint: disable=redefined-builtin
def add_numeric_argument(
self,
*args,
type: Callable[[str], int | float] = float,
min_value: int | float | None = None,
max_value: int | float | None = None,
**kwargs,
) -> None:
"""Adds a numeric argument with constrains on its type and its minimum or maximum value.
Note that the default value (if defined) is not checked unless the argument is an optional argument
and no value is provided in the command line.
Args:
type: Type to convert the argument value to when parsing.
min_value: Minimum value constrain. If `None`, no minimum value constrain.
max_value: Maximum value constrain. If `None`, no maximum value constrain.
"""
# If both minimum and maximum values are defined, ensure min_value <= max_value
if (min_value is not None) and (max_value is not None) and (min_value > max_value):
raise ArgumentError("minimum value is greater than maximum value")
# Add lambda function to check numeric constrains when parsing argument
kwargs["type"] = lambda x: self._validate_number(x, type, min_value, max_value)
self.add_argument(*args, **kwargs)

# pylint: disable=redefined-builtin
def add_server_arguments(self, prefix: str = "", include_database: bool = False, help: str = "") -> None:
"""Adds the usual set of arguments needed to connect to a server, i.e. `--host`, `--port`, `--user`
Expand Down
43 changes: 41 additions & 2 deletions tests/argparse/test_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from contextlib import nullcontext as does_not_raise
from pathlib import Path
import re
from typing import Any, ContextManager
from typing import Any, Callable, ContextManager

import pytest
from pytest import param, raises
from sqlalchemy.engine import make_url

from ensembl.utils.argparse import ArgumentParser
from ensembl.utils.argparse import ArgumentError, ArgumentParser


def _args_dict_to_cmd_list(args_dict: dict[str, Any]) -> list[str]:
Expand Down Expand Up @@ -181,6 +181,45 @@ def test_add_argument_url(self) -> None:
args = parser.parse_args(cmd_args)
assert args.url == make_url("https://github.com")

@pytest.mark.dependency(depends=["add_argument"])
@pytest.mark.parametrize(
"value, value_type, min_value, max_value, expectation",
[
param("3", int, None, None, does_not_raise(), id="Value has expected type"),
param("3", int, 3, 3, does_not_raise(), id="Value equal to minimum and maximum values"),
param("3.5", float, 3.4, 3.6, does_not_raise(), id="Value within range"),
param("3", int, 3, 2, raises(ArgumentError), id="Minimum value greater than maximum value"),
param("3.2", int, None, None, raises(SystemExit), id="Value has incorrect type"),
param("3", int, 4, None, raises(SystemExit), id="Value lower than minimum value"),
param("3", int, None, 2, raises(SystemExit), id="Value greater than maximum value"),
],
)
def test_add_numeric_argument(
self,
value: str,
value_type: Callable[[str], int | float],
min_value: int | float | None,
max_value: int | float | None,
expectation: ContextManager,
) -> None:
"""Tests `ArgumentParser.add_numeric_argument()` method.
Args:
value: Argument value.
value_type: Expected argument type.
min_value: Minimum value constrain. If `None`, no minimum value constrain.
max_value: Maximum value constrain. If `None`, no maximum value constrain.
expectation: Context manager for the expected exception.
"""
parser = ArgumentParser()
# Add numeric argument to parser and its command line, and check that the argument is properly parsed
with expectation:
parser.add_numeric_argument("--num", type=value_type, min_value=min_value, max_value=max_value)
cmd_args = ["--num", value]
args = parser.parse_args(cmd_args)
assert args.num == value_type(value)

@pytest.mark.dependency(name="add_server_arguments", depends=["add_argument"])
@pytest.mark.parametrize(
"prefix, include_database",
Expand Down

0 comments on commit 418438e

Please sign in to comment.