-
Notifications
You must be signed in to change notification settings - Fork 0
/
balance_data.py
36 lines (26 loc) · 1.26 KB
/
balance_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import pandas as pd
from tqdm import tqdm
# Path of the input CSV that the rest of this script filters.
file_path = 'Data/Merged_Train.csv'
# Read the whole file into an in-memory DataFrame.
df = pd.read_csv(filepath_or_buffer=file_path)
def is_augmented(image_id):
    """Return True when ``image_id`` contains the ``'_rot'`` marker.

    The check is a plain substring test, so the marker may appear anywhere
    in the id.
    """
    augmentation_marker = '_rot'
    return augmentation_marker in image_id
# Derive a "base" id for every row: everything before the first '_rot' in
# image_id (ids without '_rot' are their own base). Presumably '_rot' marks
# rotation-augmented copies of a base image — confirm against the data.
df['base_image_id'] = df['image_id'].apply(lambda x: x.split('_rot')[0])

# Category ids whose presence makes a base image (and every row sharing its
# base id) worth keeping.
# NOTE(review): an earlier comment described these as "Tin (category 2) or
# Other (category 1)", which does not match the 3.0 used here. The values are
# left unchanged — confirm which id/label is actually intended.
categories_to_keep = [3.0, 1.0]

# Rows that carry one of the wanted categories.
has_wanted_category = df['category_id'].isin(categories_to_keep)

# A base image is kept when at least one of its rows has a wanted category.
# Selecting the matching rows once yields the same set as scanning the whole
# frame separately for every unique base id, without the quadratic cost (so
# no per-id loop or progress bar is needed).
base_image_ids_to_keep = set(df.loc[has_wanted_category, 'base_image_id'])

# Keep every row (original and '_rot' variants alike) of the selected bases.
df_filtered = df[df['base_image_id'].isin(base_image_ids_to_keep)]

# The helper column was only needed for grouping; do not write it out.
df_filtered = df_filtered.drop(columns=['base_image_id'])

# Save the filtered rows to a new CSV file (without the DataFrame index).
filtered_file_path = 'Data/Balanced_Train_Merge.csv'
df_filtered.to_csv(filtered_file_path, index=False)

print(f"Filtered file saved at: {filtered_file_path}")