Skip to content

Commit 027c8a8

Browse files
committed
Save the data locally.
1 parent 99fc16f commit 027c8a8

File tree

9 files changed

+116
-37
lines changed

9 files changed

+116
-37
lines changed

editor/core/data_types.py

+7
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@ def str_to_mlc_data_type(data_type: str) -> mlc.DataType | None:
3737
if data_type == str_data_type:
3838
return mlc_data_type
3939
return None
40+
41+
42+
def mlc_to_str_data_type(data_type: str) -> mlc.DataType | None:
43+
for str_data_type, mlc_data_type in zip(STR_DATA_TYPES, MLC_DATA_TYPES):
44+
if data_type == mlc_data_type:
45+
return str_data_type
46+
return None

editor/core/files.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import requests
99

1010
from .names import find_unique_name
11+
from .path import get_resource_path
1112
from .state import FileObject
1213
from .state import FileSet
1314

@@ -97,7 +98,9 @@ def get_dataframe(file_type: FileType, file: io.BytesIO | epath.Path) -> pd.Data
9798
raise NotImplementedError()
9899

99100

100-
def file_from_url(file_type: FileType, url: str, names: set[str]) -> FileObject:
101+
def file_from_url(
102+
file_type: FileType, url: str, names: set[str], folder: epath.Path
103+
) -> FileObject:
101104
"""Downloads locally and extracts the file information."""
102105
file_path = hash_file_path(url)
103106
if not file_path.exists():
@@ -112,30 +115,38 @@ def file_from_url(file_type: FileType, url: str, names: set[str]) -> FileObject:
112115
encoding_format=file_type.encoding_format,
113116
sha256=sha256,
114117
df=df,
118+
folder=folder,
115119
)
116120

117121

118122
def file_from_upload(
119-
file_type: FileType, file: io.BytesIO, names: set[str]
123+
file_type: FileType, file: io.BytesIO, names: set[str], folder: epath.Path
120124
) -> FileObject:
121125
"""Uploads locally and extracts the file information."""
122-
sha256 = _sha256(file.getvalue())
126+
value = file.getvalue()
127+
content_url = f"data/{file.name}"
128+
sha256 = _sha256(value)
129+
with get_resource_path(content_url).open("wb") as f:
130+
f.write(value)
123131
df = get_dataframe(file_type, file).infer_objects()
124132
return FileObject(
125133
name=find_unique_name(names, file.name),
126134
description="",
127-
content_url=f"data/{file.name}",
135+
content_url=content_url,
128136
encoding_format=file_type.encoding_format,
129137
sha256=sha256,
130138
df=df,
139+
folder=folder,
131140
)
132141

133142

134-
def file_from_form(type: str, names: set[str]) -> FileObject | FileSet:
143+
def file_from_form(
144+
type: str, names: set[str], folder: epath.Path
145+
) -> FileObject | FileSet:
135146
"""Creates a file based on manually added fields."""
136147
if type == FILE_OBJECT:
137-
return FileObject(name=find_unique_name(names, "file_object"))
148+
return FileObject(name=find_unique_name(names, "file_object"), folder=folder)
138149
elif type == FILE_SET:
139-
return FileSet(name=find_unique_name(names, "file_set"))
150+
return FileSet(name=find_unique_name(names, "file_set"), folder=folder)
140151
else:
141152
raise ValueError("type has to be one of FILE_OBJECT, FILE_SET")

editor/core/files_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ def test_check_file_csv():
1818
f.write("a,1\n")
1919
f.write("b,2\n")
2020
f.write("c,3\n")
21-
file = file_from_url(FileTypes.CSV, "https://my.url", set())
21+
file = file_from_url(FileTypes.CSV, "https://my.url", set(), epath.Path())
2222
pd.testing.assert_frame_equal(
2323
file.df, pd.DataFrame({"column1": ["a", "b", "c"], "column2": [1, 2, 3]})
2424
)
2525
# Fails with unknown encoding_format:
2626
with pytest.raises(NotImplementedError):
27-
file_from_url("unknown", "https://my.url", set())
27+
file_from_url("unknown", "https://my.url", set(), epath.Path())

editor/core/path.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from etils import epath
2+
import streamlit as st
3+
4+
from core.state import CurrentProject
5+
6+
7+
def get_resource_path(content_url: str) -> epath.Path:
8+
"""Gets the path on disk of the resource with `content_url`."""
9+
project: CurrentProject = st.session_state[CurrentProject]
10+
path = project.path / content_url
11+
if not path.parent.exists():
12+
path.parent.mkdir(parents=True, exist_ok=True)
13+
return path

editor/core/state.py

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class FileObject:
137137
sha256: str | None = None
138138
df: pd.DataFrame | None = None
139139
rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)
140+
folder: epath.PathLike | None = None
140141

141142

142143
@dataclasses.dataclass

editor/cypress/e2e/uploadCsv.cy.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ describe('Editor loads a local CSV as a resource', () => {
5252
cy.contains('base.csv_record_set (2 fields)').click()
5353
// We also see the fields with the proper types.
5454
cy.get('[data-testid="stDataFrameResizable"]').contains("column1")
55-
cy.get('[data-testid="stDataFrameResizable"]').contains("https://schema.org/Text")
55+
cy.get('[data-testid="stDataFrameResizable"]').contains("Text")
5656
cy.get('[data-testid="stDataFrameResizable"]').contains("column2")
57-
cy.get('[data-testid="stDataFrameResizable"]').contains("https://schema.org/Integer")
57+
cy.get('[data-testid="stDataFrameResizable"]').contains("Integer")
5858

5959
// I can edit the details of the fields.
6060
cy.contains('Generating the dataset...').should('not.exist')

editor/events/resources.py

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import streamlit as st
55

66
from core.files import FILE_OBJECT
7+
from core.path import get_resource_path
78
from core.state import FileObject
89
from core.state import FileSet
910
from core.state import Metadata
@@ -47,6 +48,11 @@ def handle_resource_change(event: ResourceEvent, resource: Resource, key: str):
4748
elif event == ResourceEvent.CONTENT_SIZE:
4849
resource.content_size = value
4950
elif event == ResourceEvent.CONTENT_URL:
51+
if resource.content_url and value:
52+
old_path = get_resource_path(resource.content_url)
53+
new_path = get_resource_path(value)
54+
if old_path.exists() and not new_path.exists():
55+
old_path.rename(new_path)
5056
resource.content_url = value
5157
elif event == ResourceEvent.TYPE:
5258
metadata: Metadata = st.session_state[Metadata]

editor/views/files.py

+35-19
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33

44
from components.tree import render_tree
55
from core.constants import DF_HEIGHT
6+
from core.constants import OAUTH_CLIENT_ID
67
from core.files import file_from_form
78
from core.files import file_from_upload
89
from core.files import file_from_url
910
from core.files import FILE_OBJECT
1011
from core.files import FILE_SET
1112
from core.files import FILE_TYPES
1213
from core.files import RESOURCE_TYPES
14+
from core.path import get_resource_path
1315
from core.record_sets import infer_record_sets
16+
from core.state import CurrentProject
1417
from core.state import FileObject
1518
from core.state import FileSet
1619
from core.state import Metadata
@@ -32,21 +35,7 @@
3235

3336
def render_files():
3437
"""Renders the views of the files: warnings and panels to display information."""
35-
metadata: Metadata = st.session_state[Metadata]
36-
warning = ""
37-
for resource in metadata.distribution:
38-
content_url = resource.content_url
39-
if (
40-
content_url
41-
and not content_url.startswith("http")
42-
and not epath.Path(content_url).exists()
43-
):
44-
warning += (
45-
f'⚠️ Resource "{resource.name}" is local (from `{content_url}`), but'
46-
" doesn't exist on the disk. Fix this by either downloading\n\n"
47-
)
48-
if warning:
49-
st.warning(warning.strip())
38+
_render_warnings()
5039
col1, col2, col3 = st.columns([1, 1, 1], gap="small")
5140
with col1:
5241
st.markdown("##### Upload more resources")
@@ -60,6 +49,31 @@ def render_files():
6049
_render_right_panel()
6150

6251

52+
def _render_warnings():
53+
"""Renders warnings concerning local files."""
54+
metadata: Metadata = st.session_state[Metadata]
55+
warning = ""
56+
for resource in metadata.distribution:
57+
content_url = resource.content_url
58+
if content_url and not content_url.startswith("http"):
59+
path = get_resource_path(content_url)
60+
if not path.exists():
61+
if OAUTH_CLIENT_ID:
62+
warning += (
63+
f'⚠️ Resource "{resource.name}" points to a local file, but'
64+
" doesn't exist on the disk. Fix this by changing the content"
65+
" URL.\n\n"
66+
)
67+
else:
68+
warning += (
69+
f'⚠️ Resource "{resource.name}" points to a local file, but'
70+
" doesn't exist on the disk. Fix this by either downloading"
71+
f" it to {path} or changing the content URL.\n\n"
72+
)
73+
if warning:
74+
st.warning(warning.strip())
75+
76+
6377
def _render_resources_panel(files: list[Resource]) -> Resource | None:
6478
"""Renders the left panel: the list of all resources."""
6579
filename_to_file: dict[str, list[Resource]] = {}
@@ -112,13 +126,15 @@ def handle_on_click():
112126
file_type = FILE_TYPES[file_type_name]
113127
metadata: Metadata = st.session_state[Metadata]
114128
names = metadata.names()
129+
project: CurrentProject = st.session_state[CurrentProject]
130+
folder = project.path
115131
if url:
116-
file = file_from_url(file_type, url, names)
132+
file = file_from_url(file_type, url, names, folder)
117133
elif uploaded_file:
118-
file = file_from_upload(file_type, uploaded_file, names)
134+
file = file_from_upload(file_type, uploaded_file, names, folder)
119135
else:
120136
resource_type = st.session_state[_MANUAL_RESOURCE_TYPE_KEY]
121-
file = file_from_form(resource_type, names)
137+
file = file_from_form(resource_type, names, folder)
122138

123139
st.session_state[Metadata].add_distribution(file)
124140
record_sets = infer_record_sets(file, names)
@@ -170,7 +186,7 @@ def close():
170186
col1, col2 = st.columns([1, 1])
171187
col1.button("Close", key=f"{i}_close", on_click=close, type="primary")
172188
col2.button(
173-
"Remove", key=f"{i}_remove", on_click=delete_line, type="secondary"
189+
"⚠️ Remove", key=f"{i}_remove", on_click=delete_line, type="secondary"
174190
)
175191

176192

editor/views/record_sets.py

+32-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import multiprocessing
22
import textwrap
33
import time
4+
import traceback
45
from typing import TypedDict
56

67
import numpy as np
@@ -9,7 +10,9 @@
910
import streamlit as st
1011

1112
from core.data_types import MLC_DATA_TYPES
13+
from core.data_types import mlc_to_str_data_type
1214
from core.data_types import STR_DATA_TYPES
15+
from core.data_types import str_to_mlc_data_type
1316
from core.query_params import expand_record_set
1417
from core.query_params import is_record_set_expanded
1518
from core.state import Field
@@ -34,7 +37,16 @@ class _Result(TypedDict):
3437
exception: Exception | None
3538

3639

37-
@st.cache_data(show_spinner="Generating the dataset...")
40+
@st.cache_data(
41+
show_spinner="Generating the dataset...",
42+
hash_funcs={
43+
"mlcroissant.Metadata": hash,
44+
"mlcroissant.Field": hash,
45+
"mlcroissant.FileObject": hash,
46+
"mlcroissant.FileSet": hash,
47+
"mlcroissant.RecordSet": hash,
48+
},
49+
)
3850
def _generate_data_with_timeout(record_set: RecordSet) -> _Result:
3951
"""Generates the data and waits at most _TIMEOUT_SECONDS."""
4052
with multiprocessing.Manager() as manager:
@@ -59,7 +71,7 @@ def _generate_data(record_set: RecordSet, result: _Result) -> pd.DataFrame | Non
5971
"""Generates the first _NUM_RECORDS records."""
6072
try:
6173
metadata: Metadata = st.session_state[Metadata]
62-
if not metadata:
74+
if metadata is None:
6375
raise ValueError(
6476
"The dataset is still incomplete. Please, go to the overview to see"
6577
" errors."
@@ -81,8 +93,8 @@ def _generate_data(record_set: RecordSet, result: _Result) -> pd.DataFrame | Non
8193
pass
8294
df.append(record)
8395
result["df"] = pd.DataFrame(df)
84-
except Exception as exception:
85-
result["exception"] = exception
96+
except Exception:
97+
result["exception"] = traceback.format_exc()
8698

8799

88100
def _handle_close_fields():
@@ -148,6 +160,10 @@ def _handle_create_record_set():
148160
metadata.add_record_set(RecordSet(name="new-record-set", description=""))
149161

150162

163+
def _handle_remove_record_set(record_set_key: int):
164+
del st.session_state[Metadata].record_sets[record_set_key]
165+
166+
151167
def _handle_fields_change(record_set_key: int, record_set: RecordSet):
152168
expand_record_set(record_set=record_set)
153169
data_editor_key = _data_editor_key(record_set_key, record_set)
@@ -166,12 +182,13 @@ def _handle_fields_change(record_set_key: int, record_set: RecordSet):
166182
elif new_field == FieldDataFrame.DESCRIPTION:
167183
field.description = new_value
168184
elif new_field == FieldDataFrame.DATA_TYPE:
169-
field.data_types = [new_value]
185+
field.data_types = [str_to_mlc_data_type(new_value)]
170186
for added_row in result["added_rows"]:
187+
data_type = str_to_mlc_data_type(added_row.get(FieldDataFrame.DATA_TYPE))
171188
field = Field(
172189
name=added_row.get(FieldDataFrame.NAME),
173190
description=added_row.get(FieldDataFrame.DESCRIPTION),
174-
data_types=[added_row.get(FieldDataFrame.DATA_TYPE)],
191+
data_types=[data_type],
175192
source=mlc.Source(),
176193
references=mlc.Source(),
177194
)
@@ -290,7 +307,7 @@ def _render_left_panel():
290307
# TODO(https://github.com/mlcommons/croissant/issues/350): Allow to display
291308
# several data types, not only the first.
292309
data_types = [
293-
field.data_types[0] if field.data_types else None
310+
mlc_to_str_data_type(field.data_types[0]) if field.data_types else None
294311
for field in record_set.fields
295312
]
296313
fields = pd.DataFrame(
@@ -359,6 +376,14 @@ def _render_left_panel():
359376
on_click=_handle_on_click_field,
360377
args=(record_set_key, record_set),
361378
)
379+
key = f"{prefix}-delete-record-set"
380+
st.button(
381+
"⚠️ Delete RecordSet",
382+
type="primary",
383+
key=key,
384+
on_click=_handle_remove_record_set,
385+
args=(record_set_key,),
386+
)
362387
st.button(
363388
"Create a new RecordSet",
364389
key=f"create-new-record-set",

0 commit comments

Comments
 (0)