feat: run lint check on every PR (#43)
This commit adds a lint job that runs on every pull request,
and cleans up the lint findings it flags across the existing code.
ccamacho authored Jun 4, 2024
1 parent 9caa617 commit 833246f
Showing 14 changed files with 230 additions and 96 deletions.
12 changes: 9 additions & 3 deletions .flake8
@@ -1,4 +1,10 @@
[flake8]
per-file-ignores =
    # line too long
    *py: E501,
extend-ignore =
    H101,
    E501
exclude =
    .git,
    .github,
    __pycache__,
    generation_pb2.py,
    generation_pb2_grpc.py
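
With this file at the repository root, a local flake8 run picks the configuration up automatically. A minimal self-check sketch (assumes flake8 is installed, e.g. from test-requirements.txt; the helper itself is illustrative, not part of this commit):

# Illustrative local lint check; flake8 discovers the .flake8 file in the
# current working directory on its own.
import subprocess
import sys

proc = subprocess.run(["flake8", "."], capture_output=True, text=True)
if proc.returncode != 0:
    print(proc.stdout)
    sys.exit("flake8 reported findings")
print("flake8: clean")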
41 changes: 41 additions & 0 deletions .github/workflows/linters.yml
@@ -0,0 +1,41 @@
---
name: linters
on:
  pull_request:
  push:
  # Run the functional tests every 8 hours.
  # This will help to identify faster if
  # there is a CI failure related to a
  # change in any dependency.
  schedule:
    - cron: '0 */8 * * *'

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 4
      matrix:
        python-version: [3.9]
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Run a commitlint
        uses: wagoid/commitlint-github-action@v2
      - name: Install dependencies
        run: |
          sudo apt-get update -y
          sudo apt-get install jq libyaml-dev git build-essential findutils -y
          sudo python3 -m pip install -r test-requirements.txt
          sudo python3 -m pip install --upgrade --ignore-installed PyYAML
          sudo python3 -m pip install --upgrade pip
          sudo python3 -m pip install --upgrade virtualenv
          sudo python3 -m pip install --upgrade setuptools
      - name: Run Linters
        run: |
          tox -e linters
63 changes: 31 additions & 32 deletions ci/azure/azure-pipelines.yml
@@ -1,37 +1,36 @@
trigger:
  - main

variables:
  - name: UBUNTU_DEP_LIST
    value: "podman pylint python3-boto3 python3-deprecation python3-flake8 python3-numpy python3-pip yajl-tools yamllint"
  - name: MACOS_DEP_LIST
    value: "ghz yajl python"
  - name: GHZ_URL
    value: "https://github.com/dagrayvid/ghz"

stages:
  - stage: lint
    jobs:
      - job: macOS_lint
        pool:
          vmImage: "macOS-latest"
        steps:
          - script: |
              brew install ${MACOS_DEP_LIST}
            displayName: "install deps"
          - script: |
              pip install -r requirements.txt
          - script: |
              sh validate.sh || true
            displayName: "macOS linter - pip"
      - job: Linux_lint
        pool:
          vmImage: ubuntu-latest
        steps:
          - script: |
              sudo apt-get install -y ${UBUNTU_DEP_LIST}
            displayName: "install deps"
          - script: |
              sh validate.sh || true
            displayName: "Linux linter - Ubuntu packages"
61 changes: 33 additions & 28 deletions dataset.py
@@ -1,3 +1,4 @@
"""Dataset class."""
import json
import logging
import random
@@ -6,6 +7,8 @@


class Dataset:
    """Dataset class."""

    def __init__(self,
                 file,
                 model_name="",
@@ -16,24 +19,25 @@ def __init__(self,
                 max_output_tokens=4096,
                 max_sequence_tokens=32000
                 ):
        """Init method."""
        logging.info("Initializing dataset with %s", locals())
        self.dataset_list = [input for input in
                             initialize_dataset(file,
                                                model_name=model_name,
                                                max_queries=max_queries,
                                                min_input_tokens=min_input_tokens,
                                                max_input_tokens=max_input_tokens,
                                                min_output_tokens=min_output_tokens,
                                                max_output_tokens=max_output_tokens,
                                                max_sequence_tokens=max_sequence_tokens,
                                                )
                             ]
        if len(self.dataset_list) < 4:
            logging.warning("Total dataset is %s elements, check filters!", len(self.dataset_list))
        self.index = 0

    def get_next_n_queries(self, n):
        """Get the N next queries."""
        max_index = len(self.dataset_list)
        next_n_indices = [i % max_index for i in range(self.index, self.index + n)]
        self.index = (self.index + n) % max_index
@@ -50,7 +54,7 @@ def initialize_dataset(
        max_output_tokens=4096,
        max_sequence_tokens=32000
):
    """Initialize the dataset."""
    prompt_format = get_format_string(model_name)
    with open(filename, "r", encoding="utf-8") as file:
        total_queries = 0
@@ -79,16 +83,16 @@ def initialize_dataset(
                    continue
                # TODO exit or just skip here?
                token_lengths_ok = filter_token_lengths(input_tokens,
                                                        output_tokens,
                                                        min_input_tokens,
                                                        max_input_tokens,
                                                        min_output_tokens,
                                                        max_output_tokens,
                                                        max_sequence_tokens)
                if (token_lengths_ok):
                    input_data = {
                        "text": prompt_format.format(prompt=prompt,
                                                     system_prompt=system_prompt),
                        "input_id": input_id,
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens,
@@ -98,24 +102,25 @@
                if total_queries >= max_queries:
                    break


def filter_token_lengths(input_tokens,
                         output_tokens,
                         min_input_tokens,
                         max_input_tokens,
                         min_output_tokens,
                         max_output_tokens,
                         max_sequence_tokens):
    """Filter the tokens by length."""
    sequence_tokens = input_tokens + output_tokens
    return (output_tokens > min_output_tokens
            and output_tokens < max_output_tokens
            and input_tokens < max_input_tokens
            and input_tokens > min_input_tokens
            and sequence_tokens < max_sequence_tokens)


def get_format_string(model_name):
    """Get the format string."""
    known_system_prompts = {
        "llama": "<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt} [/INST]",
        "flan": "Question: {prompt}\n\nAnswer:",
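
As a usage sketch for the Dataset class above (the file name and filter values are illustrative, not from this commit):

from dataset import Dataset

# Hypothetical JSON dataset file matching the schema initialize_dataset
# expects; "llama" is one of the model keys known to get_format_string.
ds = Dataset("dataset.json", model_name="llama", max_queries=100)
batch = ds.get_next_n_queries(8)  # indices wrap modulo the filtered list length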
2 changes: 1 addition & 1 deletion generation_pb2.py

Some generated files are not rendered by default.

6 changes: 0 additions & 6 deletions linter.sh

This file was deleted.

20 changes: 14 additions & 6 deletions load_test.py
@@ -1,16 +1,21 @@
"""Main llm-load-test CLI entrypoint."""

import logging
import logging.handlers
import multiprocessing as mp
import sys
import time
from user import User

from dataset import Dataset

import logging_utils

import utils
from dataset import Dataset
from user import User


def run_main_process(concurrency, duration, dataset, dataset_q, stop_q):
    """Run the main process."""
    logging.info("Test from main process")

    # Initialize the dataset_queue with 4*concurrency requests
@@ -49,7 +54,7 @@ def run_warmup(
        warmup_reqs=10,
        warmup_timeout=60,
):
    """Run the warmup tasks."""
    # Put requests in warmup queue
    for query in dataset.get_next_n_queries(warmup_reqs):
        dataset_q.put(query)
@@ -94,6 +99,7 @@ def run_warmup(


def gather_results(results_pipes):
    """Get the results."""
    # Receive all results from each processes results_pipe
    logging.debug("Receiving results from user processes")
    results_list = []
@@ -104,6 +110,7 @@


def exit_gracefully(procs, warmup_q, dataset_q, stop_q, logger_q, log_reader_thread, code):
    """Exit gracefully."""
    # Signal users to stop sending requests
    if warmup_q is not None and warmup_q.empty():
        warmup_q.put(None)
@@ -129,20 +136,21 @@ def exit_gracefully(procs, warmup_q, dataset_q, stop_q, logger_q, log_reader_thread, code):


def main(args):
    """Load test CLI entrypoint."""
    args = utils.parse_args(args)

    mp_ctx = mp.get_context("spawn")
    logger_q = mp_ctx.Queue()
    log_reader_thread = logging_utils.init_logging(args.log_level, logger_q)

    # Create processes and their Users
    stop_q = mp_ctx.Queue(1)
    dataset_q = mp_ctx.Queue()
    warmup_q = mp_ctx.Queue(1)
    procs = []
    results_pipes = []

    # Parse config
    logging.debug("Parsing YAML config file %s", args.config)
    concurrency, duration, plugin = 0, 0, None
    try:
@@ -205,7 +213,7 @@ def main(args):
        results_list = gather_results(results_pipes)

        utils.write_output(config, results_list)
    except Exception:
        logging.exception("Unexpected exception in main process")
        exit_gracefully(procs, warmup_q, dataset_q, stop_q, logger_q, log_reader_thread, 1)
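
For orientation, a hedged sketch of driving main() directly; the real flag names live in utils.parse_args, which this diff does not show, so they are inferred from args.config and args.log_level above:

import load_test

# Hypothetical invocation; flag spellings are assumptions, not confirmed here.
load_test.main(["--config", "config.yaml", "--log_level", "debug"])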
4 changes: 4 additions & 0 deletions logging_utils.py
@@ -1,8 +1,11 @@
"""Main logging class."""

import logging
import threading


def logger_thread(q):
    """Get the logger thread."""
    while True:
        record = q.get()
        if record is None:
@@ -12,6 +15,7 @@ def logger_thread(q):


def init_logging(log_level, logger_q):
    """Initialize the logger."""
    logging_format = (
        "%(asctime)s %(levelname)-8s %(name)s %(processName)-10s %(message)s"
    )
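
logging_utils implements the standard queue-based pattern for logging from multiple processes: each worker attaches a QueueHandler to its root logger, and a single thread in the main process drains the shared queue. A self-contained sketch of that pattern (names are illustrative, not this module's API):

import logging
import logging.handlers
import multiprocessing as mp
import threading


def drain(q):
    # Mirror of logger_thread above: stop on the None sentinel.
    while True:
        record = q.get()
        if record is None:
            break
        logging.getLogger(record.name).handle(record)


def worker(q):
    root = logging.getLogger()
    root.addHandler(logging.handlers.QueueHandler(q))
    root.setLevel(logging.INFO)
    root.info("hello from %s", mp.current_process().name)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    t = threading.Thread(target=drain, args=(q,))
    t.start()
    p = ctx.Process(target=worker, args=(q,))
    p.start()
    p.join()
    q.put(None)  # stop the drain thread
    t.join()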
10 changes: 9 additions & 1 deletion result.py
@@ -1,5 +1,11 @@
"""Main result class."""


class RequestResult:
    """Request result class."""

    def __init__(self, user_id, input_id, input_tokens):
        """Init method."""
        self.user_id = user_id
        self.input_id = input_id
        self.input_tokens = input_tokens
@@ -20,12 +26,14 @@ def __init__(self, user_id, input_id, input_tokens):
        self.error_text = None

    def asdict(self):
        """Return a dictionary."""
        # Maybe later we will want to only include some fields in the results,
        # but for now, this just puts all object fields in a dict.
        return vars(self)

    # Fill in calculated fields like response_time, tt_ack, ttft, tpot.
    def calculate_results(self):
        """Calculate the results."""
        # Only calculate results if response is error-free.
        if self.error_code is None and self.error_text is None:
            # response_time in seconds
@@ -41,7 +49,7 @@ def calculate_results(self):
            self.itl = (1000 * (self.end_time - self.first_token_time)) / (
                self.output_tokens - 1
            )  # Inter-token latency in ms. Distinct from TPOT as it excludes the first token time.
            self.tpot = (
                self.response_time / self.output_tokens
            )  # Time per output token in ms
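
To make the units concrete, toy values for the fields above (timestamps in seconds, reported latencies in milliseconds; illustrative numbers, not measurements from this commit):

end_time, first_token_time, output_tokens = 1.20, 0.20, 21
itl = (1000 * (end_time - first_token_time)) / (output_tokens - 1)  # 50.0 ms
# tpot spreads the whole response over every output token; with a
# response_time of 1200 ms: 1200 / 21 ≈ 57.1 ms per token.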
