#!/usr/bin/env python3
import os
from typing import Dict, List

import numpy as np
from datasets import load_dataset
import evaluate
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)
# ======================
# LABEL SCHEMA
# ======================
LABELS: List[str] = [
    "pre-1900",
    "1900-1945",
    "1946-1979",
    "1980-1999",
    "2000-2015",
    "2016-present",
]
id2label: Dict[int, str] = {i: l for i, l in enumerate(LABELS)}
label2id: Dict[str, int] = {l: i for i, l in enumerate(LABELS)}
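
# Illustrative only: if the raw data carries publication years instead of
# period strings, a hypothetical helper like this (not used anywhere below)
# shows how years would map onto the schema above.
def year_to_label(year: int) -> str:
    if year < 1900:
        return "pre-1900"
    if year <= 1945:
        return "1900-1945"
    if year <= 1979:
        return "1946-1979"
    if year <= 1999:
        return "1980-1999"
    if year <= 2015:
        return "2000-2015"
    return "2016-present"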
# Base model to fine-tune (overridable via the BASE_MODEL env var)
BASE_MODEL = os.environ.get("BASE_MODEL", "distilroberta-base")
# Hugging Face Hub repo where the fine-tuned model will be pushed
HUB_MODEL_ID = "DelaliScratchwerk/time-period-classifier-bert"
# ======================
# LOAD DATA
# ======================
# Expect CSVs at data/train.csv and data/val.csv
dataset = load_dataset(
    "csv",
    data_files={
        "train": "data/train.csv",
        "validation": "data/val.csv",
    },
)
print("Raw dataset:", dataset)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)


def encode_batch(batch):
    # Tokenize the texts; truncation keeps sequences within the model's limit
    enc = tokenizer(batch["text"], truncation=True)
    # Map string labels -> integer ids; strip() guards against trailing
    # spaces in the CSV
    enc["labels"] = [label2id[l.strip()] for l in batch["label"]]
    return enc
# IMPORTANT: remove the original 'text' and 'label' columns so the Trainer
# only sees tensor-ready fields
encoded = dataset.map(
    encode_batch,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
print(encoded)
print("Encoded train sample keys:", encoded["train"][0].keys())
# should be: dict_keys(['input_ids', 'attention_mask', 'labels'])
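
# Optional sanity check (a minimal sketch): every encoded label id should
# fall inside the schema before training starts.
for _split in ("train", "validation"):
    assert all(0 <= y < len(LABELS) for y in encoded[_split]["labels"]), _split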
# ======================
# MODEL
# ======================
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=len(LABELS),
    id2label=id2label,
    label2id=label2id,
)
# ======================
# METRICS
# ======================
accuracy = evaluate.load("accuracy")
f1_macro = evaluate.load("f1")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "f1_macro": f1_macro.compute(
            predictions=preds, references=labels, average="macro"
        )["f1"],
    }
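
# Quick self-check (illustrative): on two perfectly predicted dummy samples,
# compute_metrics should report accuracy and macro-F1 of 1.0.
_demo_logits = np.array([[2.0, 0.1, 0.0, 0.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0, 0.0, 0.1, 2.0]])
print("metric sanity:", compute_metrics((_demo_logits, np.array([0, 5]))))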
# Dynamic padding: pad each batch to its longest sequence rather than a fixed max
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
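
# Optional smoke test (a sketch; torch is already a transformers dependency):
# dynamically pad two encoded samples and run them through the untrained
# model to confirm the logits have shape (batch_size, num_labels).
import torch

with torch.no_grad():
    _batch = data_collator([encoded["train"][i] for i in range(2)])
    assert model(**_batch).logits.shape == (2, len(LABELS))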
# ======================
# TRAINING ARGS
# ======================
training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="no",
    load_best_model_at_end=False,
    logging_steps=50,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_private_repo=False,
)
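
# Note (an assumption about intent): with save_strategy="no", nothing is
# checkpointed during training, so the push below uploads the last-epoch
# model. To keep the best epoch instead, one option is:
#   save_strategy="epoch", load_best_model_at_end=True,
#   metric_for_best_model="f1_macro", greater_is_better=True,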
# ======================
# TRAINER
# ======================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
if __name__ == "__main__":
    trainer.train()
    # Push the final model and tokenizer to the Hub (the last-epoch model,
    # since load_best_model_at_end=False above)
    trainer.push_to_hub()
    tokenizer.push_to_hub(HUB_MODEL_ID)
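
# Once the push completes, the model can be loaded anywhere for inference
# (a sketch, assuming the upload above succeeded):
#   from transformers import pipeline
#   clf = pipeline("text-classification", model=HUB_MODEL_ID)
#   print(clf("The dial-up modem screeched to life."))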