File size: 5,115 Bytes
f709e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a3df21
 
f709e5e
4a3df21
f709e5e
 
 
 
4a3df21
f709e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from transformers import Pipeline, AutoModelForTokenClassification
import numpy as np
from eval import retrieve_predictions, align_tokens_labels_from_wordids
from reading import read_dataset
from utils import read_config



def write_sentences_to_format(sentences: list[str], filename: str):
    """
    Écrit une phrase dans un fichier, un mot par ligne, avec le format :
    index<TAB>mot<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>_<TAB>Seg=...
    """

    if not sentences:
        return ""
    if isinstance(sentences, str):
        sentences=[sentences]
        import sys
        sys.stderr.write("Warning: only one sentence provided as a string instead of a list of sentences.\n")
    
    full="# newdoc_id = GUM_academic_discrimination\n"
    for sentence in sentences:
        words = sentence.strip().split()
        for i, word in enumerate(words, start=1):
            # Le premier mot → B-seg, sinon O
            seg_label = "B-seg" if i == 1 or word[0].isupper() else "O"
            line = f"{i}\t{word}\t_\t_\t_\t_\t_\t_\t_\tSeg={seg_label}\n"
            full+=line
    if filename:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(full)
        
    return full


class DiscoursePipeline(Pipeline):
    def __init__(self, model_id, tokenizer, output_folder="./pipe_out",sat_model:str="sat-3l", **kwargs):
        auto_model = AutoModelForTokenClassification.from_pretrained(model_id)
        super().__init__(model=auto_model, tokenizer=tokenizer, **kwargs)
        self.config = {"model_checkpoint": model_id, "sent_spliter":"sat","task":"seg","type":"tok","trace":False,"report_to":"none","sat_model":sat_model,"tok_config":{
        "padding":"max_length",
        "truncation":True,
        "max_length": 512
    }}
        self.model = model_id
        self.output_folder = output_folder

    def _sanitize_parameters(self, **kwargs):
        # Permet de passer des paramètres optionnels comme add_lang_token etc.
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, text:str):
        self.original_text=text
        formatted_text=write_sentences_to_format(text.split("\n"), filename=None)
        dataset, _ = read_dataset(
            formatted_text,
            output_path=self.output_folder,
            config=self.config,
            add_lang_token=True,
            add_frame_token=True,
        )
        return {"dataset": dataset}

    def _forward(self, inputs):
        dataset = inputs["dataset"]
        preds_from_model, label_ids, _ = retrieve_predictions(
            self.model, dataset, self.output_folder, self.tokenizer, self.config
        )
        return {"preds": preds_from_model, "labels": label_ids, "dataset": dataset}

    def postprocess(self, outputs):
        preds = np.argmax(outputs["preds"], axis=-1)
        predictions = align_tokens_labels_from_wordids(preds, outputs["dataset"], self.tokenizer)
        edus=text_to_edus(self.original_text, predictions)
        return edus

def get_plain_text_from_format(formatted_text:str) -> str:
    """
    Lit un fichier conllu ou tok et retourne son contenu sous forme de chaîne de caractères.
    """
    formatted_text=formatted_text.split("\n")
    s=""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t"))>1:
                s+=line.split("\t")[1]+" "
    return s.strip()


def get_preds_from_format(formatted_text:str) -> str:
    """
    Lit un fichier conllu ou tok et retourne son contenu sous forme de chaîne de caractères.
    """
    formatted_text=formatted_text.split("\n")
    s=""
    for line in formatted_text:
        if not line.startswith("#"):
            if len(line.split("\t"))>1:
                s+=line.split("\t")[-1]+" "
    return s.strip()


def text_to_edus(text: str, labels: list[str]) -> list[str]:
    """
    Découpe un texte brut en EDUs à partir d'une séquence de labels BIO.

    Args:
        text (str): Le texte brut (séquence de mots séparés par des espaces).
        labels (list[str]): La séquence de labels BIO (B, I, O),
                            de même longueur que le nombre de tokens du texte.

    Returns:
        list[str]: La liste des EDUs (chaque EDU est une sous-chaîne du texte).
    """
    words = text.strip().split()
    if len(words) != len(labels):
        raise ValueError(f"Longueur mismatch: {len(words)} mots vs {len(labels)} labels")

    edus = []
    current_edu = []

    for word, label in zip(words, labels):
        if label == "Conn=O" or label == "Seg=O":
            current_edu.append(word)

        elif label == "Conn=B-conn" or label == "Seg=B-seg":
            # Finir l'EDU courant si ouvert
            if current_edu:
                
                edus.append(" ".join(current_edu))
                current_edu = []
            current_edu.append(word)

    # Si un EDU est resté ouvert, on le ferme
    if current_edu:
        edus.append(" ".join(current_edu))

    return edus