Upload 15 files

- Dockerfile +20 -0
- README.md +103 -0
- app.py +149 -0
- app.sh +4 -0
- data/sample_dataset.jsonl +10 -0
- packages.txt +3 -0
- pages/01_Dataset_Upload.py +218 -0
- pages/02_Model_Configuration.py +208 -0
- pages/03_Training_Monitor.py +218 -0
- pages/04_Evaluation.py +246 -0
- requirements.txt +15 -0
- utils/auth.py +50 -0
- utils/huggingface.py +152 -0
- utils/training.py +347 -0
- utils/ui.py +163 -0
Dockerfile
ADDED
@@ -0,0 +1,20 @@
+FROM python:3.10-slim
+
+WORKDIR /app
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the rest of the application
+COPY . .
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1
+
+# Expose port for Streamlit
+EXPOSE 8501
+
+# Command to run the application
+CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md
ADDED
@@ -0,0 +1,103 @@
+---
+title: Gemma Fine-tuning UI
+emoji: 🤖
+colorFrom: blue
+colorTo: indigo
+sdk: streamlit
+sdk_version: 1.30.0
+app_file: app.py
+pinned: false
+---
+
+# Gemma Fine-tuning UI
+
+A web-based user interface for fine-tuning Google's Gemma models using Hugging Face infrastructure.
+
+
+
+## Features
+
+- **Dataset Upload**: Upload and preprocess your custom training data in CSV, JSON, or JSONL format
+- **Model Configuration**: Configure Gemma model version and hyperparameters
+- **Training Management**: Start, monitor, and manage fine-tuning jobs
+- **Evaluation**: Test your fine-tuned model with interactive generation
+- **Export Options**: Use your model directly from Hugging Face or export in various formats
+
+## Installation
+
+1. Clone this repository:
+```bash
+git clone https://github.com/yourusername/gemma-finetuning-ui.git
+cd gemma-finetuning-ui
+```
+
+2. Create a virtual environment and install dependencies:
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+pip install -r requirements.txt
+```
+
+3. Run the application:
+```bash
+streamlit run app.py
+```
+
+## Usage
+
+1. **Authentication**: Provide your Hugging Face API token with write permissions
+2. **Dataset Preparation**: Upload your dataset and configure column mappings
+3. **Model Selection**: Choose between Gemma 2B or 7B and customize training parameters
+4. **Training**: Start the fine-tuning process and monitor progress
+5. **Evaluation**: Test your fine-tuned model with custom prompts
+6. **Deployment**: Export or directly use your model from Hugging Face
+
+## Hugging Face Spaces Deployment
+
+This application is designed to be deployed easily on Hugging Face Spaces:
+
+1. Create a new Space on [Hugging Face Spaces](https://huggingface.co/spaces)
+2. Select Streamlit as the SDK
+3. Connect your GitHub repository or upload the files directly
+4. The Space will automatically detect and install the requirements
+
+## Requirements
+
+- Python 3.8+
+- Streamlit 1.30.0+
+- Hugging Face Account with API token
+- For training: GPU access (recommended)
+
+## Project Structure
+
+```
+.
+├── app.py                      # Main Streamlit application
+├── pages/                      # Multi-page app components
+│   ├── 01_Dataset_Upload.py
+│   ├── 02_Model_Configuration.py
+│   ├── 03_Training_Monitor.py
+│   └── 04_Evaluation.py
+├── utils/                      # Utility functions
+│   ├── auth.py                 # Authentication utilities
+│   ├── huggingface.py          # Hugging Face API integration
+│   ├── training.py             # Training utilities
+│   └── ui.py                   # UI components and styling
+├── data/                       # Sample data and uploads
+└── requirements.txt            # Project dependencies
+```
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
+
+## Acknowledgements
+
+- [Google Gemma Models](https://huggingface.co/google/gemma-7b)
+- [Hugging Face Transformers](https://huggingface.co/docs/transformers/index)
+- [Streamlit](https://streamlit.io/)
+- [PEFT - Parameter-Efficient Fine-Tuning](https://github.com/huggingface/peft)
+
+---
+
+Developed as a simplified interface for fine-tuning Gemma models with Hugging Face.
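The Usage steps above assume a dataset of prompt/response pairs. As a minimal sketch (the file name and toy examples are illustrative; the field names match `data/sample_dataset.jsonl` below), a compatible JSONL file can be written like this:

```python
import json

# Two toy prompt/response pairs in the schema used by data/sample_dataset.jsonl
examples = [
    {"prompt": "What is LoRA?", "response": "LoRA is a parameter-efficient fine-tuning method."},
    {"prompt": "Name one Gemma model size.", "response": "Gemma 2B."},
]

# Write one JSON object per line, which is what the JSONL upload path expects
with open("my_dataset.jsonl", "w", encoding="utf-8") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")
```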
app.py
ADDED
@@ -0,0 +1,149 @@
+import streamlit as st
+from utils.auth import check_hf_token
+from utils.ui import set_page_style
+
+# Set page configuration
+st.set_page_config(
+    page_title="Gemma Fine-tuning UI",
+    page_icon="🤖",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for navigation and authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.markdown("**Navigation:**")
+    if auth_status:
+        st.page_link("pages/01_Dataset_Upload.py", label="1️⃣ Dataset Upload", icon="📤")
+        st.page_link("pages/02_Model_Configuration.py", label="2️⃣ Model Configuration", icon="⚙️")
+        st.page_link("pages/03_Training_Monitor.py", label="3️⃣ Training Monitor", icon="📊")
+        st.page_link("pages/04_Evaluation.py", label="4️⃣ Model Evaluation", icon="🔍")
+
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("Welcome to Gemma Fine-tuning UI")
+
+if not hf_token:
+    st.info("👈 Please enter your Hugging Face token in the sidebar to get started")
+
+    with st.expander("ℹ️ How to get a Hugging Face token", expanded=True):
+        st.markdown("""
+        1. Go to [Hugging Face](https://huggingface.co/settings/tokens)
+        2. Sign in or create an account
+        3. Create a new token with write access
+        4. Copy and paste the token in the sidebar
+        """)
+
+    st.divider()
+
+    st.subheader("What you can do with this app:")
+
+    st.markdown("""
+    ### 📝 Simple Fine-tuning Process
+
+    This app provides a straightforward interface for fine-tuning Gemma models with your own data:
+    """)
+
+    cols = st.columns(2)
+
+    with cols[0]:
+        st.markdown("""
+        ✅ **Upload your dataset**
+        - Support for CSV and JSON/JSONL formats
+        - Manual input option for small datasets
+        - Automatic preprocessing for Gemma format
+
+        ✅ **Configure Gemma model parameters**
+        - Choose between Gemma 2B and 7B models
+        - Adjust learning rate, batch size, and epochs
+        - LoRA parameter configuration
+        """)
+
+    with cols[1]:
+        st.markdown("""
+        ✅ **Monitor training progress**
+        - Visual training progress tracking
+        - Loss curve visualization
+        - Training logs and status updates
+
+        ✅ **Evaluate and use your model**
+        - Interactive testing interface
+        - Export options for deployment
+        - Usage examples with code snippets
+        """)
+
+else:
+    st.success("You're all set up! Follow these steps to fine-tune your Gemma model:")
+
+    # Simple step-by-step guide
+    col1, col2 = st.columns(2)
+
+    with col1:
+        st.markdown("### Step 1: Prepare Your Dataset")
+        st.markdown("""
+        - Upload your dataset in CSV or JSONL format
+        - Ensure your data has prompt/instruction and response columns
+        - The app will preprocess the data into the right format for Gemma
+        """)
+        st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="📤")
+
+    with col2:
+        st.markdown("### Step 2: Configure Your Model")
+        st.markdown("""
+        - Select either Gemma 2B (faster) or 7B (more powerful)
+        - Adjust hyperparameters based on your needs
+        - Basic configurations work well for most use cases
+        """)
+        st.page_link("pages/02_Model_Configuration.py", label="Go to Model Configuration", icon="⚙️")
+
+    st.divider()
+
+    col3, col4 = st.columns(2)
+
+    with col3:
+        st.markdown("### Step 3: Train Your Model")
+        st.markdown("""
+        - Start the training process
+        - Monitor progress with real-time updates
+        - Training a model may take time depending on your dataset size
+        """)
+        st.page_link("pages/03_Training_Monitor.py", label="Go to Training Monitor", icon="📊")
+
+    with col4:
+        st.markdown("### Step 4: Evaluate & Use Your Model")
+        st.markdown("""
+        - Test your model with custom prompts
+        - Compare results with the base model
+        - Export your model for use in applications
+        """)
+        st.page_link("pages/04_Evaluation.py", label="Go to Model Evaluation", icon="🔍")
+
+# Notes about CPU limitations
+st.info("""
+**Note on Training Limitations**: This app is running on CPU resources (2vCPU, 16GB RAM), which means:
+
+- For actual training, we use Parameter-Efficient Fine-Tuning (PEFT) with LoRA to reduce memory requirements
+- Training will be slower than on GPU hardware
+- For very large datasets, consider using this interface to prepare your data and configuration,
+  then download the config to run training on more powerful hardware
+""")
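app.py validates tokens via `check_hf_token` from utils/auth.py, whose body is not shown in this excerpt. A plausible sketch, assuming huggingface_hub's `HfApi.whoami` (which raises when a token is rejected), not the actual implementation:

```python
from huggingface_hub import HfApi

def check_hf_token(token: str) -> bool:
    """Return True if the token authenticates against the Hugging Face Hub.

    A sketch only -- the real utils/auth.py implementation is not visible
    in this commit excerpt.
    """
    try:
        # whoami() raises (e.g. an HTTPError) when the token is invalid
        HfApi().whoami(token=token)
        return True
    except Exception:
        return False
```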
app.sh
ADDED
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+# Run the Streamlit app
+streamlit run app.py
data/sample_dataset.jsonl
ADDED
@@ -0,0 +1,10 @@
+{"prompt": "Explain the concept of machine learning in simple terms.", "response": "Machine learning is like teaching a computer to learn from examples rather than programming it with specific rules. Imagine showing a child many pictures of cats and dogs until they can recognize the difference. Similarly, we show computers lots of data examples, and they learn patterns to make predictions or decisions without being explicitly programmed for each scenario."}
+{"prompt": "What is fine-tuning in the context of language models?", "response": "Fine-tuning is the process of taking a pre-trained language model and further training it on a specific dataset for a particular task. It's like taking a general-purpose tool and customizing it for a specialized job. The model keeps its general knowledge but becomes better at the specific task you want it to perform, like answering medical questions or writing in a particular style."}
+{"prompt": "How does a neural network work?", "response": "A neural network is a computing system inspired by the human brain. It consists of interconnected nodes (neurons) organized in layers. Each connection has a weight that adjusts as the network learns. The network processes input data through these layers, with each neuron applying a mathematical function to its inputs and passing the result to the next layer. Through a process called backpropagation, the network adjusts its weights based on errors in its output, gradually improving its performance on a task."}
+{"prompt": "Write a short poem about artificial intelligence.", "response": "Silicon dreams in digital streams,\nThinking machines with invisible seams.\nLearning patterns humans can't see,\nArtificial minds, what will they be?\n\nVast knowledge stored in neural webs,\nAnswering questions as curiosity ebbs.\nPartner or servant or something more—\nAI's potential we've yet to explore."}
+{"prompt": "Summarize the concept of transfer learning.", "response": "Transfer learning is a machine learning technique where a model developed for one task is reused as the starting point for a model on a second task. It's like learning to ride a bicycle makes it easier to learn to ride a motorcycle—the skills transfer. In practice, this means taking a pre-trained model (often trained on a large dataset) and fine-tuning it for your specific task with a smaller dataset, saving computational resources and time while achieving good performance."}
+{"prompt": "How would you explain deep learning to a 10-year-old?", "response": "Deep learning is like teaching a computer to learn by itself, similar to how you learn to recognize your friends' faces. Imagine you have a robot that you want to teach to identify dogs. Instead of telling it exactly what a dog looks like (pointy ears, tail, four legs), you show it thousands of pictures of dogs. The robot has layers of simple learning parts, like building blocks stacked together. Each layer learns something different - one might learn edges, another might learn shapes, and so on. After seeing enough pictures, the robot can recognize dogs on its own, even dogs it's never seen before!"}
+{"prompt": "What are the ethical concerns with AI development?", "response": "Ethical concerns in AI development include: 1) Bias and fairness issues when AI systems reflect or amplify human biases in their training data, 2) Privacy concerns as AI systems collect and process vast amounts of personal data, 3) Accountability questions about who is responsible when AI makes harmful decisions, 4) Transparency problems when AI systems function as 'black boxes' with unexplainable decisions, 5) Job displacement as automation changes employment landscapes, and 6) Safety and control concerns about ensuring AI systems remain aligned with human values as they become more capable."}
+{"prompt": "Compare supervised and unsupervised learning.", "response": "Supervised learning is like learning with a teacher. The algorithm is trained on labeled data, where each example has input features and the correct output. It learns to map inputs to outputs based on these examples, like predicting house prices from features like size and location. Unsupervised learning is like learning without a teacher. The algorithm finds patterns in unlabeled data without specific guidance. It might group similar items together (clustering) or find ways to represent data more efficiently. Supervised learning is used for classification and regression tasks, while unsupervised learning is used for clustering, association, and dimensionality reduction."}
+{"prompt": "Explain how language models like Gemma work.", "response": "Language models like Gemma work by predicting the next word in a sequence based on the context of previous words. They're trained on massive text datasets to understand patterns in language. These models use transformer architecture with attention mechanisms that help them focus on relevant parts of the input when making predictions. The model processes text as tokens (pieces of words or characters) and maps these to high-dimensional vector representations that capture semantic meaning. During training, the model adjusts its billions of parameters to minimize prediction errors. When you prompt the model, it uses its learned patterns to generate text that continues your input in a coherent and contextually appropriate way."}
+{"prompt": "What is the difference between AI, machine learning, and deep learning?", "response": "These terms represent nested concepts. Artificial Intelligence (AI) is the broadest, referring to any technique that enables computers to mimic human intelligence. Machine Learning is a subset of AI where systems learn from data without being explicitly programmed, improving automatically through experience. Deep Learning is a specialized subset of machine learning that uses neural networks with many layers (hence 'deep'). Think of it like this: AI is like the entire field of transportation, machine learning is like motorized vehicles specifically, and deep learning is like self-driving cars – a specialized, advanced implementation."}
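To sanity-check this sample file locally, it can be read the same way the Dataset Upload page parses JSONL; a minimal sketch using pandas:

```python
import pandas as pd

# Read one JSON object per line into a DataFrame with prompt/response columns
df = pd.read_json("data/sample_dataset.jsonl", lines=True)
print(df.shape)              # (10, 2)
print(df.columns.tolist())   # ['prompt', 'response']
```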
packages.txt
ADDED
@@ -0,0 +1,3 @@
+libgl1-mesa-glx
+gcc
+g++
pages/01_Dataset_Upload.py
ADDED
@@ -0,0 +1,218 @@
+import streamlit as st
+import pandas as pd
+import io
+import json
+import os
+from utils.ui import set_page_style
+from utils.auth import check_hf_token
+from utils.huggingface import create_dataset_repo, upload_dataset_to_hub, preprocess_dataset
+
+# Set page configuration
+st.set_page_config(
+    page_title="Dataset Upload - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("📤 Dataset Upload")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Dataset repository settings
+with st.expander("Dataset Repository Settings", expanded=True):
+    col1, col2 = st.columns(2)
+
+    with col1:
+        dataset_repo_name = st.text_input(
+            "Repository Name",
+            value=f"gemma-finetune-dataset-{pd.Timestamp.now().strftime('%Y%m%d')}",
+            help="Name of the Hugging Face repository to store your dataset"
+        )
+
+    with col2:
+        is_private = st.toggle("Private Repository", value=True,
+                               help="Make your dataset repository private (recommended)")
+
+# Dataset upload section
+st.header("Upload Your Dataset")
+
+dataset_format = st.radio(
+    "Select dataset format",
+    options=["CSV", "JSON/JSONL", "Manual Input"],
+    horizontal=True
+)
+
+# Session state to store dataset
+if "dataset" not in st.session_state:
+    st.session_state.dataset = None
+if "dataset_preview" not in st.session_state:
+    st.session_state.dataset_preview = None
+
+# Upload handlers
+if dataset_format == "CSV":
+    uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
+
+    if uploaded_file:
+        try:
+            # Read the file
+            df = pd.read_csv(uploaded_file)
+            st.session_state.dataset = df
+            st.session_state.dataset_preview = df.head(5)
+        except Exception as e:
+            st.error(f"Error reading CSV file: {str(e)}")
+
+elif dataset_format == "JSON/JSONL":
+    uploaded_file = st.file_uploader("Upload a JSON/JSONL file", type=["json", "jsonl"])
+
+    if uploaded_file:
+        try:
+            # Read file contents
+            content = uploaded_file.read()
+
+            # Try to parse as JSON array first
+            try:
+                data = json.loads(content)
+                if isinstance(data, list):
+                    df = pd.DataFrame(data)
+                else:
+                    df = pd.DataFrame([data])
+            except:
+                # Try to parse as JSONL
+                try:
+                    content = io.StringIO(content.decode("utf-8"))
+                    df = pd.read_json(content, lines=True)
+                except:
+                    raise ValueError("Invalid JSON/JSONL format")
+
+            st.session_state.dataset = df
+            st.session_state.dataset_preview = df.head(5)
+        except Exception as e:
+            st.error(f"Error reading JSON/JSONL file: {str(e)}")
+
+elif dataset_format == "Manual Input":
+    st.info("Enter your training examples manually:")
+
+    num_examples = st.number_input("Number of examples", min_value=1, max_value=20, value=3)
+    examples = []
+
+    for i in range(num_examples):
+        st.subheader(f"Example {i+1}")
+        col1, col2 = st.columns(2)
+
+        with col1:
+            prompt = st.text_area(f"Prompt/Instruction {i+1}", height=100)
+
+        with col2:
+            response = st.text_area(f"Response {i+1}", height=100)
+
+        if prompt and response:
+            examples.append({"prompt": prompt, "response": response})
+
+    if examples:
+        df = pd.DataFrame(examples)
+        st.session_state.dataset = df
+        st.session_state.dataset_preview = df
+
+# Display dataset preview if available
+if st.session_state.dataset_preview is not None:
+    st.subheader("Dataset Preview")
+    st.dataframe(st.session_state.dataset_preview, use_container_width=True)
+
+    # Column mapping
+    st.subheader("Column Mapping")
+    st.info("Select which columns contain prompts/instructions and responses")
+
+    columns = st.session_state.dataset.columns.tolist()
+
+    col1, col2 = st.columns(2)
+
+    with col1:
+        prompt_column = st.selectbox("Prompt/Instruction Column", options=columns,
+                                     index=columns.index("prompt") if "prompt" in columns else 0)
+
+    with col2:
+        response_column = st.selectbox("Response Column", options=columns,
+                                       index=columns.index("response") if "response" in columns else min(1, len(columns)-1))
+
+    # Dataset statistics
+    st.subheader("Dataset Statistics")
+
+    stats_col1, stats_col2, stats_col3 = st.columns(3)
+
+    with stats_col1:
+        st.metric("Total Examples", len(st.session_state.dataset))
+
+    with stats_col2:
+        avg_prompt_len = st.session_state.dataset[prompt_column].str.len().mean().round(1)
+        st.metric("Avg. Prompt Length", f"{avg_prompt_len} chars")
+
+    with stats_col3:
+        avg_response_len = st.session_state.dataset[response_column].str.len().mean().round(1)
+        st.metric("Avg. Response Length", f"{avg_response_len} chars")
+
+    # Upload button
+    st.subheader("Upload to Hugging Face")
+
+    if st.button("Process and Upload Dataset", type="primary"):
+        with st.spinner("Processing dataset..."):
+            # Preprocess the dataset into the right format for Gemma
+            try:
+                processed_df = preprocess_dataset(
+                    st.session_state.dataset,
+                    prompt_column=prompt_column,
+                    response_column=response_column
+                )
+
+                # Create repository
+                success, result = create_dataset_repo(dataset_repo_name, private=is_private)
+
+                if not success:
+                    st.error(f"Failed to create repository: {result}")
+                else:
+                    st.success(f"Created repository: {dataset_repo_name}")
+
+                    # Upload dataset
+                    success, result = upload_dataset_to_hub(
+                        processed_df,
+                        "train.jsonl",
+                        dataset_repo_name
+                    )
+
+                    if success:
+                        st.session_state["dataset_repo"] = dataset_repo_name
+                        st.success(f"Dataset uploaded successfully! You can now proceed to model configuration.")
+
+                        # Next steps button
+                        st.page_link("pages/02_Model_Configuration.py", label="Next: Configure Model", icon="⚙️")
+                    else:
+                        st.error(f"Failed to upload dataset: {result}")
+            except Exception as e:
+                st.error(f"Error processing dataset: {str(e)}")
+else:
+    st.info("Please upload or input your dataset above")
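The upload flow above hands formatting off to `preprocess_dataset` from utils/huggingface.py, whose body is not visible in this excerpt. One plausible sketch, assuming Gemma's instruction-tuned turn markers and an output column named `text` (both assumptions, not the committed code):

```python
import pandas as pd

def preprocess_dataset(df: pd.DataFrame, prompt_column: str, response_column: str) -> pd.DataFrame:
    """Sketch: wrap each prompt/response pair in Gemma chat-turn markers.

    The real utils/huggingface.py implementation is not shown in this diff;
    the 'text' column name and the turn-marker template are assumptions.
    """
    def to_gemma_text(row):
        return (
            f"<start_of_turn>user\n{row[prompt_column]}<end_of_turn>\n"
            f"<start_of_turn>model\n{row[response_column]}<end_of_turn>"
        )

    out = pd.DataFrame()
    out["text"] = df.apply(to_gemma_text, axis=1)
    return out
```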
pages/02_Model_Configuration.py
ADDED
@@ -0,0 +1,208 @@
+import streamlit as st
+import pandas as pd
+from utils.ui import set_page_style
+from utils.auth import check_hf_token
+from utils.huggingface import prepare_training_config
+from utils.training import create_model_repo, upload_training_config, setup_training_script
+
+# Set page configuration
+st.set_page_config(
+    page_title="Model Configuration - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("⚙️ Model Configuration")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Check if dataset repository is set
+if "dataset_repo" not in st.session_state:
+    st.warning("Please upload a dataset first")
+    st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="📤")
+
+    # For testing purposes, allow manual entry
+    manual_repo = st.text_input("Or enter dataset repository name manually:")
+    if manual_repo:
+        st.session_state["dataset_repo"] = manual_repo
+    else:
+        st.stop()
+
+# Model configuration
+st.header("1. Select Gemma Model")
+
+model_version = st.radio(
+    "Select Gemma model version",
+    options=["google/gemma-2b", "google/gemma-7b"],
+    horizontal=True,
+    help="Choose the Gemma model size. 2B is faster, 7B is more capable."
+)
+
+st.header("2. Training Configuration")
+
+# Output repository settings
+with st.expander("Output Repository Settings", expanded=True):
+    col1, col2 = st.columns(2)
+
+    with col1:
+        output_repo_name = st.text_input(
+            "Repository Name",
+            value=f"gemma-finetuned-{pd.Timestamp.now().strftime('%Y%m%d')}",
+            help="Name of the Hugging Face repository to store your fine-tuned model"
+        )
+
+    with col2:
+        is_private = st.toggle("Private Repository", value=True,
+                               help="Make your model repository private (recommended)")
+
+# Tabs for Basic and Advanced configuration
+basic_tab, advanced_tab = st.tabs(["Basic Configuration", "Advanced Configuration"])
+
+with basic_tab:
+    col1, col2 = st.columns(2)
+
+    with col1:
+        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3,
+                           help="Number of complete passes through the dataset")
+
+        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8,
+                               help="Number of examples processed together")
+
+    with col2:
+        learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-3,
+                                        value=2e-5, format="%.1e",
+                                        help="Step size for gradient updates")
+
+        fp16_training = st.toggle("Use FP16 Training", value=True,
+                                  help="Use 16-bit precision for faster training")
+
+with advanced_tab:
+    col1, col2 = st.columns(2)
+
+    with col1:
+        lora_rank = st.slider("LoRA Rank", min_value=1, max_value=64, value=8,
+                              help="Rank of the LoRA matrices")
+
+        lora_alpha = st.slider("LoRA Alpha", min_value=1, max_value=128, value=32,
+                               help="Scaling factor for LoRA")
+
+        lora_dropout = st.slider("LoRA Dropout", min_value=0.0, max_value=0.5, value=0.05, step=0.01,
+                                 help="Dropout probability for LoRA layers")
+
+    with col2:
+        weight_decay = st.number_input("Weight Decay", min_value=0.0, max_value=0.1,
+                                       value=0.01, step=0.001,
+                                       help="L2 regularization strength")
+
+        gradient_accumulation = st.slider("Gradient Accumulation Steps", min_value=1, max_value=16, value=1,
+                                          help="Number of steps to accumulate gradients")
+
+        warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=0,
+                                 help="Steps of linear learning rate warmup")
+
+st.header("3. Training Summary")
+
+# Create hyperparameters dictionary
+hyperparams = {
+    "epochs": epochs,
+    "batch_size": batch_size,
+    "learning_rate": learning_rate,
+    "fp16": fp16_training,
+    "lora_rank": lora_rank,
+    "lora_alpha": lora_alpha,
+    "lora_dropout": lora_dropout,
+    "weight_decay": weight_decay,
+    "gradient_accumulation": gradient_accumulation,
+    "warmup_steps": warmup_steps,
+    "max_steps": -1,  # -1 means train for full epochs
+    "max_grad_norm": 1.0
+}
+
+# Display summary
+col1, col2 = st.columns(2)
+
+with col1:
+    st.subheader("Selected Model")
+    st.info(model_version)
+
+    st.subheader("Dataset Repository")
+    st.info(st.session_state["dataset_repo"])
+
+with col2:
+    st.subheader("Key Hyperparameters")
+    st.write(f"Epochs: {epochs}")
+    st.write(f"Batch Size: {batch_size}")
+    st.write(f"Learning Rate: {learning_rate}")
+    st.write(f"LoRA Rank: {lora_rank}")
+
+# Start training button
+st.header("4. Start Training")
+
+if st.button("Prepare and Launch Training", type="primary"):
+    with st.spinner("Setting up training job..."):
+        # First create the model repository
+        success, repo_url = create_model_repo(output_repo_name, private=is_private)
+
+        if not success:
+            st.error(f"Failed to create model repository: {repo_url}")
+        else:
+            st.success(f"Created model repository: {output_repo_name}")
+
+            # Prepare training configuration
+            config = prepare_training_config(
+                model_name=model_version,
+                hyperparams=hyperparams,
+                dataset_repo=st.session_state["dataset_repo"],
+                output_repo=output_repo_name
+            )
+
+            # Upload training configuration
+            success, message = upload_training_config(config, output_repo_name)
+
+            if not success:
+                st.error(f"Failed to upload training configuration: {message}")
+            else:
+                st.success("Uploaded training configuration")
+
+                # Setup training script
+                success, message = setup_training_script(output_repo_name, config)
+
+                if not success:
+                    st.error(f"Failed to setup training script: {message}")
+                else:
+                    st.success("Uploaded training script")
+
+                    # Store in session state
+                    st.session_state["model_repo"] = output_repo_name
+                    st.session_state["model_version"] = model_version
+                    st.session_state["hyperparams"] = hyperparams
+
+                    # Success message and next step
+                    st.success("Training job prepared successfully! You can now monitor the training progress.")
+                    st.page_link("pages/03_Training_Monitor.py", label="Next: Monitor Training", icon="📊")
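The `hyperparams` dictionary assembled on this page maps directly onto PEFT and Transformers configuration objects. A hedged sketch of how a downstream training script might consume it (the `target_modules` choice is an assumption; the real `prepare_training_config`/`setup_training_script` code is not shown in this excerpt):

```python
from peft import LoraConfig
from transformers import TrainingArguments

def build_training_objects(hyperparams: dict, output_dir: str = "outputs"):
    # LoRA adapter settings from the UI's advanced tab
    lora_config = LoraConfig(
        r=hyperparams["lora_rank"],
        lora_alpha=hyperparams["lora_alpha"],
        lora_dropout=hyperparams["lora_dropout"],
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumption, not from the diff
        task_type="CAUSAL_LM",
    )

    # Trainer settings from the basic/advanced tabs
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=hyperparams["epochs"],
        per_device_train_batch_size=hyperparams["batch_size"],
        learning_rate=hyperparams["learning_rate"],
        fp16=hyperparams["fp16"],
        weight_decay=hyperparams["weight_decay"],
        gradient_accumulation_steps=hyperparams["gradient_accumulation"],
        warmup_steps=hyperparams["warmup_steps"],
        max_steps=hyperparams["max_steps"],      # -1 defers to num_train_epochs
        max_grad_norm=hyperparams["max_grad_norm"],
    )
    return lora_config, training_args
```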
pages/03_Training_Monitor.py
ADDED
@@ -0,0 +1,218 @@
+import streamlit as st
+import time
+import json
+import pandas as pd
+import matplotlib.pyplot as plt
+import plotly.graph_objects as go
+from utils.ui import set_page_style, display_metric
+from utils.auth import check_hf_token
+from utils.training import simulate_training_progress
+
+# Set page configuration
+st.set_page_config(
+    page_title="Training Monitor - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("📊 Training Monitor")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Check if training was started
+if "model_repo" not in st.session_state:
+    st.warning("No active training jobs found")
+    st.page_link("pages/02_Model_Configuration.py", label="Go to Model Configuration", icon="⚙️")
+
+    # For testing purposes, allow manual entry
+    manual_repo = st.text_input("Or enter model repository name manually:")
+    if manual_repo:
+        st.session_state["model_repo"] = manual_repo
+        st.session_state["model_version"] = "google/gemma-2b"  # Default
+    else:
+        st.stop()
+
+# Training information
+st.header("Training Information")
+
+col1, col2 = st.columns(2)
+
+with col1:
+    st.subheader("Model Repository")
+    st.info(st.session_state["model_repo"])
+
+    st.subheader("Base Model")
+    st.info(st.session_state.get("model_version", "google/gemma-2b"))
+
+with col2:
+    # For demo purposes, create a button to start simulated training
+    if "training_started" not in st.session_state:
+        st.subheader("Start Training")
+        if st.button("Launch Training Job", type="primary"):
+            st.session_state["training_started"] = True
+            st.experimental_rerun()
+    else:
+        st.subheader("Status")
+        st.success("Training in Progress")
+
+        # Simulate a cancel button
+        if st.button("Cancel Training Job", type="secondary"):
+            st.warning("This is a simulation - in a real environment, this would cancel the training job")
+
+# If training has started, show the progress
+if st.session_state.get("training_started", False):
+    st.header("Training Progress")
+
+    # Create a placeholder for the progress bar
+    progress_bar = st.progress(0)
+
+    # Create placeholder metrics
+    col1, col2, col3, col4 = st.columns(4)
+
+    # Get training progress (simulated for demo)
+    progress_data = simulate_training_progress()
+
+    # Update progress bar
+    progress_bar.progress(progress_data["progress"])
+
+    # Update metrics
+    with col1:
+        display_metric("Epoch", f"{progress_data['current_epoch'] + 1}/{progress_data['total_epochs']}")
+
+    with col2:
+        display_metric("Loss", f"{progress_data['loss']:.4f}")
+
+    with col3:
+        display_metric("Learning Rate", f"{progress_data['learning_rate']:.1e}")
+
+    with col4:
+        status_text = "Complete" if progress_data["status"] == "completed" else "Running"
+        display_metric("Status", status_text)
+
+    # Create training history visualization
+    st.subheader("Training Metrics")
+
+    # Simulate training history data
+    if "training_history" not in st.session_state:
+        st.session_state.training_history = []
+
+    # Add current data point to history if not completed
+    if progress_data["status"] != "completed" or len(st.session_state.training_history) == 0:
+        st.session_state.training_history.append({
+            "epoch": progress_data["current_epoch"],
+            "loss": progress_data["loss"],
+            "learning_rate": progress_data["learning_rate"],
+            "timestamp": time.time()
+        })
+
+    # Convert history to DataFrame
+    history_df = pd.DataFrame(st.session_state.training_history)
+
+    if not history_df.empty and len(history_df) > 1:
+        # Create tabs for different visualizations
+        loss_tab, lr_tab = st.tabs(["Loss Curve", "Learning Rate"])
+
+        with loss_tab:
+            # Create a Plotly figure for the loss curve
+            fig = go.Figure()
+            fig.add_trace(go.Scatter(
+                x=history_df["epoch"],
+                y=history_df["loss"],
+                mode='lines+markers',
+                name='Training Loss',
+                line=dict(color='#FF4B4B', width=3)
+            ))
+
+            fig.update_layout(
+                title="Training Loss Over Time",
+                xaxis_title="Epoch",
+                yaxis_title="Loss",
+                template="plotly_white",
+                height=400
+            )
+
+            st.plotly_chart(fig, use_container_width=True)
+
+        with lr_tab:
+            # Create a Plotly figure for the learning rate
+            fig = go.Figure()
+            fig.add_trace(go.Scatter(
+                x=history_df["epoch"],
+                y=history_df["learning_rate"],
+                mode='lines+markers',
+                name='Learning Rate',
+                line=dict(color='#0068C9', width=3)
+            ))
+
+            fig.update_layout(
+                title="Learning Rate Schedule",
+                xaxis_title="Epoch",
+                yaxis_title="Learning Rate",
+                template="plotly_white",
+                height=400
+            )
+
+            st.plotly_chart(fig, use_container_width=True)
+    else:
+        st.info("Training metrics will appear here once training progresses")
+
+    # Training logs
+    st.subheader("Training Logs")
+
+    # Simulate logs
+    log_lines = [
+        f"[{pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}] Initialized training job",
+        f"[{pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}] Loading model: {st.session_state.get('model_version', 'google/gemma-2b')}",
+        f"[{pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}] Preparing LoRA configuration"
+    ]
+
+    # Add epoch logs based on progress
+    current_epoch = progress_data["current_epoch"]
+    for epoch in range(min(current_epoch + 1, progress_data["total_epochs"])):
+        timestamp = pd.Timestamp.now() - pd.Timedelta(seconds=(current_epoch - epoch) * 60)
+        log_lines.append(f"[{timestamp.strftime('%Y-%m-%d %H:%M:%S')}] Epoch {epoch+1}/{progress_data['total_epochs']} started")
+
+        if epoch < current_epoch:
+            # For completed epochs, add completion log
+            timestamp = pd.Timestamp.now() - pd.Timedelta(seconds=(current_epoch - epoch - 0.5) * 60)
+            sim_loss = max(2.5 - (epoch * 0.5), 0.5)
+            log_lines.append(f"[{timestamp.strftime('%Y-%m-%d %H:%M:%S')}] Epoch {epoch+1} completed: loss={sim_loss:.4f}")
+
+    # Display logs in a scrollable area
+    st.code("\n".join(log_lines))
+
+    # Next steps (only show when training is complete)
+    if progress_data["status"] == "completed":
+        st.success("Training completed successfully!")
+        st.page_link("pages/04_Evaluation.py", label="Next: Evaluate Model", icon="🔍")
+    else:
+        # Auto-refresh for monitoring
+        st.empty()
+        time.sleep(2)  # Wait for 2 seconds
+        st.experimental_rerun()
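The monitor reads its numbers from `simulate_training_progress` in utils/training.py, which this excerpt does not include. Judging from the keys the page consumes, a minimal time-based sketch could look like:

```python
import time

_START = time.time()

def simulate_training_progress(total_epochs: int = 3, seconds_per_epoch: float = 60.0) -> dict:
    """Sketch of a fake progress feed with the keys 03_Training_Monitor.py reads.

    The real utils/training.py logic is not shown in this diff; the pacing
    and loss/learning-rate curves here are assumptions.
    """
    elapsed = time.time() - _START
    progress = min(elapsed / (total_epochs * seconds_per_epoch), 1.0)
    current_epoch = min(int(progress * total_epochs), total_epochs - 1)
    return {
        "progress": progress,                      # 0.0 .. 1.0, fed to st.progress
        "current_epoch": current_epoch,
        "total_epochs": total_epochs,
        "loss": max(2.5 - 2.0 * progress, 0.5),    # loosely mirrors the page's log simulation
        "learning_rate": 2e-5 * (1.0 - progress),  # simple linear decay
        "status": "completed" if progress >= 1.0 else "running",
    }
```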
pages/04_Evaluation.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
from utils.ui import set_page_style, create_card
|
| 6 |
+
from utils.auth import check_hf_token
|
| 7 |
+
|
| 8 |
+
# Set page configuration
|
| 9 |
+
st.set_page_config(
|
| 10 |
+
page_title="Model Evaluation - Gemma Fine-tuning",
|
| 11 |
+
page_icon="🤖",
|
| 12 |
+
layout="wide"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# Apply custom styling
|
| 16 |
+
set_page_style()
|
| 17 |
+
|
| 18 |
+
# Sidebar for authentication
|
| 19 |
+
with st.sidebar:
|
| 20 |
+
st.title("🤖 Gemma Fine-tuning")
|
| 21 |
+
st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
|
| 22 |
+
use_column_width=True)
|
| 23 |
+
|
| 24 |
+
# Authentication section
|
| 25 |
+
st.subheader("🔑 Authentication")
|
| 26 |
+
hf_token = st.text_input("Hugging Face API Token", type="password",
|
| 27 |
+
help="Enter your Hugging Face write token to enable model fine-tuning")
|
| 28 |
+
auth_status = check_hf_token(hf_token) if hf_token else False
|
| 29 |
+
|
| 30 |
+
if auth_status:
|
| 31 |
+
st.success("Authenticated successfully!")
|
| 32 |
+
elif hf_token:
|
| 33 |
+
st.error("Invalid token. Please check and try again.")
|
| 34 |
+
|
| 35 |
+
st.divider()
|
| 36 |
+
st.caption("A simple UI for fine-tuning Gemma models")
|
| 37 |
+
|
| 38 |
+
# Main content
|
| 39 |
+
st.title("🔍 Model Evaluation")
|
| 40 |
+
|
| 41 |
+
if not hf_token or not auth_status:
|
| 42 |
+
st.warning("Please authenticate with your Hugging Face token in the sidebar first")
|
| 43 |
+
st.stop()
|
| 44 |
+
|
| 45 |
+
# Check if model is trained
|
| 46 |
+
if "model_repo" not in st.session_state:
|
| 47 |
+
st.warning("No trained model found")
|
| 48 |
+
st.page_link("pages/03_Training_Monitor.py", label="Go to Training Monitor", icon="📊")
|
| 49 |
+
|
| 50 |
+
# For testing purposes, allow manual entry
|
| 51 |
+
manual_repo = st.text_input("Or enter model repository name manually:")
|
| 52 |
+
if manual_repo:
|
| 53 |
+
st.session_state["model_repo"] = manual_repo
|
| 54 |
+
st.session_state["model_version"] = "google/gemma-2b" # Default
|
| 55 |
+
else:
|
| 56 |
+
st.stop()
|
| 57 |
+
|
| 58 |
+
# Model information
|
| 59 |
+
st.header("Model Information")
|
| 60 |
+
|
| 61 |
+
col1, col2 = st.columns(2)
|
| 62 |
+
|
| 63 |
+
with col1:
|
| 64 |
+
st.subheader("Fine-tuned Model")
|
| 65 |
+
username = st.session_state.get("hf_username", "user")
|
| 66 |
+
model_id = f"{username}/{st.session_state['model_repo']}"
|
| 67 |
+
st.info(model_id)
|
| 68 |
+
|
| 69 |
+
st.markdown(f"[View on Hugging Face 🔗](https://huggingface.co/{model_id})")
|
| 70 |
+
|
| 71 |
+
with col2:
|
| 72 |
+
st.subheader("Base Model")
|
| 73 |
+
st.info(st.session_state.get("model_version", "google/gemma-2b"))
|
| 74 |
+
|
| 75 |
+
# Interactive testing
|
| 76 |
+
st.header("Interactive Testing")
|
| 77 |
+
|
| 78 |
+
def generate_response(prompt, max_length=100, temperature=0.7):
|
| 79 |
+
"""Simulate generating a response from the fine-tuned model"""
|
| 80 |
+
# In a real implementation, this would call the model API
|
| 81 |
+
|
| 82 |
+
# For demo purposes, simulate a response with a delay
|
| 83 |
+
with st.spinner("Generating response..."):
|
| 84 |
+
# Simulate thinking time
|
| 85 |
+
time.sleep(2)
|
| 86 |
+
|
| 87 |
+
# Generate a simple response based on the prompt for demonstration
|
| 88 |
+
if "hello" in prompt.lower() or "hi" in prompt.lower():
|
| 89 |
+
return "Hello! How can I assist you today?"
|
| 90 |
+
elif "name" in prompt.lower():
|
| 91 |
+
return "I'm a fine-tuned version of Gemma, created to assist you!"
|
| 92 |
+
elif "weather" in prompt.lower():
|
| 93 |
+
return "I don't have real-time data access, but I can tell you about weather patterns and climate information if you'd like."
|
| 94 |
+
elif "help" in prompt.lower():
|
| 95 |
+
return "I'm here to help answer questions, provide information, or assist with tasks. Just let me know what you need!"
|
| 96 |
+
elif len(prompt) < 10:
|
| 97 |
+
return "Could you please provide more details so I can give you a better response?"
|
| 98 |
+
else:
|
| 99 |
+
return f"Based on your request about '{prompt[:20]}...', I've analyzed the content and formulated this response. This is a demonstration of how the fine-tuned model would respond to your input, tailored to the data it was trained on."
|
| 100 |
+
|

# Create the interface
with st.expander("Generation Settings", expanded=False):
    col1, col2 = st.columns(2)

    with col1:
        max_length = st.slider("Maximum Length", min_value=10, max_value=500, value=150,
                               help="Maximum number of tokens to generate")

    with col2:
        temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
                                help="Controls randomness: 1.0 is more creative, 0.1 is more deterministic")

prompt = st.text_area("Enter your prompt:", height=150,
                      help="Enter a prompt or instruction for the model to respond to")

if st.button("Generate Response", type="primary"):
    if prompt:
        response = generate_response(prompt, max_length, temperature)

        # Store in history
        if "conversation_history" not in st.session_state:
            st.session_state.conversation_history = []

        st.session_state.conversation_history.append({
            "prompt": prompt,
            "response": response,
            "timestamp": pd.Timestamp.now().isoformat()
        })

        # Display response
        st.subheader("Model Response")
        st.markdown(f"<div style='background-color: #f8f9fa; padding: 15px; border-radius: 8px;'>{response}</div>", unsafe_allow_html=True)
    else:
        st.warning("Please enter a prompt first")

# Display conversation history
if "conversation_history" in st.session_state and st.session_state.conversation_history:
    st.header("Conversation History")

    for i, item in enumerate(reversed(st.session_state.conversation_history)):
        with st.container():
            st.markdown(f"### Conversation {len(st.session_state.conversation_history) - i}")

            st.markdown("**Prompt:**")
            st.markdown(f"<div style='background-color: #e6f7ff; padding: 10px; border-radius: 8px;'>{item['prompt']}</div>", unsafe_allow_html=True)

            st.markdown("**Response:**")
            st.markdown(f"<div style='background-color: #f8f9fa; padding: 10px; border-radius: 8px;'>{item['response']}</div>", unsafe_allow_html=True)

            st.caption(f"Generated at: {pd.Timestamp(item['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}")
            st.divider()

# Export section
st.header("Export Model")

export_tab1, export_tab2 = st.tabs(["Export Options", "Usage Guide"])

with export_tab1:
    st.subheader("Export Configuration")

    export_format = st.radio("Export Format", options=["Hugging Face Hub", "ONNX", "TensorFlow Lite", "PyTorch"],
                             horizontal=True)

    if export_format == "Hugging Face Hub":
        st.success("Your model is already available on the Hugging Face Hub!")
        username = st.session_state.get("hf_username", "user")
        model_id = f"{username}/{st.session_state['model_repo']}"
        st.code(f"from transformers import AutoModelForCausalLM, AutoTokenizer\n\nmodel = AutoModelForCausalLM.from_pretrained('{model_id}')\ntokenizer = AutoTokenizer.from_pretrained('{model_id}')")
    else:
        st.info("This export format is under development and will be available soon.")

        # Show a sample of how to export
        if export_format == "ONNX":
            st.code("""
# Example code for ONNX export (not functional in this demo)
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.onnx import export, FeaturesManager

model_id = "your_username/your_model_repo"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Build an ONNX config for this architecture, then export
onnx_config_cls = FeaturesManager.get_config(model.config.model_type, feature="causal-lm")
onnx_config = onnx_config_cls(model.config)
export(
    preprocessor=tokenizer,
    model=model,
    config=onnx_config,
    opset=13,
    output=Path("model.onnx"),
)
""")
        elif export_format == "PyTorch":
            st.code("""
# Example code for PyTorch export (not functional in this demo)
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your_username/your_model_repo"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save model locally
model.save_pretrained("./my_exported_model")
tokenizer.save_pretrained("./my_exported_model")
""")

with export_tab2:
    st.subheader("How to Use Your Model")

    st.markdown("""
### Using with Hugging Face Transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Replace with your model ID
model_id = "your_username/your_model_repo"

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Generate text
inputs = tokenizer("What is machine learning?", return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```

### Deployment Options

1. **Hugging Face Inference API** - Easiest option for quick deployment; see the sketch below
2. **Gradio or Streamlit** - For creating interactive demos
3. **FastAPI or Flask** - For creating backend API services
4. **Mobile Deployment** - Use TensorFlow Lite or ONNX formats
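
A minimal sketch for option 1 (illustrative; the model ID and token are placeholders), calling the hosted model through the `huggingface_hub` InferenceClient:

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="your_username/your_model_repo", token="hf_...")
print(client.text_generation("What is machine learning?", max_new_tokens=100))
```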
""")

# Next steps
st.divider()
st.subheader("Continue Your Journey")

col1, col2 = st.columns(2)

with col1:
    st.page_link("pages/01_Dataset_Upload.py", label="Start New Project", icon="🔄")

with col2:
    st.write("Finished? Export your model or try it out in the playground above.")
requirements.txt
ADDED
@@ -0,0 +1,15 @@
streamlit>=1.30.0
huggingface_hub>=0.20.0
pandas>=2.0.0
matplotlib>=3.7.0
transformers>=4.35.0
peft>=0.7.0
torch>=2.0.0
bitsandbytes>=0.40.0
accelerate>=0.20.0
datasets>=2.10.0
tensorboard>=2.13.0
plotly>=5.15.0
pillow>=10.0.0
watchdog>=3.0.0
scipy>=1.10.0
utils/auth.py
ADDED
@@ -0,0 +1,50 @@
import streamlit as st
import os
from huggingface_hub import HfApi, login

def check_hf_token(token):
    """
    Validate Hugging Face token and login if valid

    Args:
        token (str): Hugging Face API token

    Returns:
        bool: True if token is valid, False otherwise
    """
    try:
        # Set token in environment and session state
        os.environ["HF_TOKEN"] = token
        st.session_state["hf_token"] = token

        # Try to log in
        login(token=token, add_to_git_credential=False)

        # Test API access
        api = HfApi(token=token)
        user_info = api.whoami()

        # Store username in session state
        st.session_state["hf_username"] = user_info.get("name")

        return True
    except Exception as e:
        st.session_state["hf_token"] = None
        st.session_state["hf_username"] = None
        print(f"Authentication error: {str(e)}")
        return False

def get_current_user():
    """
    Get the currently authenticated user's information

    Returns:
        dict: User information or None if not authenticated
    """
    if "hf_token" in st.session_state and st.session_state["hf_token"]:
        try:
            api = HfApi(token=st.session_state["hf_token"])
            return api.whoami()
        except Exception:
            return None
    return None
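
# Example usage (illustrative sketch) from a page that needs authentication:
#
#   token = st.text_input("Hugging Face token", type="password")
#   if token and check_hf_token(token):
#       st.success(f"Logged in as {st.session_state['hf_username']}")
#   else:
#       st.stop()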
utils/huggingface.py
ADDED
@@ -0,0 +1,152 @@
import os
import json
import tempfile
import pandas as pd
import streamlit as st
from huggingface_hub import HfApi, upload_file, create_repo
from transformers import AutoTokenizer

def create_dataset_repo(repo_name, private=True):
    """
    Create a new dataset repository on Hugging Face Hub

    Args:
        repo_name (str): Name of the repository
        private (bool): Whether the repository should be private

    Returns:
        tuple: (success (bool), repo URL or error message (str))
    """
    try:
        token = st.session_state.get("hf_token")
        if not token:
            return False, "No Hugging Face token found"

        username = st.session_state.get("hf_username", "user")
        full_repo_name = f"{username}/{repo_name}"

        api = HfApi(token=token)
        repo_url = api.create_repo(
            repo_id=full_repo_name,
            repo_type="dataset",
            private=private,
            exist_ok=True
        )

        return True, repo_url
    except Exception as e:
        return False, str(e)

def upload_dataset_to_hub(file_data, file_name, repo_name):
    """
    Upload a dataset file to Hugging Face Hub

    Args:
        file_data (bytes/DataFrame): File content as bytes or a DataFrame
        file_name (str): Name to save the file as
        repo_name (str): Repository to upload to

    Returns:
        tuple: (success (bool), message (str))
    """
    try:
        token = st.session_state.get("hf_token")
        if not token:
            return False, "No Hugging Face token found"

        username = st.session_state.get("hf_username", "user")
        repo_id = f"{username}/{repo_name}"

        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as tmp:
            # If it's a DataFrame, save as JSONL
            if isinstance(file_data, pd.DataFrame):
                file_data.to_json(tmp.name, orient="records", lines=True)
            else:
                # Otherwise, assume it's bytes
                tmp.write(file_data)

        # Upload file to repository
        upload_file(
            path_or_fileobj=tmp.name,
            path_in_repo=file_name,
            repo_id=repo_id,
            token=token,
            repo_type="dataset"
        )

        # Clean up temporary file
        os.unlink(tmp.name)
        return True, f"File uploaded to {repo_id}"
    except Exception as e:
        return False, str(e)

def prepare_training_config(model_name, hyperparams, dataset_repo, output_repo):
    """
    Prepare a training configuration for Gemma fine-tuning

    Args:
        model_name (str): Model identifier
        hyperparams (dict): Training hyperparameters
        dataset_repo (str): Dataset repository name
        output_repo (str): Output repository name

    Returns:
        dict: Training configuration
    """
    username = st.session_state.get("hf_username", "user")

    config = {
        "model_name_or_path": model_name,
        "dataset_name": f"{username}/{dataset_repo}",
        "output_dir": f"{username}/{output_repo}",
        "num_train_epochs": hyperparams.get("epochs", 3),
        "per_device_train_batch_size": hyperparams.get("batch_size", 8),
        "learning_rate": hyperparams.get("learning_rate", 2e-5),
        "weight_decay": hyperparams.get("weight_decay", 0.01),
        "save_strategy": "epoch",
        "evaluation_strategy": "epoch",
        "fp16": hyperparams.get("fp16", False),
        "peft_config": {
            "r": hyperparams.get("lora_rank", 8),
            "lora_alpha": hyperparams.get("lora_alpha", 32),
            "lora_dropout": hyperparams.get("lora_dropout", 0.05),
            "bias": "none",
            "task_type": "CAUSAL_LM"
        },
        "optim": "adamw_torch",
        "logging_steps": 50,
        "gradient_accumulation_steps": hyperparams.get("gradient_accumulation", 1),
        "max_steps": hyperparams.get("max_steps", -1),
        "warmup_steps": hyperparams.get("warmup_steps", 0),
        "max_grad_norm": hyperparams.get("max_grad_norm", 1.0),
    }

    return config
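
# Example call (illustrative; the repository names are placeholders):
#
#   config = prepare_training_config(
#       model_name="google/gemma-2b",
#       hyperparams={"epochs": 3, "batch_size": 4, "learning_rate": 2e-5, "lora_rank": 8},
#       dataset_repo="my-dataset",
#       output_repo="my-gemma-ft",
#   )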

def preprocess_dataset(df, prompt_column, response_column, model_name="google/gemma-2b"):
    """
    Preprocess a dataset for Gemma fine-tuning

    Args:
        df (DataFrame): Dataset
        prompt_column (str): Column containing prompts/instructions
        response_column (str): Column containing responses
        model_name (str): Model identifier for tokenizer (currently unused; kept for future tokenizer-aware preprocessing)

    Returns:
        DataFrame: Processed dataset
    """
    # Check if columns exist
    if prompt_column not in df.columns or response_column not in df.columns:
        raise ValueError(f"Columns {prompt_column} and/or {response_column} not found in dataset")

    # Simple format for instruction tuning using Gemma's chat turn markers
    df["text"] = df.apply(
        lambda row: f"<start_of_turn>user\n{row[prompt_column]}<end_of_turn>\n<start_of_turn>model\n{row[response_column]}<end_of_turn>",
        axis=1
    )

    # Return the processed dataset
    return df[["text"]]
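
# Example (illustrative): a row with prompt "What is ML?" and response
# "Machine learning is..." is flattened into a single "text" field:
#
#   <start_of_turn>user
#   What is ML?<end_of_turn>
#   <start_of_turn>model
#   Machine learning is...<end_of_turn>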
utils/training.py
ADDED
@@ -0,0 +1,347 @@
import os
import json
import time
import pandas as pd
import streamlit as st
from huggingface_hub import HfApi, create_repo, upload_file
import tempfile

def create_model_repo(repo_name, private=True):
    """
    Create a new model repository on Hugging Face Hub

    Args:
        repo_name (str): Name of the repository
        private (bool): Whether the repository should be private

    Returns:
        tuple: (success (bool), message (str))
    """
    try:
        token = st.session_state.get("hf_token")
        if not token:
            return False, "No Hugging Face token found"

        username = st.session_state.get("hf_username", "user")
        full_repo_name = f"{username}/{repo_name}"

        api = HfApi(token=token)
        repo_url = api.create_repo(
            repo_id=full_repo_name,
            private=private,
            exist_ok=True
        )

        return True, repo_url
    except Exception as e:
        return False, str(e)

def upload_training_config(config, repo_name):
    """
    Upload a training configuration file to Hugging Face Hub

    Args:
        config (dict): Training configuration
        repo_name (str): Repository to upload to

    Returns:
        tuple: (success (bool), message (str))
    """
    try:
        token = st.session_state.get("hf_token")
        if not token:
            return False, "No Hugging Face token found"

        username = st.session_state.get("hf_username", "user")
        repo_id = f"{username}/{repo_name}"

        with tempfile.NamedTemporaryFile(delete=False, suffix='.json') as tmp:
            with open(tmp.name, 'w') as f:
                json.dump(config, f, indent=2)

        # Upload file to repository
        upload_file(
            path_or_fileobj=tmp.name,
            path_in_repo="training_config.json",
            repo_id=repo_id,
            token=token
        )

        # Clean up temporary file
        os.unlink(tmp.name)
        return True, f"Training config uploaded to {repo_id}"
    except Exception as e:
        return False, str(e)

def setup_training_script(repo_name, config):
    """
    Generate and upload a training script to the repository

    Args:
        repo_name (str): Repository name
        config (dict): Training configuration

    Returns:
        tuple: (success (bool), message (str))
    """
    # Create a training script using the transformers Trainer with 4-bit quantization
    script = """
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
import json
import os
import torch
from huggingface_hub import login

# Load configuration
with open("training_config.json", "r") as f:
    config = json.load(f)

# Login to Hugging Face
login(token=os.environ.get("HF_TOKEN"))

# Load dataset
print("Loading dataset:", config["dataset_name"])
dataset = load_dataset(config["dataset_name"])

# Prepare a train/validation split if one does not already exist
if "train" in dataset and "validation" not in dataset:
    dataset = dataset["train"].train_test_split(test_size=0.1)
elif "train" not in dataset:
    # No train split: fall back to the first available split and split it
    print("Available splits:", list(dataset.keys()))
    first_split = list(dataset.keys())[0]
    dataset = dataset[first_split].train_test_split(test_size=0.1)

print("Dataset splits:", list(dataset.keys()))

# Print dataset sample
print("Dataset sample:", dataset["train"][0])

# Load tokenizer
print("Loading tokenizer for model:", config["model_name_or_path"])
tokenizer = AutoTokenizer.from_pretrained(config["model_name_or_path"])
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model with 4-bit quantization (bitsandbytes) to reduce memory usage
print("Loading model with quantization...")
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
    config["model_name_or_path"],
    quantization_config=quant_config,
    device_map="auto",
    use_cache=False,  # Required for gradient checkpointing
)

# Enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()

# Print memory usage before PEFT (GPU only)
if torch.cuda.is_available():
    print(f"Model loaded. Memory usage: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Prepare model for training with LoRA
print("Setting up LoRA with rank:", config["peft_config"]["r"])
peft_config = LoraConfig(
    r=config["peft_config"]["r"],
    lora_alpha=config["peft_config"]["lora_alpha"],
    lora_dropout=config["peft_config"]["lora_dropout"],
    bias=config["peft_config"]["bias"],
    task_type=config["peft_config"]["task_type"],
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Prepare the quantized model for k-bit training and attach the LoRA adapters
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Setup training arguments with CPU-friendly defaults
print("Setting up training arguments...")
training_args = TrainingArguments(
    output_dir=config["output_dir"],
    num_train_epochs=config["num_train_epochs"],
    per_device_train_batch_size=config["per_device_train_batch_size"],
    per_device_eval_batch_size=max(1, config["per_device_train_batch_size"] // 2),
    learning_rate=config["learning_rate"],
    weight_decay=config["weight_decay"],
    save_strategy=config["save_strategy"],
    evaluation_strategy=config["evaluation_strategy"],
    fp16=config["fp16"] and torch.cuda.is_available(),
    optim=config["optim"],
    logging_steps=config["logging_steps"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    max_steps=config["max_steps"],  # -1 disables the step limit
    warmup_steps=config["warmup_steps"],
    max_grad_norm=config["max_grad_norm"],
    push_to_hub=True,
    hub_token=os.environ.get("HF_TOKEN"),
    dataloader_num_workers=0,  # Lower CPU usage for smaller machines
    use_cpu=not torch.cuda.is_available(),  # Force CPU if no GPU
    lr_scheduler_type="cosine",  # Better LR scheduling for small datasets
    report_to=["tensorboard"],  # Enable tensorboard logging
)

# Define data collator
print("Setting up data collator...")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Tokenize dataset
print("Tokenizing dataset...")
def tokenize_function(examples):
    # Use a smaller max length on CPU to save memory
    max_length = 256 if not torch.cuda.is_available() else 512
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length
    )

# Show progress while tokenizing
print("Mapping tokenization function...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=8,  # Smaller batch size for CPU
    remove_columns=dataset["train"].column_names,  # Remove original columns after tokenizing
    desc="Tokenizing dataset",
)

# train_test_split names the held-out split "test"; prefer "validation" if present
eval_split = "validation" if "validation" in tokenized_dataset else "test"

# Initialize trainer
print("Initializing trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset[eval_split],
    data_collator=data_collator,
)

# Start training
print("Starting training...")
try:
    trainer.train()
    # Save model and tokenizer
    print("Saving model...")
    trainer.save_model()
    print("Training completed successfully!")
except Exception as e:
    print(f"Error during training: {str(e)}")
    # Save checkpoint even if error occurred
    try:
        trainer.save_model("./checkpoint-error")
        print("Saved checkpoint before error")
    except Exception:
        print("Could not save checkpoint")
"""

    try:
        token = st.session_state.get("hf_token")
        if not token:
            return False, "No Hugging Face token found"

        username = st.session_state.get("hf_username", "user")
        repo_id = f"{username}/{repo_name}"

        with tempfile.NamedTemporaryFile(delete=False, suffix='.py') as tmp:
            with open(tmp.name, 'w') as f:
                f.write(script)

        # Upload file to repository
        upload_file(
            path_or_fileobj=tmp.name,
            path_in_repo="train.py",
            repo_id=repo_id,
            token=token
        )

        # Clean up temporary file
        os.unlink(tmp.name)

        # Also create and upload a requirements file for the training job
        requirements = """
transformers>=4.35.0
peft>=0.7.0
bitsandbytes>=0.40.0
datasets>=2.10.0
torch>=2.0.0
tensorboard>=2.13.0
accelerate>=0.20.0
huggingface_hub>=0.15.0
scipy>=1.10.0
"""

        with tempfile.NamedTemporaryFile(delete=False, suffix='.txt') as tmp:
            with open(tmp.name, 'w') as f:
                f.write(requirements)

        # Upload file to repository
        upload_file(
            path_or_fileobj=tmp.name,
            path_in_repo="requirements.txt",
            repo_id=repo_id,
            token=token
        )

        # Clean up temporary file
        os.unlink(tmp.name)

        return True, f"Training script and requirements uploaded to {repo_id}"
    except Exception as e:
        return False, str(e)

def simulate_training_progress():
    """
    Simulate training progress for demonstration purposes
    """
    if "training_progress" not in st.session_state:
        st.session_state.training_progress = {
            "started": time.time(),
            "current_epoch": 0,
            "total_epochs": 3,
            "loss": 2.5,
            "learning_rate": 2e-5,
            "progress": 0.0,
            "status": "running"
        }

    # Update progress based on elapsed time (simulated)
    elapsed = time.time() - st.session_state.training_progress["started"]
    epoch_duration = 60  # Simulate each epoch taking 60 seconds

    # Calculate current progress
    total_duration = epoch_duration * st.session_state.training_progress["total_epochs"]
    progress = min(elapsed / total_duration, 1.0)

    # Calculate current epoch
    current_epoch = min(
        int(progress * st.session_state.training_progress["total_epochs"]),
        st.session_state.training_progress["total_epochs"]
    )

    # Simulate decreasing loss
    loss = max(2.5 - (progress * 2.0), 0.5)

    # Update session state
    st.session_state.training_progress.update({
        "progress": progress,
        "current_epoch": current_epoch,
        "loss": loss,
        "status": "completed" if progress >= 1.0 else "running"
    })

    return st.session_state.training_progress
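
# Example usage (illustrative), e.g. from pages/03_Training_Monitor.py:
#
#   progress = simulate_training_progress()
#   st.progress(progress["progress"])
#   st.metric("Epoch", f"{progress['current_epoch']}/{progress['total_epochs']}")
#   st.metric("Loss", f"{progress['loss']:.2f}")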
utils/ui.py
ADDED
@@ -0,0 +1,163 @@
import streamlit as st

def set_page_style():
    """
    Set custom CSS styling for the Streamlit app with light mode
    """
    st.markdown("""
<style>
    /* Main theme colors - Light Mode */
    :root {
        --primary-color: #3366FF;
        --background-color: #FFFFFF;
        --secondary-background-color: #F5F7F9;
        --text-color: #333333;
        --font: 'Source Sans Pro', sans-serif;
    }

    /* Overall page background */
    .stApp {
        background-color: var(--background-color);
        color: var(--text-color);
    }

    /* Headers */
    h1, h2, h3, h4, h5, h6 {
        color: #111111;
        font-weight: 700;
    }

    /* Card-like elements */
    .stCard {
        border-radius: 8px;
        padding: 20px;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        background-color: white;
        border: 1px solid #EAEAEA;
        margin-bottom: 15px;
    }

    /* Sidebar */
    section[data-testid="stSidebar"] {
        background-color: #F5F7F9;
        border-right: 1px solid #EAEAEA;
    }

    /* Buttons */
    .stButton button {
        border-radius: 4px;
        font-weight: 600;
        background-color: var(--primary-color);
        color: white;
    }

    .stButton button:hover {
        background-color: #2a56d9;
    }

    /* Banner image */
    .banner-img {
        border-radius: 8px;
        margin-bottom: 15px;
    }

    /* Metric display */
    .metric-container {
        background-color: #FFFFFF;
        border-radius: 6px;
        padding: 12px;
        text-align: center;
        box-shadow: 0 1px 3px rgba(0,0,0,0.05);
        border: 1px solid #EAEAEA;
    }

    .metric-value {
        font-size: 20px;
        font-weight: 700;
        color: #3366FF;
    }

    .metric-label {
        font-size: 14px;
        color: #555555;
        margin-top: 5px;
    }

    /* Make input fields a bit nicer */
    input, textarea, select {
        border-radius: 4px !important;
        border: 1px solid #CCCCCC !important;
    }

    /* Improve readability of text elements */
    p, li, label, div {
        color: #333333;
    }

    /* Info, success, warning boxes */
    .stAlert {
        border-radius: 4px;
    }

    /* Tabs styling */
    .stTabs [data-baseweb="tab-list"] {
        gap: 8px;
    }

    .stTabs [data-baseweb="tab"] {
        border-radius: 4px 4px 0 0;
        padding: 8px 16px;
        background-color: #F0F2F6;
    }

    .stTabs [aria-selected="true"] {
        background-color: white;
        border-top: 2px solid var(--primary-color);
    }

    /* Code blocks */
    code {
        background-color: #F0F2F6;
        color: #333333;
        padding: 2px 4px;
        border-radius: 3px;
    }

    /* Progress bar */
    .stProgress > div > div {
        background-color: var(--primary-color);
    }
</style>
""", unsafe_allow_html=True)

def display_metric(title, value, container=None):
    """
    Display a metric in a nice formatted container

    Args:
        title (str): Title of the metric
        value (str/int/float): Value to display
        container (streamlit container, optional): Container to display the metric in
    """
    target = container if container else st
    target.markdown(f"""
<div class="metric-container">
    <div class="metric-value">{value}</div>
    <div class="metric-label">{title}</div>
</div>
""", unsafe_allow_html=True)

def create_card(content, container=None):
    """
    Create a card-like container for content

    Args:
        content (callable): Function to call to render content inside the card
        container (streamlit container, optional): Container to place the card in
    """
    target = container if container else st

    with target.container():
        # Note: Streamlit renders each markdown call separately, so these div
        # markers act as visual bookends rather than a true HTML wrapper.
        target.markdown('<div class="stCard">', unsafe_allow_html=True)
        content()
        target.markdown('</div>', unsafe_allow_html=True)
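
# Example usage (illustrative):
#
#   create_card(lambda: st.write("Hello from inside a card"))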