Skip to content

Commit

Permalink
fix: handle marimo[extras] in --sandbox and package installation (#3425)
Browse files Browse the repository at this point in the history
This has a few fixes to be more resilient with `--sandbox`.

1. When installing `marimo[sql]` in the UI, we don't include the version
(same as `marimo`)
2. When reading the deps from the notebook, we dedupe `marimo` and any
`marimo[extras]`, and add the version correctly.
  • Loading branch information
mscolnick authored Jan 14, 2025
1 parent 42679f4 commit 19f4628
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 17 deletions.
63 changes: 48 additions & 15 deletions marimo/_cli/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def _read_pyproject(script: str) -> Dict[str, Any] | None:
return None


def _get_python_version_requirement(pyproject: Dict[str, Any]) -> str | None:
def _get_python_version_requirement(
pyproject: Dict[str, Any] | None,
) -> str | None:
"""Extract Python version requirement from pyproject metadata."""
if pyproject is None:
return None
Expand Down Expand Up @@ -203,6 +205,49 @@ def prompt_run_in_sandbox(name: str | None) -> bool:
return False


def _is_marimo_dependency(dependency: str) -> bool:
# Split on any version specifier
without_version = re.split(r"[=<>~]+", dependency)[0]
# Match marimo and marimo[extras], but not marimo-<something-else>
return without_version == "marimo" or without_version.startswith("marimo[")


def _is_versioned(dependency: str) -> bool:
return any(c in dependency for c in ("==", ">=", "<=", ">", "<", "~"))


def _normalize_sandbox_dependencies(
dependencies: List[str], marimo_version: str
) -> List[str]:
"""Normalize marimo dependencies to have only one version.
If multiple marimo dependencies exist, prefer the one with brackets.
Add version to the remaining one if not already versioned.
"""
# Find all marimo dependencies
marimo_deps = [d for d in dependencies if _is_marimo_dependency(d)]
if not marimo_deps:
# During development, you can comment this out to install an
# editable version of marimo assuming you are in the marimo directory
# DO NOT COMMIT THIS WHEN SUBMITTING PRs
# return dependencies + [f"marimo -e ."]

return dependencies + [f"marimo=={marimo_version}"]

# Prefer the one with brackets if it exists
bracketed = next((d for d in marimo_deps if "[" in d), None)
chosen = bracketed if bracketed else marimo_deps[0]

# Remove all marimo deps
filtered = [d for d in dependencies if not _is_marimo_dependency(d)]

# Add version if not already versioned
if not _is_versioned(chosen):
chosen = f"{chosen}=={marimo_version}"

return filtered + [chosen]


def run_in_sandbox(
args: List[str],
name: Optional[str] = None,
Expand All @@ -219,20 +264,8 @@ def run_in_sandbox(
get_dependencies_from_filename(name) if name is not None else []
)

# The sandbox needs to manage marimo, too, to make sure
# that the outer environment doesn't leak into the sandbox.
if "marimo" not in dependencies:
dependencies.append("marimo")

# Rename marimo to marimo=={__version__}
index_of_marimo = dependencies.index("marimo")
if index_of_marimo != -1:
dependencies[index_of_marimo] = f"marimo=={__version__}"

# During development, you can comment this out to install an
# editable version of marimo assuming you are in the marimo directory
# DO NOT COMMIT THIS WHEN SUBMITTING PRs
# dependencies[index_of_marimo] = "-e ."
# Normalize marimo dependencies
dependencies = _normalize_sandbox_dependencies(dependencies, __version__)

with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".txt"
Expand Down
13 changes: 11 additions & 2 deletions marimo/_runtime/packages/pypi_package_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ def update_notebook_script_metadata(
import_namespaces_to_add: Optional[List[str]] = None,
import_namespaces_to_remove: Optional[List[str]] = None,
) -> None:
"""Update the notebook's script metadata with the packages to add/remove.
Args:
filepath: Path to the notebook file
packages_to_add: List of packages to add to the script metadata
packages_to_remove: List of packages to remove from the script metadata
import_namespaces_to_add: List of import namespaces to add
import_namespaces_to_remove: List of import namespaces to remove
"""
packages_to_add = packages_to_add or []
packages_to_remove = packages_to_remove or []
import_namespaces_to_add = import_namespaces_to_add or []
Expand All @@ -152,8 +161,8 @@ def _is_installed(package: str) -> bool:
return without_brackets.lower() in version_map

def _maybe_add_version(package: str) -> str:
# Skip marimo
if package == "marimo":
# Skip marimo and marimo[<mod>], but not marimo-<something-else>
if package == "marimo" or package.startswith("marimo["):
return package
without_brackets = package.split("[")[0]
version = version_map.get(without_brackets.lower())
Expand Down
69 changes: 69 additions & 0 deletions tests/_cli/test_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from marimo._cli.sandbox import (
_get_dependencies,
_get_python_version_requirement,
_is_marimo_dependency,
_normalize_sandbox_dependencies,
_pyproject_toml_to_requirements_txt,
_read_pyproject,
get_dependencies_from_filename,
Expand Down Expand Up @@ -239,3 +241,70 @@ def test_get_dependencies_with_nonexistent_file():

# Test with None
assert get_dependencies_from_filename(None) == [] # type: ignore


def test_normalize_marimo_dependencies():
# Test adding marimo when not present
assert _normalize_sandbox_dependencies(["numpy"], "1.0.0") == [
"numpy",
"marimo==1.0.0",
]

# Test preferring bracketed version
assert _normalize_sandbox_dependencies(
["marimo", "marimo[extras]", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]==1.0.0"]

# Test keeping existing version with brackets
assert _normalize_sandbox_dependencies(
["marimo[extras]>=0.1.0", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]>=0.1.0"]

# Test adding version when none exists
assert _normalize_sandbox_dependencies(
["marimo[extras]", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]==1.0.0"]

# Test keeping only one marimo dependency
assert _normalize_sandbox_dependencies(
["marimo>=0.1.0", "marimo[extras]>=0.2.0", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]>=0.2.0"]
assert _normalize_sandbox_dependencies(
["marimo", "marimo[extras]>=0.2.0", "numpy"], "1.0.0"
) == ["numpy", "marimo[extras]>=0.2.0"]

# Test various version specifiers are preserved
version_specs = [
"==0.1.0",
">=0.1.0",
"<=0.1.0",
">0.1.0",
"<0.1.0",
"~=0.1.0",
]
for spec in version_specs:
assert _normalize_sandbox_dependencies(
[f"marimo{spec}", "numpy"], "1.0.0"
) == ["numpy", f"marimo{spec}"]


def test_is_marimo_dependency():
assert _is_marimo_dependency("marimo")
assert _is_marimo_dependency("marimo[extras]")
assert not _is_marimo_dependency("marimo-extras")
assert not _is_marimo_dependency("marimo-ai")

# With version specifiers
assert _is_marimo_dependency("marimo==0.1.0")
assert _is_marimo_dependency("marimo[extras]>=0.1.0")
assert _is_marimo_dependency("marimo[extras]==0.1.0")
assert _is_marimo_dependency("marimo[extras]~=0.1.0")
assert _is_marimo_dependency("marimo[extras]<=0.1.0")
assert _is_marimo_dependency("marimo[extras]>=0.1.0")
assert _is_marimo_dependency("marimo[extras]<=0.1.0")

# With other packages
assert not _is_marimo_dependency("numpy")
assert not _is_marimo_dependency("pandas")
assert not _is_marimo_dependency("marimo-ai")
assert not _is_marimo_dependency("marimo-ai==0.1.0")
108 changes: 108 additions & 0 deletions tests/_runtime/packages/test_package_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,111 @@ def _get_version_map(self) -> dict[str, str]:
"ibis-framework[duckdb]==2.0",
],
]


def test_update_script_metadata_marimo_packages() -> None:
runs_calls: list[list[str]] = []

class MockUvPackageManager(UvPackageManager):
def run(self, command: list[str]) -> bool:
runs_calls.append(command)
return True

def _get_version_map(self) -> dict[str, str]:
return {
"marimo": "0.1.0",
"marimo-ai": "0.2.0",
"pandas": "2.0.0",
}

pm = MockUvPackageManager()

# Test 1: Basic package handling
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo-ai", # Should have version (different package)
"pandas", # Should have version
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo-ai==0.2.0",
"pandas==2.0.0",
]
]
runs_calls.clear()

# Test 2: Marimo package consolidation - should prefer marimo[ai] over marimo
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo",
"marimo[sql]",
"pandas",
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo",
"marimo[sql]",
"pandas==2.0.0",
]
]
runs_calls.clear()

# Test 3: Multiple marimo extras - should use first one
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo",
"marimo[sql]",
"marimo[recommended]",
"pandas",
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo",
"marimo[sql]",
"marimo[recommended]",
"pandas==2.0.0",
]
]
runs_calls.clear()

# Test 4: Only plain marimo
pm.update_notebook_script_metadata(
filepath="nb.py",
packages_to_add=[
"marimo",
"pandas",
],
)
assert runs_calls == [
[
"uv",
"--quiet",
"add",
"--script",
"nb.py",
"marimo",
"pandas==2.0.0",
]
]
runs_calls.clear()

0 comments on commit 19f4628

Please sign in to comment.