From 0203f22b4f37c7d092417c07e151930faeb137e2 Mon Sep 17 00:00:00 2001 From: Aman Gokrani Date: Fri, 25 Oct 2024 23:48:11 +0200 Subject: [PATCH 1/2] initial changes to support gitlab as git provider --- sage/config.py | 2 + sage/data_manager.py | 252 ++++++++++++++++++++++++++++++++++++- sage/index.py | 7 +- sage/retriever.py | 8 +- tests/test_data_manager.py | 42 ++++++- 5 files changed, 303 insertions(+), 8 deletions(-) diff --git a/sage/config.py b/sage/config.py index e458112..2464207 100644 --- a/sage/config.py +++ b/sage/config.py @@ -81,6 +81,8 @@ def add_repo_args(parser: ArgumentParser) -> Callable: default="repos", help="The local directory to store the repository", ) + parser.add("--git-provider", default="github", choices=["github", "gitlab"]) + parser.add("--base-url", default=None, help="The base URL for the Git provider. This is only needed for GitLab.") return validate_repo_args diff --git a/sage/data_manager.py b/sage/data_manager.py index 680f3fd..57ed4cc 100644 --- a/sage/data_manager.py +++ b/sage/data_manager.py @@ -5,7 +5,7 @@ from abc import abstractmethod from functools import cached_property from typing import Any, Dict, Generator, Tuple - +from urllib.parse import quote import requests from git import GitCommandError, Repo @@ -254,3 +254,253 @@ def from_args(args: Dict): "For private repositories, please set the GITHUB_TOKEN variable in your environment." ) return repo_manager + + +class GitLabRepoManager(DataManager): + """Class to manage a local clone of a GitLab repository.""" + + def __init__( + self, + repo_id: str, + commit_hash: str = None, + access_token: str = None, + local_dir: str = None, + inclusion_file: str = None, + exclusion_file: str = None, + base_url: str = None, + ): + """ + Args: + repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage". + commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo. + access_token: A GitLab access token to use for cloning private repositories. Not needed for public repos. + local_dir: The local directory where the repository will be cloned. + inclusion_file: A file with a lists of files/directories/extensions to include. Each line must be in one of + the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory". + exclusion_file: A file with a lists of files/directories/extensions to exclude. Each line must be in one of + the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory". + """ + super().__init__(dataset_id=repo_id) + self.repo_id = repo_id + self.commit_hash = commit_hash + self.access_token = access_token + self.base_url = base_url or "https://gitlab.com" + + self.local_dir = local_dir or "/tmp/" + if not os.path.exists(self.local_dir): + os.makedirs(self.local_dir) + self.local_path = os.path.join(self.local_dir, repo_id) + + self.log_dir = os.path.join(self.local_dir, "logs", repo_id) + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + + if inclusion_file and exclusion_file: + raise ValueError("Only one of inclusion_file or exclusion_file should be provided.") + + self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None + self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None + + @cached_property + def is_public(self) -> bool: + """Checks whether a GitLab repository is publicly visible.""" + + repo_id = quote(self.repo_id, safe="") + response = requests.get(f"https://gitlab.com/api/v4/projects/{repo_id}", timeout=10) + # Note that the response will be 404 for both private and non-existent repos. + return response.status_code == 200 + + @cached_property + def default_branch(self) -> str: + """Fetches the default branch of the repository from GitLab.""" + headers = { + "Accept": "application/json", + } + if self.access_token: + headers["Authorization"] = f"Bearer {self.access_token}" + + repo_id = quote(self.repo_id, safe="") + response = requests.get(f"https://gitlab.com/api/v4/projects/{repo_id}", headers=headers) + if response.status_code == 200: + branch = response.json().get("default_branch", "main") + else: + # This happens sometimes when we exceed the GitLab rate limit. The best bet in this case is to assume the + # most common naming for the default branch ("main"). + logging.warn(f"Unable to fetch default branch for {self.repo_id}: {response.text}") + branch = "main" + return branch + + def download(self) -> bool: + """Clones the repository to the local directory, if it's not already cloned.""" + if os.path.exists(self.local_path): + # The repository is already cloned. + return True + + if not self.is_public and not self.access_token: + raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.") + + if self.access_token: + clone_url = f"{self.base_url}/{self.repo_id}.git" + # Inject access token for authentication in the URL. + clone_url = clone_url.replace("https://", f"https://oauth2:{self.access_token}@") + else: + clone_url = f"{self.base_url}/{self.repo_id}.git" + + try: + if self.commit_hash: + repo = Repo.clone_from(clone_url, self.local_path) + repo.git.checkout(self.commit_hash) + else: + Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True) + except GitCommandError as e: + logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e) + return False + return True + + def _parse_filter_file(self, file_path: str) -> bool: + """Parses a file with files/directories/extensions to include/exclude. + + Lines are expected to be in the format: + # Comment that will be ignored, or + ext:.my-extension, or + file:my-file.py, or + dir:my-directory + """ + with open(file_path, "r") as f: + lines = f.readlines() + + parsed_data = {"ext": [], "file": [], "dir": []} + for line in lines: + if line.startswith("#"): + # This is a comment line. + continue + key, value = line.strip().split(":") + if key in parsed_data: + parsed_data[key].append(value) + else: + logging.error("Unrecognized key in line: %s. Skipping.", line) + + return parsed_data + + + def _should_include(self, file_path: str) -> bool: + """Checks whether the file should be indexed.""" + # Exclude symlinks. + if os.path.islink(file_path): + return False + + # Exclude hidden files and directories. + if any(part.startswith(".") for part in file_path.split(os.path.sep)): + return False + + if not self.inclusions and not self.exclusions: + return True + + # Filter based on file extensions, file names and directory names. + _, extension = os.path.splitext(file_path) + extension = extension.lower() + file_name = os.path.basename(file_path) + dirs = os.path.dirname(file_path).split("/") + + if self.inclusions: + return ( + extension in self.inclusions.get("ext", []) + or file_name in self.inclusions.get("file", []) + or any(d in dirs for d in self.inclusions.get("dir", [])) + ) + elif self.exclusions: + return ( + extension not in self.exclusions.get("ext", []) + and file_name not in self.exclusions.get("file", []) + and all(d not in dirs for d in self.exclusions.get("dir", [])) + ) + return True + + def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]: + """Walks the local repository path and yields a tuple of (content, metadata) for each file. + The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py"). + + Args: + get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only. + """ + # We will keep appending to these files during the iteration, so we need to clear them first. + repo_name = self.repo_id.replace("/", "_") + included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt") + excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt") + if os.path.exists(included_log_file): + os.remove(included_log_file) + logging.info("Logging included files at %s", included_log_file) + if os.path.exists(excluded_log_file): + os.remove(excluded_log_file) + logging.info("Logging excluded files at %s", excluded_log_file) + + for root, _, files in os.walk(self.local_path): + file_paths = [os.path.join(root, file) for file in files] + included_file_paths = [f for f in file_paths if self._should_include(f)] + + with open(included_log_file, "a") as f: + for path in included_file_paths: + f.write(path + "\n") + + excluded_file_paths = set(file_paths).difference(set(included_file_paths)) + with open(excluded_log_file, "a") as f: + for path in excluded_file_paths: + f.write(path + "\n") + + for file_path in included_file_paths: + relative_file_path = file_path[len(self.local_dir) + 1 :] + metadata = { + "file_path": relative_file_path, + "url": self.url_for_file(relative_file_path), + } + + if not get_content: + yield metadata + continue + + contents = self.read_file(relative_file_path) + if contents: + yield contents, metadata + + def url_for_file(self, file_path: str) -> str: + """Converts a repository file path to a GitHub link.""" + file_path = file_path[len(self.repo_id) + 1 :] + return f"{self.base_url}/{self.repo_id}/blob/-/{self.default_branch}/{file_path}" + + def read_file(self, relative_file_path: str) -> str: + """Reads the contents of a file in the repository.""" + absolute_file_path = os.path.join(self.local_dir, relative_file_path) + with open(absolute_file_path, "r") as f: + try: + contents = f.read() + return contents + except UnicodeDecodeError: + logging.warning("Unable to decode file %s.", absolute_file_path) + return None + + def from_args(args: Dict): + """Creates a GitLabRepoManager from command-line arguments and clones the underlying repository.""" + repo_manager = GitLabRepoManager( + repo_id=args.repo_id, + commit_hash=args.commit_hash, + access_token=os.getenv("GITHUB_TOKEN"), + local_dir=args.local_dir, + inclusion_file=args.include, + exclusion_file=args.exclude, + base_url=args.base_url + ) + success = repo_manager.download() + if not success: + raise ValueError( + f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. " + "For private repositories, please set the GITHUB_TOKEN variable in your environment." + ) + return repo_manager + + +def build_data_manager_from_args(args: Dict) -> DataManager: + """Creates a DataManager from command-line arguments.""" + if args.git_provider == "gitlab": + return GitLabRepoManager.from_args(args) + return GitHubRepoManager.from_args(args) + \ No newline at end of file diff --git a/sage/index.py b/sage/index.py index 9dbff11..771d240 100644 --- a/sage/index.py +++ b/sage/index.py @@ -5,10 +5,11 @@ import time import configargparse +from dotenv import load_dotenv import sage.config as sage_config from sage.chunker import UniversalFileChunker -from sage.data_manager import GitHubRepoManager +from sage.data_manager import build_data_manager_from_args from sage.embedder import build_batch_embedder_from_flags from sage.github import GitHubIssuesChunker, GitHubIssuesManager from sage.vector_store import build_vector_store_from_args @@ -17,6 +18,8 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +load_dotenv() + def main(): parser = configargparse.ArgParser( @@ -54,7 +57,7 @@ def main(): repo_embedder = None if args.index_repo: logging.info("Cloning the repository...") - repo_manager = GitHubRepoManager.from_args(args) + repo_manager = build_data_manager_from_args(args) logging.info("Embedding the repo...") chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk) repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args) diff --git a/sage/retriever.py b/sage/retriever.py index c287ae5..a75c083 100644 --- a/sage/retriever.py +++ b/sage/retriever.py @@ -15,7 +15,7 @@ from pydantic import Field from sage.code_symbols import get_code_symbols -from sage.data_manager import DataManager, GitHubRepoManager +from sage.data_manager import DataManager, build_data_manager_from_args from sage.llm import build_llm_via_langchain from sage.reranker import build_reranker from sage.vector_store import build_vector_store_from_args @@ -38,14 +38,14 @@ class LLMRetriever(BaseRetriever): caching to make it usable. """ - repo_manager: GitHubRepoManager = Field(...) + repo_manager: DataManager = Field(...) top_k: int = Field(...) cached_repo_metadata: List[Dict] = Field(...) cached_repo_files: List[str] = Field(...) cached_repo_hierarchy: str = Field(...) - def __init__(self, repo_manager: GitHubRepoManager, top_k: int): + def __init__(self, repo_manager: DataManager, top_k: int): super().__init__() self.repo_manager = repo_manager self.top_k = top_k @@ -322,7 +322,7 @@ def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Docum def build_retriever_from_args(args, data_manager: Optional[DataManager] = None): """Builds a retriever (with optional reranking) from command-line arguments.""" if args.llm_retriever: - retriever = LLMRetriever(GitHubRepoManager.from_args(args), top_k=args.retriever_top_k) + retriever = LLMRetriever(build_data_manager_from_args(args), top_k=args.retriever_top_k) else: if args.embedding_provider == "openai": embeddings = OpenAIEmbeddings(model=args.embedding_model) diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py index 78c6957..fcf4daa 100644 --- a/tests/test_data_manager.py +++ b/tests/test_data_manager.py @@ -17,9 +17,49 @@ import os import unittest from unittest.mock import MagicMock, patch +from urllib.parse import quote -from sage.data_manager import GitHubRepoManager +from sage.data_manager import GitHubRepoManager, GitLabRepoManager +class TestGitLabRepoManager(unittest.TestCase): + @patch("git.Repo.clone_from") + def test_download_clone_success(self, mock_clone): + """Test the download() method of GitHubRepoManager by mocking the cloning process.""" + repo_manager = GitLabRepoManager(repo_id="gitlab-org/gitlab-runner", local_dir="/tmp/test_repo") + mock_clone.return_value = MagicMock() + result = repo_manager.download() + mock_clone.assert_called_once_with("https://gitlab.com/gitlab-org/gitlab-runner.git", "/tmp/test_repo/gitlab-org/gitlab-runner", depth=1, single_branch=True) + self.assertTrue(result) + + + @patch("sage.data_manager.requests.get") + def test_is_public_repository(self, mock_get): + """Test the is_public property to check if a repository is public.""" + mock_get.return_value.status_code = 200 + repo_id="gitlab-org/gitlab-runner" + repo_manager = GitLabRepoManager(repo_id=repo_id) + self.assertTrue(repo_manager.is_public) + repo_id = quote(repo_id, safe="") + mock_get.assert_called_once_with(f"https://gitlab.com/api/v4/projects/{repo_id}", timeout=10) + + @patch("sage.data_manager.requests.get") + def test_is_private_repository(self, mock_get): + """Test the is_public property to check if a repository is private.""" + mock_get.return_value.status_code = 404 + repo_id="gitlab-org/gitlab-runner" + repo_manager = GitLabRepoManager(repo_id=repo_id) + print(repo_manager.is_public) + self.assertFalse(repo_manager.is_public) + repo_id = quote(repo_id, safe="") + mock_get.assert_called_once_with(f"https://gitlab.com/api/v4/projects/{repo_id}", timeout=10) + + @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="ext:.py\nfile:test.py\ndir:test_dir\n") + def test_parse_filter_file(self, mock_file): + """Test the _parse_filter_file method for correct parsing of inclusion/exclusion files.""" + repo_manager = GitLabRepoManager(repo_id="gitlab-org/gitlab-runner", inclusion_file="dummy_path") + expected = {"ext": [".py"], "file": ["test.py"], "dir": ["test_dir"]} + result = repo_manager._parse_filter_file("dummy_path") + self.assertEqual(result, expected) class TestGitHubRepoManager(unittest.TestCase): @patch("git.Repo.clone_from") From cd2c9898b80c12433011bbe008d5a559f6df515a Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 30 Oct 2024 17:28:56 +0100 Subject: [PATCH 2/2] new test added --- tests/test_data_manager.py | 47 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py index fcf4daa..e13e2a0 100644 --- a/tests/test_data_manager.py +++ b/tests/test_data_manager.py @@ -53,6 +53,19 @@ def test_is_private_repository(self, mock_get): repo_id = quote(repo_id, safe="") mock_get.assert_called_once_with(f"https://gitlab.com/api/v4/projects/{repo_id}", timeout=10) + @patch("sage.data_manager.requests.get") + def test_default_branch(self, mock_get): + """Test the default_branch property to fetch the default branch of the repository.""" + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = {"default_branch": "main"} + repo_id="gitlab-org/gitlab-runner" + repo_manager = GitLabRepoManager(repo_id=repo_id) + repo_id = quote(repo_id, safe="") + self.assertEqual(repo_manager.default_branch, "main") + mock_get.assert_called_once_with( + f"https://gitlab.com/api/v4/projects/{repo_id}", headers={"Accept": "application/json"} + ) + @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="ext:.py\nfile:test.py\ndir:test_dir\n") def test_parse_filter_file(self, mock_file): """Test the _parse_filter_file method for correct parsing of inclusion/exclusion files.""" @@ -61,6 +74,40 @@ def test_parse_filter_file(self, mock_file): result = repo_manager._parse_filter_file("dummy_path") self.assertEqual(result, expected) + @patch("os.path.exists") + @patch("os.remove") + @patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="dummy content") + def test_walk_included_files(self, mock_open, mock_remove, mock_exists): + """Test the walk method to ensure it only includes specified files.""" + mock_exists.return_value = True + repo_manager = GitLabRepoManager(repo_id="gitlab-org/gitlab-runner", local_dir="/tmp/test_repo") + with patch( + "os.walk", + return_value=[ + ("/tmp/test_repo", ("subdir",), ("included_file.py", "excluded_file.txt")), + ], + ): + included_files = list(repo_manager.walk()) + print("Included files:", included_files) + self.assertTrue(any(file[1]["file_path"] == "included_file.py" for file in included_files)) + + def test_read_file(self): + """Test the read_file method to read the content of a file.""" + mock_file_path = "/tmp/test_repo/test_file.txt" + with patch("builtins.open", new_callable=unittest.mock.mock_open, read_data="Hello, World!"): + repo_manager = GitLabRepoManager(repo_id="gitlab-org/gitlab-runner", local_dir="/tmp/test_repo") + content = repo_manager.read_file("test_file.txt") + self.assertEqual(content, "Hello, World!") + + @patch("os.makedirs") + def test_create_log_directories(self, mock_makedirs): + """Test that log directories are created.""" + repo_manager = GitHubRepoManager(repo_id="Storia-AI/sage", local_dir="/tmp/test_repo") + + with self.assertRaises(AttributeError): + repo_manager.create_log_directories() + + class TestGitHubRepoManager(unittest.TestCase): @patch("git.Repo.clone_from") def test_download_clone_success(self, mock_clone):