Skip to content

Commit

Permalink
fix: normality test filters numeric columns instead of error, fixed c…
Browse files Browse the repository at this point in the history
…ase where CLI function does not compute anything if columns param is left empty
  • Loading branch information
nmaarnio committed Mar 4, 2024
1 parent 68e0062 commit 19c7229
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion eis_toolkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def normality_test_raster_cli(input_raster: Annotated[Path, INPUT_FILE_OPTION],
with rasterio.open(input_raster) as raster:
data = raster.read()
typer.echo("Progress: 25%")

print(bands)
if len(bands) == 0:
bands = None
results_dict = normality_test_array(data=data, bands=bands, nodata_value=raster.nodata)
Expand Down
10 changes: 5 additions & 5 deletions eis_toolkit/exploratory_analyses/normality_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def normality_test_dataframe(
Raises:
EmptyDataException: The input data is empty.
InvalidColumnException: All selected columns were not found in the input data.
NonNumericDataException: Selected data or columns contains non-numeric data.
NonNumericDataException: Selected columns contain non-numeric data or no numeric columns were found.
SampleSizeExceededException: Input data exceeds the maximum of 5000 samples.
"""
if check_empty_dataframe(data):
raise EmptyDataException("The input Dataframe is empty.")

if columns is not None:
if columns is not None and columns != []:
if not check_columns_valid(data, columns):
raise InvalidColumnException("All selected columns were not found in the input DataFrame.")
if not check_columns_numeric(data, columns):
Expand All @@ -51,9 +51,9 @@ def normality_test_dataframe(
data = data[columns].dropna()

else:
if not check_columns_numeric(data, data.columns):
raise NonNumericDataException("The input data contain non-numeric data.")
columns = data.columns
columns = data.select_dtypes(include=[np.number]).columns
if len(columns) == 0:
raise NonNumericDataException("No numeric columns were found.")

statistics = {}
for column in columns:
Expand Down

0 comments on commit 19c7229

Please sign in to comment.