import streamlit as st
import pandas as pd
from utils.ui import set_page_style
from utils.auth import check_hf_token
from utils.huggingface import prepare_training_config
from utils.training import create_model_repo, upload_training_config, setup_training_script
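
# NOTE: the utils.* imports above are app-local helper modules, not PyPI
# packages. check_hf_token is assumed to validate the token against the Hub
# (e.g. via huggingface_hub.HfApi().whoami()); see utils/ for the definitions.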

# Set page configuration
st.set_page_config(
    page_title="Model Configuration - Gemma Fine-tuning",
    page_icon="🤗",
    layout="wide"
)

# Apply custom styling
set_page_style()

# Sidebar for authentication
with st.sidebar:
    st.title("🤗 Gemma Fine-tuning")
    st.image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
        use_column_width=True
    )

    # Authentication section
    st.subheader("🔑 Authentication")
    hf_token = st.text_input(
        "Hugging Face API Token",
        type="password",
        help="Enter your Hugging Face write token to enable model fine-tuning"
    )

    auth_status = check_hf_token(hf_token) if hf_token else False

    if auth_status:
        st.success("Authenticated successfully!")
    elif hf_token:
        st.error("Invalid token. Please check and try again.")

    st.divider()
    st.caption("A simple UI for fine-tuning Gemma models")

# Main content
st.title("⚙️ Model Configuration")

if not auth_status:
    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
    st.stop()

# Check if dataset repository is set
if "dataset_repo" not in st.session_state:
    st.warning("Please upload a dataset first")
    st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="📤")

    # For testing purposes, allow manual entry
    manual_repo = st.text_input("Or enter dataset repository name manually:")
    if manual_repo:
        st.session_state["dataset_repo"] = manual_repo
    else:
        st.stop()

# Model configuration
st.header("1. Select Gemma Model")

model_version = st.radio(
    "Select Gemma model version",
    options=["google/gemma-2b", "google/gemma-7b"],
    horizontal=True,
    help="Choose the Gemma model size. 2B is faster, 7B is more capable."
)
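
# Note: both Gemma checkpoints are gated on the Hugging Face Hub, so the
# authenticated account must have accepted Google's Gemma license before
# training can download the base weights.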
st.header("2. Training Configuration")
# Output repository settings
with st.expander("Output Repository Settings", expanded=True):
col1, col2 = st.columns(2)
with col1:
output_repo_name = st.text_input(
"Repository Name",
value=f"gemma-finetuned-{pd.Timestamp.now().strftime('%Y%m%d')}",
help="Name of the Hugging Face repository to store your fine-tuned model"
)
with col2:
is_private = st.toggle("Private Repository", value=True,
help="Make your model repository private (recommended)")

# Tabs for Basic and Advanced configuration
basic_tab, advanced_tab = st.tabs(["Basic Configuration", "Advanced Configuration"])

with basic_tab:
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3,
                           help="Number of complete passes through the dataset")
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8,
                               help="Number of examples processed together")
    with col2:
        learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-3,
                                        value=2e-5, format="%.1e",
                                        help="Step size for gradient updates")
        fp16_training = st.toggle("Use FP16 Training", value=True,
                                  help="Use 16-bit precision for faster training")

with advanced_tab:
    col1, col2 = st.columns(2)
    with col1:
        lora_rank = st.slider("LoRA Rank", min_value=1, max_value=64, value=8,
                              help="Rank of the LoRA matrices")
        lora_alpha = st.slider("LoRA Alpha", min_value=1, max_value=128, value=32,
                               help="Scaling factor for LoRA")
        lora_dropout = st.slider("LoRA Dropout", min_value=0.0, max_value=0.5, value=0.05, step=0.01,
                                 help="Dropout probability for LoRA layers")
    with col2:
        weight_decay = st.number_input("Weight Decay", min_value=0.0, max_value=0.1,
                                       value=0.01, step=0.001,
                                       help="L2 regularization strength")
        gradient_accumulation = st.slider("Gradient Accumulation Steps", min_value=1, max_value=16, value=1,
                                          help="Number of steps to accumulate gradients")
        warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=0,
                                 help="Steps of linear learning rate warmup")
st.header("3. Training Summary")
# Create hyperparameters dictionary
hyperparams = {
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
"fp16": fp16_training,
"lora_rank": lora_rank,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout,
"weight_decay": weight_decay,
"gradient_accumulation": gradient_accumulation,
"warmup_steps": warmup_steps,
"max_steps": -1, # -1 means train for full epochs
"max_grad_norm": 1.0
}

# Display summary
col1, col2 = st.columns(2)
with col1:
    st.subheader("Selected Model")
    st.info(model_version)
    st.subheader("Dataset Repository")
    st.info(st.session_state["dataset_repo"])
with col2:
    st.subheader("Key Hyperparameters")
    st.write(f"Epochs: {epochs}")
    st.write(f"Batch Size: {batch_size}")
    st.write(f"Learning Rate: {learning_rate}")
    st.write(f"LoRA Rank: {lora_rank}")

# Start training button
st.header("4. Start Training")

if st.button("Prepare and Launch Training", type="primary"):
    with st.spinner("Setting up training job..."):
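        # Launching is three Hub operations in sequence: create the model
        # repository, upload the training config, then upload the training
        # script that consumes it. Each helper returns (success, detail),
        # where detail carries an error message when success is False.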
        # First create the model repository
        success, repo_url = create_model_repo(output_repo_name, private=is_private)

        if not success:
            st.error(f"Failed to create model repository: {repo_url}")
        else:
            st.success(f"Created model repository: {output_repo_name}")

            # Prepare training configuration
            config = prepare_training_config(
                model_name=model_version,
                hyperparams=hyperparams,
                dataset_repo=st.session_state["dataset_repo"],
                output_repo=output_repo_name
            )
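
            # config is assumed to be a JSON-serializable dict along the lines of
            # {"model_name": ..., "dataset_repo": ..., "output_repo": ..., "hyperparams": {...}};
            # the exact schema lives in utils/huggingface.py.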

            # Upload training configuration
            success, message = upload_training_config(config, output_repo_name)
            if not success:
                st.error(f"Failed to upload training configuration: {message}")
            else:
                st.success("Uploaded training configuration")

                # Set up the training script
                success, message = setup_training_script(output_repo_name, config)
                if not success:
                    st.error(f"Failed to set up training script: {message}")
                else:
                    st.success("Uploaded training script")

                    # Store details in session state for the monitoring page
                    st.session_state["model_repo"] = output_repo_name
                    st.session_state["model_version"] = model_version
                    st.session_state["hyperparams"] = hyperparams

                    # Success message and next step
                    st.success("Training job prepared successfully! You can now monitor the training progress.")
                    st.page_link("pages/03_Training_Monitor.py", label="Next: Monitor Training", icon="📊")