test_discut / pipeline.py
poyum's picture
app
4a3df21
from transformers import Pipeline, AutoModelForTokenClassification
import numpy as np
from eval import retrieve_predictions, align_tokens_labels_from_wordids
from reading import read_dataset
from utils import read_config
def write_sentences_to_format(sentences: list[str], filename: str):
    """
    Build a CoNLL-U-like document from raw sentences, one token per line.

    Every token line has the form::

        index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...

    The document starts with a fixed ``# newdoc_id`` comment line. Each
    sentence is tokenised with ``str.split()`` and its token indices restart
    at 1. The ``Seg`` column is a placeholder heuristic: ``B-seg`` for the
    first token of a sentence or any token starting with an uppercase
    character, ``O`` otherwise.

    Args:
        sentences: The sentences to format. A single ``str`` is also accepted
            (a warning is written to stderr and it is treated as one sentence).
        filename: If truthy, the document is also written to this path
            (UTF-8). ``None`` / ``""`` means "do not write a file".

    Returns:
        The formatted document, or ``""`` when ``sentences`` is empty (in that
        case no file is written either).
    """
    if not sentences:
        return ""
    if isinstance(sentences, str):
        import sys

        sentences = [sentences]
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")

    lines = ["# newdoc_id = GUM_academic_discrimination\n"]
    for sentence in sentences:
        for i, word in enumerate(sentence.strip().split(), start=1):
            # First token of the sentence, or a capitalised token -> B-seg; otherwise O.
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            lines.append(f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n")
    # Build once with join instead of repeated string concatenation (quadratic).
    full = "".join(lines)

    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)
    return full
class DiscoursePipeline(Pipeline):
    """
    Hugging Face ``Pipeline`` that segments raw text into discourse units (EDUs).

    Stages:
      * ``preprocess``: format the text (one sentence per line) into a
        CoNLL-U-like string and load it with ``read_dataset``.
      * ``_forward``: run ``retrieve_predictions`` on that dataset.
      * ``postprocess``: argmax the predictions, align them to words and cut
        the original text into EDUs with ``text_to_edus``.

    The input text is carried through the stage dictionaries rather than kept
    on the instance, so one pipeline object can serve several inputs
    (iterators, DataLoader workers, concurrent calls) without mixing them up.
    """

    def __init__(self, model_id, tokenizer, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        """
        Args:
            model_id: Checkpoint identifier/path used to load the token
                classification model; also stored in ``self.config``.
            tokenizer: Tokenizer handed to the base ``Pipeline``.
            output_folder: Directory passed to ``read_dataset`` and
                ``retrieve_predictions`` as their output location.
            sat_model: Value stored under ``"sat_model"`` in ``self.config``.
            **kwargs: Forwarded to ``Pipeline.__init__``.
        """
        auto_model = AutoModelForTokenClassification.from_pretrained(model_id)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        self.config = {
            "model_checkpoint": model_id,
            "sent_spliter": "sat",
            "task": "seg",
            "type": "tok",
            "trace": False,
            "report_to": "none",
            "sat_model": sat_model,
            "tok_config": {
                "padding": "max_length",
                "truncation": True,
                "max_length": 512,
            },
        }
        # NOTE(review): this replaces the model object stored by Pipeline.__init__
        # with the checkpoint identifier, which _forward hands to
        # retrieve_predictions. Confirm retrieve_predictions expects an id/path
        # rather than a loaded model before changing this.
        self.model = model_id
        self.output_folder = output_folder

    def _sanitize_parameters(self, add_lang_token=None, add_frame_token=None, **kwargs):
        """
        Route call-time options to the pipeline stages.

        ``add_lang_token`` / ``add_frame_token`` are forwarded to
        ``preprocess`` when given; other keyword arguments are ignored.
        """
        preprocess_params = {}
        if add_lang_token is not None:
            preprocess_params["add_lang_token"] = add_lang_token
        if add_frame_token is not None:
            preprocess_params["add_frame_token"] = add_frame_token
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, text: str, add_lang_token: bool = True, add_frame_token: bool = True):
        """
        Turn raw text (one sentence per line) into the dataset consumed by the model.

        Returns:
            ``{"dataset": <read_dataset result>, "text": text}``.
        """
        # Kept only for backward compatibility with code reading this attribute;
        # postprocess uses the text carried in the stage dicts instead.
        self.original_text = text
        formatted_text = write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=add_lang_token,
            add_frame_token=add_frame_token,
        )
        return {"dataset": dataset, "text": text}

    def _forward(self, inputs):
        """Run the model over the prepared dataset and pass the input text along."""
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {
            "preds": preds_from_model,
            "labels": label_ids,
            "dataset": dataset,
            "text": inputs.get("text"),
        }

    def postprocess(self, outputs):
        """Convert model scores into word labels and cut the input text into EDUs."""
        # Class scores -> predicted class index along the last axis.
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        text = outputs.get("text")
        if text is None:
            # Fallback for outputs produced by an overridden _forward without "text".
            text = self.original_text
        return text_to_edus(text, predictions)
def get_plain_text_from_format(formatted_text: str) -> str:
    """
    Extract the words of a CoNLL-U / .tok formatted string as plain text.

    Lines starting with ``#`` (comments) are skipped, as are lines with fewer
    than two tab-separated fields (e.g. blank sentence separators). For the
    remaining lines the second column (the word form) is kept.

    Args:
        formatted_text: Contents of the document (a string, not a path).

    Returns:
        The words joined by single spaces, with surrounding whitespace stripped.
    """
    words = []
    for line in formatted_text.split("\n"):
        if line.startswith("#"):
            continue
        fields = line.split("\t")
        if len(fields) > 1:
            words.append(fields[1])
    return " ".join(words).strip()
def get_preds_from_format(formatted_text: str) -> str:
    """
    Extract the last column (the label, e.g. ``Seg=B-seg``) of a CoNLL-U / .tok string.

    Lines starting with ``#`` (comments) are skipped, as are lines with fewer
    than two tab-separated fields (e.g. blank sentence separators).

    Args:
        formatted_text: Contents of the document (a string, not a path).

    Returns:
        The labels joined by single spaces, with surrounding whitespace stripped.
    """
    labels = []
    for line in formatted_text.split("\n"):
        if line.startswith("#"):
            continue
        fields = line.split("\t")
        if len(fields) > 1:
            labels.append(fields[-1])
    return " ".join(labels).strip()
def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """
    Split raw text into EDUs (elementary discourse units) using per-word labels.

    A label opens a new EDU when its tag -- the part after an optional
    ``Key=`` prefix such as ``Seg=`` or ``Conn=`` -- is ``B`` or starts with
    ``B-`` (e.g. ``Seg=B-seg``, ``Conn=B-conn``, ``B-seg``, ``B``). Every other
    label (``O``, ``I-...``, or anything unrecognised) continues the current
    EDU, so no word of the input is ever dropped.

    Args:
        text: Raw text; words are obtained with ``str.split()``, so any
            whitespace (including newlines) separates words.
        labels: One label per word, aligned with the words of ``text``.

    Returns:
        The EDUs, each made of its words joined by single spaces.

    Raises:
        ValueError: If the number of words and labels differ.
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Longueur mismatch: {len(words)} mots vs {len(labels)} labels")
    edus = []
    current_edu = []
    for word, label in zip(words, labels):
        # Drop an optional "Key=" prefix: "Seg=B-seg" -> "B-seg".
        tag = str(label).split("=", 1)[-1]
        starts_edu = tag == "B" or tag.startswith("B-")
        # A beginning label closes the EDU in progress (if any).
        if starts_edu and current_edu:
            edus.append(" ".join(current_edu))
            current_edu = []
        current_edu.append(word)
    # Close the last EDU if one is still open.
    if current_edu:
        edus.append(" ".join(current_edu))
    return edus