import streamlit as st
import pandas as pd

from utils.ui import set_page_style
from utils.auth import check_hf_token
from utils.huggingface import prepare_training_config
from utils.training import create_model_repo, upload_training_config, setup_training_script

# Set page configuration
st.set_page_config(
    page_title="Model Configuration - Gemma Fine-tuning",
    page_icon="🤗",
    layout="wide"
)

# Apply custom styling
set_page_style()
# Sidebar for authentication
with st.sidebar:
    st.title("🤗 Gemma Fine-tuning")
    st.image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
        use_column_width=True
    )

    # Authentication section
    st.subheader("🔑 Authentication")
    hf_token = st.text_input(
        "Hugging Face API Token", type="password",
        help="Enter your Hugging Face write token to enable model fine-tuning"
    )
    auth_status = check_hf_token(hf_token) if hf_token else False

    if auth_status:
        st.success("Authenticated successfully!")
    elif hf_token:
        st.error("Invalid token. Please check and try again.")

    st.divider()
    st.caption("A simple UI for fine-tuning Gemma models")
# Main content
st.title("⚙️ Model Configuration")

if not hf_token or not auth_status:
    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
    st.stop()

# Check if a dataset repository has been set
if "dataset_repo" not in st.session_state:
    st.warning("Please upload a dataset first")
    st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="📤")

    # For testing purposes, allow manual entry
    manual_repo = st.text_input("Or enter dataset repository name manually:")
    if manual_repo:
        st.session_state["dataset_repo"] = manual_repo
    else:
        st.stop()
# Model configuration
st.header("1. Select Gemma Model")

model_version = st.radio(
    "Select Gemma model version",
    options=["google/gemma-2b", "google/gemma-7b"],
    horizontal=True,
    help="Choose the Gemma model size. 2B is faster, 7B is more capable."
)

st.header("2. Training Configuration")

# Output repository settings
with st.expander("Output Repository Settings", expanded=True):
    col1, col2 = st.columns(2)
    with col1:
        output_repo_name = st.text_input(
            "Repository Name",
            value=f"gemma-finetuned-{pd.Timestamp.now().strftime('%Y%m%d')}",
            help="Name of the Hugging Face repository to store your fine-tuned model"
        )
    with col2:
        is_private = st.toggle(
            "Private Repository", value=True,
            help="Make your model repository private (recommended)"
        )
# Tabs for basic and advanced configuration
basic_tab, advanced_tab = st.tabs(["Basic Configuration", "Advanced Configuration"])

with basic_tab:
    col1, col2 = st.columns(2)
    with col1:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3,
                           help="Number of complete passes through the dataset")
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8,
                               help="Number of examples processed together")
    with col2:
        learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-3,
                                        value=2e-5, format="%.1e",
                                        help="Step size for gradient updates")
        fp16_training = st.toggle("Use FP16 Training", value=True,
                                  help="Use 16-bit precision for faster training")

with advanced_tab:
    col1, col2 = st.columns(2)
    with col1:
        lora_rank = st.slider("LoRA Rank", min_value=1, max_value=64, value=8,
                              help="Rank of the LoRA matrices")
        lora_alpha = st.slider("LoRA Alpha", min_value=1, max_value=128, value=32,
                               help="Scaling factor for LoRA")
        lora_dropout = st.slider("LoRA Dropout", min_value=0.0, max_value=0.5, value=0.05, step=0.01,
                                 help="Dropout probability for LoRA layers")
    with col2:
        weight_decay = st.number_input("Weight Decay", min_value=0.0, max_value=0.1,
                                       value=0.01, step=0.001,
                                       help="L2 regularization strength")
        gradient_accumulation = st.slider("Gradient Accumulation Steps", min_value=1, max_value=16, value=1,
                                          help="Number of steps to accumulate gradients")
        warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=0,
                                 help="Steps of linear learning rate warmup")
st.header("3. Training Summary")

# Collect the hyperparameters into a single dictionary
hyperparams = {
    "epochs": epochs,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "fp16": fp16_training,
    "lora_rank": lora_rank,
    "lora_alpha": lora_alpha,
    "lora_dropout": lora_dropout,
    "weight_decay": weight_decay,
    "gradient_accumulation": gradient_accumulation,
    "warmup_steps": warmup_steps,
    "max_steps": -1,  # -1 means train for the full number of epochs
    "max_grad_norm": 1.0
}

# Display summary
col1, col2 = st.columns(2)
with col1:
    st.subheader("Selected Model")
    st.info(model_version)
    st.subheader("Dataset Repository")
    st.info(st.session_state["dataset_repo"])
with col2:
    st.subheader("Key Hyperparameters")
    st.write(f"Epochs: {epochs}")
    st.write(f"Batch Size: {batch_size}")
    st.write(f"Learning Rate: {learning_rate}")
    st.write(f"LoRA Rank: {lora_rank}")
# Start training button
st.header("4. Start Training")

if st.button("Prepare and Launch Training", type="primary"):
    with st.spinner("Setting up training job..."):
        # First create the model repository
        success, repo_url = create_model_repo(output_repo_name, private=is_private)
        if not success:
            st.error(f"Failed to create model repository: {repo_url}")
        else:
            st.success(f"Created model repository: {output_repo_name}")

            # Prepare the training configuration
            config = prepare_training_config(
                model_name=model_version,
                hyperparams=hyperparams,
                dataset_repo=st.session_state["dataset_repo"],
                output_repo=output_repo_name
            )

            # Upload the training configuration
            success, message = upload_training_config(config, output_repo_name)
            if not success:
                st.error(f"Failed to upload training configuration: {message}")
            else:
                st.success("Uploaded training configuration")

                # Set up the training script
                success, message = setup_training_script(output_repo_name, config)
                if not success:
                    st.error(f"Failed to set up training script: {message}")
                else:
                    st.success("Uploaded training script")

                    # Store the configuration in session state for the next page
                    st.session_state["model_repo"] = output_repo_name
                    st.session_state["model_version"] = model_version
                    st.session_state["hyperparams"] = hyperparams

                    # Success message and next step
                    st.success("Training job prepared successfully! You can now monitor the training progress.")
                    st.page_link("pages/03_Training_Monitor.py", label="Next: Monitor Training", icon="📊")
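
# ---------------------------------------------------------------------------
# Note: the helpers imported from utils.* above are defined elsewhere in the
# repository and are not shown in this file. The two functions below are a
# minimal sketch of how check_hf_token and create_model_repo *could* be built
# on top of huggingface_hub; they are illustrative assumptions only and may
# differ from the actual implementations in utils/auth.py and utils/training.py.
# ---------------------------------------------------------------------------
from typing import Optional, Tuple

from huggingface_hub import HfApi, create_repo


def _check_hf_token_sketch(token: str) -> bool:
    """Sketch: return True if the token resolves to a valid Hugging Face account."""
    try:
        HfApi(token=token).whoami()
        return True
    except Exception:
        return False


def _create_model_repo_sketch(repo_name: str, private: bool = True,
                              token: Optional[str] = None) -> Tuple[bool, str]:
    """Sketch: create (or reuse) a model repo and return (success, repo_url_or_error)."""
    try:
        url = create_repo(repo_id=repo_name, private=private, token=token, exist_ok=True)
        return True, str(url)
    except Exception as exc:
        return False, str(exc)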