Skip to content

Commit

Permalink
Merge pull request #1109 from AI4Bharat/dataset_filtering
Browse files Browse the repository at this point in the history
added filtering for datasets
  • Loading branch information
ishvindersethi22 authored Aug 21, 2024
2 parents ec10755 + da3bb7d commit 4e83bcd
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 9 deletions.
109 changes: 103 additions & 6 deletions backend/functions/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@

from shoonya_backend.locks import Lock
from utils.constants import LANG_CHOICES

from projects.tasks import filter_data_items
from projects.models import BATCH
from dataset import models as dataset_models
from projects.registry_helper import ProjectRegistry
import logging

logger = logging.getLogger(__name__)
Expand All @@ -73,6 +76,10 @@ def sentence_text_translate_and_save_translation_pairs(
input_dataset_instance_id,
output_dataset_instance_id,
batch_size,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
api_type="indic-trans-v2",
checks_for_particular_languages=False,
automate_missing_data_items=True,
Expand All @@ -88,6 +95,10 @@ def sentence_text_translate_and_save_translation_pairs(
Allowed - [indic-trans, google, indic-trans-v2, azure, blank]
checks_for_particular_languages (bool): If True, checks for the particular languages in the translations.
automate_missing_data_items (bool): If True, consider only those data items that are missing in the target dataset instance.
filter_string (str): string to filter input data.
sampling_mode (str): can be batch or full.
sampling_parameters (json): is a json that contains, batch number and batch size
"""
task_name = "sentence_text_translate_and_save_translation_pairs"
output_sentences = list(
Expand All @@ -114,6 +125,14 @@ def sentence_text_translate_and_save_translation_pairs(
"metadata_json",
)
)
if filter_string and sampling_mode and sampling_parameters:
input_sentences = get_filtered_items(
"SentenceText",
input_dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)

# Convert the input_sentences list into a dataframe
input_sentences_complete_df = pd.DataFrame(
Expand Down Expand Up @@ -404,7 +423,15 @@ def conversation_data_machine_translation(

@shared_task(bind=True)
def generate_ocr_prediction_json(
self, dataset_instance_id, user_id, api_type, automate_missing_data_items
self,
dataset_instance_id,
user_id,
api_type,
automate_missing_data_items,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
"""Function to generate OCR prediction data and to save to the same data item.
Args:
Expand Down Expand Up @@ -437,7 +464,14 @@ def generate_ocr_prediction_json(
)
except Exception as e:
ocr_data_items = []

if filter_string and sampling_mode and sampling_parameters:
ocr_data_items = get_filtered_items(
"OCRDocument",
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
# converting the dataset_instance to pandas dataframe.
ocr_data_items_df = pd.DataFrame(
ocr_data_items,
Expand Down Expand Up @@ -556,7 +590,15 @@ def generate_ocr_prediction_json(

@shared_task(bind=True)
def generate_asr_prediction_json(
self, dataset_instance_id, user_id, api_type, automate_missing_data_items
self,
dataset_instance_id,
user_id,
api_type,
automate_missing_data_items,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
"""Function to generate ASR prediction data and to save to the same data item.
Args:
Expand Down Expand Up @@ -590,7 +632,14 @@ def generate_asr_prediction_json(
)
except Exception as e:
asr_data_items = []

if filter_string and sampling_mode and sampling_parameters:
asr_data_items = get_filtered_items(
"SpeechConversation",
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
# converting the dataset_instance to pandas dataframe.
asr_data_items_df = pd.DataFrame(
asr_data_items,
Expand Down Expand Up @@ -704,7 +753,16 @@ def generate_asr_prediction_json(


@shared_task(bind=True)
def populate_draft_data_json(self, pk, user_id, fields_list):
def populate_draft_data_json(
self,
pk,
user_id,
fields_list,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
task_name = "populate_draft_data_json"
try:
dataset_instance = DatasetInstance.objects.get(pk=pk)
Expand All @@ -713,6 +771,10 @@ def populate_draft_data_json(self, pk, user_id, fields_list):
dataset_type = dataset_instance.dataset_type
dataset_model = apps.get_model("dataset", dataset_type)
dataset_items = dataset_model.objects.filter(instance_id=dataset_instance)
if filter_string and sampling_mode and sampling_parameters:
dataset_items = get_filtered_items(
dataset_type, pk, filter_string, sampling_mode, sampling_parameters
)
cnt = 0
for dataset_item in dataset_items:
new_draft_data_json = {}
Expand Down Expand Up @@ -1696,3 +1758,38 @@ def upload_all_projects_to_blob_and_get_url(csv_files_directory):
return "Error in generating url"
blob_url = f"https://{account_name}.blob.{endpoint_suffix}/{CONTAINER_NAME_FOR_DOWNLOAD_ALL_PROJECTS}/{blob_client.blob_name}?{sas_token}"
return blob_url


def get_filtered_items(
dataset_model,
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
):
registry_helper = ProjectRegistry.get_instance()
project_type = registry_helper.get_project_name_from_dataset(dataset_model)
if not isinstance(dataset_instance_id, list):
dataset_instance_id = [dataset_instance_id]
filtered_items = filter_data_items(
project_type=project_type,
dataset_instance_ids=dataset_instance_id,
filter_string=filter_string,
)
# Apply sampling
if sampling_mode == BATCH:
batch_size = sampling_parameters["batch_size"]
try:
batch_number = sampling_parameters["batch_number"]
if len(batch_number) == 0:
batch_number = [1]
except KeyError:
batch_number = [1]
sampled_items = []
for batch_num in batch_number:
sampled_items += filtered_items[
batch_size * (batch_num - 1) : batch_size * batch_num
]
else:
sampled_items = filtered_items
return sampled_items
40 changes: 37 additions & 3 deletions backend/functions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ def schedule_sentence_text_translate_job(request):
automate_missing_data_items = request.data.get(
"automate_missing_data_items", "true"
)
filter_string = request.data.get("filter_string", None)
sampling_mode = request.data.get("sampling_mode", None)
sampling_parameters = request.data.get("sampling_parameters_json", None)
variable_parameters = request.data.get("variable_parameters", None)

# Convert checks for languages into boolean
checks_for_particular_languages = checks_for_particular_languages.lower() == "true"
Expand Down Expand Up @@ -311,6 +315,10 @@ def schedule_sentence_text_translate_job(request):
input_dataset_instance_id=input_dataset_instance_id,
output_dataset_instance_id=output_dataset_instance_id,
batch_size=batch_size,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
api_type=api_type,
checks_for_particular_languages=checks_for_particular_languages,
automate_missing_data_items=automate_missing_data_items,
Expand Down Expand Up @@ -537,7 +545,10 @@ def schedule_ocr_prediction_json_population(request):
except KeyError:
automate_missing_data_items = True

# Calling a function asynchronously to create ocr predictions.
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

uid = request.user.id

Expand All @@ -546,6 +557,10 @@ def schedule_ocr_prediction_json_population(request):
user_id=uid,
api_type=api_type,
automate_missing_data_items=automate_missing_data_items,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

# Returning response
Expand Down Expand Up @@ -574,8 +589,20 @@ def schedule_draft_data_json_population(request):
pk = request.data["dataset_instance_id"]

uid = request.user.id
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

populate_draft_data_json.delay(pk=pk, user_id=uid, fields_list=fields_list)
populate_draft_data_json(
pk=pk,
user_id=uid,
fields_list=fields_list,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

ret_dict = {"message": "draft_data_json population started"}
ret_status = status.HTTP_200_OK
Expand Down Expand Up @@ -624,7 +651,10 @@ def schedule_asr_prediction_json_population(request):
except KeyError:
automate_missing_data_items = True

# Calling a function asynchronously to create ocr predictions.
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

uid = request.user.id

Expand All @@ -633,6 +663,10 @@ def schedule_asr_prediction_json_population(request):
user_id=uid,
api_type=api_type,
automate_missing_data_items=automate_missing_data_items,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

ret_dict = {"message": "Generating ASR Predictions"}
Expand Down
12 changes: 12 additions & 0 deletions backend/projects/registry_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,15 @@ def validate_registry(self):
)

return True

def get_project_name_from_dataset(self, dataset_name: str):
for project_key, project_type in self.project_types.items():
input_dataset = project_type.get("input_dataset", {})
output_dataset = project_type.get("output_dataset", {})

if (
input_dataset.get("class") == dataset_name
or output_dataset.get("class") == dataset_name
):
return project_key
return None

0 comments on commit 4e83bcd

Please sign in to comment.