Add inference to CI
gongy committed Feb 15, 2024
1 parent 9752a6c commit 88d78a4
Showing 8 changed files with 430 additions and 13 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-cd.yml
@@ -43,3 +43,7 @@ jobs:
      - name: Check training results
        run: |
          python ci/check_loss.py
+      - name: Check inference results
+        run: |
+          python ci/check_inference.py
21 changes: 21 additions & 0 deletions ci/check_inference.py
@@ -0,0 +1,21 @@
import subprocess


if __name__ == "__main__":

    with open(".last_run_name", "r") as f:
        run_name = f.read().strip()

    prompt = """[INST] Using the schema context below, generate a SQL query that answers the question.
CREATE TABLE head (age INTEGER)
How many heads of the departments are older than 56 ? [/INST] """

    p = subprocess.Popen(["modal", "run", "src.inference", "--run-folder", f"/runs/{run_name}", "--prompt", prompt], stdout=subprocess.PIPE)
    output = ""

    for line in iter(p.stdout.readline, b''):
        output += line.decode()
        print(line.decode())

    print("Asserting that the output contains the expected SQL query")
    assert "[SQL] SELECT" in output and "[/SQL]" in output
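The script streams the modal CLI's stdout line by line, so inference progress shows up in the CI log as it happens. For comparison, a minimal non-streaming sketch (not part of this commit) using subprocess.run; the prompt is elided here and the modal invocation matches the script above:

import subprocess

# Non-streaming variant: buffers all output, so the CI log stays silent until
# the command finishes -- which is why the commit streams with Popen instead.
with open(".last_run_name") as f:
    run_name = f.read().strip()

prompt = "[INST] ... [/INST] "  # same prompt as in the script above, elided

result = subprocess.run(
    ["modal", "run", "src.inference", "--run-folder", f"/runs/{run_name}", "--prompt", prompt],
    stdout=subprocess.PIPE,
    check=True,  # raise CalledProcessError if the modal CLI exits nonzero
)
output = result.stdout.decode()
assert "[SQL] SELECT" in output and "[/SQL]" in output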
5 changes: 4 additions & 1 deletion ci/check_loss.py
@@ -27,5 +27,8 @@
train_loss = float(results["TrainingLoss"].iloc[-1])
val_loss = float(results["ValidationLoss"].iloc[-1])

+# Arbitrary threshold
+max_loss = 10 if b"Mixtral" in contents else 0.25
+
print(f"Loss: {train_loss:.2f} (training), {val_loss:.2f} (validation)")
-sys.exit(val_loss > 0.25)  # Arbitrary threshold
+sys.exit(val_loss > max_loss)
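Two things worth noting in this hunk: the b"Mixtral" test implies contents holds raw bytes (presumably the config file read earlier in the script, outside this hunk), and sys.exit is given the boolean comparison directly, where False maps to exit status 0 (the CI step passes) and True maps to status 1 (the step fails). A minimal illustration of that convention:

import sys

val_loss, max_loss = 0.19, 0.25  # illustrative values, not from a real run
failed = val_loss > max_loss     # False -> exit code 0, True -> exit code 1
sys.exit(failed)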
9 changes: 7 additions & 2 deletions ci/prep_for_ci.py
@@ -7,10 +7,15 @@
@click.option("--data")
def main(config: str, data: str):
    """Set the config to train for only one epoch and truncate the dataset."""
-    train_set_size = 1000
-    val_set_size = 64
    with open(config) as f:
        cfg = yaml.safe_load(f.read())
+
+    if cfg["sample_packing"]:
+        train_set_size = 2048
+    else:
+        train_set_size = 1024
+    val_set_size = 64
+
    cfg["val_set_size"] = val_set_size
    cfg["num_epochs"] = 1
    cfg.pop("eval_steps", None)  # Evaluate once at the end of the epoch
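The tail of prep_for_ci.py is not loaded in this view. A hypothetical sketch of what the truncation step could look like, assuming the --data file is line-delimited JSONL; the names mirror the hunk above, and none of this code is confirmed by the diff:

# Hypothetical continuation -- the real tail of the script is not shown here.
# Truncate the dataset to train_set_size + val_set_size lines, then persist the
# modified config (yaml is already imported at the top of the script).
with open(data) as f:
    lines = f.readlines()
with open(data, "w") as f:
    f.writelines(lines[: train_set_size + val_set_size])

with open(config, "w") as f:
    yaml.dump(cfg, f)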
4 changes: 2 additions & 2 deletions config/mixtral.yml
@@ -1,4 +1,4 @@
-base_model: mistralai/Mixtral-8x7B-v0.1
+base_model: mistralai/Mixtral-8x7B-Instruct-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
@@ -69,7 +69,7 @@ wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
-micro_batch_size: 16
+micro_batch_size: 8
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
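For scale: with gradient_accumulation_steps left at 1, halving micro_batch_size from 16 to 8 halves the tokens processed per optimizer step. A quick back-of-the-envelope, assuming sequence_len: 4096 (taken from the companion config below, not visible in this hunk):

micro_batch_size = 8             # was 16 before this commit
gradient_accumulation_steps = 1  # unchanged
sequence_len = 4096              # assumed from mixtral_out_of_box.yml below

tokens_per_step = micro_batch_size * gradient_accumulation_steps * sequence_len
print(tokens_per_step)  # 32768 tokens per optimizer step, per data-parallel worker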
104 changes: 104 additions & 0 deletions config/mixtral_out_of_box.yml
@@ -0,0 +1,104 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  # This will be the path used for the data when it is saved to the Volume in the cloud.
  - path: data.jsonl
    ds_type: json
    type:
      # JSONL file contains question, context, answer fields per line.
      # This gets mapped to instruction, input, output axolotl tags.
      field_instruction: question
      field_input: context
      field_output: answer
      # Format is used by axolotl to generate the prompt.
      format: |-
        [INST] Using the schema context below, generate a SQL query that answers the question.
        {input}
        {instruction} [/INST]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out

## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
# - gate
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - w1
# - w2
# - w3

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: /root/axolotl/deepspeed_configs/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
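For reference, a minimal sketch of how the format template above renders one JSONL row into a training prompt. It imitates axolotl's field mapping (question → instruction, context → input) with plain str.format; the row values are illustrative:

row = {
    "question": "How many heads of the departments are older than 56 ?",
    "context": "CREATE TABLE head (age INTEGER)",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}

template = (
    "[INST] Using the schema context below, generate a SQL query that answers the question.\n"
    "{input}\n"
    "{instruction} [/INST]"
)

prompt = template.format(input=row["context"], instruction=row["question"])
print(prompt)  # matches the hard-coded prompt in ci/check_inference.py above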
(2 remaining changed files not loaded in this view)
