Skip to content

Commit

Permalink
Working mixtral support
Browse files Browse the repository at this point in the history
  • Loading branch information
gongy committed Feb 14, 2024
1 parent 62cfb65 commit aeb0765
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
fail-fast: false
matrix:
config: ["codellama", "llama-2", "mistral"]
config: ["codellama", "llama-2", "mistral", "mixtral"]
env:
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
Expand Down
105 changes: 105 additions & 0 deletions config/mixtral.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
# This will be the path used for the data when it is saved to the Volume in the cloud.
- path: data.jsonl
ds_type: json
type:
# JSONL file contains question, context, answer fields per line.
# This gets mapped to instruction, input, output axolotl tags.
field_instruction: question
field_input: context
field_output: answer
# Format is used by axolotl to generate the prompt.
format: |-
[INST] Using the schema context below, generate a SQL query that answers the question.
{input}
{instruction} [/INST]
dataset_prepared_path:
val_set_size: 32
output_dir: ./lora-out

## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*

model_config:
output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
#lora_target_modules:
# - gate
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - w1
# - w2
# - w3

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed: /root/axolotl/deepspeed_configs/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
4 changes: 2 additions & 2 deletions src/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

APP_NAME = "example-axolotl"

# Axolotl image hash corresponding to 0.4.0 release
# Axolotl image hash corresponding to 0.4.0 release (2024-02-14)
AXOLOTL_REGISTRY_SHA = (
"af4d878e9fbc90c7ba30fa78ce4d6d95b1ccba398ab944efbd322d7c0d6313c8"
"d5b941ba2293534c01c23202c8fc459fd2a169871fa5e6c45cb00f363d474b6a"
)

axolotl_image = (
Expand Down
6 changes: 3 additions & 3 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ def get_model_choices():

@stub.local_entrypoint()
def main():
dir = os.path.dirname(__file__)
with open(f"{dir}/config.yml", "r") as cfg, open(
f"{dir}/my_data.jsonl", "r"
parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
with open(f"{parent}/config/mixtral.yml", "r") as cfg, open(
f"{parent}/data/sqlqa.jsonl", "r"
) as data:
handle = gui.spawn(cfg.read(), data.read())
url = stub.q.get()
Expand Down
3 changes: 1 addition & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
)

N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_MEM = int(os.environ.get("GPU_MEM", 40))
GPU_CONFIG = modal.gpu.A100(count=N_GPUS, memory=GPU_MEM)
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)


def print_common_training_issues(config):
Expand Down

0 comments on commit aeb0765

Please sign in to comment.