import streamlit as st
import pandas as pd
from utils.ui import set_page_style
from utils.auth import check_hf_token
from utils.huggingface import prepare_training_config
from utils.training import create_model_repo, upload_training_config, setup_training_script
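# NOTE: the helpers imported from utils are assumed to follow the contracts used below:
# check_hf_token(token) -> bool, prepare_training_config(...) -> a config object that the
# upload helpers accept, and create_model_repo / upload_training_config /
# setup_training_script each return a (success, message_or_url) tuple.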

# Set page configuration
st.set_page_config(
    page_title="Model Configuration - Gemma Fine-tuning",
    page_icon="πŸ€–",
    layout="wide"
)

# Apply custom styling
set_page_style()

# Sidebar for authentication
with st.sidebar:
    st.title("πŸ€– Gemma Fine-tuning")
    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png", 
             use_column_width=True)
    
    # Authentication section
    st.subheader("πŸ”‘ Authentication")
    hf_token = st.text_input("Hugging Face API Token", type="password", 
                             help="Enter your Hugging Face write token to enable model fine-tuning")
    auth_status = check_hf_token(hf_token) if hf_token else False
    
    if auth_status:
        st.success("Authenticated successfully!")
    elif hf_token:
        st.error("Invalid token. Please check and try again.")
    
    st.divider()
    st.caption("A simple UI for fine-tuning Gemma models")

# Main content
st.title("βš™οΈ Model Configuration")

if not hf_token or not auth_status:
    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
    st.stop()

# Check if dataset repository is set
if "dataset_repo" not in st.session_state:
    st.warning("Please upload a dataset first")
    st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="πŸ“€")
    
    # For testing purposes, allow manual entry
    manual_repo = st.text_input("Or enter dataset repository name manually:")
    if manual_repo:
        st.session_state["dataset_repo"] = manual_repo
    else:
        st.stop()

# Model configuration
st.header("1. Select Gemma Model")

model_version = st.radio(
    "Select Gemma model version",
    options=["google/gemma-2b", "google/gemma-7b"],
    horizontal=True,
    help="Choose the Gemma model size. 2B is faster, 7B is more capable."
)

st.header("2. Training Configuration")

# Output repository settings
with st.expander("Output Repository Settings", expanded=True):
    col1, col2 = st.columns(2)
    
    with col1:
        output_repo_name = st.text_input(
            "Repository Name",
            value=f"gemma-finetuned-{pd.Timestamp.now().strftime('%Y%m%d')}",
            help="Name of the Hugging Face repository to store your fine-tuned model"
        )
    
    with col2:
        is_private = st.toggle("Private Repository", value=True, 
                               help="Make your model repository private (recommended)")

# Tabs for Basic and Advanced configuration
basic_tab, advanced_tab = st.tabs(["Basic Configuration", "Advanced Configuration"])

with basic_tab:
    col1, col2 = st.columns(2)
    
    with col1:
        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3, 
                          help="Number of complete passes through the dataset")
        
        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8,
                             help="Number of examples processed together")
    
    with col2:
        learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-3, 
                                       value=2e-5, format="%.1e",
                                       help="Step size for gradient updates")
        
        fp16_training = st.toggle("Use FP16 Training", value=True,
                                help="Use 16-bit precision for faster training")

with advanced_tab:
    col1, col2 = st.columns(2)
    
    with col1:
        lora_rank = st.slider("LoRA Rank", min_value=1, max_value=64, value=8, 
                            help="Rank of the LoRA matrices")
        
        lora_alpha = st.slider("LoRA Alpha", min_value=1, max_value=128, value=32,
                             help="Scaling factor for LoRA")
        
        lora_dropout = st.slider("LoRA Dropout", min_value=0.0, max_value=0.5, value=0.05, step=0.01,
                               help="Dropout probability for LoRA layers")
    
    with col2:
        weight_decay = st.number_input("Weight Decay", min_value=0.0, max_value=0.1, 
                                      value=0.01, step=0.001,
                                      help="L2 regularization strength")
        
        gradient_accumulation = st.slider("Gradient Accumulation Steps", min_value=1, max_value=16, value=1,
                                        help="Number of steps to accumulate gradients")
        
        warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=0,
                               help="Steps of linear learning rate warmup")

st.header("3. Training Summary")

# Create hyperparameters dictionary
hyperparams = {
    "epochs": epochs,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "fp16": fp16_training,
    "lora_rank": lora_rank,
    "lora_alpha": lora_alpha,
    "lora_dropout": lora_dropout,
    "weight_decay": weight_decay,
    "gradient_accumulation": gradient_accumulation,
    "warmup_steps": warmup_steps,
    "max_steps": -1,  # -1 means train for full epochs
    "max_grad_norm": 1.0
}

# Display summary
col1, col2 = st.columns(2)

with col1:
    st.subheader("Selected Model")
    st.info(model_version)
    
    st.subheader("Dataset Repository")
    st.info(st.session_state["dataset_repo"])

with col2:
    st.subheader("Key Hyperparameters")
    st.write(f"Epochs: {epochs}")
    st.write(f"Batch Size: {batch_size}")
    st.write(f"Learning Rate: {learning_rate}")
    st.write(f"LoRA Rank: {lora_rank}")

# Start training button
st.header("4. Start Training")

if st.button("Prepare and Launch Training", type="primary"):
    with st.spinner("Setting up training job..."):
        # First create the model repository
        success, repo_url = create_model_repo(output_repo_name, private=is_private)
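        # create_model_repo is assumed to return (True, repo_url) on success and
        # (False, error_message) on failure, so repo_url doubles as the error text below.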
        
        if not success:
            st.error(f"Failed to create model repository: {repo_url}")
        else:
            st.success(f"Created model repository: {output_repo_name}")
            
            # Prepare training configuration
            config = prepare_training_config(
                model_name=model_version,
                hyperparams=hyperparams,
                dataset_repo=st.session_state["dataset_repo"],
                output_repo=output_repo_name
            )
            
            # Upload training configuration
            success, message = upload_training_config(config, output_repo_name)
            
            if not success:
                st.error(f"Failed to upload training configuration: {message}")
            else:
                st.success("Uploaded training configuration")
                
                # Setup training script
                success, message = setup_training_script(output_repo_name, config)
                
                if not success:
                    st.error(f"Failed to setup training script: {message}")
                else:
                    st.success("Uploaded training script")
                    
                    # Store in session state
                    st.session_state["model_repo"] = output_repo_name
                    st.session_state["model_version"] = model_version
                    st.session_state["hyperparams"] = hyperparams
                    
                    # Success message and next step
                    st.success("Training job prepared successfully! You can now monitor the training progress.")
                    st.page_link("pages/03_Training_Monitor.py", label="Next: Monitor Training", icon="πŸ“Š")