Skip to content

fix(cli): ensure create_crew respects --provider flag #2499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 54 additions & 70 deletions src/crewai/cli/create_crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,68 +93,81 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
if not skip_provider:
if not provider:
provider_models = get_provider_data()
if not provider_models:
return

existing_provider = None
for provider, env_keys in ENV_VARS.items():
if any(
"key_name" in details and details["key_name"] in env_vars
for details in env_keys
):
existing_provider = provider
break

if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration.", fg="yellow")
return

provider_models = get_provider_data()
if not provider_models:
click.secho("Could not retrieve provider data.", fg="red")
return

while True:
selected_provider = select_provider(provider_models)
if selected_provider is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider: # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
)
selected_provider = None

if provider:
provider = provider.lower()
if provider in provider_models:
selected_provider = provider
click.secho(f"Using specified provider: {selected_provider.capitalize()}", fg="green")
else:
click.secho(f"Warning: Specified provider '{provider}' is not recognized. Please select one.", fg="yellow")

if not selected_provider:
existing_provider = None
for p, env_keys in ENV_VARS.items():
if any(
"key_name" in details and details["key_name"] in env_vars
for details in env_keys
):
existing_provider = p
break

if existing_provider:
if not click.confirm(
f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?"
):
click.secho("Keeping existing provider configuration. Exiting provider setup.", fg="yellow")
copy_template_files(folder_path, name, class_name, parent_folder)
click.secho(f"Crew '{name}' created successfully!", fg="green")
click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan")
return
else:
pass

while True:
selected_provider = select_provider(provider_models)
if selected_provider is None:
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider:
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
)

if not selected_provider:
click.secho("Provider selection failed. Exiting.", fg="red")
sys.exit(1)


# Check if the selected provider has predefined models
if selected_provider in MODELS and MODELS[selected_provider]:
while True:
selected_model = select_model(selected_provider, provider_models)
if selected_model is None: # User typed 'q'
if selected_model is None:
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_model: # Valid selection
if selected_model:
break
click.secho(
"No model selected. Please try again or press 'q' to exit.",
fg="red",
)
env_vars["MODEL"] = selected_model

# Check if the selected provider requires API keys
if selected_provider in ENV_VARS:
provider_env_vars = ENV_VARS[selected_provider]
for details in provider_env_vars:
if details.get("default", False):
# Automatically add default key-value pairs
for key, value in details.items():
if key not in ["prompt", "key_name", "default"]:
env_vars[key] = value
elif "key_name" in details:
# Prompt for non-default key-value pairs
prompt = details["prompt"]
key_name = details["key_name"]
api_key_value = click.prompt(prompt, default="", show_default=False)
Expand All @@ -167,41 +180,12 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
click.secho("API keys and model saved to .env file", fg="green")
else:
click.secho(
"No API keys provided. Skipping .env file creation.", fg="yellow"
"No API keys provided or required by provider. Skipping .env file creation.", fg="yellow"
)

click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")

package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"

root_template_files = (
[".gitignore", "pyproject.toml", "README.md", "knowledge/user_preference.txt"]
if not parent_folder
else []
)
tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"]
config_template_files = ["config/agents.yaml", "config/tasks.yaml"]
src_template_files = (
["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"]
)

for file_name in root_template_files:
src_file = templates_dir / file_name
dst_file = folder_path / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)

src_folder = folder_path / "src" / folder_name if not parent_folder else folder_path

for file_name in src_template_files:
src_file = templates_dir / file_name
dst_file = src_folder / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)

if not parent_folder:
for file_name in tools_template_files + config_template_files:
src_file = templates_dir / file_name
dst_file = src_folder / file_name
copy_template(src_file, dst_file, name, class_name, folder_name)
copy_template_files(folder_path, name, class_name, parent_folder)

click.secho(f"Crew {name} created successfully!", fg="green", bold=True)
click.secho(f"Crew '{name}' created successfully!", fg="green")
click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan")
142 changes: 142 additions & 0 deletions tests/cli/test_create_crew.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import pytest
from click.testing import CliRunner
from unittest.mock import patch, MagicMock
from pathlib import Path
import sys

# Ensure the src directory is in the Python path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src'))

from crewai.cli.cli import crewai
from crewai.cli import create_crew
from crewai.cli.constants import MODELS, ENV_VARS

# Mock provider data for testing
MOCK_PROVIDER_DATA = {
'openai': {'models': ['gpt-4', 'gpt-3.5-turbo']},
'google': {'models': ['gemini-pro']},
'anthropic': {'models': ['claude-3-opus']}
}

MOCK_VALID_PROVIDERS = list(MOCK_PROVIDER_DATA.keys())

@pytest.fixture
def runner():
return CliRunner()

@pytest.fixture(autouse=True)
def isolate_fs(monkeypatch):
# Prevent tests from interacting with the actual filesystem or real env vars
monkeypatch.setattr(Path, 'mkdir', lambda *args, **kwargs: None)
monkeypatch.setattr(Path, 'exists', lambda *args: False) # Assume folders don't exist initially
monkeypatch.setattr(create_crew, 'load_env_vars', lambda *args: {}) # Start with empty env vars
monkeypatch.setattr(create_crew, 'write_env_file', lambda *args, **kwargs: None)
monkeypatch.setattr(create_crew, 'copy_template_files', lambda *args, **kwargs: None)

@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider')
@patch('crewai.cli.create_crew.select_model')
@patch('click.prompt')
@patch('click.confirm', return_value=True) # Default to confirming prompts
def test_create_crew_with_valid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name> --provider <valid_provider>`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'openai'])

print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
assert "Using specified provider: Openai" in result.output
mock_select_provider.assert_not_called() # Should not ask interactively
# Depending on whether openai needs models/keys, check select_model/prompt calls
assert "Crew 'testcrew' created successfully!" in result.output

@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate user selecting google
@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro')
@patch('click.prompt')
@patch('click.confirm', return_value=True)
def test_create_crew_with_invalid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name> --provider <invalid_provider>`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'invalidprovider'])

print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
assert "Warning: Specified provider 'invalidprovider' is not recognized." in result.output
mock_select_provider.assert_called_once() # Should ask interactively
# Check if subsequent steps for the selected provider (google) ran
mock_select_model.assert_called_once()
assert "Crew 'testcrew' created successfully!" in result.output

@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider', return_value='anthropic') # Simulate user selecting anthropic
@patch('crewai.cli.create_crew.select_model', return_value='claude-3-opus')
@patch('click.prompt', return_value='sk-abc') # Simulate API key entry
@patch('click.confirm', return_value=True)
def test_create_crew_no_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name>`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])

print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
assert "Using specified provider:" not in result.output # Should not mention specified provider
mock_select_provider.assert_called_once() # Should ask interactively
mock_select_model.assert_called_once()
# Check if prompt for API key was called (assuming anthropic needs one)
if 'anthropic' in ENV_VARS and any('key_name' in d for d in ENV_VARS['anthropic']):
mock_prompt.assert_called()
assert "Crew 'testcrew' created successfully!" in result.output

@patch('crewai.cli.create_crew.get_provider_data')
@patch('crewai.cli.create_crew.select_provider')
@patch('crewai.cli.create_crew.select_model')
@patch('click.prompt')
@patch('click.confirm')
def test_create_crew_skip_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner):
"""Test `crewai create crew <name> --skip_provider`"""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--skip_provider'])

print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
mock_get_data.assert_not_called()
mock_select_provider.assert_not_called()
mock_select_model.assert_not_called()
mock_prompt.assert_not_called()
mock_confirm.assert_not_called()
assert "Crew 'testcrew' created successfully!" in result.output

@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate selecting new provider
@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro')
@patch('click.prompt')
@patch('click.confirm', return_value=True) # User confirms override
def test_create_crew_existing_override(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner):
"""Test `crewai create crew <name>` with existing config and user overrides."""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])

print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
mock_confirm.assert_called_once_with(
'Found existing environment variable configuration for Openai. Do you want to override it?'
)
mock_select_provider.assert_called_once() # Should ask for new provider after confirming override
assert "Crew 'testcrew' created successfully!" in result.output

@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env
@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA)
@patch('crewai.cli.create_crew.select_provider')
@patch('crewai.cli.create_crew.select_model')
@patch('click.prompt')
@patch('click.confirm', return_value=False) # User denies override
def test_create_crew_existing_keep(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner):
"""Test `crewai create crew <name>` with existing config and user keeps it."""
result = runner.invoke(crewai, ['create', 'crew', 'testcrew'])

print(f"CLI Output:\n{result.output}") # Debug output
assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}"
mock_confirm.assert_called_once_with(
'Found existing environment variable configuration for Openai. Do you want to override it?'
)
assert "Keeping existing provider configuration. Exiting provider setup." in result.output
mock_select_provider.assert_not_called() # Should NOT ask for new provider
assert "Crew 'testcrew' created successfully!" in result.output

Loading