from transformers import Pipeline, AutoModelForTokenClassification
import numpy as np

from eval import retrieve_predictions, align_tokens_labels_from_wordids
from reading import read_dataset
from utils import read_config
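
# Note: eval, reading and utils are assumed to be project-local sibling modules of this
# file (not PyPI packages); read_config is imported here but not used below.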


def write_sentences_to_format(sentences: list[str], filename: str):
    """
    Writes sentences to a file, one word per line, in the format:
    index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...
    Returns the formatted string; the file is only written when a filename is given.
    """
    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences = [sentences]
        import sys
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")
    full = "# newdoc_id = GUM_academic_discrimination\n"
    for sentence in sentences:
        words = sentence.strip().split()
        for i, word in enumerate(words, start=1):
            # First word of the sentence (or any capitalized word) -> B-seg, otherwise O
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            line = f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n"
            full += line
    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)
    return full
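

# Usage sketch (illustrative, not part of the original module): the helper turns raw
# sentences into the tab-separated .tok-style block consumed by read_dataset, e.g.
#   write_sentences_to_format(["The cat sleeps ."], filename=None)
# yields a "# newdoc_id = ..." header followed by one tab-separated line per word:
# "1  The ... Seg=B-seg", "2  cat ... Seg=O", "3  sleeps ... Seg=O", "4  .  ... Seg=O"
# (capitalized words are also tagged B-seg by the heuristic above).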


class DiscoursePipeline(Pipeline):
    def __init__(self, model_id, tokenizer, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        auto_model = AutoModelForTokenClassification.from_pretrained(model_id)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        self.config = {
            "model_checkpoint": model_id,
            "sent_spliter": "sat",
            "task": "seg",
            "type": "tok",
            "trace": False,
            "report_to": "none",
            "sat_model": sat_model,
            "tok_config": {
                "padding": "max_length",
                "truncation": True,
                "max_length": 512,
            },
        }
        # Keep the checkpoint id on self.model: retrieve_predictions() in _forward is called
        # with it rather than with the AutoModelForTokenClassification instance loaded above.
        self.model = model_id
        self.output_folder = output_folder

    def _sanitize_parameters(self, **kwargs):
        # Allows optional parameters (e.g. add_lang_token) to be passed through.
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, text: str):
        self.original_text = text
        formatted_text = write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}

    def _forward(self, inputs):
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}

    def postprocess(self, outputs):
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        edus = text_to_edus(self.original_text, predictions)
        return edus
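

# Minimal usage sketch (illustrative; "my-seg-checkpoint" is a placeholder checkpoint path,
# not part of this module). The three stages can be driven explicitly as below; the
# transformers Pipeline machinery chains the same stages when the object is called directly.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("my-seg-checkpoint")
#   pipe = DiscoursePipeline("my-seg-checkpoint", tokenizer, output_folder="./pipe_out")
#   edus = pipe.postprocess(pipe._forward(pipe.preprocess("The test failed because the sample was contaminated .")))
#   # edus -> list of EDU strings produced by text_to_edus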


def get_plain_text_from_format(formatted_text: str) -> str:
    """
    Reads a conllu/tok formatted string and returns its plain text
    (the word column), space-separated.
    """
    formatted_text = formatted_text.split("\n")
    s = ""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[1] + " "
    return s.strip()
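

# Example (illustrative): applied to the output of write_sentences_to_format(["Hello world"], None),
# get_plain_text_from_format keeps the second column of each token line, i.e. "Hello world".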


def get_preds_from_format(formatted_text: str) -> str:
    """
    Reads a conllu/tok formatted string and returns its labels
    (the last column), space-separated.
    """
    formatted_text = formatted_text.split("\n")
    s = ""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[-1] + " "
    return s.strip()
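

# Example (illustrative): on the same block as above, get_preds_from_format keeps the last
# column instead, i.e. "Seg=B-seg Seg=O".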


def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """
    Splits a raw text into EDUs based on a sequence of BIO labels.

    Args:
        text (str): The raw text (a sequence of space-separated words).
        labels (list[str]): The sequence of BIO labels (B, I, O),
            of the same length as the number of tokens in the text.

    Returns:
        list[str]: The list of EDUs (each EDU is a substring of the text).
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Length mismatch: {len(words)} words vs {len(labels)} labels")
    edus = []
    current_edu = []
    for word, label in zip(words, labels):
        if label == "Conn=O" or label == "Seg=O":
            current_edu.append(word)
        elif label == "Conn=B-conn" or label == "Seg=B-seg":
            # Close the current EDU if one is open
            if current_edu:
                edus.append(" ".join(current_edu))
                current_edu = []
            current_edu.append(word)
    # If an EDU is still open at the end, close it
    if current_edu:
        edus.append(" ".join(current_edu))
    return edus
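

# Example (illustrative): with labels aligned one-to-one with the words,
#   text_to_edus("He left because it rained .",
#                ["Seg=B-seg", "Seg=O", "Seg=B-seg", "Seg=O", "Seg=O", "Seg=O"])
# returns ["He left", "because it rained ."].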