Skip to content

Commit

Permalink
Fix AttributeError: 'NoneType' object has no attribute 'strip' when e…
Browse files Browse the repository at this point in the history
…xporting form data (#2453)

* fix bug when exporting SAV

fix bug AttributeError: 'NoneType' object has no attribute 'strip'.

* enhance test

* refactor code

* update docstring
  • Loading branch information
kelvin-muchiri authored Jul 21, 2023
1 parent ea3ac4f commit 28d2990
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 42 deletions.
Binary file not shown.
50 changes: 50 additions & 0 deletions onadata/libs/tests/utils/test_export_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,56 @@ def _test_sav_file(section):
section_name = section["name"].replace("/", "_")
_test_sav_file(section_name)

def test_export_zipped_zap_missing_en_label(self):
"""Blank language label defaults to label for default language"""
survey = create_survey_from_xls(
_logger_fixture_path("childrens_survey_sw_missing_en_label.xlsx"),
default_name="childrens_survey_sw",
)
# default_language is set to swahili
self.assertEqual(survey.to_json_dict().get("default_language"), "swahili")
export_builder = ExportBuilder()
export_builder.TRUNCATE_GROUP_TITLE = True
export_builder.INCLUDE_LABELS = True
# export to be in english
export_builder.language = "english"
export_builder.set_survey(survey)

with NamedTemporaryFile(suffix=".zip") as temp_zip_file:
filename = temp_zip_file.name
export_builder.to_zipped_sav(filename, self.data)
temp_zip_file.seek(0)
temp_dir = tempfile.mkdtemp()
with zipfile.ZipFile(temp_zip_file.name, "r") as zip_file:
zip_file.extractall(temp_dir)

# check that each file exists
self.assertTrue(os.path.exists(os.path.join(temp_dir, f"{survey.name}.sav")))
checks = 0

for section in export_builder.sections:
section_name = section["name"]

if section_name == "childrens_survey_sw":
# Default swahili label is used incase english label is missing for question
result = filter(
lambda question: question["label"] == "1. Jina lako ni?",
section["elements"],
)
self.assertEqual(len(list(result)), 1)
checks += 1

if section_name == "children":
# Default swahili label is used incase english label is missing for choice
result = filter(
lambda choice: choice["label"] == "fav_colors/Nyekundu",
section["elements"],
)
self.assertEqual(len(list(result)), 1)
checks += 1

self.assertEqual(checks, 2)

# pylint: disable=invalid-name
def test_generate_field_title_truncated_titles(self):
self._create_childrens_survey()
Expand Down
113 changes: 71 additions & 42 deletions onadata/libs/utils/export_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,46 @@
from six import iteritems

from onadata.apps.logger.models.osmdata import OsmData
from onadata.apps.logger.models.xform import (QUESTION_TYPES_TO_EXCLUDE,
_encode_for_mongo)
from onadata.apps.logger.models.xform import (
QUESTION_TYPES_TO_EXCLUDE,
_encode_for_mongo,
)
from onadata.apps.viewer.models.data_dictionary import DataDictionary
from onadata.libs.utils.common_tags import (ATTACHMENTS, BAMBOO_DATASET_ID,
DELETEDAT, DURATION, GEOLOCATION,
ID, INDEX, MULTIPLE_SELECT_TYPE,
NOTES, PARENT_INDEX,
PARENT_TABLE_NAME,
REPEAT_INDEX_TAGS, REVIEW_COMMENT,
REVIEW_DATE, REVIEW_STATUS,
SAV_255_BYTES_TYPE,
SAV_NUMERIC_TYPE, SELECT_BIND_TYPE,
SELECT_ONE, STATUS,
SUBMISSION_TIME, SUBMITTED_BY,
TAGS, UUID, VERSION,
XFORM_ID_STRING)
from onadata.libs.utils.common_tools import (get_choice_label,
get_choice_label_value,
get_value_or_attachment_uri,
str_to_bool, track_task_progress)
from onadata.libs.utils.common_tags import (
ATTACHMENTS,
BAMBOO_DATASET_ID,
DELETEDAT,
DURATION,
GEOLOCATION,
ID,
INDEX,
MULTIPLE_SELECT_TYPE,
NOTES,
PARENT_INDEX,
PARENT_TABLE_NAME,
REPEAT_INDEX_TAGS,
REVIEW_COMMENT,
REVIEW_DATE,
REVIEW_STATUS,
SAV_255_BYTES_TYPE,
SAV_NUMERIC_TYPE,
SELECT_BIND_TYPE,
SELECT_ONE,
STATUS,
SUBMISSION_TIME,
SUBMITTED_BY,
TAGS,
UUID,
VERSION,
XFORM_ID_STRING,
)
from onadata.libs.utils.common_tools import (
get_choice_label,
get_choice_label_value,
get_value_or_attachment_uri,
str_to_bool,
track_task_progress,
)
from onadata.libs.utils.mongo import _decode_from_mongo, _is_invalid_for_mongo

# the bind type of select multiples that we use to compare
Expand Down Expand Up @@ -100,7 +120,7 @@ def dict_to_joined_export(data, index, indices, name, survey, row, media_xpaths=
media_xpaths = [] if media_xpaths is None else media_xpaths
# pylint: disable=too-many-nested-blocks
if isinstance(data, dict):
for (key, val) in iteritems(data):
for key, val in iteritems(data):
if isinstance(val, list) and key not in [NOTES, ATTACHMENTS, TAGS]:
output[key] = []
for child in val:
Expand All @@ -118,7 +138,7 @@ def dict_to_joined_export(data, index, indices, name, survey, row, media_xpaths=
}
# iterate over keys within new_output and append to
# main output
for (out_key, out_val) in iteritems(new_output):
for out_key, out_val in iteritems(new_output):
if isinstance(out_val, list):
if out_key not in output:
output[out_key] = []
Expand Down Expand Up @@ -202,7 +222,7 @@ def decode_mongo_encoded_section_names(data):
:param data: A dictionary to decode.
"""
results = {}
for (k, v) in iteritems(data):
for k, v in iteritems(data):
new_v = v
if isinstance(v, dict):
new_v = decode_mongo_encoded_section_names(v)
Expand Down Expand Up @@ -351,10 +371,19 @@ def format_field_title(
return title

def get_choice_label_from_dict(self, label):
"""Returns the choice label for the default language."""
"""Returns the choice label for the default language
If a label for the target language is blank then the default
language is used
"""
if isinstance(label, dict):
language = self.get_default_language(list(label))
label = label.get(self.language or language)
default_language = self.get_default_language(list(label))
default_label = label.get(default_language)

if self.language:
return label.get(self.language, default_label)

return default_label

return label

Expand Down Expand Up @@ -475,7 +504,7 @@ def build_sections(
)
elif isinstance(child, Question) and (
child.bind.get("type") not in QUESTION_TYPES_TO_EXCLUDE
and child.type not in QUESTION_TYPES_TO_EXCLUDE
and child.type not in QUESTION_TYPES_TO_EXCLUDE # noqa W503
):
# add to survey_sections
if isinstance(child, Question):
Expand All @@ -490,7 +519,7 @@ def build_sections(
data_dicionary.get_label(
child_xpath, elem=child, language=language
)
or _title
or _title # noqa W503
)
current_section["elements"].append(
{
Expand All @@ -511,7 +540,7 @@ def build_sections(
# if its a select multiple, make columns out of its choices
if (
child.bind.get("type") == SELECT_BIND_TYPE
and child.type == MULTIPLE_SELECT_TYPE
and child.type == MULTIPLE_SELECT_TYPE # noqa W503
):
choices = []
if self.SPLIT_SELECT_MULTIPLES:
Expand Down Expand Up @@ -590,7 +619,7 @@ def build_sections(
)
if (
child.bind.get("type") == SELECT_BIND_TYPE
and child.type == SELECT_ONE
and child.type == SELECT_ONE # noqa W503
):
_append_xpaths_to_section(
current_section_name,
Expand Down Expand Up @@ -679,7 +708,7 @@ def split_select_multiples(
:return: the row dict with select multiples choice as fields in the row
"""
# for each select_multiple, get the associated data and split it
for (xpath, choices) in iteritems(select_multiples):
for xpath, choices in iteritems(select_multiples):
# get the data matching this xpath
data = row.get(xpath) and str(row.get(xpath))
selections = []
Expand Down Expand Up @@ -738,7 +767,7 @@ def split_select_multiples(
def split_gps_components(cls, row, gps_fields):
"""Splits GPS components into their own fields."""
# for each gps_field, get associated data and split it
for (xpath, gps_components) in iteritems(gps_fields):
for xpath, gps_components in iteritems(gps_fields):
data = row.get(xpath)
if data:
gps_parts = data.split()
Expand All @@ -749,7 +778,7 @@ def split_gps_components(cls, row, gps_fields):
@classmethod
def decode_mongo_encoded_fields(cls, row, encoded_fields):
"""Update encoded fields with their corresponding xpath"""
for (xpath, encoded_xpath) in iteritems(encoded_fields):
for xpath, encoded_xpath in iteritems(encoded_fields):
if row.get(encoded_xpath):
val = row.pop(encoded_xpath)
row.update({xpath: val})
Expand Down Expand Up @@ -819,8 +848,8 @@ def pre_process_row(self, row, section):
value = row.get(elm["xpath"])
if (
elm["type"] in ExportBuilder.TYPES_TO_CONVERT
and value is not None
and value != ""
and value is not None # noqa W503
and value != "" # noqa W503
):
row[elm["xpath"]] = ExportBuilder.convert_type(value, elm["type"])

Expand Down Expand Up @@ -932,15 +961,15 @@ def write_row(row, csv_writer, fields):

# write zipfile
with ZipFile(path, "w", ZIP_DEFLATED, allowZip64=True) as zip_file:
for (section_name, csv_def) in iteritems(csv_defs):
for section_name, csv_def in iteritems(csv_defs):
csv_file = csv_def["csv_file"]
csv_file.seek(0)
zip_file.write(
csv_file.name, "_".join(section_name.split("/")) + ".csv"
)

# close files when we are done
for (section_name, csv_def) in iteritems(csv_defs):
for section_name, csv_def in iteritems(csv_defs):
csv_def["csv_file"].close()

@classmethod
Expand Down Expand Up @@ -1150,7 +1179,7 @@ def _get_sav_value_labels(self, xpath_var_names=None):
for question in choice_questions:
if (
xpath_var_names
and question.get_abbreviated_xpath() not in xpath_var_names
and question.get_abbreviated_xpath() not in xpath_var_names # noqa W503
):
continue
var_name = (
Expand Down Expand Up @@ -1293,7 +1322,7 @@ def _get_element_type(element_xpath):
)
for element in elements
]
+ [
+ [ # noqa W503
(
_var_types[item],
SAV_NUMERIC_TYPE
Expand All @@ -1302,7 +1331,7 @@ def _get_element_type(element_xpath):
)
for item in self.extra_columns
]
+ [
+ [ # noqa W503
(
x[1],
SAV_NUMERIC_TYPE
Expand Down Expand Up @@ -1400,20 +1429,20 @@ def write_row(row, sav_writer, fields):
index += 1
track_task_progress(i, total_records)

for (section_name, sav_def) in iteritems(sav_defs):
for section_name, sav_def in iteritems(sav_defs):
sav_def["sav_writer"].closeSavFile(sav_def["sav_writer"].fh, mode="wb")

# write zipfile
with ZipFile(path, "w", ZIP_DEFLATED, allowZip64=True) as zip_file:
for (section_name, sav_def) in iteritems(sav_defs):
for section_name, sav_def in iteritems(sav_defs):
sav_file = sav_def["sav_file"]
sav_file.seek(0)
zip_file.write(
sav_file.name, "_".join(section_name.split("/")) + ".sav"
)

# close files when we are done
for (section_name, sav_def) in iteritems(sav_defs):
for section_name, sav_def in iteritems(sav_defs):
sav_def["sav_file"].close()

def get_fields(self, dataview, section, key):
Expand Down

0 comments on commit 28d2990

Please sign in to comment.