import sys

import numpy as np
from transformers import Pipeline, AutoModelForTokenClassification

from eval import retrieve_predictions, align_tokens_labels_from_wordids
from reading import read_dataset
from utils import read_config
def write_sentences_to_format(sentences: list[str], filename: str):
    """
    Formats sentences one word per line, in the format:
    index<TAB>word<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...
    Returns the formatted string and, if `filename` is given, also writes it to that file.
    """
    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences = [sentences]
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")
    full = "# newdoc_id = GUM_academic_discrimination\n"
    for sentence in sentences:
        words = sentence.strip().split()
        for i, word in enumerate(words, start=1):
            # First word of the sentence, or any capitalized word -> B-seg; otherwise O
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            line = f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n"
            full += line
    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)
    return full
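# Example (illustrative): write_sentences_to_format(["Hello world"]) returns
# "# newdoc_id = GUM_academic_discrimination\n" followed by
# "1\tHello\t_\t..._\tSeg=B-seg\n" and "2\tworld\t_\t..._\tSeg=O\n"
# (columns 3-9 are underscore placeholders, elided here).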
class DiscoursePipeline(Pipeline):
    def __init__(self, model_id, tokenizer, output_folder="./pipe_out", sat_model: str = "sat-3l", **kwargs):
        auto_model = AutoModelForTokenClassification.from_pretrained(model_id)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        self.config = {
            "model_checkpoint": model_id,
            "sent_spliter": "sat",
            "task": "seg",
            "type": "tok",
            "trace": False,
            "report_to": "none",
            "sat_model": sat_model,
            "tok_config": {
                "padding": "max_length",
                "truncation": True,
                "max_length": 512,
            },
        }
        # Note: this overwrites the Pipeline's `model` attribute with the
        # checkpoint id string; _forward passes it to retrieve_predictions().
        self.model = model_id
        self.output_folder = output_folder
    def _sanitize_parameters(self, **kwargs):
        # Allows optional parameters such as add_lang_token, etc. to be passed through.
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params
    def preprocess(self, text: str):
        self.original_text = text
        formatted_text = write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}
    def _forward(self, inputs):
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}
    def postprocess(self, outputs):
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        edus = text_to_edus(self.original_text, predictions)
        return edus
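# Illustrative usage (a sketch: the checkpoint id below is a placeholder, and
# the call goes through transformers' standard Pipeline __call__ dispatch to
# preprocess / _forward / postprocess):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("some/seg-checkpoint")  # hypothetical id
#   pipe = DiscoursePipeline("some/seg-checkpoint", tokenizer)
#   edus = pipe("First sentence here.\nSecond sentence here.")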
def get_plain_text_from_format(formatted_text: str) -> str:
    """
    Reads conllu- or tok-formatted text and returns the word column
    (column 2) as a plain-text string.
    """
    lines = formatted_text.split("\n")
    s = ""
    for line in lines:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[1] + " "
    return s.strip()
def get_preds_from_format(formatted_text: str) -> str:
    """
    Reads conllu- or tok-formatted text and returns the label column
    (last column) as a space-separated string.
    """
    lines = formatted_text.split("\n")
    s = ""
    for line in lines:
        if not line.startswith("#"):
            if len(line.split("\t")) > 1:
                s += line.split("\t")[-1] + " "
    return s.strip()
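# Example: given the line "1\tHello\t_\t_\t_\t_\t_\t_\t_\tSeg=B-seg",
# get_plain_text_from_format() contributes "Hello" while
# get_preds_from_format() contributes "Seg=B-seg".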
def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """
    Splits a raw text into EDUs from a sequence of BIO labels.
    Args:
        text (str): The raw text (a sequence of space-separated words).
        labels (list[str]): The sequence of BIO labels (B, I, O),
            with the same length as the number of tokens in the text.
    Returns:
        list[str]: The list of EDUs (each EDU is a substring of the text).
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Length mismatch: {len(words)} words vs {len(labels)} labels")
    edus = []
    current_edu = []
    for word, label in zip(words, labels):
        if label == "Conn=O" or label == "Seg=O":
            current_edu.append(word)
        elif label == "Conn=B-conn" or label == "Seg=B-seg":
            # Close the current EDU if one is open
            if current_edu:
                edus.append(" ".join(current_edu))
                current_edu = []
            current_edu.append(word)
    # Close any EDU left open at the end
    if current_edu:
        edus.append(" ".join(current_edu))
    return edus
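if __name__ == "__main__":
    # Minimal sketch of the BIO-to-EDU split; the labels below are illustrative,
    # not the output of any particular checkpoint.
    demo_text = "He left because it rained"
    demo_labels = ["Seg=B-seg", "Seg=O", "Seg=B-seg", "Seg=O", "Seg=O"]
    print(text_to_edus(demo_text, demo_labels))
    # -> ['He left', 'because it rained']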