-
Notifications
You must be signed in to change notification settings - Fork 97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: support gitlab as git provider #96
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
from abc import abstractmethod | ||
from functools import cached_property | ||
from typing import Any, Dict, Generator, Tuple | ||
|
||
from urllib.parse import quote | ||
import requests | ||
from git import GitCommandError, Repo | ||
|
||
|
@@ -254,3 +254,253 @@ def from_args(args: Dict): | |
"For private repositories, please set the GITHUB_TOKEN variable in your environment." | ||
) | ||
return repo_manager | ||
|
||
|
||
class GitLabRepoManager(DataManager): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This new class introduces a lot of code duplication. IIUC, the only difference compared to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! Should be possible but I don't think getting different if else would be a good idea if you are incorporating other git providers as well. I haven't looked into other providers maybe something to check before we merge them into one class? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GitHub and GitLab are the main ones, I'm not sure if there's a third one as widely used. Alternative to if/else clauses, we could have an abstract class
Does that make sense? |
||
"""Class to manage a local clone of a GitLab repository.""" | ||
|
||
def __init__(
    self,
    repo_id: str,
    commit_hash: str = None,
    access_token: str = None,
    local_dir: str = None,
    inclusion_file: str = None,
    exclusion_file: str = None,
    base_url: str = None,
):
    """Prepares the local paths and file filters for a clone of a GitLab repository.

    Nothing is cloned here; call `download()` to fetch the repository.

    Args:
        repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
        commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
        access_token: A GitLab access token to use for cloning private repositories. Not needed for public repos.
        local_dir: The local directory where the repository will be cloned. Defaults to "/tmp/".
        inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one of
            the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
        exclusion_file: A file with a list of files/directories/extensions to exclude. Each line must be in one of
            the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
        base_url: Root URL of the GitLab instance hosting the repository. Defaults to "https://gitlab.com".

    Raises:
        ValueError: If both inclusion_file and exclusion_file are provided.
    """
    # Validate first, so that an invalid call does not leave freshly created directories behind.
    if inclusion_file and exclusion_file:
        raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")

    super().__init__(dataset_id=repo_id)
    self.repo_id = repo_id
    self.commit_hash = commit_hash
    self.access_token = access_token
    # Drop trailing slashes so that URLs built as f"{base_url}/..." never contain "//".
    self.base_url = (base_url or "https://gitlab.com").rstrip("/")

    self.local_dir = local_dir or "/tmp/"
    # exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
    os.makedirs(self.local_dir, exist_ok=True)
    self.local_path = os.path.join(self.local_dir, repo_id)

    self.log_dir = os.path.join(self.local_dir, "logs", repo_id)
    os.makedirs(self.log_dir, exist_ok=True)

    self.inclusions = self._parse_filter_file(inclusion_file) if inclusion_file else None
    self.exclusions = self._parse_filter_file(exclusion_file) if exclusion_file else None
|
||
@cached_property
def is_public(self) -> bool:
    """Checks whether the GitLab repository is publicly visible.

    Sends an unauthenticated request to the projects API of the configured GitLab instance
    (`self.base_url`), so self-hosted instances are queried rather than gitlab.com.

    Returns:
        True if the API answers with HTTP 200, False otherwise.
    """
    # The projects API addresses a project by its URL-encoded path ("owner%2Frepo").
    project_path = quote(self.repo_id, safe="")
    api_url = f"{self.base_url.rstrip('/')}/api/v4/projects/{project_path}"
    response = requests.get(api_url, timeout=10)
    # Note that the response will be 404 for both private and non-existent repos.
    return response.status_code == 200
|
||
@cached_property
def default_branch(self) -> str:
    """Fetches the default branch of the repository from the configured GitLab instance.

    Returns:
        The branch name reported by the projects API, or "main" when the request does not
        succeed or the project reports no default branch.
    """
    headers = {"Accept": "application/json"}
    if self.access_token:
        headers["Authorization"] = f"Bearer {self.access_token}"

    # The projects API addresses a project by its URL-encoded path ("owner%2Frepo").
    project_path = quote(self.repo_id, safe="")
    api_url = f"{self.base_url.rstrip('/')}/api/v4/projects/{project_path}"
    # Same timeout as is_public, so a stalled connection cannot hang the caller forever.
    response = requests.get(api_url, headers=headers, timeout=10)
    if response.status_code == 200:
        # The field may be present but null, hence `or` instead of a .get() default.
        return response.json().get("default_branch") or "main"

    # This happens sometimes when we exceed the GitLab rate limit. The best bet in this case is to assume the
    # most common naming for the default branch ("main").
    logging.warning("Unable to fetch default branch for %s: %s", self.repo_id, response.text)
    return "main"
|
||
def download(self) -> bool:
    """Clones the repository to the local directory, if it's not already cloned.

    Returns:
        True if the repository is available locally (already present or freshly cloned),
        False if the clone failed.

    Raises:
        ValueError: If no access token is set and the repository is not publicly visible.
    """
    if os.path.exists(self.local_path):
        # The repository is already cloned.
        return True

    # Check the token first: is_public performs a network request that is pointless when we can authenticate.
    if not self.access_token and not self.is_public:
        raise ValueError(f"Repo {self.repo_id} is private or doesn't exist.")

    public_url = f"{self.base_url}/{self.repo_id}.git"
    clone_url = public_url
    if self.access_token:
        # Inject access token for authentication in the URL.
        clone_url = public_url.replace("https://", f"https://oauth2:{self.access_token}@", 1)

    try:
        if self.commit_hash:
            # A specific commit may not be reachable from a shallow clone, so fetch the full history.
            repo = Repo.clone_from(clone_url, self.local_path)
            repo.git.checkout(self.commit_hash)
        else:
            Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
    except GitCommandError as e:
        # Never write the access token to the logs: report the token-free URL and scrub the error text.
        error_text = str(e)
        if self.access_token:
            error_text = error_text.replace(self.access_token, "***")
        logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, public_url, error_text)
        return False
    return True
|
||
def _parse_filter_file(self, file_path: str) -> Dict:
    """Parses a file with files/directories/extensions to include/exclude.

    Lines are expected to be in the format:
        # Comment that will be ignored, or
        ext:.my-extension, or
        file:my-file.py, or
        dir:my-directory

    Blank lines are ignored. Lines without a recognized "key:" prefix are logged and skipped.

    Args:
        file_path: Path of the filter file to read.

    Returns:
        A dict with the keys "ext", "file" and "dir", each mapping to the list of values found.
    """
    parsed_data = {"ext": [], "file": [], "dir": []}
    with open(file_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                # Blank line or comment line.
                continue
            # Split on the first colon only, so values that contain ":" stay intact and a
            # line without any colon does not raise on unpacking.
            key, separator, value = line.partition(":")
            key = key.strip()
            if separator and key in parsed_data:
                parsed_data[key].append(value.strip())
            else:
                logging.error("Unrecognized key in line: %s. Skipping.", line)

    return parsed_data
|
||
|
||
def _should_include(self, file_path: str) -> bool:
    """Checks whether the file should be indexed.

    Symlinks and hidden files/directories are always skipped. Otherwise the inclusion list
    (if set) or the exclusion list (if set) decides, based on the file's extension, name and
    the directories it lives in.

    Args:
        file_path: Path of a file, normally located under `self.local_path`.

    Returns:
        True if the file passes all filters, False otherwise.
    """
    # Exclude symlinks.
    if os.path.islink(file_path):
        return False

    # Judge only the part of the path inside the clone. Otherwise a hidden or filter-matching
    # directory *above* the clone (e.g. local_dir="/home/me/.cache") would affect every file.
    path_in_clone = file_path
    clone_root = getattr(self, "local_path", None)
    if clone_root:
        root_prefix = os.path.join(clone_root, "")  # Guarantees a trailing separator.
        if file_path.startswith(root_prefix):
            path_in_clone = file_path[len(root_prefix) :]

    # Exclude hidden files and directories.
    if any(part.startswith(".") for part in path_in_clone.split(os.path.sep)):
        return False

    if not self.inclusions and not self.exclusions:
        return True

    # Filter based on file extensions, file names and directory names.
    extension = os.path.splitext(file_path)[1].lower()
    file_name = os.path.basename(file_path)
    # Split on the platform separator, consistently with the hidden-path check above.
    dirs = os.path.dirname(path_in_clone).split(os.path.sep)

    if self.inclusions:
        return (
            extension in self.inclusions.get("ext", [])
            or file_name in self.inclusions.get("file", [])
            or any(d in dirs for d in self.inclusions.get("dir", []))
        )
    # Inclusions are empty here, so the guard above guarantees exclusions are set.
    return (
        extension not in self.exclusions.get("ext", [])
        and file_name not in self.exclusions.get("file", [])
        and all(d not in dirs for d in self.exclusions.get("dir", []))
    )
|
||
def walk(self, get_content: bool = True) -> Generator[Tuple[Any, Dict], None, None]:
    """Walks the local repository path and yields a tuple of (content, metadata) for each file.
    The filepath is relative to the local directory (e.g. "org/repo/your/file/path.py").

    The paths of included and excluded files are appended to two log files under `self.log_dir`.

    Args:
        get_content: When set to True, yields (content, metadata) tuples. When set to False, yields metadata only.

    Yields:
        (content, metadata) tuples, or metadata dicts when get_content is False. Files for which
        `read_file` returns nothing (empty or undecodable) are skipped when get_content is True.
    """
    # We will keep appending to these files during the iteration, so we need to clear them first.
    repo_name = self.repo_id.replace("/", "_")
    included_log_file = os.path.join(self.log_dir, f"included_{repo_name}.txt")
    excluded_log_file = os.path.join(self.log_dir, f"excluded_{repo_name}.txt")
    if os.path.exists(included_log_file):
        os.remove(included_log_file)
        logging.info("Logging included files at %s", included_log_file)
    if os.path.exists(excluded_log_file):
        os.remove(excluded_log_file)
        logging.info("Logging excluded files at %s", excluded_log_file)

    for root, _, files in os.walk(self.local_path):
        file_paths = [os.path.join(root, file) for file in files]
        included_file_paths = [f for f in file_paths if self._should_include(f)]
        included_lookup = set(included_file_paths)
        # Keep the directory listing order so the excluded log is deterministic.
        excluded_file_paths = [f for f in file_paths if f not in included_lookup]

        with open(included_log_file, "a") as f:
            f.writelines(path + "\n" for path in included_file_paths)

        with open(excluded_log_file, "a") as f:
            f.writelines(path + "\n" for path in excluded_file_paths)

        for file_path in included_file_paths:
            # relpath is correct whether or not local_dir ends with a separator; slicing by
            # len(local_dir) + 1 dropped the first character for the default "/tmp/".
            relative_file_path = os.path.relpath(file_path, self.local_dir)
            metadata = {
                "file_path": relative_file_path,
                "url": self.url_for_file(relative_file_path),
            }

            if not get_content:
                yield metadata
                continue

            contents = self.read_file(relative_file_path)
            if contents:
                yield contents, metadata
|
||
def url_for_file(self, file_path: str) -> str:
    """Converts a repository file path to a GitLab link.

    Args:
        file_path: Path prefixed with the repository id, e.g. "org/repo/your/file/path.py".

    Returns:
        The URL of the file on the default branch of the configured GitLab instance.
    """
    # Strip the leading "<repo_id>/" to get the path inside the repository.
    path_in_repo = file_path[len(self.repo_id) + 1 :]
    # GitLab file URLs have the shape <base>/<project>/-/blob/<ref>/<path>.
    return f"{self.base_url.rstrip('/')}/{self.repo_id}/-/blob/{self.default_branch}/{path_in_repo}"
|
||
def read_file(self, relative_file_path: str) -> str:
    """Reads the contents of a file in the repository.

    Args:
        relative_file_path: Path of the file relative to `self.local_dir`.

    Returns:
        The text of the file, or None if the file cannot be decoded as UTF-8 (e.g. it is binary).
    """
    absolute_file_path = os.path.join(self.local_dir, relative_file_path)
    # Decode explicitly as UTF-8 so results do not depend on the platform's locale encoding.
    with open(absolute_file_path, "r", encoding="utf-8") as f:
        try:
            return f.read()
        except UnicodeDecodeError:
            logging.warning("Unable to decode file %s.", absolute_file_path)
            return None
|
||
@staticmethod
def from_args(args: Dict):
    """Creates a GitLabRepoManager from command-line arguments and clones the underlying repository.

    Args:
        args: Parsed command-line arguments, read as attributes: repo_id, commit_hash, local_dir,
            include, exclude, and optionally gitlab_base_url (or base_url).

    Returns:
        A GitLabRepoManager whose repository is available locally.

    Raises:
        ValueError: If the repository could not be cloned.
    """
    # Prefer a GitLab-specific variable; keep GITHUB_TOKEN as a fallback for existing setups.
    access_token = os.getenv("GITLAB_TOKEN") or os.getenv("GITHUB_TOKEN")
    # Accept both spellings of the flag; when neither is set the manager defaults to gitlab.com.
    base_url = getattr(args, "gitlab_base_url", None) or getattr(args, "base_url", None)

    repo_manager = GitLabRepoManager(
        repo_id=args.repo_id,
        commit_hash=args.commit_hash,
        access_token=access_token,
        local_dir=args.local_dir,
        inclusion_file=args.include,
        exclusion_file=args.exclude,
        base_url=base_url,
    )
    if not repo_manager.download():
        raise ValueError(
            f"Unable to clone {args.repo_id}. Please check that it exists and you have access to it. "
            "For private repositories, please set the GITLAB_TOKEN variable in your environment."
        )
    return repo_manager
|
||
|
||
def build_data_manager_from_args(args: Dict) -> DataManager:
    """Builds the DataManager that matches the git provider given on the command line.

    A `git_provider` of "gitlab" selects GitLabRepoManager; any other value selects GitHubRepoManager.
    """
    manager_cls = GitLabRepoManager if args.git_provider == "gitlab" else GitHubRepoManager
    return manager_cls.from_args(args)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please name the flag `--gitlab-base-url` to avoid confusion?