Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch bad data prep #1644

Merged
merged 26 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
83ea875
unit test
milocress Nov 6, 2024
d989282
exempt test cases from json output file validation check
milocress Nov 6, 2024
720aa2b
update convert_delta_to_json.py
milocress Nov 6, 2024
9ab2134
make split optional
milocress Nov 6, 2024
74361e4
fix
milocress Nov 6, 2024
158cbef
support optional args?
milocress Nov 6, 2024
2678de8
Merge branch 'main' into milo/catch-storage-issues
dakinggg Nov 6, 2024
655c01b
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 6, 2024
79e6411
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 8, 2024
9b22b6d
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 11, 2024
17fb20e
fix
milocress Nov 12, 2024
7191e0b
merged
milocress Nov 12, 2024
650dc16
assert file written
milocress Nov 12, 2024
e19b713
fix
milocress Nov 12, 2024
1b6fbaf
side effect
milocress Nov 12, 2024
2f53059
fix silliness
milocress Nov 12, 2024
39f6421
uncreated named temp file
milocress Nov 12, 2024
10b946a
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 12, 2024
6299a01
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 12, 2024
841c383
type
milocress Nov 12, 2024
7bddf5b
merge
milocress Nov 12, 2024
d19b7e7
fix
milocress Nov 13, 2024
38ad33e
behold, the most difficult unit test fixture ever created
milocress Nov 13, 2024
adcd2a1
pyright
milocress Nov 13, 2024
a80779e
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 13, 2024
72ee750
Merge branch 'main' into milo/catch-storage-issues
milocress Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import os
import re
Expand All @@ -26,6 +27,8 @@
FailedToCreateSQLConnectionError,
FaultyDataPrepCluster,
InsufficientPermissionsError,
MisconfiguredHfDatasetError,
StoragePermissionError,
UCNotEnabledError,
)

Expand Down Expand Up @@ -681,7 +684,7 @@ def fetch_DT(

log.info(f'Directory {json_output_folder} created.')

# validate_and_get_cluster_info allows cluster_id to be None if use_serverless is True
# `validate_and_get_cluster_info` allows cluster_id to be None if use_serverless is True.
method, dbsql, sparkSession = validate_and_get_cluster_info(
cluster_id=cluster_id,
databricks_host=DATABRICKS_HOST,
Expand Down Expand Up @@ -732,12 +735,41 @@ def fetch_DT(
if dbsql is not None:
dbsql.close()

# combine downloaded jsonl into one big jsonl for IFT
# Combine downloaded jsonl into one big jsonl for IFT.
iterative_combine_jsons(
milocress marked this conversation as resolved.
Show resolved Hide resolved
json_output_folder,
os.path.join(json_output_folder, json_output_filename),
)

_validate_written_file(
json_output_folder,
json_output_filename,
delta_table_name,
)


def _validate_written_file(
json_output_folder: str,
json_output_filename: str,
delta_table_name: str,
):
# Validate downloaded dataset is actually downloaded.
with open(os.path.join(json_output_folder, json_output_filename)) as f:
is_empty = True
for line in f.readlines():
is_empty = False
try:
json.loads(line)
except:
raise MisconfiguredHfDatasetError(
delta_table_name,
split=json_output_folder,
milocress marked this conversation as resolved.
Show resolved Hide resolved
) from ValueError('line')
milocress marked this conversation as resolved.
Show resolved Hide resolved
if is_empty:
raise StoragePermissionError(
f'Unable to download {delta_table_name}, check network permissions.',
)


def _check_imports():
try:
Expand Down
152 changes: 95 additions & 57 deletions tests/a_scripts/data_prep/test_convert_delta_to_json.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
import sys
import unittest
from argparse import Namespace
from tempfile import NamedTemporaryFile
from typing import Any
from unittest.mock import MagicMock, mock_open, patch

Expand All @@ -14,14 +16,18 @@
from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
FaultyDataPrepCluster,
InsufficientPermissionsError,
_validate_written_file,
download,
fetch,
fetch_DT,
format_tablename,
iterative_combine_jsons,
run_query,
)
from llmfoundry.utils.exceptions import DeltaTableNotFoundError
from llmfoundry.utils.exceptions import (
DeltaTableNotFoundError,
StoragePermissionError,
)


class TestConvertDeltaToJsonl(unittest.TestCase):
Expand Down Expand Up @@ -102,17 +108,20 @@ def test_stream_delta_to_json(
)
mock_workspace_client.return_value.clusters.get = mock_cluster_get

fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
batch_size=batch_size,
json_output_filename=json_output_filename,
)
try:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
batch_size=batch_size,
json_output_filename=json_output_filename,
)
except FileNotFoundError:
milocress marked this conversation as resolved.
Show resolved Hide resolved
pass
mock_sql_connect.assert_called_once_with(
server_hostname='test_host',
http_path='test_path',
Expand Down Expand Up @@ -287,15 +296,18 @@ def test_dbconnect_called(
) # Mock return value for getOrCreate
mock_databricks_session.builder.remote.return_value = mock_remote

fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
try:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
except FileNotFoundError:
pass
mock_databricks_session.builder.remote.assert_called_once_with(
host=DATABRICKS_HOST,
token=DATABRICKS_TOKEN,
Expand Down Expand Up @@ -342,15 +354,19 @@ def test_sqlconnect_called_dbr13(
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
try:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
except FileNotFoundError:
pass

mock_sql_connect.assert_called_once_with(
server_hostname=DATABRICKS_HOST,
http_path=http_path,
Expand Down Expand Up @@ -397,15 +413,19 @@ def test_sqlconnect_called_dbr14(
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
try:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
except FileNotFoundError:
pass

mock_sql_connect.assert_called_once_with(
server_hostname=DATABRICKS_HOST,
http_path=http_path,
Expand Down Expand Up @@ -452,15 +472,18 @@ def test_sqlconnect_called_https(
)
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
try:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
except FileNotFoundError:
pass
mock_sql_connect.assert_called_once_with(
server_hostname='test-host',
http_path=http_path,
Expand Down Expand Up @@ -504,15 +527,19 @@ def test_serverless(
mock_cluster_response = Namespace(spark_version='14.2.0-scala2.12')
mock_workspace_client.return_value.clusters.get.return_value = mock_cluster_response

fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
try:
fetch_DT(
delta_table_name=delta_table_name,
json_output_folder=json_output_folder,
http_path=http_path,
cluster_id=cluster_id,
DATABRICKS_HOST=DATABRICKS_HOST,
DATABRICKS_TOKEN=DATABRICKS_TOKEN,
use_serverless=use_serverless,
)
except FileNotFoundError:
pass

assert not mock_sql_connect.called
assert not mock_databricks_session.builder.remote.called

Expand Down Expand Up @@ -644,3 +671,14 @@ def test_fetch_nonexistent_table_error(

# Verify that get_total_rows was called
mock_gtr.assert_called_once()

def test_fetch_DT_catches_bad_download(self):
    """An empty output file must surface as a StoragePermissionError."""
    with NamedTemporaryFile() as tmp:
        # NamedTemporaryFile creates an empty file on disk; split its path
        # into the (folder, filename) pair the validator expects.
        folder, name = os.path.split(tmp.name)
        with self.assertRaises(StoragePermissionError):
            _validate_written_file(
                folder,
                name,
                'test_delta_table',
            )
Loading