File size: 10,108 Bytes
1ec7405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
"""
Visualize training metrics from MLflow runs.

Generates plots showing:
- Loss curves (training/validation)
- Task-specific metrics over time
- Learning rate schedule
- Training speed analysis

Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import mlflow
import mlflow.tracking
import seaborn as sns

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.utils.logging import configure_logging, get_logger

configure_logging()
logger = get_logger(__name__)

# Configure plotting style
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["figure.dpi"] = 100

OUTPUTS_DIR = PROJECT_ROOT / "outputs"
MLRUNS_DIR = PROJECT_ROOT / "mlruns"


def load_training_history() -> dict[str, object] | None:
    """Load training history from JSON if available."""
    history_path = OUTPUTS_DIR / "training_history.json"
    if history_path.exists():
        with open(history_path) as f:
            data: dict[str, object] = json.load(f)
            return data
    return None


def get_latest_run():
    """Get the most recent MLflow run."""
    mlflow.set_tracking_uri(f"file://{MLRUNS_DIR}")
    client = mlflow.tracking.MlflowClient()

    # Get the experiment (LexiMind)
    experiment = client.get_experiment_by_name("LexiMind")
    if not experiment:
        logger.error("No 'LexiMind' experiment found")
        return None

    # Get all runs, sorted by start time
    runs = client.search_runs(
        experiment_ids=[experiment.experiment_id],
        order_by=["start_time DESC"],
        max_results=1,
    )

    if not runs:
        logger.error("No runs found in experiment")
        return None

    return runs[0]


def plot_loss_curves(run):
    """Plot training and validation loss over time."""
    client = mlflow.tracking.MlflowClient()

    # Get metrics
    train_loss = client.get_metric_history(run.info.run_id, "train_total_loss")
    val_loss = client.get_metric_history(run.info.run_id, "val_total_loss")

    fig, ax = plt.subplots(figsize=(12, 6))

    if not train_loss:
        # Create placeholder plot
        ax.text(
            0.5,
            0.5,
            "No training data yet\n\nWaiting for first epoch to complete...",
            ha="center",
            va="center",
            fontsize=14,
            color="gray",
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
    else:
        # Extract steps and values
        train_steps = [m.step for m in train_loss]
        train_values = [m.value for m in train_loss]

        ax.plot(train_steps, train_values, label="Training Loss", linewidth=2, alpha=0.8)

        if val_loss:
            val_steps = [m.step for m in val_loss]
            val_values = [m.value for m in val_loss]
            ax.plot(val_steps, val_values, label="Validation Loss", linewidth=2, alpha=0.8)

        ax.legend(fontsize=11)

    ax.set_xlabel("Epoch", fontsize=12)
    ax.set_ylabel("Loss", fontsize=12)
    ax.set_title("Training Progress: Total Loss", fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    output_path = OUTPUTS_DIR / "training_loss_curve.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    logger.info(f"✓ Saved loss curve to {output_path}")
    plt.close()


def plot_task_metrics(run):
    """Plot metrics for each task."""
    client = mlflow.tracking.MlflowClient()

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle("Task-Specific Training Metrics", fontsize=16, fontweight="bold")

    # Summarization
    ax = axes[0, 0]
    train_sum = client.get_metric_history(run.info.run_id, "train_summarization_loss")
    val_sum = client.get_metric_history(run.info.run_id, "val_summarization_loss")

    if train_sum:
        ax.plot(
            [m.step for m in train_sum], [m.value for m in train_sum], label="Train", linewidth=2
        )
    if val_sum:
        ax.plot([m.step for m in val_sum], [m.value for m in val_sum], label="Val", linewidth=2)
    ax.set_title("Summarization Loss", fontweight="bold")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Emotion
    ax = axes[0, 1]
    train_emo = client.get_metric_history(run.info.run_id, "train_emotion_loss")
    val_emo = client.get_metric_history(run.info.run_id, "val_emotion_loss")
    train_f1 = client.get_metric_history(run.info.run_id, "train_emotion_f1")
    val_f1 = client.get_metric_history(run.info.run_id, "val_emotion_f1")

    if train_emo:
        ax.plot(
            [m.step for m in train_emo],
            [m.value for m in train_emo],
            label="Train Loss",
            linewidth=2,
        )
    if val_emo:
        ax.plot(
            [m.step for m in val_emo], [m.value for m in val_emo], label="Val Loss", linewidth=2
        )

    ax2 = ax.twinx()
    if train_f1:
        ax2.plot(
            [m.step for m in train_f1],
            [m.value for m in train_f1],
            label="Train F1",
            linewidth=2,
            linestyle="--",
            alpha=0.7,
        )
    if val_f1:
        ax2.plot(
            [m.step for m in val_f1],
            [m.value for m in val_f1],
            label="Val F1",
            linewidth=2,
            linestyle="--",
            alpha=0.7,
        )

    ax.set_title("Emotion Detection", fontweight="bold")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax2.set_ylabel("F1 Score")
    ax.legend(loc="upper left")
    ax2.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    # Topic
    ax = axes[1, 0]
    train_topic = client.get_metric_history(run.info.run_id, "train_topic_loss")
    val_topic = client.get_metric_history(run.info.run_id, "val_topic_loss")
    train_acc = client.get_metric_history(run.info.run_id, "train_topic_accuracy")
    val_acc = client.get_metric_history(run.info.run_id, "val_topic_accuracy")

    if train_topic:
        ax.plot(
            [m.step for m in train_topic],
            [m.value for m in train_topic],
            label="Train Loss",
            linewidth=2,
        )
    if val_topic:
        ax.plot(
            [m.step for m in val_topic], [m.value for m in val_topic], label="Val Loss", linewidth=2
        )

    ax2 = ax.twinx()
    if train_acc:
        ax2.plot(
            [m.step for m in train_acc],
            [m.value for m in train_acc],
            label="Train Acc",
            linewidth=2,
            linestyle="--",
            alpha=0.7,
        )
    if val_acc:
        ax2.plot(
            [m.step for m in val_acc],
            [m.value for m in val_acc],
            label="Val Acc",
            linewidth=2,
            linestyle="--",
            alpha=0.7,
        )

    ax.set_title("Topic Classification", fontweight="bold")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax2.set_ylabel("Accuracy")
    ax.legend(loc="upper left")
    ax2.legend(loc="upper right")
    ax.grid(True, alpha=0.3)

    # Summary statistics
    ax = axes[1, 1]
    ax.axis("off")

    # Get final metrics
    summary_text = "Final Metrics (Last Epoch)\n" + "=" * 35 + "\n\n"

    if val_topic and val_acc:
        summary_text += f"Topic Accuracy: {val_acc[-1].value:.1%}\n"
    if val_emo and val_f1:
        summary_text += f"Emotion F1: {val_f1[-1].value:.1%}\n"
    if val_sum:
        summary_text += f"Summarization Loss: {val_sum[-1].value:.3f}\n"

    ax.text(0.1, 0.5, summary_text, fontsize=12, family="monospace", verticalalignment="center")

    plt.tight_layout()
    output_path = OUTPUTS_DIR / "task_metrics.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    logger.info(f"✓ Saved task metrics to {output_path}")
    plt.close()


def plot_learning_rate(run):
    """Plot learning rate schedule if available."""
    client = mlflow.tracking.MlflowClient()
    lr_metrics = client.get_metric_history(run.info.run_id, "learning_rate")

    fig, ax = plt.subplots(figsize=(12, 5))

    if not lr_metrics:
        # Create placeholder
        ax.text(
            0.5,
            0.5,
            "No learning rate data yet\n\n(Will be logged in future training runs)",
            ha="center",
            va="center",
            fontsize=14,
            color="gray",
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
    else:
        steps = [m.step for m in lr_metrics]
        values = [m.value for m in lr_metrics]

        ax.plot(steps, values, linewidth=2, color="darkblue")

        # Mark warmup region
        warmup_steps = 1000  # From config
        if warmup_steps < max(steps):
            ax.axvline(warmup_steps, color="red", linestyle="--", alpha=0.5, label="Warmup End")
            ax.legend()

    ax.set_xlabel("Step", fontsize=12)
    ax.set_ylabel("Learning Rate", fontsize=12)
    ax.set_title("Learning Rate Schedule (Cosine with Warmup)", fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    output_path = OUTPUTS_DIR / "learning_rate_schedule.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    logger.info(f"✓ Saved LR schedule to {output_path}")
    plt.close()


def main():
    """Generate all training visualizations."""
    logger.info("Loading MLflow data...")

    run = get_latest_run()
    if not run:
        logger.error("No training run found. Make sure training has started.")
        return

    logger.info(f"Analyzing run: {run.info.run_id}")

    OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)

    logger.info("Generating visualizations...")

    plot_loss_curves(run)
    plot_task_metrics(run)
    plot_learning_rate(run)

    logger.info("\n" + "=" * 60)
    logger.info("✓ All visualizations saved to outputs/")
    logger.info("=" * 60)
    logger.info("  - training_loss_curve.png")
    logger.info("  - task_metrics.png")
    logger.info("  - learning_rate_schedule.png")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()