Skip to content

Commit

Permalink
fix progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
peterneher committed Oct 31, 2023
1 parent bb8ab34 commit 939980b
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions radtract/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def run_cv_experiment(feature_files, targets, remove_map_substrings=[], n_jobs=-
features_df = normalize_features(features_df)

is_classification = True
if type(targets) is list:
targets = np.array(targets)
if not np.issubdtype(targets.dtype, np.integer):
print('Targets are not integer. Interpreting as regression problem.')
is_classification = False
Expand All @@ -113,17 +115,21 @@ def run_cv_experiment(feature_files, targets, remove_map_substrings=[], n_jobs=-
print('Starting classification experiment')

if folds > 1:
cv = StratifiedKFold(n_splits=folds)
print('Using {}-fold stratified cross-validation'.format(folds))
bar = ProgressBar(folds * 10)
else:
cv = LeaveOneOut()
print('Using leave-one-out cross-validation')
cv = LeaveOneOut()
bar = ProgressBar(len(feature_files) * 10)

predictions = []
ground_truth = []
classifiers = []
bar = ProgressBar(len(feature_files) * 10)
for seed in range(10):

if folds > 1:
cv = StratifiedKFold(n_splits=folds, shuffle=True, random_state=seed)

for train_idxs, test_idxs in cv.split(features_df, targets):

x_train = features_df.iloc[train_idxs, :]
Expand Down

0 comments on commit 939980b

Please sign in to comment.