added filtering for datasets #1109

Merged · 1 commit · Aug 21, 2024
109 changes: 103 additions & 6 deletions backend/functions/tasks.py
@@ -57,7 +57,10 @@

from shoonya_backend.locks import Lock
from utils.constants import LANG_CHOICES

from projects.tasks import filter_data_items
from projects.models import BATCH
from dataset import models as dataset_models
from projects.registry_helper import ProjectRegistry
import logging

logger = logging.getLogger(__name__)
@@ -73,6 +76,10 @@ def sentence_text_translate_and_save_translation_pairs(
input_dataset_instance_id,
output_dataset_instance_id,
batch_size,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
api_type="indic-trans-v2",
checks_for_particular_languages=False,
automate_missing_data_items=True,
@@ -88,6 +95,10 @@ def sentence_text_translate_and_save_translation_pairs(
Allowed - [indic-trans, google, indic-trans-v2, azure, blank]
checks_for_particular_languages (bool): If True, checks for the particular languages in the translations.
automate_missing_data_items (bool): If True, consider only those data items that are missing in the target dataset instance.
filter_string (str): string used to filter the input data.
sampling_mode (str): either "batch" or "full".
sampling_parameters (json): JSON containing the batch number(s) and batch size.
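    Example (illustrative values only): {"batch_size": 100, "batch_number": [1, 2]}
    would select the first two batches of 100 filtered items each.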

"""
task_name = "sentence_text_translate_and_save_translation_pairs"
output_sentences = list(
@@ -114,6 +125,14 @@
"metadata_json",
)
)
if filter_string and sampling_mode and sampling_parameters:
input_sentences = get_filtered_items(
"SentenceText",
input_dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
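# When filter_string, sampling_mode and sampling_parameters are all supplied,
# the filtered/sampled subset replaces the full list built above; the OCR, ASR
# and draft-data tasks below apply the same pattern.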

# Convert the input_sentences list into a dataframe
input_sentences_complete_df = pd.DataFrame(
@@ -404,7 +423,15 @@ def conversation_data_machine_translation(

@shared_task(bind=True)
def generate_ocr_prediction_json(
self, dataset_instance_id, user_id, api_type, automate_missing_data_items
self,
dataset_instance_id,
user_id,
api_type,
automate_missing_data_items,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
"""Function to generate OCR prediction data and to save to the same data item.
Args:
@@ -437,7 +464,14 @@ def generate_ocr_prediction_json(
)
except Exception as e:
ocr_data_items = []

if filter_string and sampling_mode and sampling_parameters:
ocr_data_items = get_filtered_items(
"OCRDocument",
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
# converting the dataset_instance to pandas dataframe.
ocr_data_items_df = pd.DataFrame(
ocr_data_items,
@@ -556,7 +590,15 @@ def generate_ocr_prediction_json(

@shared_task(bind=True)
def generate_asr_prediction_json(
self, dataset_instance_id, user_id, api_type, automate_missing_data_items
self,
dataset_instance_id,
user_id,
api_type,
automate_missing_data_items,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
"""Function to generate ASR prediction data and to save to the same data item.
Args:
@@ -590,7 +632,14 @@ def generate_asr_prediction_json(
)
except Exception as e:
asr_data_items = []

if filter_string and sampling_mode and sampling_parameters:
asr_data_items = get_filtered_items(
"SpeechConversation",
dataset_instance_id,
filter_string,
sampling_mode,
sampling_parameters,
)
# converting the dataset_instance to pandas dataframe.
asr_data_items_df = pd.DataFrame(
asr_data_items,
@@ -704,7 +753,16 @@ def generate_asr_prediction_json(


@shared_task(bind=True)
def populate_draft_data_json(self, pk, user_id, fields_list):
def populate_draft_data_json(
self,
pk,
user_id,
fields_list,
filter_string,
sampling_mode,
sampling_parameters,
variable_parameters,
):
task_name = "populate_draft_data_json"
try:
dataset_instance = DatasetInstance.objects.get(pk=pk)
@@ -713,6 +771,10 @@ def populate_draft_data_json(self, pk, user_id, fields_list):
dataset_type = dataset_instance.dataset_type
dataset_model = apps.get_model("dataset", dataset_type)
dataset_items = dataset_model.objects.filter(instance_id=dataset_instance)
if filter_string and sampling_mode and sampling_parameters:
dataset_items = get_filtered_items(
dataset_type, pk, filter_string, sampling_mode, sampling_parameters
)
cnt = 0
for dataset_item in dataset_items:
new_draft_data_json = {}
@@ -1696,3 +1758,38 @@ def upload_all_projects_to_blob_and_get_url(csv_files_directory):
return "Error in generating url"
blob_url = f"https://{account_name}.blob.{endpoint_suffix}/{CONTAINER_NAME_FOR_DOWNLOAD_ALL_PROJECTS}/{blob_client.blob_name}?{sas_token}"
return blob_url


def get_filtered_items(
    dataset_model,
    dataset_instance_id,
    filter_string,
    sampling_mode,
    sampling_parameters,
):
    registry_helper = ProjectRegistry.get_instance()
    project_type = registry_helper.get_project_name_from_dataset(dataset_model)
    if not isinstance(dataset_instance_id, list):
        dataset_instance_id = [dataset_instance_id]
    filtered_items = filter_data_items(
        project_type=project_type,
        dataset_instance_ids=dataset_instance_id,
        filter_string=filter_string,
    )
    # Apply sampling
    if sampling_mode == BATCH:
        batch_size = sampling_parameters["batch_size"]
        try:
            batch_number = sampling_parameters["batch_number"]
            if len(batch_number) == 0:
                batch_number = [1]
        except KeyError:
            batch_number = [1]
        sampled_items = []
        for batch_num in batch_number:
            sampled_items += filtered_items[
                batch_size * (batch_num - 1) : batch_size * batch_num
            ]
    else:
        sampled_items = filtered_items
    return sampled_items
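
For intuition, the batch sampling above slices the filtered list into consecutive windows of batch_size and concatenates the requested windows. A minimal standalone sketch with made-up values (not part of the PR):

# Illustrative only: mirrors the slicing arithmetic in get_filtered_items.
filtered_items = list(range(250))  # pretend 250 data items survived filtering
batch_size = 100
batch_number = [1, 3]  # request the first and third batches
sampled_items = []
for batch_num in batch_number:
    sampled_items += filtered_items[batch_size * (batch_num - 1) : batch_size * batch_num]
assert sampled_items == list(range(100)) + list(range(200, 250))  # third batch runs short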
40 changes: 37 additions & 3 deletions backend/functions/views.py
@@ -274,6 +274,10 @@ def schedule_sentence_text_translate_job(request):
automate_missing_data_items = request.data.get(
"automate_missing_data_items", "true"
)
filter_string = request.data.get("filter_string", None)
sampling_mode = request.data.get("sampling_mode", None)
sampling_parameters = request.data.get("sampling_parameters_json", None)
variable_parameters = request.data.get("variable_parameters", None)

# Convert checks for languages into boolean
checks_for_particular_languages = checks_for_particular_languages.lower() == "true"
@@ -311,6 +315,10 @@
input_dataset_instance_id=input_dataset_instance_id,
output_dataset_instance_id=output_dataset_instance_id,
batch_size=batch_size,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
api_type=api_type,
checks_for_particular_languages=checks_for_particular_languages,
automate_missing_data_items=automate_missing_data_items,
@@ -537,7 +545,10 @@ def schedule_ocr_prediction_json_population(request):
except KeyError:
automate_missing_data_items = True

# Calling a function asynchronously to create ocr predictions.
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

uid = request.user.id

Expand All @@ -546,6 +557,10 @@ def schedule_ocr_prediction_json_population(request):
user_id=uid,
api_type=api_type,
automate_missing_data_items=automate_missing_data_items,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

# Returning response
@@ -574,8 +589,20 @@ def schedule_draft_data_json_population(request):
pk = request.data["dataset_instance_id"]

uid = request.user.id
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

populate_draft_data_json.delay(pk=pk, user_id=uid, fields_list=fields_list)
populate_draft_data_json(
pk=pk,
user_id=uid,
fields_list=fields_list,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)
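# Note: the former asynchronous .delay(...) invocation shown above is replaced
# by a direct, synchronous call to the task.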

ret_dict = {"message": "draft_data_json population started"}
ret_status = status.HTTP_200_OK
@@ -624,7 +651,10 @@ def schedule_asr_prediction_json_population(request):
except KeyError:
automate_missing_data_items = True

# Calling a function asynchronously to create ocr predictions.
filter_string = request.data.get("filter_string")
sampling_mode = request.data.get("sampling_mode")
sampling_parameters = request.data.get("sampling_parameters_json")
variable_parameters = request.data.get("variable_parameters")

uid = request.user.id

Expand All @@ -633,6 +663,10 @@ def schedule_asr_prediction_json_population(request):
user_id=uid,
api_type=api_type,
automate_missing_data_items=automate_missing_data_items,
filter_string=filter_string,
sampling_mode=sampling_mode,
sampling_parameters=sampling_parameters,
variable_parameters=variable_parameters,
)

ret_dict = {"message": "Generating ASR Predictions"}
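Across these views, the filtering path is driven by four optional request fields. A hypothetical request body for schedule_draft_data_json_population, showing only dataset_instance_id plus the new fields (key names come from the request.data.get(...) calls above; the id, filter expression and sampling values are illustrative assumptions):

payload = {
    "dataset_instance_id": 12,                # illustrative id
    "filter_string": "<filter expression>",   # exact filter syntax is not shown in this PR
    "sampling_mode": "batch",                 # or "full"
    "sampling_parameters_json": {"batch_size": 100, "batch_number": [1]},
    "variable_parameters": None,              # passed through to the task; unused by get_filtered_items
}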
12 changes: 12 additions & 0 deletions backend/projects/registry_helper.py
@@ -253,3 +253,15 @@ def validate_registry(self):
)

return True

    def get_project_name_from_dataset(self, dataset_name: str):
        for project_key, project_type in self.project_types.items():
            input_dataset = project_type.get("input_dataset", {})
            output_dataset = project_type.get("output_dataset", {})

            if (
                input_dataset.get("class") == dataset_name
                or output_dataset.get("class") == dataset_name
            ):
                return project_key
        return None
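
A minimal usage sketch (the dataset class name is taken from the tasks.py calls above; the return value depends on the registry contents):

registry_helper = ProjectRegistry.get_instance()
# Returns the key of the first project type whose input or output dataset class
# is "OCRDocument", or None when no project type references that class.
project_type = registry_helper.get_project_name_from_dataset("OCRDocument")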