Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pip install support #6

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiflows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .flow_verse import loading
from .utils import logging

VERSION = "0.1.6"
VERSION = "0.1.7"
189 changes: 150 additions & 39 deletions aiflows/flow_verse/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from huggingface_hub.hf_api import HfApi

from aiflows.utils import logging

from . import utils
import subprocess
import pkg_resources
import importlib
from aiflows.flow_verse import utils

logger = logging.get_logger(__name__)
logger.warn = logger.warning
Expand Down Expand Up @@ -389,6 +391,96 @@ def remove_dir_or_link(sync_dir: str):
raise ValueError(f"Invalid sync_dir: {sync_dir}, it is not a valid directory nor a valid link")


def get_unsatisfied_pip_requirements(requirements_file):
    """Return the requirements of a pip requirements file that are not satisfied.

    Every line is stripped of its trailing comment and parsed with
    ``pkg_resources.Requirement.parse``; the parsed requirement is then checked
    against the distributions found in ``pkg_resources.working_set``.

    Lines are handled as follows:

    * blank lines and comment-only lines are ignored;
    * pip option lines (anything starting with ``-``, e.g. ``-r other.txt`` or
      ``--index-url ...``) are ignored, because they configure pip rather than
      name a distribution whose version can be checked here;
    * requirements whose environment marker evaluates to false in the current
      environment are ignored, because pip would skip them as well;
    * lines that cannot be parsed as a requirement specifier (e.g. a bare URL
      or a local path) are reported as unsatisfied, since their installation
      state cannot be verified and pip should decide.

    :param requirements_file: The path to the requirements file
    :type requirements_file: str
    :return: A list of unsatisfied pip requirements
    :rtype: List[str]
    """
    import re

    # Reload pkg_resources so that packages installed earlier in this process
    # (e.g. by previous flow modules of the same flow) appear in the working set.
    importlib.reload(pkg_resources)

    with open(requirements_file, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f]

    # Index installed distributions by lower-cased project name.
    installed_distributions = {dist.project_name.lower(): dist for dist in pkg_resources.working_set}

    # A comment starts at a "#" that opens the line or follows whitespace.
    # Splitting on every "#" would cut URL fragments such as "#egg=name".
    comment_pattern = re.compile(r"(^|\s+)#.*$")

    unsatisfied_requirements = []
    for line in lines:
        req = comment_pattern.sub("", line).strip()
        if not req:
            continue

        # pip options are not requirement specifiers; leave them to pip.
        if req.startswith("-"):
            continue

        try:
            req_dist = pkg_resources.Requirement.parse(req)
        except ValueError:
            # Parse failures are expected to be ValueError subclasses.
            # The line cannot be verified, so report it and let pip handle it.
            unsatisfied_requirements.append(req)
            continue

        # Skip requirements that do not apply to this environment.
        marker = getattr(req_dist, "marker", None)
        if marker is not None and not marker.evaluate():
            continue

        installed_dist = installed_distributions.get(req_dist.project_name.lower())
        if installed_dist is None or installed_dist not in req_dist:
            unsatisfied_requirements.append(req)

    return unsatisfied_requirements


def display_and_confirm_requirements(flow_name,requirements):
    """Show the missing requirements of a flow and ask whether to install them.

    When there is nothing to install, no question is asked and ``False`` is
    returned immediately. Otherwise the requirements are listed one per line
    and the decision is delegated to ``utils.yes_no_question``.

    :param flow_name: The name of the flow
    :type flow_name: str
    :param requirements: The list of unsatisfied pip requirements
    :type requirements: List[str]
    :return: True if the user wants to install the requirements, False otherwise
    :rtype: bool
    """
    # Nothing is missing, so there is nothing to confirm.
    if not len(requirements):
        return False

    bullet_list = "\n".join(f" - {item}" for item in requirements)

    prompt = f"""\n{flow_name} is requesting to install the following pip requirements:\n{bullet_list}\n Do you want to proceed with the installation?"""
    declined = f"""Installation of requirements for {flow_name} is canceled. This may impact the proper functioning of the Flow."""
    accepted = f"Requirements from {flow_name} will be installed."

    return utils.yes_no_question(
        logger,
        prompt,
        accepted,
        declined,
        colorama_style=colorama.Fore.RED,
    )


def install_requirements(synced_flow_mod_spec):
    """Install the missing pip requirements of a synced flow module.

    The module's ``pip_requirements.txt`` (located in its sync directory) is
    checked against the installed packages. If something is missing, the user
    is asked for confirmation and, on approval, the whole requirements file is
    handed to pip.

    pip is run as ``<current interpreter> -m pip`` so that packages are
    installed into the environment that is executing this code, rather than
    into whichever ``pip`` executable happens to be first on ``PATH``.
    A failing pip run is logged as a warning and does not raise.

    :param synced_flow_mod_spec: The synced flow module specification
    :type synced_flow_mod_spec: FlowModuleSpec
    :raises ValueError: If the flow module has no ``pip_requirements.txt`` file
    """
    import sys

    repo_id = synced_flow_mod_spec.repo_id
    requirements_file = os.path.join(synced_flow_mod_spec.sync_dir, "pip_requirements.txt")

    # For the moment, every flow module is required to ship a pip_requirements.txt file.
    if not os.path.exists(requirements_file):
        raise ValueError(f"Every flow module must have a pip_requirements.txt file, but {requirements_file} does not exist for {repo_id}")

    unsatisfied_requirements = get_unsatisfied_pip_requirements(requirements_file)

    # Returns False without prompting when nothing is missing.
    if not display_and_confirm_requirements(repo_id, unsatisfied_requirements):
        return

    completed = subprocess.run([sys.executable, "-m", "pip", "install", "-r", requirements_file])
    if completed.returncode != 0:
        # Best effort: the flow may still work, but the user should know why it might not.
        logger.warning(
            f"pip exited with code {completed.returncode} while installing the requirements of {repo_id} from {requirements_file}"
        )



# TODO(Yeeef): add repo_hash and modified_flag to reduce recomputation


Expand Down Expand Up @@ -576,32 +668,42 @@ def sync_remote_dep(
sync_dir_modified = is_sync_dir_modified(sync_dir, previous_synced_flow_mod_spec.cache_dir)

if overwrite:
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {flow_mod_id} will be overwritten, are you sure? (Y/N){colorama.Style.RESET_ALL}"
)
user_input = input()
if user_input != "Y":
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {flow_mod_id} will not be overwritten.{colorama.Style.RESET_ALL}"
)
overwrite = False

question_message = \
f"[{caller_module_name}] {flow_mod_id} will be overwritten, are you sure?"

no_message = \
f"[{caller_module_name}] {flow_mod_id} will not be overwritten."

yes_message = \
f"[{caller_module_name}]{flow_mod_id} will be fetched from remote."

overwrite = utils.yes_no_question(logger, question_message,yes_message,no_message)

if not overwrite:
synced_flow_mod_spec = previous_synced_flow_mod_spec
else:
logger.info(f"{flow_mod_id} will be fetched from remote.{colorama.Style.RESET_ALL}")
synced_flow_mod_spec = fetch_remote(repo_id, revision, sync_dir, cache_root)

elif previous_synced_flow_mod_spec.mod_id != flow_mod_id:
# user has supplied a new flow_mod_id, we fetch the remote directly with warning
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} already synced, it will be overwritten by new revision {flow_mod_id}, are you sure? (Y/N){colorama.Style.RESET_ALL}"
)
user_input = input()
if user_input != "Y":
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {flow_mod_id} will not be overwritten.{colorama.Style.RESET_ALL}"
)

question_message = \
f"{previous_synced_flow_mod_spec.mod_id} already synced, it will be overwritten by new revision {flow_mod_id}, are you sure? "

no_message = \
f"[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} will not be overwritten."

yes_message = \
f"[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} will be fetched from remote."

fetch_from_remote = utils.yes_no_question(logger, question_message,yes_message,no_message)

if not fetch_from_remote:
synced_flow_mod_spec = previous_synced_flow_mod_spec
else:
synced_flow_mod_spec = fetch_remote(repo_id, revision, sync_dir, cache_root)

### user has supplied same flow_mod_id(repo_id:revision), we check if the remote commit has changed
elif not remote_revision_commit_hash_changed:
# trivial case, we do nothing
Expand Down Expand Up @@ -670,33 +772,41 @@ def sync_local_dep(
assert sync_dir == previous_synced_flow_mod_spec.sync_dir, (sync_dir, previous_synced_flow_mod_spec.sync_dir)

if overwrite:
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {flow_mod_id} will be overwritten, are you sure? (Y/N){colorama.Style.RESET_ALL}"
)
user_input = input()
if user_input != "Y":
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {flow_mod_id} will not be overwritten.{colorama.Style.RESET_ALL}"
)
overwrite = False

question_message = \
f"[{caller_module_name}] {flow_mod_id} will be overwritten, are you sure?"

no_message = \
f"[{caller_module_name}] {flow_mod_id} will not be overwritten."

yes_message = \
f"[{caller_module_name}] {flow_mod_id} will be fetched from local."

overwrite = utils.yes_no_question(logger, question_message,yes_message,no_message)

if not overwrite:
synced_flow_mod_spec = previous_synced_flow_mod_spec
else:
logger.info(f"{flow_mod_id} will be fetched from local")
synced_flow_mod_spec = fetch_local(repo_id, module_synced_from_dir, sync_dir)

elif previous_synced_flow_mod_spec.mod_id != flow_mod_id:
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} already synced, it will be overwritten by {flow_mod_id}, are you sure? (Y/N){colorama.Style.RESET_ALL}"
)
user_input = input()
if user_input != "Y":
logger.warn(
f"{colorama.Fore.RED}[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} will not be overwritten.{colorama.Style.RESET_ALL}"
)

question_message = \
f"[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} already synced, it will be overwritten by {flow_mod_id}, are you sure?"

no_message = \
f"[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} will not be overwritten."

yes_message = \
f"[{caller_module_name}] {previous_synced_flow_mod_spec.mod_id} will be fetched from local."

fetch_from_local = utils.yes_no_question(logger, question_message,yes_message,no_message)

if not fetch_from_local:
synced_flow_mod_spec = previous_synced_flow_mod_spec
else:
logger.info(f"{flow_mod_id} will be fetched from local")
synced_flow_mod_spec = fetch_local(repo_id, module_synced_from_dir, sync_dir)

else:
logger.info(f"{flow_mod_id} already synced, skip")
synced_flow_mod_spec = previous_synced_flow_mod_spec
Expand Down Expand Up @@ -817,6 +927,7 @@ def _sync_dependencies(
)
# logger.debug(f"add remote dep {synced_flow_mod_spec} to flow_mod_summary")
flow_mod_summary.add_mod(synced_flow_mod_spec)
install_requirements(synced_flow_mod_spec)

# write flow.mod
# logger.debug(f"write flow mod summary: {flow_mod_summary}")
Expand Down
42 changes: 41 additions & 1 deletion aiflows/flow_verse/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import re

import colorama

def build_hf_cache_path(repo_id: str, commit_hash: str, cache_root: str) -> str:
"""
Expand All @@ -17,6 +17,7 @@ def build_hf_cache_path(repo_id: str, commit_hash: str, cache_root: str) -> str:
:return: The path to the cache directory for the given model snapshot.
:rtype: str
"""
breakpoint()
username, modelname = repo_id.split("/")
relative_path = os.path.join(f"models--{username}--{modelname}", "snapshots", commit_hash)
return os.path.join(cache_root, relative_path)
Expand All @@ -31,3 +32,42 @@ def is_local_revision(revision: str):
:rtype: bool
"""
return os.path.exists(revision)

def yes_no_question(logger,question_message,yes_message, no_message, colorama_style=colorama.Fore.RED):
    """Ask a yes/no question on standard input and return the user's decision.

    The question is repeated until a valid answer is given. Answers are
    compared case-insensitively after stripping surrounding whitespace;
    ``y``/``yes`` count as yes and ``n``/``no`` count as no. If standard input
    is exhausted (``EOFError``, e.g. in a non-interactive run), the answer is
    treated as no so that nothing is installed or overwritten without consent.

    :param logger: The logger used to display the question and the follow-up messages
    :type logger: logging.Logger
    :param question_message: The message to display when asking the question
    :type question_message: str
    :param yes_message: The message to display when the user answers yes
    :type yes_message: str
    :param no_message: The message to display when the user answers no
    :type no_message: str
    :param colorama_style: The colorama style to use when displaying the question, defaults to colorama.Fore.RED
    :type colorama_style: colorama.Fore, optional
    :return: True if the user answers yes, False otherwise
    :rtype: bool
    """
    affirmative_answers = {"y", "yes"}
    negative_answers = {"n", "no"}

    while True:
        logger.warning(
            f""" {colorama_style} {question_message} (Y/N){colorama.Style.RESET_ALL}"""
        )

        try:
            user_input = input().strip().lower()
        except EOFError:
            # No interactive input available: fall back to the safe choice.
            user_input = "n"

        if user_input in affirmative_answers:
            logger.warning(
                f"{colorama_style} {yes_message} {colorama.Style.RESET_ALL}"
            )
            return True

        if user_input in negative_answers:
            logger.warning(
                f"{colorama_style} {no_message} {colorama.Style.RESET_ALL}"
            )
            return False

        logger.warning("Invalid input. Please enter 'Y' or 'N'.")
Loading