#!/usr/bin/python
# -*- coding: utf-8 -*-

import os, sys

import datasets
import transformers 

import disrpt_io
import utils

# TODO: remove once loading the supported-language list (previously taken from ersatz) is handled
##from ersatz import utils
##LANGUAGES = utils.MODELS.keys()
LANGUAGES = []

def read_dataset( input_path, output_path, config, add_lang_token=True, add_frame_token=True, lang_token="", frame_token="" ):
	'''
	- Read the file in input_path
	- Return a Dataset corresponding to the file

	Parameters
	----------
	input_path : str
		Path to the dataset file (or a raw string in the expected format)
	output_path : str
		Path to an output directory that can be used to write new split files
	config : dict
		Experiment configuration; must contain "model_checkpoint", used to
		instantiate the tokenizer
	add_lang_token : bool
		If True, add a special language token at the beginning of each sequence
	add_frame_token : bool
		If True, add a special framework token at the beginning of each sequence
	lang_token : str
		Language token to use when the input is a raw string (no filename to parse it from)
	frame_token : str
		Framework token to use when the input is a raw string

	Returns
	-------
	Dataset
		The DatasetSeq instance built from input_path
	Tokenizer
		The tokenizer used for the dataset
	'''
	model_checkpoint = config["model_checkpoint"]
	# -- Init tokenizer
	tokenizer = transformers.AutoTokenizer.from_pretrained( model_checkpoint )
	# -- Read and tokenize
	dataset = DatasetSeq( input_path, output_path, config, tokenizer, add_lang_token=add_lang_token,add_frame_token=add_frame_token,lang_token=lang_token,frame_token=frame_token )
	dataset.read_and_tokenize()
	# TODO move this into the class? or do it elsewhere
	LABEL_NAMES_BIO = retrieve_bio_labels( dataset ) # TODO should be computed only once for all splits
	dataset.set_label_names_bio(LABEL_NAMES_BIO)
	return dataset, tokenizer
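
# Hypothetical usage sketch (not called anywhere in this module): shows how
# read_dataset is expected to be invoked. The checkpoint name, the file path
# and the config values below are illustrative assumptions, but the config
# keys mirror the ones actually accessed in this module.
def _example_read_dataset():
	example_config = {
		"model_checkpoint": "bert-base-multilingual-cased",
		"type": "conllu",        # 'tok', 'conllu' or 'split'
		"task": "seg",           # 'seg' or 'conn'
		"trace": False,
		"sent_spliter": None,
		"tok_config": {"truncation": True, "padding": False, "max_length": 512},
	}
	dataset, tokenizer = read_dataset(
		"data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
		output_path=None, config=example_config )
	# dataset.tokenized_datasets is a HuggingFace Dataset with input_ids,
	# attention_mask and subword-aligned labels
	return dataset, tokenizer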

# --------------------------------------------------------------------------
# DatasetDict

class DatasetDisc:
	def __init__(self, annotations_file, output_path, config, tokenizer, dset=None ):
		""" 
		Here we save the location of our input file,
		load the data, i.e. retrieve the list of texts and associated labels,
		build the vocabulary if none is given,
		and define the pipelines used to prepare the data 
		"""
		self.annotations_file = annotations_file
		if isinstance(annotations_file, str) and not os.path.isfile(annotations_file):
			print("Input is a raw string, not a file")
			self.basename = "input"
			self.dset = dset
			self.corpus_name = None
		else:
			self.basename = os.path.basename( self.annotations_file )
			self.dset = self.basename.split(".")[2].split('_')[1]
			self.corpus_name = self.basename.split('_')[0]

		self.tokenizer = tokenizer
		self.config = config 
		# If a sentence splitter is used, the files with the new segmentation will be saved here
		self.output_path = output_path

		# Retrieve info from config: TODO check against info from dir name?
		self.mode = config["type"]
		self.task = config["task"]
		self.trace = config["trace"]
		self.tok_config = config["tok_config"]
		self.sent_spliter = config.get("sent_spliter") # may be absent; only used for segmentation

		# Additional fields
		self.id2label, self.label2id = {}, {} 

		# -- Use disrpt_io to read the file and retrieve annotated data
		self.corpus = init_corpus( self.task ) # initialize a Corpus instance, depending on the task

	def read_and_tokenize( self ):
		print("\n-- READ FROM FILE:", self.annotations_file )
		try:
			self.read_annotations( )
		except Exception as err:
			print( f"Problem when reading {self.annotations_file}: {err=}, {type(err)=}", file=sys.stderr )
			raise
		
		#print("\n-- SET LABELS")
		self.set_labels( )
		print( "self.label2id", self.label2id )
		
		#print("\n-- TOKENIZE DATASET")
		self.tokenize_dataset()
		if self.trace and self.dset:
			print( "\n-- FINISHED READING", self.dset, "PRINTING TRACE --")
			self.print_trace()

	def tokenize_dataset( self ):
		# Specific to subclasses
		raise NotImplementedError
	
	def set_labels( self ):
		# Specific to subclasses 
		raise NotImplementedError
			
	# TODO move this outside the class?
	# TODO use **kwargs instead?
	def read_annotations( self ):
		'''
		Generate a Corpus object based on the input file.
		Since .tok files are not segmented into sentences, a sentence splitter
		is used (configured via the "sent_spliter" config entry).
		'''
		if os.path.isfile(self.annotations_file):
			self.corpus.from_file(self.annotations_file)
			lang = os.path.basename(self.annotations_file).split(".")[0]
			frame = os.path.basename(self.annotations_file).split(".")[1]
			base = os.path.basename(self.annotations_file)
		else:
			# assume the input is raw text already in the expected format
			src = self.mode if self.mode in ["tok", "conllu", "split"] else "conllu"
			self.corpus.from_string(self.annotations_file,src=src)
			lang = self.lang_token 
			frame = self.frame_token 
			base = "input.text"

		#print(f"[DEBUG] lang? {lang}")
		for doc in self.corpus.docs:
			doc.lang = lang
			doc.frame = frame
		# print(corpus)
		# Split the corpus into sentences using the configured sentence splitter
		if self.mode == 'tok':
			from wtpsplit import SaT
			sat_version = "sat-3l"
			if "sat_model" in self.config:
				sat_version = self.config["sat_model"]
			sat_model = SaT(sat_version)
			self.corpus.sentence_split( model=self.sent_spliter, lang="default-multilingual", sat_model=sat_model )
			# Write files with the split sentences
			parts = base.split(".")[:-1]
			split_filename = ".".join(parts) + ".split"
			split_file = os.path.join(self.output_path, split_filename)
			self.corpus.format(file=split_file)
		# No sentence splitting (and no file writing) is needed when mode is 'conllu' or 'split'

	def print_trace( self ):
		print( "\n| Annotation_file: ", self.annotations_file )		
		print( '| Output_path:', self.output_path )
		print( '| Nb_of_instances:', len(self.dataset), "(", len(self.dataset['labels']), ")" )
	 	#		"(", len(self.dataset['tokens']), len(self.dataset['labels']), ")" )

	def print_stats( self ):
		print( "| Annotation_file: ", self.annotations_file )
		if self.dset: print( "| Data_split: ", self.dset )
		print( "| Task: ", self.task )
		print( "| Lang: ", self.lang )
		print( "| Mode: ", self.mode )
		print( "| Label_names: ", self.LABEL_NAMES)
		#print( "---Number_of_documents", len( self.corpus.docs ) )
		print( "| Number_of_instances: ", len(self.dataset) )
		# TODO : add number of docs: not computed for .rels for now 

# -------------------------------------------------------------------------------------------------
class DatasetSeq(DatasetDisc):
	def __init__( self, annotations_file, output_path, config, tokenizer, add_lang_token=True, add_frame_token=True,
			dset=None,lang_token="",frame_token="" ):
		""" 
		Class for tasks corresponding to sequence labeling problem 
			(seg, conn).
		Here we save the location of our input file,
		load the data, i.e. retrieve the list of texts and associated 
			labels,
		build the vocabulary if none is given,
		and define the pipelines used to prepare the data """
		DatasetDisc.__init__( self, annotations_file, output_path, config,
			tokenizer, dset=dset )
		self.add_lang_token = add_lang_token
		self.add_frame_token = add_frame_token
		self.lang_token = lang_token
		self.frame_token = frame_token
		
		if self.mode == 'tok' and self.output_path is None:
			# Default to the directory of the input file for the .split output
			self.output_path = os.path.dirname( self.annotations_file )

		self.sent_spliter = None
		if "sent_spliter" in self.config:
			self.sent_spliter = self.config["sent_spliter"] #only for seg 
		
		self.LABEL_NAMES_BIO = None
		# # TODO not used, really a good idea?
		# self.data_collator = transformers.DataCollatorForTokenClassification(tokenizer=self.tokenizer, 
		# 	padding=self.tok_config["padding"] )
		
	def tokenize_dataset( self ):	
		# -- Create a HuggingFace Dataset object
		if self.trace:
			print(f"\n-- Creating dataset from generator (add_lang_token={self.add_lang_token})")
		self.dataset = datasets.Dataset.from_generator(
			gen,
			gen_kwargs={"corpus": self.corpus, "label2id": self.label2id, "mode": self.mode, "add_lang_token": self.add_lang_token,"add_frame_token":self.add_frame_token},
		)
		if self.trace:
			print( self.dataset[0])
		# Keep track of the alignment between words and subtokens, even when not marked with ##
		# (BERT-style tokenizers may further split on punctuation even when given a list of words)
		self.all_word_ids = []
		# Align labels according to tokenized subwords
		if self.trace:
			print( "\n-- Mapping dataset labels and subwords ")
		self.tokenized_datasets = self.dataset.map(
			tokenize_and_align_labels,
			fn_kwargs = {"tokenizer":self.tokenizer, 
							"id2label":self.id2label, 
							"label2id":self.label2id, 
							"all_word_ids":self.all_word_ids,
							"config":self.config},
			batched=True,
			remove_columns=self.dataset.column_names,
			)
		if self.trace:
			print( self.tokenized_datasets[0])
		
		
	def set_labels(self):
		self.LABEL_NAMES = self.corpus.LABELS
		self.id2label = {i: label for i, label in enumerate( self.LABEL_NAMES )}
		self.label2id = {v: k for k,v in self.id2label.items()}

	def set_label_names_bio( self, LABEL_NAMES_BIO ):
		self.LABEL_NAMES_BIO = LABEL_NAMES_BIO


	def print_trace( self ):
		super().print_trace()
		print( '\n--First sentence: original tokens and labels.\n')
		print( self.dataset[0]['tokens'] )
		print( self.dataset[0]['labels'] )
		print( "\n---First sentence: tokenized version:\n")
		print( self.tokenized_datasets[0] )	
		# print( '\nSource word ids:', len(self.all_word_ids) )

	# # TODO prepare a compute_stats before printing, to allow partial printing without trace mode
	# def print_stats( self ):
	# 	super().print_stats()
	# 	print( "| Number_of_documents", len( self.corpus.docs ) )
		

def init_corpus( task ):
	task = task.strip().lower()
	if task == 'conn':
		return disrpt_io.ConnectiveCorpus()
	elif task == 'seg':
		return disrpt_io.SegmentCorpus()
	else:
		raise NotImplementedError( f"Unknown task: {task}" )

def gen( corpus, label2id, mode, add_lang_token=True, add_frame_token=True ):
	# Optionally prepend special language / framework tokens to each sequence
	source = "split"
	if mode == 'conllu':
		source = "conllu"
	for doc in corpus.docs:
		lang = getattr(doc, 'lang', 'xx')
		lang_token = f"[LANG={lang}]"

		frame = getattr(doc, 'frame', 'xx')
		frame_token = f"[FRAME={frame}]"
		sent_list = doc.sentences[source] if source in doc.sentences else doc.sentences
		for sentence in sent_list:
			labels = []
			tokens = []
			if add_lang_token:
				tokens.append(lang_token)
				labels.append(-100)
			if add_frame_token:
				tokens.append(frame_token)
				labels.append(-100)
				#print(f"[DEBUG] Ajout du token frame {frame_token} pour la phrase: {' '.join([t.form for t in sentence.toks])}")
			for t in sentence.toks:
				tokens.append(t.form)
				if t.label == '_':
					if 'O' in label2id:
						labels.append(label2id['O'])
					else:
						labels.append(list(label2id.values())[0])
				else:
					labels.append(label2id[t.label])
			yield {
				"tokens": tokens,
				"labels": labels
			}
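
# Hypothetical sketch (not called anywhere): builds a tiny duck-typed corpus
# mimicking the disrpt_io attributes accessed by gen() (.docs, .sentences,
# .toks, .form, .label) and shows the dict format it yields, including the
# special [LANG=...] / [FRAME=...] tokens labelled -100.
def _example_gen_output():
	class _Tok:
		def __init__(self, form, label):
			self.form, self.label = form, label
	class _Sent:
		def __init__(self, toks):
			self.toks = toks
	class _Doc:
		def __init__(self):
			self.lang, self.frame = "eng", "pdtb"
			self.sentences = {"split": [_Sent([_Tok("But", "Seg=B-Conn"),
				_Tok("he", "_"),
				_Tok("left", "_")])]}
	class _Corpus:
		docs = [_Doc()]
	label2id = {"_": 0, "Seg=B-Conn": 1, "Seg=I-Conn": 2}
	# Expected single item:
	#   {"tokens": ["[LANG=eng]", "[FRAME=pdtb]", "But", "he", "left"],
	#    "labels": [-100, -100, 1, 0, 0]}
	return list(gen(_Corpus(), label2id, mode="split"))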


def get_tokenizer( model_checkpoint ):
	return transformers.AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize_and_align_labels( dataset, tokenizer, id2label, label2id, all_word_ids, config ):
	'''
	(Done in batches)
	To preprocess the whole dataset, we tokenize all the inputs and apply
	align_labels_with_tokens() to all the labels.
	(With HuggingFace Datasets, this function is called through Dataset.map with batched=True.)
	word_ids() needs the index of the example we want the word IDs for when the
	tokenizer input is a list of texts (here, a list of lists of words), so that
	index is passed as well.
	Tokenizer options (truncation, padding, max_length) come from config["tok_config"].
	'''
	tokenized_inputs = tokenizer(
		dataset["tokens"], 
		truncation=config["tok_config"]['truncation'], 
		padding=config["tok_config"]['padding'],
		max_length=config["tok_config"]['max_length'],
		is_split_into_words=True
	)
	# tokenized_inputs = tokenizer(
	# 	dataset["tokens"], truncation=True, padding=True, is_split_into_words=True
	# )
	all_labels = dataset["labels"]
	new_labels = []
	#print( "tokenized_inputs.word_ids()", tokenized_inputs.word_ids() )
	#print( [tokenizer.decode(tok) for tok in tokenized_inputs['input_ids']])
	##with progressbar.ProgressBar(max_value=len(all_labels)) as bar:
	##for i in tqdm(range(len(all_labels))):
	for i, labels in enumerate(all_labels):
		word_ids = tokenized_inputs.word_ids(i)
		new_labels.append(align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs ))
		# Used to fill the self.all_word_ids field of the Dataset object, but should probably be done somewhere else
		all_word_ids.append( word_ids )
		##bar.update(i)
	tokenized_inputs["labels"] = new_labels
	return tokenized_inputs
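
# Hypothetical sketch (not called anywhere): running tokenize_and_align_labels
# on a single hand-built batch, outside of Dataset.map. The checkpoint name and
# label set are illustrative assumptions; calling this would download a tokenizer.
def _example_tokenize_and_align():
	tokenizer = get_tokenizer("bert-base-multilingual-cased")
	config = {"tok_config": {"truncation": True, "padding": False, "max_length": 128}}
	id2label = {0: "_", 1: "Seg=B-Conn", 2: "Seg=I-Conn"}
	label2id = {v: k for k, v in id2label.items()}
	batch = {"tokens": [["But", "he", "left"]], "labels": [[1, 0, 0]]}
	out = tokenize_and_align_labels(batch, tokenizer, id2label, label2id,
		all_word_ids=[], config=config)
	# out["labels"][0] holds the word-level labels realigned to subword ids,
	# with -100 on special tokens and on non-initial subwords of 'B-' labels
	return out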

def align_labels_with_tokens(labels, word_ids, id2label, label2id, tokenizer, tokenized_inputs):
	'''
	BERT-like tokenization creates new subword tokens, so the labels must be realigned.
	Special tokens get a label of -100, because -100 is the index that is ignored
	by default in the loss function we will use (cross entropy).
	Each continuation subword keeps the label of the word it belongs to, except
	that 'B-' labels are masked with -100 so that only the first subword of a
	word carries the B- tag. [Adapted from the HF course on NER]
	'''
	new_labels = []
	current_word = None
	for word_id in word_ids:
		# TODO: the first word (word_id == 0, or maybe 1) is the special
		# language token and may need to be forced to -100 here
		if word_id == 0:
			pass
		if word_id != current_word:
			# Start of a new word!
			current_word = word_id
			label = -100 if word_id is None else labels[word_id]
			new_labels.append(label)
		elif word_id is None:
			# Special token
			new_labels.append(-100)
		else:
			# Same word as previous token
			label = labels[word_id]
			# Only check for 'B-' when the label is not -100
			if label != -100 and 'B-' in id2label[label]:
				label = -100
			new_labels.append(label)
	return new_labels
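
# Hypothetical sketch (not called anywhere): the word_ids list below is the kind
# of output a HuggingFace fast tokenizer returns for one example (None marks
# special tokens, repeated ids mark subwords of the same word); the subword
# segmentation shown in the comment is illustrative only.
def _example_align_labels():
	id2label = {0: "_", 1: "Seg=B-Conn", 2: "Seg=I-Conn"}
	label2id = {v: k for k, v in id2label.items()}
	labels = [1, 0, 0]                    # word-level labels for "But he left"
	word_ids = [None, 0, 0, 1, 2, None]   # e.g. [CLS] Bu ##t he left [SEP]
	aligned = align_labels_with_tokens(labels, word_ids, id2label, label2id,
		tokenizer=None, tokenized_inputs=None)
	# Expected: [-100, 1, -100, 0, 0, -100]
	# (special tokens and the 'B-' continuation subword are masked)
	return aligned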


def retrieve_bio_labels( dataset ):
	'''
	Needed for compute_metrics: the metrics assume a standard BIO scheme, so we
	create a mapping from the original labels to BIO labels, e.g.:
	'_' --> 'O'
	'Seg=B-Conn' --> 'B'
	'Seg=I-Conn' --> 'I'
	Should also work for segmentation TODO: check
	dataset: a DatasetSeq instance, used to retrieve the original label set
	Return: list: label names mapped to BIO
	'''
	# need a Dataset instance to retrieve the original label sets
	task = dataset.task
	LABEL_NAMES_BIO = []
	LABEL_NAMES = dataset.LABEL_NAMES
	label2idx, idx2newl = {}, {}
	if task in ["conn", "seg"]:
		for i,l in enumerate( LABEL_NAMES ):
			label2idx[l] = i
	for l in label2idx:
		nl = ''
		if 'B' in l:
			nl = 'B'
		elif 'I' in l:
			nl = 'I'
		else:
			nl = 'O'
		idx2newl[label2idx[l]] = nl
	for i in sorted(idx2newl):
		LABEL_NAMES_BIO.append(idx2newl[i])
		#label_names = ['O', 'B', 'I']
	return LABEL_NAMES_BIO
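
# Hypothetical sketch (not called anywhere): retrieve_bio_labels only reads
# .task and .LABEL_NAMES, so a SimpleNamespace is enough to illustrate the
# mapping; the label names are illustrative.
def _example_retrieve_bio_labels():
	from types import SimpleNamespace
	fake = SimpleNamespace(task="conn", LABEL_NAMES=["_", "Seg=B-Conn", "Seg=I-Conn"])
	# Expected: ["O", "B", "I"]
	return retrieve_bio_labels(fake)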

# def _compute_distrib( dataset, id2label ):
# 	distrib = {}
# 	multi = []
# 	for inst in dataset: 
# 		label = id2label[inst['label']]
# 		if label in distrib:
# 			distrib[label] += 1
# 		else:
# 			distrib[label] = 1
# 		len_labels = len( inst["all_labels"])
# 		if len_labels > 1:
# 			#count_multi += 1
# 			multi.append( len_labels )
# 	return distrib, multi

# Defines the language code for the sentence splitter; should be done in disrpt_io?
def set_language( lang ):
	#lang = "default-multilingual" #default value
	# patch: map the non-standard 'sp' code to ISO 'es'
	if lang == "sp": lang = "es"
	if lang not in LANGUAGES:
		lang = "default-multilingual"
	return lang


# ------------------------------------------------------------------
if __name__=="__main__":
	import argparse

	parser = argparse.ArgumentParser(
		description='DISCUT: reading data from disrpt_io and converting to HuggingFace'
	)
	# TRAIN AND DEV are (list of) FILES or DIRECTORIES
	parser.add_argument("-t", "--train", 
		help="Training file. Default: data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu",
		default="data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu")
	
	parser.add_argument("-d", "--dev", 
		help="Dev file. Default: data/eng.sample.rstdt/eng.sample.rstdt_dev.conllu",
		default="data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu")
	
	# OUTPUT DIRECTORY
	parser.add_argument("-o", "--output",
		help="Directory where models and pred will be saved. Default: /home/cbraud/experiments/expe_discut_2025/",
		default="")

	# CONFIG FILE
	parser.add_argument("-c", "--config",
		help="Config file. Default: ./config_seg.json",
		default="./config_seg.json")

	# TRACE / VERBOSITY
	parser.add_argument( '-v', '--trace', 
		action='store_true',
		default=False,
		help="Whether to print full messages. If used, it will override the value in config file.")

	args = parser.parse_args()

	train_path = args.train
	dev_path = args.dev
	if not os.path.isfile(dev_path):
		print( "ERROR with dev file:", dev_path )
	output_path = args.output
	config_file = args.config
	#eval = args.eval
	trace = args.trace

	print( '\n-[JEDIS]--PROGRAM (reader) ARGUMENTS')
	print( '| Train_path', train_path )
	print( '| Dev_path', dev_path )
	print( "| Output_path", output_path )
	print( '| Config', config_file )
	
	print( '\n-[JEDIS]--CONFIG INFO')
	config = utils.read_config( config_file )
	utils.print_config(config)
	# WE override the config file if the user says no trace in arguments
	# easier than modifying the config files each time
	if not trace:
		config['trace'] = False 

	print( "\n-[JEDIS]--READING DATASETS" )
	# dictionnary containing train (if model=='train') and/or dev (test) Dataset instance
	datasets, tokenizer = read_dataset( train_path, dev_path, config, add_lang_token=True )
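
	# Example invocation (hypothetical; assumes this file is saved as, e.g., reader.py):
	#   python reader.py -t data_test/eng.sample.rstdt/eng.sample.rstdt_train.conllu \
	#     -d data_test/eng.sample.rstdt/eng.sample.rstdt_dev.conllu \
	#     -c ./config_seg.json -o out/ -v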