Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: support gitlab as git provider #96

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def add_repo_args(parser: ArgumentParser) -> Callable:
default="repos",
help="The local directory to store the repository",
)
parser.add("--git-provider", default="github", choices=["github", "gitlab"])
parser.add("--base-url", default=None, help="The base URL for the Git provider. This is only needed for GitLab.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please name the flag --gitlab-base-url to avoid confusion?

return validate_repo_args


Expand Down
252 changes: 251 additions & 1 deletion sage/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import abstractmethod
from functools import cached_property
from typing import Any, Dict, Generator, Tuple

from urllib.parse import quote
import requests
from git import GitCommandError, Repo

Expand Down Expand Up @@ -254,3 +254,253 @@ def from_args(args: Dict):
"For private repositories, please set the GITHUB_TOKEN variable in your environment."
)
return repo_manager


class GitLabRepoManager(DataManager):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new class introduces a lot of code duplication. IIUC, the only difference compared to GitHubRepoManager is how URLs are constructed. Is that correct? In that case, we could have a single unified GitRepoManager that takes a base URL (which would be github.com for GitHub and gitlab.com for GitLab). And maybe we'd need occasional if/else for e.g. url_for_file.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! It should be possible, but I don't think adding different if/else branches would be a good idea if you plan to incorporate other git providers as well. I haven't looked into other providers — maybe that's something to check before we merge them into one class?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GitHub and GitLab are the main ones, I'm not sure if there's a third one as widely used.

Alternative to if/else clauses, we could have an abstract class GitRepoManager that is implemented by GitHubRepoManager and GitLabRepoManager. For instance:

class GitRepoManager:
    @abstractmethod
    def get_repo_url(self, repo_id: str):
        pass

    @cached_property
    def is_public(self) -> bool:
        """Checks whether a GitHub repository is publicly visible."""
        response = requests.get(self.get_repo_url(self.repo_id), timeout=10)
        # Note that the response will be 404 for both private and non-existent repos.
        return response.status_code == 200

    ... other methods

class GitHubRepoManager(GitRepoManager):
    def get_repo_url(self, repo_id: str):
        return f"https://api.github.com/repos/{self.repo_id}"
    ...

class GitLabRepoManager(GitRepoManager):
    def get_repo_url(self, repo_id: str):
        return f"https://gitlab.com/api/v4/projects/{repo_id}"  
    ... 

Does that make sense?

"""Class to manage a local clone of a GitLab repository."""

def __init__(
    self,
    repo_id: str,
    commit_hash: str = None,
    access_token: str = None,
    local_dir: str = None,
    inclusion_file: str = None,
    exclusion_file: str = None,
    base_url: str = None,
):
    """
    Args:
        repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
        commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
        access_token: A GitLab access token to use for cloning private repositories. Not needed for public repos.
        local_dir: The local directory where the repository will be cloned.
        inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one of
            the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
        exclusion_file: A file with a list of files/directories/extensions to exclude. Each line must be in one of
            the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
        base_url: The base URL of the GitLab instance (for self-hosted deployments). Defaults to
            "https://gitlab.com".

    Raises:
        ValueError: if both inclusion_file and exclusion_file are provided.
    """
    super().__init__(dataset_id=repo_id)
    self.repo_id = repo_id
    self.commit_hash = commit_hash
    self.access_token = access_token
    self.base_url = base_url or "https://gitlab.com"

    # Normalize away any trailing separator: walk() strips the local_dir
    # prefix with `file_path[len(self.local_dir) + 1:]`, which assumes no
    # trailing slash (the old default "/tmp/" chopped one extra character
    # off every relative path).
    self.local_dir = (local_dir or "/tmp").rstrip("/") or "/"
    if not os.path.exists(self.local_dir):
        os.makedirs(self.local_dir)
    self.local_path = os.path.join(self.local_dir, repo_id)

    self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    if inclusion_file and exclusion_file:
        raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")

    self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
    self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None

@cached_property
def is_public(self) -> bool:
    """Checks whether a GitLab repository is publicly visible.

    Note that the GitLab API returns 404 for both private and non-existent repos,
    so a False result cannot distinguish the two cases.
    """
    # URL-encode "owner/repo" into a single path component ("owner%2Frepo"),
    # as required by the GitLab projects API.
    repo_id = quote(self.repo_id, safe="")
    # Use the configured base URL rather than hardcoding gitlab.com, so that
    # self-hosted GitLab instances work too.
    response = requests.get(f"{self.base_url}/api/v4/projects/{repo_id}", timeout=10)
    return response.status_code == 200

@cached_property
def default_branch(self) -> str:
    """Fetches the default branch of the repository from GitLab.

    Falls back to "main" when the API call fails (e.g. due to rate limiting).
    """
    headers = {
        "Accept": "application/json",
    }
    if self.access_token:
        headers["Authorization"] = f"Bearer {self.access_token}"

    # URL-encode "owner/repo" into a single path component for the API.
    repo_id = quote(self.repo_id, safe="")
    # Use the configured base URL (supports self-hosted GitLab) and bound the
    # request with a timeout so a hung connection can't block indefinitely.
    response = requests.get(f"{self.base_url}/api/v4/projects/{repo_id}", headers=headers, timeout=10)
    if response.status_code == 200:
        branch = response.json().get("default_branch", "main")
    else:
        # This happens sometimes when we exceed the GitLab rate limit. The best bet in this case is to assume the
        # most common naming for the default branch ("main").
        # logging.warn is deprecated; logging.warning is the supported spelling.
        logging.warning("Unable to fetch default branch for %s: %s", self.repo_id, response.text)
        branch = "main"
    return branch

def download(self) -> bool:
    """Clones the repository to the local directory, if it's not already cloned.

    Returns:
        True if the repository is available locally (already cloned or cloned now),
        False if the git clone failed.

    Raises:
        ValueError: if the repo is private (or doesn't exist) and no access token was provided.
    """
    if os.path.exists(self.local_path):
        # The repository is already cloned.
        return True

    if not self.is_public and not self.access_token:
        raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")

    # Both the authenticated and anonymous cases clone from the same URL; the
    # token is merely injected into it (GitLab uses "oauth2:<token>@" auth).
    clone_url = f"{self.base_url}/{self.repo_id}.git"
    if self.access_token:
        clone_url = clone_url.replace("https://", f"https://oauth2:{self.access_token}@")

    try:
        if self.commit_hash:
            # Checking out a specific commit requires the full history.
            repo = Repo.clone_from(clone_url, self.local_path)
            repo.git.checkout(self.commit_hash)
        else:
            # Otherwise a shallow, single-branch clone is much faster.
            Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
    except GitCommandError as e:
        logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
        return False
    return True

def _parse_filter_file(self, file_path: str) -> bool:
"""Parses a file with files/directories/extensions to include/exclude.

Lines are expected to be in the format:
# Comment that will be ignored, or
ext:.my-extension, or
file:my-file.py, or
dir:my-directory
"""
with open(file_path, "r") as f:
lines = f.readlines()

parsed_data = {"ext": [], "file": [], "dir": []}
for line in lines:
if line.startswith("#"):
# This is a comment line.
continue
key, value = line.strip().split(":")
if key in parsed_data:
parsed_data[key].append(value)
else:
logging.error("Unrecognized key in line: %s. Skipping.", line)

return parsed_data


def _should_include(self, file_path: str) -> bool:
"""Checks whether the file should be indexed."""
# Exclude symlinks.
if os.path.islink(file_path):
return False

# Exclude hidden files and directories.
if any(part.startswith(".") for part in file_path.split(os.path.sep)):
return False

if not self.inclusions and not self.exclusions:
return True

# Filter based on file extensions, file names and directory names.
_, extension = os.path.splitext(file_path)
extension = extension.lower()
file_name = os.path.basename(file_path)
dirs = os.path.dirname(file_path).split("/")

if self.inclusions:
return (
extension in self.inclusions.get("ext", [])
or file_name in self.inclusions.get("file", [])
or any(d in dirs for d in self.inclusions.get("dir", []))
)
elif self.exclusions:
return (
extension not in self.exclusions.get("ext", [])
and file_name not in self.exclusions.get("file", [])
and all(d not in dirs for d in self.exclusions.get("dir", []))
)
return True

def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
    """Walks the local repository path and yields a tuple of (content, metadata) for each file.
    The filepath is relative to the root of the repository (e.g. "org/repo/your/file/path.py").

    Args:
        get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.
    """
    # We will keep appending to these files during the iteration, so we need to clear them first.
    repo_name = self.repo_id.replace("/", "_")
    included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
    excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
    if os.path.exists(included_log_file):
        os.remove(included_log_file)
    logging.info("Logging included files at %s", included_log_file)
    if os.path.exists(excluded_log_file):
        os.remove(excluded_log_file)
    logging.info("Logging excluded files at %s", excluded_log_file)

    for root, _, files in os.walk(self.local_path):
        file_paths = [os.path.join(root, file) for file in files]
        included_file_paths = [f for f in file_paths if self._should_include(f)]

        # Audit trail: record which files were indexed and which were skipped.
        with open(included_log_file, "a") as f:
            for path in included_file_paths:
                f.write(path + "\n")

        excluded_file_paths = set(file_paths).difference(set(included_file_paths))
        with open(excluded_log_file, "a") as f:
            for path in excluded_file_paths:
                f.write(path + "\n")

        for file_path in included_file_paths:
            # Strip the local_dir prefix (+1 for the path separator).
            # NOTE(review): this assumes self.local_dir has no trailing
            # separator — TODO confirm; a default of "/tmp/" would make the
            # slice drop one extra character from every relative path.
            relative_file_path = file_path[len(self.local_dir) + 1 :]
            metadata = {
                "file_path": relative_file_path,
                "url": self.url_for_file(relative_file_path),
            }

            if not get_content:
                # Metadata-only mode: skip reading the file from disk.
                yield metadata
                continue

            contents = self.read_file(relative_file_path)
            # Files that cannot be decoded (read_file returns None) are dropped.
            if contents:
                yield contents, metadata

def url_for_file(self, file_path: str) -> str:
    """Converts a repository file path to a GitLab web link.

    Args:
        file_path: path relative to local_dir, i.e. prefixed with "owner/repo/".
    """
    # Strip the "owner/repo/" prefix to get the path within the repository.
    file_path = file_path[len(self.repo_id) + 1 :]
    # GitLab web URLs use the "/-/blob/<branch>/<path>" form: the "/-/"
    # separator precedes "blob", it does not follow it.
    return f"{self.base_url}/{self.repo_id}/-/blob/{self.default_branch}/{file_path}"

def read_file(self, relative_file_path: str) -> str:
    """Returns the textual contents of a repository file, or None if it cannot be decoded."""
    full_path = os.path.join(self.local_dir, relative_file_path)
    try:
        with open(full_path, "r") as handle:
            return handle.read()
    except UnicodeDecodeError:
        # Binary (or wrongly-encoded) files are skipped rather than raised on.
        logging.warning("Unable to decode file %s.", full_path)
        return None

def from_args(args: Dict):
    """Creates a GitLabRepoManager from command-line arguments and clones the underlying repository.

    Raises:
        ValueError: if the repository cannot be cloned.
    """
    repo_manager = GitLabRepoManager(
        repo_id=args.repo_id,
        commit_hash=args.commit_hash,
        # Read the GitLab token, not the GitHub one (fixes a copy-paste bug).
        access_token=os.getenv("GITLAB_TOKEN"),
        local_dir=args.local_dir,
        inclusion_file=args.include,
        exclusion_file=args.exclude,
        base_url=args.base_url,
    )
    success = repo_manager.download()
    if not success:
        raise ValueError(
            f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
            "For private repositories, please set the GITLAB_TOKEN variable in your environment."
        )
    return repo_manager


def build_data_manager_from_args(args: Dict) -> DataManager:
    """Creates a DataManager from command-line arguments."""
    # Dispatch on the configured git provider; GitHub is the default.
    manager_cls = GitLabRepoManager if args.git_provider == "gitlab" else GitHubRepoManager
    return manager_cls.from_args(args)

7 changes: 5 additions & 2 deletions sage/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import time

import configargparse
from dotenv import load_dotenv

import sage.config as sage_config
from sage.chunker import UniversalFileChunker
from sage.data_manager import GitHubRepoManager
from sage.data_manager import build_data_manager_from_args
from sage.embedder import build_batch_embedder_from_flags
from sage.github import GitHubIssuesChunker, GitHubIssuesManager
from sage.vector_store import build_vector_store_from_args
Expand All @@ -17,6 +18,8 @@
logger = logging.getLogger()
logger.setLevel(logging.INFO)

load_dotenv()


def main():
parser = configargparse.ArgParser(
Expand Down Expand Up @@ -54,7 +57,7 @@ def main():
repo_embedder = None
if args.index_repo:
logging.info("Cloning the repository...")
repo_manager = GitHubRepoManager.from_args(args)
repo_manager = build_data_manager_from_args(args)
logging.info("Embedding the repo...")
chunker = UniversalFileChunker(max_tokens=args.tokens_per_chunk)
repo_embedder = build_batch_embedder_from_flags(repo_manager, chunker, args)
Expand Down
8 changes: 4 additions & 4 deletions sage/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import Field

from sage.code_symbols import get_code_symbols
from sage.data_manager import DataManager, GitHubRepoManager
from sage.data_manager import DataManager, build_data_manager_from_args
from sage.llm import build_llm_via_langchain
from sage.reranker import build_reranker
from sage.vector_store import build_vector_store_from_args
Expand All @@ -38,14 +38,14 @@ class LLMRetriever(BaseRetriever):
caching to make it usable.
"""

repo_manager: GitHubRepoManager = Field(...)
repo_manager: DataManager = Field(...)
top_k: int = Field(...)

cached_repo_metadata: List[Dict] = Field(...)
cached_repo_files: List[str] = Field(...)
cached_repo_hierarchy: str = Field(...)

def __init__(self, repo_manager: GitHubRepoManager, top_k: int):
def __init__(self, repo_manager: DataManager, top_k: int):
super().__init__()
self.repo_manager = repo_manager
self.top_k = top_k
Expand Down Expand Up @@ -322,7 +322,7 @@ def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Docum
def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
"""Builds a retriever (with optional reranking) from command-line arguments."""
if args.llm_retriever:
retriever = LLMRetriever(GitHubRepoManager.from_args(args), top_k=args.retriever_top_k)
retriever = LLMRetriever(build_data_manager_from_args(args), top_k=args.retriever_top_k)
else:
if args.embedding_provider == "openai":
embeddings = OpenAIEmbeddings(model=args.embedding_model)
Expand Down
Loading