astroknotsheep committed
Commit b619545 · verified · 1 Parent(s): f94c2e7

Upload 15 files
Dockerfile ADDED
@@ -0,0 +1,20 @@
+FROM python:3.10-slim
+
+WORKDIR /app
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the rest of the application
+COPY . .
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1
+
+# Expose port for Streamlit
+EXPOSE 8501
+
+# Command to run the application
+CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md ADDED
@@ -0,0 +1,103 @@
+---
+title: Gemma Fine-tuning UI
+emoji: 🤖
+colorFrom: blue
+colorTo: indigo
+sdk: streamlit
+sdk_version: 1.30.0
+app_file: app.py
+pinned: false
+---
+
+# Gemma Fine-tuning UI
+
+A web-based user interface for fine-tuning Google's Gemma models using Hugging Face infrastructure.
+
+![Gemma Banner](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png)
+
+## Features
+
+- **Dataset Upload**: Upload and preprocess your custom training data in CSV, JSON, or JSONL format
+- **Model Configuration**: Configure the Gemma model version and hyperparameters
+- **Training Management**: Start, monitor, and manage fine-tuning jobs
+- **Evaluation**: Test your fine-tuned model with interactive generation
+- **Export Options**: Use your model directly from Hugging Face or export it in various formats
+
+## Installation
+
+1. Clone this repository:
+```bash
+git clone https://github.com/yourusername/gemma-finetuning-ui.git
+cd gemma-finetuning-ui
+```
+
+2. Create a virtual environment and install dependencies:
+```bash
+python -m venv venv
+source venv/bin/activate  # On Windows: venv\Scripts\activate
+pip install -r requirements.txt
+```
+
+3. Run the application:
+```bash
+streamlit run app.py
+```
+
+## Usage
+
+1. **Authentication**: Provide your Hugging Face API token with write permissions
+2. **Dataset Preparation**: Upload your dataset and configure column mappings
+3. **Model Selection**: Choose between Gemma 2B or 7B and customize training parameters
+4. **Training**: Start the fine-tuning process and monitor progress
+5. **Evaluation**: Test your fine-tuned model with custom prompts
+6. **Deployment**: Export your model or use it directly from Hugging Face
+
+## Hugging Face Spaces Deployment
+
+This application is designed to be deployed easily on Hugging Face Spaces:
+
+1. Create a new Space on [Hugging Face Spaces](https://huggingface.co/spaces)
+2. Select Streamlit as the SDK
+3. Connect your GitHub repository or upload the files directly
+4. The Space will automatically detect and install the requirements
+
+## Requirements
+
+- Python 3.8+
+- Streamlit 1.30.0+
+- Hugging Face account with an API token
+- For training: GPU access (recommended)
+
+## Project Structure
+
+```
+.
+├── app.py                 # Main Streamlit application
+├── pages/                 # Multi-page app components
+│   ├── 01_Dataset_Upload.py
+│   ├── 02_Model_Configuration.py
+│   ├── 03_Training_Monitor.py
+│   └── 04_Evaluation.py
+├── utils/                 # Utility functions
+│   ├── auth.py            # Authentication utilities
+│   ├── huggingface.py     # Hugging Face API integration
+│   ├── training.py        # Training utilities
+│   └── ui.py              # UI components and styling
+├── data/                  # Sample data and uploads
+└── requirements.txt       # Project dependencies
+```
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
+
+## Acknowledgements
+
+- [Google Gemma Models](https://huggingface.co/google/gemma-7b)
+- [Hugging Face Transformers](https://huggingface.co/docs/transformers/index)
+- [Streamlit](https://streamlit.io/)
+- [PEFT - Parameter-Efficient Fine-Tuning](https://github.com/huggingface/peft)
+
+---
+
+Developed as a simplified interface for fine-tuning Gemma models with Hugging Face.
app.py ADDED
@@ -0,0 +1,149 @@
+import streamlit as st
+from utils.auth import check_hf_token
+from utils.ui import set_page_style
+
+# Set page configuration
+st.set_page_config(
+    page_title="Gemma Fine-tuning UI",
+    page_icon="🤖",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for navigation and authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.markdown("**Navigation:**")
+    if auth_status:
+        st.page_link("pages/01_Dataset_Upload.py", label="1️⃣ Dataset Upload", icon="📤")
+        st.page_link("pages/02_Model_Configuration.py", label="2️⃣ Model Configuration", icon="⚙️")
+        st.page_link("pages/03_Training_Monitor.py", label="3️⃣ Training Monitor", icon="📊")
+        st.page_link("pages/04_Evaluation.py", label="4️⃣ Model Evaluation", icon="🔍")
+
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("Welcome to Gemma Fine-tuning UI")
+
+if not hf_token:
+    st.info("👈 Please enter your Hugging Face token in the sidebar to get started")
+
+    with st.expander("ℹ️ How to get a Hugging Face token", expanded=True):
+        st.markdown("""
+        1. Go to [Hugging Face](https://huggingface.co/settings/tokens)
+        2. Sign in or create an account
+        3. Create a new token with write access
+        4. Copy and paste the token in the sidebar
+        """)
+
+    st.divider()
+
+    st.subheader("What you can do with this app:")
+
+    st.markdown("""
+    ### 📝 Simple Fine-tuning Process
+
+    This app provides a straightforward interface for fine-tuning Gemma models with your own data:
+    """)
+
+    cols = st.columns(2)
+
+    with cols[0]:
+        st.markdown("""
+        ✅ **Upload your dataset**
+        - Support for CSV and JSON/JSONL formats
+        - Manual input option for small datasets
+        - Automatic preprocessing for Gemma format
+
+        ✅ **Configure Gemma model parameters**
+        - Choose between Gemma 2B and 7B models
+        - Adjust learning rate, batch size, and epochs
+        - LoRA parameter configuration
+        """)
+
+    with cols[1]:
+        st.markdown("""
+        ✅ **Monitor training progress**
+        - Visual training progress tracking
+        - Loss curve visualization
+        - Training logs and status updates
+
+        ✅ **Evaluate and use your model**
+        - Interactive testing interface
+        - Export options for deployment
+        - Usage examples with code snippets
+        """)
+
+else:
+    st.success("You're all set up! Follow these steps to fine-tune your Gemma model:")
+
+    # Simple step-by-step guide
+    col1, col2 = st.columns(2)
+
+    with col1:
+        st.markdown("### Step 1: Prepare Your Dataset")
+        st.markdown("""
+        - Upload your dataset in CSV or JSONL format
+        - Ensure your data has prompt/instruction and response columns
+        - The app will preprocess the data into the right format for Gemma
+        """)
+        st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="📤")
+
+    with col2:
+        st.markdown("### Step 2: Configure Your Model")
+        st.markdown("""
+        - Select either Gemma 2B (faster) or 7B (more powerful)
+        - Adjust hyperparameters based on your needs
+        - Basic configurations work well for most use cases
+        """)
+        st.page_link("pages/02_Model_Configuration.py", label="Go to Model Configuration", icon="⚙️")
+
+    st.divider()
+
+    col3, col4 = st.columns(2)
+
+    with col3:
+        st.markdown("### Step 3: Train Your Model")
+        st.markdown("""
+        - Start the training process
+        - Monitor progress with real-time updates
+        - Training a model may take time depending on your dataset size
+        """)
+        st.page_link("pages/03_Training_Monitor.py", label="Go to Training Monitor", icon="📊")
+
+    with col4:
+        st.markdown("### Step 4: Evaluate & Use Your Model")
+        st.markdown("""
+        - Test your model with custom prompts
+        - Compare results with the base model
+        - Export your model for use in applications
+        """)
+        st.page_link("pages/04_Evaluation.py", label="Go to Model Evaluation", icon="🔍")
+
+# Notes about CPU limitations
+st.info("""
+**Note on Training Limitations**: This app is running on CPU resources (2vCPU, 16GB RAM), which means:
+
+- For actual training, we use Parameter-Efficient Fine-Tuning (PEFT) with LoRA to reduce memory requirements
+- Training will be slower than on GPU hardware
+- For very large datasets, consider using this interface to prepare your data and configuration,
+  then download the config to run training on more powerful hardware
+""")
app.sh ADDED
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+# Run the Streamlit app
+streamlit run app.py
data/sample_dataset.jsonl ADDED
@@ -0,0 +1,10 @@
+{"prompt": "Explain the concept of machine learning in simple terms.", "response": "Machine learning is like teaching a computer to learn from examples rather than programming it with specific rules. Imagine showing a child many pictures of cats and dogs until they can recognize the difference. Similarly, we show computers lots of data examples, and they learn patterns to make predictions or decisions without being explicitly programmed for each scenario."}
+{"prompt": "What is fine-tuning in the context of language models?", "response": "Fine-tuning is the process of taking a pre-trained language model and further training it on a specific dataset for a particular task. It's like taking a general-purpose tool and customizing it for a specialized job. The model keeps its general knowledge but becomes better at the specific task you want it to perform, like answering medical questions or writing in a particular style."}
+{"prompt": "How does a neural network work?", "response": "A neural network is a computing system inspired by the human brain. It consists of interconnected nodes (neurons) organized in layers. Each connection has a weight that adjusts as the network learns. The network processes input data through these layers, with each neuron applying a mathematical function to its inputs and passing the result to the next layer. Through a process called backpropagation, the network adjusts its weights based on errors in its output, gradually improving its performance on a task."}
+{"prompt": "Write a short poem about artificial intelligence.", "response": "Silicon dreams in digital streams,\nThinking machines with invisible seams.\nLearning patterns humans can't see,\nArtificial minds, what will they be?\n\nVast knowledge stored in neural webs,\nAnswering questions as curiosity ebbs.\nPartner or servant or something more—\nAI's potential we've yet to explore."}
+{"prompt": "Summarize the concept of transfer learning.", "response": "Transfer learning is a machine learning technique where a model developed for one task is reused as the starting point for a model on a second task. It's like learning to ride a bicycle makes it easier to learn to ride a motorcycle—the skills transfer. In practice, this means taking a pre-trained model (often trained on a large dataset) and fine-tuning it for your specific task with a smaller dataset, saving computational resources and time while achieving good performance."}
+{"prompt": "How would you explain deep learning to a 10-year-old?", "response": "Deep learning is like teaching a computer to learn by itself, similar to how you learn to recognize your friends' faces. Imagine you have a robot that you want to teach to identify dogs. Instead of telling it exactly what a dog looks like (pointy ears, tail, four legs), you show it thousands of pictures of dogs. The robot has layers of simple learning parts, like building blocks stacked together. Each layer learns something different - one might learn edges, another might learn shapes, and so on. After seeing enough pictures, the robot can recognize dogs on its own, even dogs it's never seen before!"}
+{"prompt": "What are the ethical concerns with AI development?", "response": "Ethical concerns in AI development include: 1) Bias and fairness issues when AI systems reflect or amplify human biases in their training data, 2) Privacy concerns as AI systems collect and process vast amounts of personal data, 3) Accountability questions about who is responsible when AI makes harmful decisions, 4) Transparency problems when AI systems function as 'black boxes' with unexplainable decisions, 5) Job displacement as automation changes employment landscapes, and 6) Safety and control concerns about ensuring AI systems remain aligned with human values as they become more capable."}
+{"prompt": "Compare supervised and unsupervised learning.", "response": "Supervised learning is like learning with a teacher. The algorithm is trained on labeled data, where each example has input features and the correct output. It learns to map inputs to outputs based on these examples, like predicting house prices from features like size and location. Unsupervised learning is like learning without a teacher. The algorithm finds patterns in unlabeled data without specific guidance. It might group similar items together (clustering) or find ways to represent data more efficiently. Supervised learning is used for classification and regression tasks, while unsupervised learning is used for clustering, association, and dimensionality reduction."}
+{"prompt": "Explain how language models like Gemma work.", "response": "Language models like Gemma work by predicting the next word in a sequence based on the context of previous words. They're trained on massive text datasets to understand patterns in language. These models use transformer architecture with attention mechanisms that help them focus on relevant parts of the input when making predictions. The model processes text as tokens (pieces of words or characters) and maps these to high-dimensional vector representations that capture semantic meaning. During training, the model adjusts its billions of parameters to minimize prediction errors. When you prompt the model, it uses its learned patterns to generate text that continues your input in a coherent and contextually appropriate way."}
+{"prompt": "What is the difference between AI, machine learning, and deep learning?", "response": "These terms represent nested concepts. Artificial Intelligence (AI) is the broadest, referring to any technique that enables computers to mimic human intelligence. Machine Learning is a subset of AI where systems learn from data without being explicitly programmed, improving automatically through experience. Deep Learning is a specialized subset of machine learning that uses neural networks with many layers (hence 'deep'). Think of it like this: AI is like the entire field of transportation, machine learning is like motorized vehicles specifically, and deep learning is like self-driving cars – a specialized, advanced implementation."}
packages.txt ADDED
@@ -0,0 +1,3 @@
+libgl1-mesa-glx
+gcc
+g++
pages/01_Dataset_Upload.py ADDED
@@ -0,0 +1,218 @@
+import streamlit as st
+import pandas as pd
+import io
+import json
+import os
+from utils.ui import set_page_style
+from utils.auth import check_hf_token
+from utils.huggingface import create_dataset_repo, upload_dataset_to_hub, preprocess_dataset
+
+# Set page configuration
+st.set_page_config(
+    page_title="Dataset Upload - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("📤 Dataset Upload")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Dataset repository settings
+with st.expander("Dataset Repository Settings", expanded=True):
+    col1, col2 = st.columns(2)
+
+    with col1:
+        dataset_repo_name = st.text_input(
+            "Repository Name",
+            value=f"gemma-finetune-dataset-{pd.Timestamp.now().strftime('%Y%m%d')}",
+            help="Name of the Hugging Face repository to store your dataset"
+        )
+
+    with col2:
+        is_private = st.toggle("Private Repository", value=True,
+                               help="Make your dataset repository private (recommended)")
+
+# Dataset upload section
+st.header("Upload Your Dataset")
+
+dataset_format = st.radio(
+    "Select dataset format",
+    options=["CSV", "JSON/JSONL", "Manual Input"],
+    horizontal=True
+)
+
+# Session state to store dataset
+if "dataset" not in st.session_state:
+    st.session_state.dataset = None
+if "dataset_preview" not in st.session_state:
+    st.session_state.dataset_preview = None
+
+# Upload handlers
+if dataset_format == "CSV":
+    uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
+
+    if uploaded_file:
+        try:
+            # Read the file
+            df = pd.read_csv(uploaded_file)
+            st.session_state.dataset = df
+            st.session_state.dataset_preview = df.head(5)
+        except Exception as e:
+            st.error(f"Error reading CSV file: {str(e)}")
+
+elif dataset_format == "JSON/JSONL":
+    uploaded_file = st.file_uploader("Upload a JSON/JSONL file", type=["json", "jsonl"])
+
+    if uploaded_file:
+        try:
+            # Read file contents
+            content = uploaded_file.read()
+
+            # Try to parse as JSON array first
+            try:
+                data = json.loads(content)
+                if isinstance(data, list):
+                    df = pd.DataFrame(data)
+                else:
+                    df = pd.DataFrame([data])
+            except Exception:
+                # Fall back to parsing as JSONL
+                try:
+                    content = io.StringIO(content.decode("utf-8"))
+                    df = pd.read_json(content, lines=True)
+                except Exception:
+                    raise ValueError("Invalid JSON/JSONL format")
+
+            st.session_state.dataset = df
+            st.session_state.dataset_preview = df.head(5)
+        except Exception as e:
+            st.error(f"Error reading JSON/JSONL file: {str(e)}")
+
+elif dataset_format == "Manual Input":
+    st.info("Enter your training examples manually:")
+
+    num_examples = st.number_input("Number of examples", min_value=1, max_value=20, value=3)
+    examples = []
+
+    for i in range(num_examples):
+        st.subheader(f"Example {i+1}")
+        col1, col2 = st.columns(2)
+
+        with col1:
+            prompt = st.text_area(f"Prompt/Instruction {i+1}", height=100)
+
+        with col2:
+            response = st.text_area(f"Response {i+1}", height=100)
+
+        if prompt and response:
+            examples.append({"prompt": prompt, "response": response})
+
+    if examples:
+        df = pd.DataFrame(examples)
+        st.session_state.dataset = df
+        st.session_state.dataset_preview = df
+
+# Display dataset preview if available
+if st.session_state.dataset_preview is not None:
+    st.subheader("Dataset Preview")
+    st.dataframe(st.session_state.dataset_preview, use_container_width=True)
+
+    # Column mapping
+    st.subheader("Column Mapping")
+    st.info("Select which columns contain prompts/instructions and responses")
+
+    columns = st.session_state.dataset.columns.tolist()
+
+    col1, col2 = st.columns(2)
+
+    with col1:
+        prompt_column = st.selectbox("Prompt/Instruction Column", options=columns,
+                                     index=columns.index("prompt") if "prompt" in columns else 0)
+
+    with col2:
+        response_column = st.selectbox("Response Column", options=columns,
+                                       index=columns.index("response") if "response" in columns else min(1, len(columns)-1))
+
+    # Dataset statistics
+    st.subheader("Dataset Statistics")
+
+    stats_col1, stats_col2, stats_col3 = st.columns(3)
+
+    with stats_col1:
+        st.metric("Total Examples", len(st.session_state.dataset))
+
+    with stats_col2:
+        avg_prompt_len = st.session_state.dataset[prompt_column].str.len().mean().round(1)
+        st.metric("Avg. Prompt Length", f"{avg_prompt_len} chars")
+
+    with stats_col3:
+        avg_response_len = st.session_state.dataset[response_column].str.len().mean().round(1)
+        st.metric("Avg. Response Length", f"{avg_response_len} chars")
+
+    # Upload button
+    st.subheader("Upload to Hugging Face")
+
+    if st.button("Process and Upload Dataset", type="primary"):
+        with st.spinner("Processing dataset..."):
+            # Preprocess the dataset into the right format for Gemma
+            try:
+                processed_df = preprocess_dataset(
+                    st.session_state.dataset,
+                    prompt_column=prompt_column,
+                    response_column=response_column
+                )
+
+                # Create repository
+                success, result = create_dataset_repo(dataset_repo_name, private=is_private)
+
+                if not success:
+                    st.error(f"Failed to create repository: {result}")
+                else:
+                    st.success(f"Created repository: {dataset_repo_name}")
+
+                    # Upload dataset
+                    success, result = upload_dataset_to_hub(
+                        processed_df,
+                        "train.jsonl",
+                        dataset_repo_name
+                    )
+
+                    if success:
+                        st.session_state["dataset_repo"] = dataset_repo_name
+                        st.success("Dataset uploaded successfully! You can now proceed to model configuration.")
+
+                        # Next steps button
+                        st.page_link("pages/02_Model_Configuration.py", label="Next: Configure Model", icon="⚙️")
+                    else:
+                        st.error(f"Failed to upload dataset: {result}")
+            except Exception as e:
+                st.error(f"Error processing dataset: {str(e)}")
+else:
+    st.info("Please upload or input your dataset above")
pages/02_Model_Configuration.py ADDED
@@ -0,0 +1,208 @@
+import streamlit as st
+import pandas as pd
+from utils.ui import set_page_style
+from utils.auth import check_hf_token
+from utils.huggingface import prepare_training_config
+from utils.training import create_model_repo, upload_training_config, setup_training_script
+
+# Set page configuration
+st.set_page_config(
+    page_title="Model Configuration - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("⚙️ Model Configuration")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Check if dataset repository is set
+if "dataset_repo" not in st.session_state:
+    st.warning("Please upload a dataset first")
+    st.page_link("pages/01_Dataset_Upload.py", label="Go to Dataset Upload", icon="📤")
+
+    # For testing purposes, allow manual entry
+    manual_repo = st.text_input("Or enter dataset repository name manually:")
+    if manual_repo:
+        st.session_state["dataset_repo"] = manual_repo
+    else:
+        st.stop()
+
+# Model configuration
+st.header("1. Select Gemma Model")
+
+model_version = st.radio(
+    "Select Gemma model version",
+    options=["google/gemma-2b", "google/gemma-7b"],
+    horizontal=True,
+    help="Choose the Gemma model size. 2B is faster, 7B is more capable."
+)
+
+st.header("2. Training Configuration")
+
+# Output repository settings
+with st.expander("Output Repository Settings", expanded=True):
+    col1, col2 = st.columns(2)
+
+    with col1:
+        output_repo_name = st.text_input(
+            "Repository Name",
+            value=f"gemma-finetuned-{pd.Timestamp.now().strftime('%Y%m%d')}",
+            help="Name of the Hugging Face repository to store your fine-tuned model"
+        )
+
+    with col2:
+        is_private = st.toggle("Private Repository", value=True,
+                               help="Make your model repository private (recommended)")
+
+# Tabs for Basic and Advanced configuration
+basic_tab, advanced_tab = st.tabs(["Basic Configuration", "Advanced Configuration"])
+
+with basic_tab:
+    col1, col2 = st.columns(2)
+
+    with col1:
+        epochs = st.slider("Number of Epochs", min_value=1, max_value=10, value=3,
+                           help="Number of complete passes through the dataset")
+
+        batch_size = st.slider("Batch Size", min_value=1, max_value=32, value=8,
+                               help="Number of examples processed together")
+
+    with col2:
+        learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-3,
+                                        value=2e-5, format="%.1e",
+                                        help="Step size for gradient updates")
+
+        fp16_training = st.toggle("Use FP16 Training", value=True,
+                                  help="Use 16-bit precision for faster training")
+
+with advanced_tab:
+    col1, col2 = st.columns(2)
+
+    with col1:
+        lora_rank = st.slider("LoRA Rank", min_value=1, max_value=64, value=8,
+                              help="Rank of the LoRA matrices")
+
+        lora_alpha = st.slider("LoRA Alpha", min_value=1, max_value=128, value=32,
+                               help="Scaling factor for LoRA")
+
+        lora_dropout = st.slider("LoRA Dropout", min_value=0.0, max_value=0.5, value=0.05, step=0.01,
+                                 help="Dropout probability for LoRA layers")
+
+    with col2:
+        weight_decay = st.number_input("Weight Decay", min_value=0.0, max_value=0.1,
+                                       value=0.01, step=0.001,
+                                       help="L2 regularization strength")
+
+        gradient_accumulation = st.slider("Gradient Accumulation Steps", min_value=1, max_value=16, value=1,
+                                          help="Number of steps to accumulate gradients")
+
+        warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=0,
+                                 help="Steps of linear learning rate warmup")
+
+st.header("3. Training Summary")
+
+# Create hyperparameters dictionary
+hyperparams = {
+    "epochs": epochs,
+    "batch_size": batch_size,
+    "learning_rate": learning_rate,
+    "fp16": fp16_training,
+    "lora_rank": lora_rank,
+    "lora_alpha": lora_alpha,
+    "lora_dropout": lora_dropout,
+    "weight_decay": weight_decay,
+    "gradient_accumulation": gradient_accumulation,
+    "warmup_steps": warmup_steps,
+    "max_steps": -1,  # -1 means train for full epochs
+    "max_grad_norm": 1.0
+}
+
+# Display summary
+col1, col2 = st.columns(2)
+
+with col1:
+    st.subheader("Selected Model")
+    st.info(model_version)
+
+    st.subheader("Dataset Repository")
+    st.info(st.session_state["dataset_repo"])
+
+with col2:
+    st.subheader("Key Hyperparameters")
+    st.write(f"Epochs: {epochs}")
+    st.write(f"Batch Size: {batch_size}")
+    st.write(f"Learning Rate: {learning_rate}")
+    st.write(f"LoRA Rank: {lora_rank}")
+
+# Start training button
+st.header("4. Start Training")
+
+if st.button("Prepare and Launch Training", type="primary"):
+    with st.spinner("Setting up training job..."):
+        # First create the model repository
+        success, repo_url = create_model_repo(output_repo_name, private=is_private)
+
+        if not success:
+            st.error(f"Failed to create model repository: {repo_url}")
+        else:
+            st.success(f"Created model repository: {output_repo_name}")
+
+            # Prepare training configuration
+            config = prepare_training_config(
+                model_name=model_version,
+                hyperparams=hyperparams,
+                dataset_repo=st.session_state["dataset_repo"],
+                output_repo=output_repo_name
+            )
+
+            # Upload training configuration
+            success, message = upload_training_config(config, output_repo_name)
+
+            if not success:
+                st.error(f"Failed to upload training configuration: {message}")
+            else:
+                st.success("Uploaded training configuration")
+
+                # Setup training script
+                success, message = setup_training_script(output_repo_name, config)
+
+                if not success:
+                    st.error(f"Failed to setup training script: {message}")
+                else:
+                    st.success("Uploaded training script")
+
+                    # Store in session state
+                    st.session_state["model_repo"] = output_repo_name
+                    st.session_state["model_version"] = model_version
+                    st.session_state["hyperparams"] = hyperparams
+
+                    # Success message and next step
+                    st.success("Training job prepared successfully! You can now monitor the training progress.")
+                    st.page_link("pages/03_Training_Monitor.py", label="Next: Monitor Training", icon="📊")
pages/03_Training_Monitor.py ADDED
@@ -0,0 +1,218 @@
+import streamlit as st
+import time
+import json
+import pandas as pd
+import matplotlib.pyplot as plt
+import plotly.graph_objects as go
+from utils.ui import set_page_style, display_metric
+from utils.auth import check_hf_token
+from utils.training import simulate_training_progress
+
+# Set page configuration
+st.set_page_config(
+    page_title="Training Monitor - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("📊 Training Monitor")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Check if training was started
+if "model_repo" not in st.session_state:
+    st.warning("No active training jobs found")
+    st.page_link("pages/02_Model_Configuration.py", label="Go to Model Configuration", icon="⚙️")
+
+    # For testing purposes, allow manual entry
+    manual_repo = st.text_input("Or enter model repository name manually:")
+    if manual_repo:
+        st.session_state["model_repo"] = manual_repo
+        st.session_state["model_version"] = "google/gemma-2b"  # Default
+    else:
+        st.stop()
+
+# Training information
+st.header("Training Information")
+
+col1, col2 = st.columns(2)
+
+with col1:
+    st.subheader("Model Repository")
+    st.info(st.session_state["model_repo"])
+
+    st.subheader("Base Model")
+    st.info(st.session_state.get("model_version", "google/gemma-2b"))
+
+with col2:
+    # For demo purposes, create a button to start simulated training
+    if "training_started" not in st.session_state:
+        st.subheader("Start Training")
+        if st.button("Launch Training Job", type="primary"):
+            st.session_state["training_started"] = True
+            st.experimental_rerun()
+    else:
+        st.subheader("Status")
+        st.success("Training in Progress")
+
+        # Simulate a cancel button
+        if st.button("Cancel Training Job", type="secondary"):
+            st.warning("This is a simulation - in a real environment, this would cancel the training job")
+
+# If training has started, show the progress
+if st.session_state.get("training_started", False):
+    st.header("Training Progress")
+
+    # Create a placeholder for the progress bar
+    progress_bar = st.progress(0)
+
+    # Create placeholder metrics
+    col1, col2, col3, col4 = st.columns(4)
+
+    # Get training progress (simulated for demo)
+    progress_data = simulate_training_progress()
+
+    # Update progress bar
+    progress_bar.progress(progress_data["progress"])
+
+    # Update metrics
+    with col1:
+        display_metric("Epoch", f"{progress_data['current_epoch'] + 1}/{progress_data['total_epochs']}")
+
+    with col2:
+        display_metric("Loss", f"{progress_data['loss']:.4f}")
+
+    with col3:
+        display_metric("Learning Rate", f"{progress_data['learning_rate']:.1e}")
+
+    with col4:
+        status_text = "Complete" if progress_data["status"] == "completed" else "Running"
+        display_metric("Status", status_text)
+
+    # Create training history visualization
+    st.subheader("Training Metrics")
+
+    # Simulate training history data
+    if "training_history" not in st.session_state:
+        st.session_state.training_history = []
+
+    # Add current data point to history if not completed
+    if progress_data["status"] != "completed" or len(st.session_state.training_history) == 0:
+        st.session_state.training_history.append({
+            "epoch": progress_data["current_epoch"],
+            "loss": progress_data["loss"],
+            "learning_rate": progress_data["learning_rate"],
+            "timestamp": time.time()
+        })
+
+    # Convert history to DataFrame
+    history_df = pd.DataFrame(st.session_state.training_history)
+
+    if not history_df.empty and len(history_df) > 1:
+        # Create tabs for different visualizations
+        loss_tab, lr_tab = st.tabs(["Loss Curve", "Learning Rate"])
+
+        with loss_tab:
+            # Create a Plotly figure for the loss curve
+            fig = go.Figure()
+            fig.add_trace(go.Scatter(
+                x=history_df["epoch"],
+                y=history_df["loss"],
+                mode='lines+markers',
+                name='Training Loss',
+                line=dict(color='#FF4B4B', width=3)
+            ))
+
+            fig.update_layout(
+                title="Training Loss Over Time",
+                xaxis_title="Epoch",
+                yaxis_title="Loss",
+                template="plotly_white",
+                height=400
+            )
+
+            st.plotly_chart(fig, use_container_width=True)
+
+        with lr_tab:
+            # Create a Plotly figure for the learning rate
+            fig = go.Figure()
+            fig.add_trace(go.Scatter(
+                x=history_df["epoch"],
+                y=history_df["learning_rate"],
+                mode='lines+markers',
+                name='Learning Rate',
+                line=dict(color='#0068C9', width=3)
+            ))
+
+            fig.update_layout(
+                title="Learning Rate Schedule",
+                xaxis_title="Epoch",
+                yaxis_title="Learning Rate",
+                template="plotly_white",
+                height=400
+            )
+
+            st.plotly_chart(fig, use_container_width=True)
+    else:
+        st.info("Training metrics will appear here once training progresses")
+
+    # Training logs
+    st.subheader("Training Logs")
+
+    # Simulate logs
+    log_lines = [
+        f"[{pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}] Initialized training job",
+        f"[{pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}] Loading model: {st.session_state.get('model_version', 'google/gemma-2b')}",
+        f"[{pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}] Preparing LoRA configuration"
+    ]
+
+    # Add epoch logs based on progress
+    current_epoch = progress_data["current_epoch"]
+    for epoch in range(min(current_epoch + 1, progress_data["total_epochs"])):
+        timestamp = pd.Timestamp.now() - pd.Timedelta(seconds=(current_epoch - epoch) * 60)
+        log_lines.append(f"[{timestamp.strftime('%Y-%m-%d %H:%M:%S')}] Epoch {epoch+1}/{progress_data['total_epochs']} started")
+
+        if epoch < current_epoch:
+            # For completed epochs, add completion log
+            timestamp = pd.Timestamp.now() - pd.Timedelta(seconds=(current_epoch - epoch - 0.5) * 60)
+            sim_loss = max(2.5 - (epoch * 0.5), 0.5)
+            log_lines.append(f"[{timestamp.strftime('%Y-%m-%d %H:%M:%S')}] Epoch {epoch+1} completed: loss={sim_loss:.4f}")
+
+    # Display logs in a scrollable area
+    st.code("\n".join(log_lines))
+
+    # Next steps (only show when training is complete)
+    if progress_data["status"] == "completed":
+        st.success("Training completed successfully!")
+        st.page_link("pages/04_Evaluation.py", label="Next: Evaluate Model", icon="🔍")
+    else:
+        # Auto-refresh for monitoring
+        st.empty()
+        time.sleep(2)  # Wait for 2 seconds
+        st.experimental_rerun()
pages/04_Evaluation.py ADDED
@@ -0,0 +1,246 @@
+import streamlit as st
+import pandas as pd
+import time
+import json
+from utils.ui import set_page_style, create_card
+from utils.auth import check_hf_token
+
+# Set page configuration
+st.set_page_config(
+    page_title="Model Evaluation - Gemma Fine-tuning",
+    page_icon="🤖",
+    layout="wide"
+)
+
+# Apply custom styling
+set_page_style()
+
+# Sidebar for authentication
+with st.sidebar:
+    st.title("🤖 Gemma Fine-tuning")
+    st.image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gemma-banner.png",
+             use_column_width=True)
+
+    # Authentication section
+    st.subheader("🔑 Authentication")
+    hf_token = st.text_input("Hugging Face API Token", type="password",
+                             help="Enter your Hugging Face write token to enable model fine-tuning")
+    auth_status = check_hf_token(hf_token) if hf_token else False
+
+    if auth_status:
+        st.success("Authenticated successfully!")
+    elif hf_token:
+        st.error("Invalid token. Please check and try again.")
+
+    st.divider()
+    st.caption("A simple UI for fine-tuning Gemma models")
+
+# Main content
+st.title("🔍 Model Evaluation")
+
+if not hf_token or not auth_status:
+    st.warning("Please authenticate with your Hugging Face token in the sidebar first")
+    st.stop()
+
+# Check if model is trained
+if "model_repo" not in st.session_state:
+    st.warning("No trained model found")
+    st.page_link("pages/03_Training_Monitor.py", label="Go to Training Monitor", icon="📊")
+
+    # For testing purposes, allow manual entry
+    manual_repo = st.text_input("Or enter model repository name manually:")
+    if manual_repo:
+        st.session_state["model_repo"] = manual_repo
+        st.session_state["model_version"] = "google/gemma-2b"  # Default
+    else:
+        st.stop()
+
+# Model information
+st.header("Model Information")
+
+col1, col2 = st.columns(2)
+
+with col1:
+    st.subheader("Fine-tuned Model")
+    username = st.session_state.get("hf_username", "user")
+    model_id = f"{username}/{st.session_state['model_repo']}"
+    st.info(model_id)
+
+    st.markdown(f"[View on Hugging Face 🔗](https://huggingface.co/{model_id})")
+
+with col2:
+    st.subheader("Base Model")
+    st.info(st.session_state.get("model_version", "google/gemma-2b"))
+
+# Interactive testing
+st.header("Interactive Testing")
+
+def generate_response(prompt, max_length=100, temperature=0.7):
+    """Simulate generating a response from the fine-tuned model"""
+    # In a real implementation, this would call the model API
+
+    # For demo purposes, simulate a response with a delay
+    with st.spinner("Generating response..."):
+        # Simulate thinking time
+        time.sleep(2)
+
+        # Generate a simple response based on the prompt for demonstration
+        if "hello" in prompt.lower() or "hi" in prompt.lower():
+            return "Hello! How can I assist you today?"
+        elif "name" in prompt.lower():
+            return "I'm a fine-tuned version of Gemma, created to assist you!"
+        elif "weather" in prompt.lower():
+            return "I don't have real-time data access, but I can tell you about weather patterns and climate information if you'd like."
+        elif "help" in prompt.lower():
+            return "I'm here to help answer questions, provide information, or assist with tasks. Just let me know what you need!"
+        elif len(prompt) < 10:
+            return "Could you please provide more details so I can give you a better response?"
+        else:
+            return f"Based on your request about '{prompt[:20]}...', I've analyzed the content and formulated this response. This is a demonstration of how the fine-tuned model would respond to your input, tailored to the data it was trained on."
+
+# Create the interface
+with st.expander("Generation Settings", expanded=False):
+    col1, col2 = st.columns(2)
+
+    with col1:
+        max_length = st.slider("Maximum Length", min_value=10, max_value=500, value=150,
+                               help="Maximum number of tokens to generate")
+
+    with col2:
+        temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1,
+                                help="Controls randomness: 1.0 is creative, 0.1 is more deterministic")
+
+prompt = st.text_area("Enter your prompt:", height=150,
+                      help="Enter a prompt or instruction for the model to respond to")
+
+if st.button("Generate Response", type="primary"):
+    if prompt:
+        response = generate_response(prompt, max_length, temperature)
+
+        # Store in history
+        if "conversation_history" not in st.session_state:
+            st.session_state.conversation_history = []
+
+        st.session_state.conversation_history.append({
+            "prompt": prompt,
+            "response": response,
+            "timestamp": pd.Timestamp.now().isoformat()
+        })
+
+        # Display response
+        st.subheader("Model Response")
+        st.markdown(f"<div style='background-color: #f8f9fa; padding: 15px; border-radius: 8px;'>{response}</div>", unsafe_allow_html=True)
+    else:
+        st.warning("Please enter a prompt first")
+
+# Display conversation history
+if "conversation_history" in st.session_state and st.session_state.conversation_history:
+    st.header("Conversation History")
+
+    for i, item in enumerate(reversed(st.session_state.conversation_history)):
+        with st.container():
+            st.markdown(f"### Conversation {len(st.session_state.conversation_history) - i}")
+
+            st.markdown("**Prompt:**")
+            st.markdown(f"<div style='background-color: #e6f7ff; padding: 10px; border-radius: 8px;'>{item['prompt']}</div>", unsafe_allow_html=True)
+
+            st.markdown("**Response:**")
+            st.markdown(f"<div style='background-color: #f8f9fa; padding: 10px; border-radius: 8px;'>{item['response']}</div>", unsafe_allow_html=True)
+
+            st.caption(f"Generated at: {pd.Timestamp(item['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}")
+            st.divider()
+
+# Export section
+st.header("Export Model")
+
+export_tab1, export_tab2 = st.tabs(["Export Options", "Usage Guide"])
+
+with export_tab1:
+    st.subheader("Export Configuration")
+
+    export_format = st.radio("Export Format", options=["Hugging Face Hub", "ONNX", "TensorFlow Lite", "PyTorch"],
+                             horizontal=True)
+
+    if export_format == "Hugging Face Hub":
+        st.success("Your model is already available on the Hugging Face Hub!")
+        username = st.session_state.get("hf_username", "user")
+        model_id = f"{username}/{st.session_state['model_repo']}"
+        st.code(f"from transformers import AutoModelForCausalLM, AutoTokenizer\n\nmodel = AutoModelForCausalLM.from_pretrained('{model_id}')\ntokenizer = AutoTokenizer.from_pretrained('{model_id}')")
+    else:
+        st.info("This export format is under development and will be available soon.")
+
+        # Show a sample of how to export
+        if export_format == "ONNX":
+            st.code("""
+# Example code for ONNX export (not functional in this demo)
+from pathlib import Path
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.onnx import export
+
+model_id = "your_username/your_model_repo"
+model = AutoModelForCausalLM.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Export to ONNX
+export(
+    tokenizer=tokenizer,
+    model=model,
+    output=Path("model.onnx"),
+    opset=13
+)
+""")
+        elif export_format == "PyTorch":
+            st.code("""
+# Example code for PyTorch export (not functional in this demo)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "your_username/your_model_repo"
+model = AutoModelForCausalLM.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Save model locally
+model.save_pretrained("./my_exported_model")
+tokenizer.save_pretrained("./my_exported_model")
+""")
+
+with export_tab2:
+    st.subheader("How to Use Your Model")
+
+    st.markdown("""
+    ### Using with Hugging Face Transformers
+
+    ```python
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    # Replace with your model ID
+    model_id = "your_username/your_model_repo"
+
+    # Load model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    # Generate text
+    inputs = tokenizer("What is machine learning?", return_tensors="pt")
+    outputs = model.generate(**inputs, max_length=100)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print(response)
+    ```
+
+    ### Deployment Options
+
+    1. **Hugging Face Inference API** - Easiest option for quick deployment
+    2. **Gradio or Streamlit** - For creating interactive demos
+    3. **FastAPI or Flask** - For creating backend API services
+    4. **Mobile Deployment** - Use TensorFlow Lite or ONNX formats
+    """)
+
+# Next steps
+st.divider()
+st.subheader("Continue Your Journey")
+
+col1, col2 = st.columns(2)
+
+with col1:
+    st.page_link("pages/01_Dataset_Upload.py", label="Start New Project", icon="🔄")
+
+with col2:
+    st.write("Finished? Export your model or try it out in the playground above.")
requirements.txt ADDED
@@ -0,0 +1,15 @@
+streamlit>=1.30.0
+huggingface_hub>=0.20.0
+pandas>=2.0.0
+matplotlib>=3.7.0
+transformers>=4.35.0
+peft>=0.7.0
+torch>=2.0.0
+bitsandbytes>=0.40.0
+accelerate>=0.20.0
+datasets>=2.10.0
+tensorboard>=2.13.0
+plotly>=5.15.0
+pillow>=10.0.0
+watchdog>=3.0.0
+scipy>=1.10.0
utils/auth.py ADDED
@@ -0,0 +1,50 @@
+import streamlit as st
+import os
+from huggingface_hub import HfApi, login
+
+def check_hf_token(token):
+    """
+    Validate a Hugging Face token and log in if it is valid
+
+    Args:
+        token (str): Hugging Face API token
+
+    Returns:
+        bool: True if token is valid, False otherwise
+    """
+    try:
+        # Set token in environment and session state
+        os.environ["HF_TOKEN"] = token
+        st.session_state["hf_token"] = token
+
+        # Try to log in
+        login(token=token, add_to_git_credential=False)
+
+        # Test API access
+        api = HfApi(token=token)
+        user_info = api.whoami()
+
+        # Store username in session state
+        st.session_state["hf_username"] = user_info["name"] if "name" in user_info else None
+
+        return True
+    except Exception as e:
+        st.session_state["hf_token"] = None
+        st.session_state["hf_username"] = None
+        print(f"Authentication error: {str(e)}")
+        return False
+
+def get_current_user():
+    """
+    Get the currently authenticated user's information
+
+    Returns:
+        dict: User information or None if not authenticated
+    """
+    if "hf_token" in st.session_state and st.session_state["hf_token"]:
+        try:
+            api = HfApi(token=st.session_state["hf_token"])
+            return api.whoami()
+        except Exception:
+            return None
+    return None
utils/huggingface.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ import pandas as pd
5
+ import streamlit as st
6
+ from huggingface_hub import HfApi, upload_file, create_repo
7
+ from transformers import AutoTokenizer
8
+
9
+ def create_dataset_repo(repo_name, private=True):
10
+ """
11
+ Create a new dataset repository on Hugging Face Hub
12
+
13
+ Args:
14
+ repo_name (str): Name of the repository
15
+ private (bool): Whether the repository should be private
16
+
17
+ Returns:
18
+ str: URL of the created repository
19
+ """
20
+ try:
21
+ token = st.session_state.get("hf_token")
22
+ if not token:
23
+ return False, "No Hugging Face token found"
24
+
25
+ username = st.session_state.get("hf_username", "user")
26
+ full_repo_name = f"{username}/{repo_name}"
27
+
28
+ api = HfApi(token=token)
29
+ repo_url = api.create_repo(
30
+ repo_id=full_repo_name,
31
+ repo_type="dataset",
32
+ private=private,
33
+ exist_ok=True
34
+ )
35
+
36
+ return True, repo_url
37
+ except Exception as e:
38
+ return False, str(e)
39
+
40
+ def upload_dataset_to_hub(file_data, file_name, repo_name):
41
+ """
42
+ Upload a dataset file to Hugging Face Hub
43
+
44
+ Args:
45
+ file_data (bytes/DataFrame): File content as bytes or a DataFrame
46
+ file_name (str): Name to save the file as
47
+ repo_name (str): Repository to upload to
48
+
49
+ Returns:
50
+ tuple: (success (bool), message (str))
51
+ """
52
+ try:
53
+ token = st.session_state.get("hf_token")
54
+ if not token:
55
+ return False, "No Hugging Face token found"
56
+
57
+ username = st.session_state.get("hf_username", "user")
58
+ repo_id = f"{username}/{repo_name}"
59
+
60
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as tmp:
61
+ # If it's a DataFrame, save as JSONL
62
+ if isinstance(file_data, pd.DataFrame):
63
+ file_data.to_json(tmp.name, orient="records", lines=True)
64
+ else:
65
+ # Otherwise, assume it's bytes
66
+ tmp.write(file_data)
67
+
68
+ # Upload file to repository
69
+ upload_file(
70
+ path_or_fileobj=tmp.name,
71
+ path_in_repo=file_name,
72
+ repo_id=repo_id,
73
+ token=token,
74
+ repo_type="dataset"
75
+ )
76
+
77
+ # Clean up temporary file
78
+ tmp_name = tmp.name
79
+
80
+ os.unlink(tmp_name)
81
+ return True, f"File uploaded to {repo_id}"
82
+ except Exception as e:
83
+ return False, str(e)
84
+
85
+ def prepare_training_config(model_name, hyperparams, dataset_repo, output_repo):
86
+ """
87
+ Prepare a training configuration for Gemma fine-tuning
88
+
89
+ Args:
90
+ model_name (str): Model identifier
91
+ hyperparams (dict): Training hyperparameters
92
+ dataset_repo (str): Dataset repository name
93
+ output_repo (str): Output repository name
94
+
95
+ Returns:
96
+ dict: Training configuration
97
+ """
98
+ username = st.session_state.get("hf_username", "user")
99
+
100
+ config = {
101
+ "model_name_or_path": model_name,
102
+ "dataset_name": f"{username}/{dataset_repo}",
103
+ "output_dir": f"{username}/{output_repo}",
104
+ "num_train_epochs": hyperparams.get("epochs", 3),
105
+ "per_device_train_batch_size": hyperparams.get("batch_size", 8),
106
+ "learning_rate": hyperparams.get("learning_rate", 2e-5),
107
+ "weight_decay": hyperparams.get("weight_decay", 0.01),
108
+ "save_strategy": "epoch",
109
+ "evaluation_strategy": "epoch",
110
+ "fp16": hyperparams.get("fp16", False),
111
+ "peft_config": {
112
+ "r": hyperparams.get("lora_rank", 8),
113
+ "lora_alpha": hyperparams.get("lora_alpha", 32),
114
+ "lora_dropout": hyperparams.get("lora_dropout", 0.05),
115
+ "bias": "none",
116
+ "task_type": "CAUSAL_LM"
117
+ },
118
+ "optim": "adamw_torch",
119
+ "logging_steps": 50,
120
+ "gradient_accumulation_steps": hyperparams.get("gradient_accumulation", 1),
121
+ "max_steps": hyperparams.get("max_steps", -1),
122
+ "warmup_steps": hyperparams.get("warmup_steps", 0),
123
+ "max_grad_norm": hyperparams.get("max_grad_norm", 1.0),
124
+ }
125
+
126
+ return config
127
+
128
+ def preprocess_dataset(df, prompt_column, response_column, model_name="google/gemma-2b"):
+     """
+     Preprocess a dataset for Gemma fine-tuning
+
+     Args:
+         df (DataFrame): Dataset
+         prompt_column (str): Column containing prompts/instructions
+         response_column (str): Column containing responses
+         model_name (str): Model identifier; currently unused, kept for
+             future tokenizer-aware preprocessing
+
+     Returns:
+         DataFrame: Processed dataset with a single "text" column
+     """
+     # Check that the mapped columns exist
+     if prompt_column not in df.columns or response_column not in df.columns:
+         raise ValueError(f"Columns {prompt_column} and/or {response_column} not found in dataset")
+
+     # Wrap each example in Gemma's single-turn chat format
+     df["text"] = df.apply(
+         lambda row: f"<start_of_turn>user\n{row[prompt_column]}<end_of_turn>\n<start_of_turn>model\n{row[response_column]}<end_of_turn>",
+         axis=1
+     )
+
+     # Return only the formatted text column
+     return df[["text"]]
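
For clarity, this is what that single-turn Gemma format looks like on a toy row. A minimal sketch: the import path and the example strings are hypothetical, and `preprocess_dataset` is the helper defined above.

```python
import pandas as pd

from utils.dataset import preprocess_dataset  # hypothetical import path

# Toy one-row dataset, for illustration only
df = pd.DataFrame({
    "prompt": ["What is LoRA?"],
    "response": ["A parameter-efficient fine-tuning method."],
})

processed = preprocess_dataset(df, "prompt", "response")
print(processed.loc[0, "text"])
# <start_of_turn>user
# What is LoRA?<end_of_turn>
# <start_of_turn>model
# A parameter-efficient fine-tuning method.<end_of_turn>
```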
utils/training.py ADDED
@@ -0,0 +1,347 @@
+ import os
+ import json
+ import time
+ import pandas as pd
+ import streamlit as st
+ from huggingface_hub import HfApi, create_repo, upload_file
+ import tempfile
+
+ def create_model_repo(repo_name, private=True):
+     """
+     Create a new model repository on Hugging Face Hub
+
+     Args:
+         repo_name (str): Name of the repository
+         private (bool): Whether the repository should be private
+
+     Returns:
+         tuple: (success (bool), message (str))
+     """
+     try:
+         token = st.session_state.get("hf_token")
+         if not token:
+             return False, "No Hugging Face token found"
+
+         username = st.session_state.get("hf_username", "user")
+         full_repo_name = f"{username}/{repo_name}"
+
+         api = HfApi(token=token)
+         repo_url = api.create_repo(
+             repo_id=full_repo_name,
+             private=private,
+             exist_ok=True
+         )
+
+         return True, repo_url
+     except Exception as e:
+         return False, str(e)
+
+ def upload_training_config(config, repo_name):
+     """
+     Upload a training configuration file to Hugging Face Hub
+
+     Args:
+         config (dict): Training configuration
+         repo_name (str): Repository to upload to
+
+     Returns:
+         tuple: (success (bool), message (str))
+     """
+     try:
+         token = st.session_state.get("hf_token")
+         if not token:
+             return False, "No Hugging Face token found"
+
+         username = st.session_state.get("hf_username", "user")
+         repo_id = f"{username}/{repo_name}"
+
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as tmp:
+             json.dump(config, tmp, indent=2)
+             tmp_name = tmp.name
+
+         # Upload after the with-block closes the file, so the JSON is
+         # fully flushed to disk before the upload starts
+         upload_file(
+             path_or_fileobj=tmp_name,
+             path_in_repo="training_config.json",
+             repo_id=repo_id,
+             token=token
+         )
+
+         # Clean up temporary file
+         os.unlink(tmp_name)
+         return True, f"Training config uploaded to {repo_id}"
+     except Exception as e:
+         return False, str(e)
+
+ def setup_training_script(repo_name, config):
+     """
+     Generate and upload a training script to the repository
+
+     Args:
+         repo_name (str): Repository name
+         config (dict): Training configuration (uploaded separately as
+             training_config.json, which the script reads at runtime)
+
+     Returns:
+         tuple: (success (bool), message (str))
+     """
+     # Create a training script using the transformers Trainer. 4-bit
+     # quantization is applied on GPU (bitsandbytes requires CUDA); on
+     # CPU the model is loaded in full precision instead.
+     script = """
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling,
+     BitsAndBytesConfig,
+ )
+ from datasets import load_dataset
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ import json
+ import os
+ import torch
+ from huggingface_hub import login
+
+ # Load configuration
+ with open("training_config.json", "r") as f:
+     config = json.load(f)
+
+ # Login to Hugging Face
+ login(token=os.environ.get("HF_TOKEN"))
+
+ # Load dataset
+ print("Loading dataset:", config["dataset_name"])
+ dataset = load_dataset(config["dataset_name"])
+
+ # Prepare a train/validation split if the dataset is not already split
+ if "train" in dataset and "validation" not in dataset:
+     dataset = dataset["train"].train_test_split(test_size=0.1)
+ elif "train" not in dataset:
+     # Fall back to the first available split and split it
+     print("Available splits:", list(dataset.keys()))
+     first_split = list(dataset.keys())[0]
+     dataset = dataset[first_split].train_test_split(test_size=0.1)
+
+ print("Dataset splits:", list(dataset.keys()))
+
+ # Print dataset sample
+ print("Dataset sample:", dataset["train"][0])
+
+ # Load tokenizer
+ print("Loading tokenizer for model:", config["model_name_or_path"])
+ tokenizer = AutoTokenizer.from_pretrained(config["model_name_or_path"])
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ # Load model, 4-bit quantized on GPU for memory efficiency
+ print("Loading model...")
+ quant_config = None
+ if torch.cuda.is_available():
+     quant_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     )
+ model = AutoModelForCausalLM.from_pretrained(
+     config["model_name_or_path"],
+     quantization_config=quant_config,
+     device_map="auto",
+     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+     use_cache=False,  # Required for gradient checkpointing
+ )
+
+ # Enable gradient checkpointing for memory efficiency
+ model.gradient_checkpointing_enable()
+
+ # Print memory usage before PEFT (GPU only)
+ if torch.cuda.is_available():
+     print(f"Model loaded. Memory usage: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
+
+ # Prepare model for training with LoRA
+ print("Setting up LoRA with rank:", config["peft_config"]["r"])
+ peft_config = LoraConfig(
+     r=config["peft_config"]["r"],
+     lora_alpha=config["peft_config"]["lora_alpha"],
+     lora_dropout=config["peft_config"]["lora_dropout"],
+     bias=config["peft_config"]["bias"],
+     task_type=config["peft_config"]["task_type"],
+     target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+ )
+
+ # Prepare the (possibly quantized) model for PEFT training
+ print("Preparing model for training...")
+ if quant_config is not None:
+     model = prepare_model_for_kbit_training(model)
+ model = get_peft_model(model, peft_config)
+ model.print_trainable_parameters()
+
+ # Setup training arguments
+ print("Setting up training arguments...")
+ training_args = TrainingArguments(
+     output_dir=config["output_dir"],
+     num_train_epochs=config["num_train_epochs"],
+     per_device_train_batch_size=config["per_device_train_batch_size"],
+     per_device_eval_batch_size=max(1, config["per_device_train_batch_size"] // 2),
+     learning_rate=config["learning_rate"],
+     weight_decay=config["weight_decay"],
+     save_strategy=config["save_strategy"],
+     evaluation_strategy=config["evaluation_strategy"],
+     fp16=config["fp16"] and torch.cuda.is_available(),
+     optim=config["optim"],
+     logging_steps=config["logging_steps"],
+     gradient_accumulation_steps=config["gradient_accumulation_steps"],
+     max_steps=config["max_steps"],  # -1 means "run for num_train_epochs"
+     warmup_steps=config["warmup_steps"],
+     max_grad_norm=config["max_grad_norm"],
+     push_to_hub=True,
+     hub_token=os.environ.get("HF_TOKEN"),
+     dataloader_num_workers=0,  # Lower CPU usage for smaller machines
+     use_cpu=not torch.cuda.is_available(),  # Force CPU if no GPU
+     lr_scheduler_type="cosine",  # Better LR scheduling for small datasets
+     report_to=["tensorboard"],  # Enable tensorboard logging
+ )
+
+ # Define data collator
+ print("Setting up data collator...")
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ # Tokenize dataset
+ print("Tokenizing dataset...")
+ def tokenize_function(examples):
+     # Use a smaller max length on CPU to save memory
+     max_length = 256 if not torch.cuda.is_available() else 512
+     return tokenizer(
+         examples["text"],
+         padding="max_length",
+         truncation=True,
+         max_length=max_length
+     )
+
+ # Show progress while tokenizing
+ print("Mapping tokenization function...")
+ tokenized_dataset = dataset.map(
+     tokenize_function,
+     batched=True,
+     batch_size=8,  # Smaller batch size for CPU
+     remove_columns=dataset["train"].column_names,  # Remove original columns after tokenizing
+     desc="Tokenizing dataset",
+ )
+
+ # train_test_split names its held-out split "test", not "validation"
+ eval_split = "validation" if "validation" in tokenized_dataset else "test"
+
+ # Initialize trainer
+ print("Initializing trainer...")
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset["train"],
+     eval_dataset=tokenized_dataset[eval_split],
+     data_collator=data_collator,
+ )
+
+ # Start training
+ print("Starting training...")
+ try:
+     trainer.train()
+     # Save model and tokenizer
+     print("Saving model...")
+     trainer.save_model()
+     print("Training completed successfully!")
+ except Exception as e:
+     print(f"Error during training: {str(e)}")
+     # Save a checkpoint even if an error occurred
+     try:
+         trainer.save_model("./checkpoint-error")
+         print("Saved checkpoint before error")
+     except Exception:
+         print("Could not save checkpoint")
+ """
+
+     try:
+         token = st.session_state.get("hf_token")
+         if not token:
+             return False, "No Hugging Face token found"
+
+         username = st.session_state.get("hf_username", "user")
+         repo_id = f"{username}/{repo_name}"
+
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.py', mode='w') as tmp:
+             tmp.write(script)
+             script_path = tmp.name
+
+         # Upload the training script to the repository
+         upload_file(
+             path_or_fileobj=script_path,
+             path_in_repo="train.py",
+             repo_id=repo_id,
+             token=token
+         )
+
+         # Clean up the script's temporary file
+         os.unlink(script_path)
+
+         # Also create and upload a requirements file for the training job
+         requirements = """transformers>=4.35.0
+ peft>=0.7.0
+ bitsandbytes>=0.40.0
+ datasets>=2.10.0
+ torch>=2.0.0
+ tensorboard>=2.13.0
+ accelerate>=0.20.0
+ huggingface_hub>=0.15.0
+ scipy>=1.10.0
+ """
+
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.txt', mode='w') as tmp:
+             tmp.write(requirements)
+             req_path = tmp.name
+
+         # Upload file to repository
+         upload_file(
+             path_or_fileobj=req_path,
+             path_in_repo="requirements.txt",
+             repo_id=repo_id,
+             token=token
+         )
+
+         # Clean up the requirements temporary file
+         os.unlink(req_path)
+
+         return True, f"Training script and requirements uploaded to {repo_id}"
+     except Exception as e:
+         return False, str(e)
+
+ def simulate_training_progress():
+     """
+     Simulate training progress for demonstration purposes
+     """
+     if "training_progress" not in st.session_state:
+         st.session_state.training_progress = {
+             "started": time.time(),
+             "current_epoch": 0,
+             "total_epochs": 3,
+             "loss": 2.5,
+             "learning_rate": 2e-5,
+             "progress": 0.0,
+             "status": "running"
+         }
+
+     # Update progress based on elapsed time (simulated)
+     elapsed = time.time() - st.session_state.training_progress["started"]
+     epoch_duration = 60  # Simulate each epoch taking 60 seconds
+
+     # Calculate current progress
+     total_duration = epoch_duration * st.session_state.training_progress["total_epochs"]
+     progress = min(elapsed / total_duration, 1.0)
+
+     # Calculate current epoch
+     current_epoch = min(
+         int(progress * st.session_state.training_progress["total_epochs"]),
+         st.session_state.training_progress["total_epochs"]
+     )
+
+     # Simulate decreasing loss
+     loss = max(2.5 - (progress * 2.0), 0.5)
+
+     # Update session state
+     st.session_state.training_progress.update({
+         "progress": progress,
+         "current_epoch": current_epoch,
+         "loss": loss,
+         "status": "completed" if progress >= 1.0 else "running"
+     })
+
+     return st.session_state.training_progress
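
One note on the model-loading step in the generated `train.py`: recent `transformers` releases expose 4-bit loading through a `BitsAndBytesConfig` passed to `from_pretrained`, and bitsandbytes quantization requires a CUDA GPU. A minimal sketch of that path, with `google/gemma-2b` used purely for illustration:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 4-bit quantization config (requires a CUDA GPU at load time)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",           # illustrative model ID
    quantization_config=bnb_config,
    device_map="auto",           # let accelerate place the weights
)
```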
utils/ui.py ADDED
@@ -0,0 +1,163 @@
+ import streamlit as st
+
+ def set_page_style():
+     """
+     Set custom CSS styling for the Streamlit app with light mode
+     """
+     st.markdown("""
+     <style>
+     /* Main theme colors - Light Mode */
+     :root {
+         --primary-color: #3366FF;
+         --background-color: #FFFFFF;
+         --secondary-background-color: #F5F7F9;
+         --text-color: #333333;
+         --font: 'Source Sans Pro', sans-serif;
+     }
+
+     /* Overall page background */
+     .stApp {
+         background-color: var(--background-color);
+         color: var(--text-color);
+     }
+
+     /* Headers */
+     h1, h2, h3, h4, h5, h6 {
+         color: #111111;
+         font-weight: 700;
+     }
+
+     /* Card-like elements */
+     .stCard {
+         border-radius: 8px;
+         padding: 20px;
+         box-shadow: 0 2px 5px rgba(0,0,0,0.1);
+         background-color: white;
+         border: 1px solid #EAEAEA;
+         margin-bottom: 15px;
+     }
+
+     /* Sidebar */
+     section[data-testid="stSidebar"] {
+         background-color: #F5F7F9;
+         border-right: 1px solid #EAEAEA;
+     }
+
+     /* Buttons */
+     .stButton button {
+         border-radius: 4px;
+         font-weight: 600;
+         background-color: var(--primary-color);
+         color: white;
+     }
+
+     .stButton button:hover {
+         background-color: #2a56d9;
+     }
+
+     /* Banner image */
+     .banner-img {
+         border-radius: 8px;
+         margin-bottom: 15px;
+     }
+
+     /* Metric display */
+     .metric-container {
+         background-color: #FFFFFF;
+         border-radius: 6px;
+         padding: 12px;
+         text-align: center;
+         box-shadow: 0 1px 3px rgba(0,0,0,0.05);
+         border: 1px solid #EAEAEA;
+     }
+
+     .metric-value {
+         font-size: 20px;
+         font-weight: 700;
+         color: #3366FF;
+     }
+
+     .metric-label {
+         font-size: 14px;
+         color: #555555;
+         margin-top: 5px;
+     }
+
+     /* Make input fields a bit nicer */
+     input, textarea, select {
+         border-radius: 4px !important;
+         border: 1px solid #CCCCCC !important;
+     }
+
+     /* Improve readability of text elements */
+     p, li, label, div {
+         color: #333333;
+     }
+
+     /* Info, success, warning boxes */
+     .stAlert {
+         border-radius: 4px;
+     }
+
+     /* Tabs styling */
+     .stTabs [data-baseweb="tab-list"] {
+         gap: 8px;
+     }
+
+     .stTabs [data-baseweb="tab"] {
+         border-radius: 4px 4px 0 0;
+         padding: 8px 16px;
+         background-color: #F0F2F6;
+     }
+
+     .stTabs [aria-selected="true"] {
+         background-color: white;
+         border-top: 2px solid var(--primary-color);
+     }
+
+     /* Code blocks */
+     code {
+         background-color: #F0F2F6;
+         color: #333333;
+         padding: 2px 4px;
+         border-radius: 3px;
+     }
+
+     /* Progress bar */
+     .stProgress > div > div {
+         background-color: var(--primary-color);
+     }
+     </style>
+     """, unsafe_allow_html=True)
+
+ def display_metric(title, value, container=None):
+     """
+     Display a metric in a nice formatted container
+
+     Args:
+         title (str): Title of the metric
+         value (str/int/float): Value to display
+         container (streamlit container, optional): Container to display the metric in
+     """
+     target = container if container else st
+     target.markdown(f"""
+     <div class="metric-container">
+         <div class="metric-value">{value}</div>
+         <div class="metric-label">{title}</div>
+     </div>
+     """, unsafe_allow_html=True)
+
+ def create_card(content, container=None):
+     """
+     Create a card-like container for content
+
+     Args:
+         content (callable): Function to call to render content inside the card
+         container (streamlit container, optional): Container to place the card in
+     """
+     target = container if container else st
+
+     # Render inside the container context so the card markup and the
+     # content land in the same block; st.* calls made within the
+     # `with` are redirected into the active container
+     with target.container():
+         st.markdown('<div class="stCard">', unsafe_allow_html=True)
+         content()
+         st.markdown('</div>', unsafe_allow_html=True)
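
Taken together, a page would typically apply the style once and then compose the helpers. A minimal sketch, assuming `utils/ui.py` is importable from the app root; the metric values shown are illustrative:

```python
import streamlit as st
from utils.ui import set_page_style, display_metric, create_card

set_page_style()

# Two metrics rendered side by side in their own columns
col1, col2 = st.columns(2)
display_metric("Epoch", "2 / 3", container=col1)
display_metric("Loss", "0.87", container=col2)

# A card wrapping arbitrary Streamlit content
def training_status():
    st.subheader("Training status")
    st.progress(0.66)

create_card(training_status)
```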