import os
import re
import nltk
nltk.download('stopwords')
nltk.download('punkt')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.util import ngrams
import spacy
# from gensim.summarization.summarizer import summarize
# from gensim.summarization import keywords
# Abstractive Summarisation
from transformers import BartForConditionalGeneration
from transformers import AutoTokenizer
import torch
# Keyword/Keyphrase Extraction
from keybert import _highlight
from keybert import KeyBERT
from keyphrase_vectorizers import KeyphraseCountVectorizer, KeyphraseTfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer
import time
import threading
from collections import defaultdict


class AbstractiveSummarizer:
    def __init__(self):
        self.nlp = spacy.load('en_core_web_lg')
        self.summary = ""
    def generate_batch(self, text, tokenizer):
        """
        Split the text into batches of whole sentences that fit the model's maximum input size.

        Arguments:
            text: the license text to summarise
            tokenizer: the tokenizer corresponding to the model, used to convert the text into tokens
        Returns:
            A list of encoded sentence batches to feed to the model
        """
        parsed = self.nlp(text)
        sents = [sent.text for sent in parsed.sents]
        max_size = tokenizer.model_max_length
        batch = tokenizer(sents, return_tensors='pt', return_length=True, padding='longest')
        inp_batch = []
        cur_batch = torch.empty((0,), dtype=torch.int64)
        for enc_sent, length in zip(batch['input_ids'], batch['length']):
            cur_size = cur_batch.shape[0]
            if (cur_size + length.item()) <= max_size:
                # The sentence still fits: append its unpadded tokens to the current batch
                cur_batch = torch.cat((cur_batch, enc_sent[:length.item()]))
            else:
                # The batch is full: store it and start a new one with this sentence
                inp_batch.append(torch.unsqueeze(cur_batch, 0))
                cur_batch = enc_sent[:length.item()]
        inp_batch.append(torch.unsqueeze(cur_batch, 0))
        return inp_batch
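    # Illustrative sketch of generate_batch's packing behaviour (the sentence
    # lengths are made-up numbers, not real output): with BART's
    # model_max_length of 1024 tokens, sentences of 600, 500 and 300 tokens
    # are packed greedily into two tensors of shape (1, 600) and (1, 800),
    # each of which is summarised independently in summarize() below.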
    def summarize(self, src, tokenizer, model):
        """
        Run the pre-trained model over each batch and join the partial summaries.

        Arguments:
            src: the license text to summarise
            tokenizer: the tokenizer corresponding to the model, used to convert the text into tokens
            model: the pre-trained model object used to perform the summarisation
        Sets as instance variable:
            summary: the summarised text
        """
        batch_texts = self.generate_batch(src, tokenizer)
        # Move each batch to the model's device so the code also works when the model is on GPU
        enc_summary_list = [model.generate(batch.to(model.device), max_length=512) for batch in batch_texts]
        summary_list = [tokenizer.batch_decode(enc_summ, skip_special_tokens=True) for enc_summ in enc_summary_list]
        summary_texts = [summ[0] for summ in summary_list]
        self.summary = " ".join(summary_texts)
    def bart(self, src):
        """
        Initialise the Facebook BART pre-trained model and call the necessary functions to summarise.

        Arguments:
            src: the text to summarise
        Sets as instance variable:
            summary: the summarised text
        """
        model_name = 'facebook/bart-large-cnn'
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
        self.summarize(src, tokenizer, model)
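
# A minimal usage sketch for the class above (assumes the 'en_core_web_lg'
# spaCy model is installed and the BART weights can be downloaded; the input
# file name is hypothetical):
#
#   summarizer = AbstractiveSummarizer()
#   summarizer.bart(open('LICENSE.txt').read())
#   print(summarizer.summary)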


def get_summary(lic_txt):
    """
    Summarise the license text and return the summary.

    Arguments:
        lic_txt: the full license text to summarise
    Returns:
        summary: the generated summary of the license
    """
    print('Summarising...')
    absSum = AbstractiveSummarizer()
    # Generate the summary (bart() stores it on the instance)
    absSum.bart(lic_txt)
    return absSum.summary


def extract_ngrams(phrase):
    """Return every contiguous n-gram (n = 1 .. number of tokens) of the phrase as a string."""
    phrase = re.sub('[^a-zA-Z0-9]', ' ', phrase)
    tokens = word_tokenize(phrase)
    res = []
    # Start at 1: 0-grams are empty and contribute nothing
    for num in range(1, len(tokens) + 1):
        temp = ngrams(tokens, num)
        res += [' '.join(grams) for grams in temp]
    return res
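
# For example, extract_ngrams('non-commercial use') yields
#   ['non', 'commercial', 'use', 'non commercial', 'commercial use',
#    'non commercial use']
# since the hyphen is first replaced by a space.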


def get_highlight_text(text, keywords):
    """
    Find the exact positions of the keywords in the text for highlighting.

    Returns the plain text if no keyword is found; otherwise a list of plain
    substrings interleaved with (substring, label, colour) tuples marking the
    spans to highlight.
    """
    text = re.sub('[-/]', ' ', text)
    # text = re.sub('(\n *){2,}', '\n', text)
    text = re.sub(' {2,}', ' ', text)
    # Group keywords by character length so one window size serves a whole group
    kw_len = defaultdict(list)
    for kw in keywords:
        kw_len[len(kw)].append(kw)
    # Slide a window of each keyword length over the text and record matching spans
    spans = []
    for length in kw_len:
        w_start, w_end = 0, length
        while w_end <= len(text):
            for kw in kw_len[length]:
                if text[w_start:w_end] == kw:
                    spans.append([w_start, w_end])
                    break
            w_start += 1
            w_end += 1
    if not spans:
        return text
    # Merge overlapping spans
    spans.sort(key=lambda x: x[0])
    merged = []
    st, end = spans[0][0], spans[0][1]
    for i in range(1, len(spans)):
        s, e = spans[i]
        if st <= s <= end:
            end = max(e, end)
        else:
            merged.append([st, end])
            st, end = s, e
    merged.append([st, end])
    # Interleave the plain text between spans with highlight tuples for the spans
    res = []
    sub_start = 0
    for s, e in merged:
        res.append(text[sub_start:s])
        res.append((text[s:e], "", "#f66"))
        sub_start = e
    res.append(text[sub_start:])
    return res
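
# Example of the return shape (the (text, label, colour) tuples match the
# annotated-text style used for highlighted display):
#   get_highlight_text('no commercial use allowed', ['commercial'])
#   -> ['no ', ('commercial', '', '#f66'), ' use allowed']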


def get_keywords(datatype, task, field, pos_text, neg_text):
    """
    Extract the permitted-use and restriction tags and check for contradictions.

    Arguments:
        datatype: type of 'data' used under the license, e.g. Model, Data, Model Derivatives, Source Code
        task: the type of task the model is designed to do
        field: which 'field' the data is used in, e.g. Medical, Commercial, Non-Commercial, Research
        pos_text: the part of the license containing information about permitted use
        neg_text: the part of the license containing information about usage restrictions
    Returns:
        p_keywords: list of positive (permitted use) keywords extracted from the summary
        n_keywords: list of negative (restriction) keywords extracted from the summary
        contrd: boolean flag showing whether there is a contradiction
        hl_text: the license text formatted for highlighted display
    """
    print('Extracting keywords...')
    datatype, task, field = datatype.lower(), task.lower(), field.lower()
    stop_words = set(stopwords.words('english'))
    stop_words = stop_words.union({'license', 'licensing', 'licensor', 'copyright', 'copyrights', 'patent'})
    pos_kw_model = KeyBERT()
    neg_kw_model = KeyBERT()
    # Candidate keyphrases: every n-gram of the user-supplied datatype/task/field terms
    candidates = []
    for term in [datatype, task, field]:
        candidates += extract_ngrams(term)
    # Open-vocabulary keyphrases from the permitted-use and restriction texts
    p_kw = pos_kw_model.extract_keywords(docs=pos_text, top_n=40, vectorizer=KeyphraseCountVectorizer(stop_words=stop_words))
    n_kw = neg_kw_model.extract_keywords(docs=neg_text, top_n=40, vectorizer=KeyphraseCountVectorizer(stop_words=stop_words))
    # Candidate-restricted keyphrases, scored against the same texts
    ngram_max = max(len(word_tokenize(x)) for x in [datatype, task, field])
    pc_kw = pos_kw_model.extract_keywords(docs=pos_text, candidates=candidates, keyphrase_ngram_range=(1, ngram_max))
    nc_kw = neg_kw_model.extract_keywords(docs=neg_text, candidates=candidates, keyphrase_ngram_range=(1, ngram_max))
    # A contradiction exists if any n-gram of the intended field appears among the restriction keywords
    all_cont = [kw for (kw, _) in nc_kw]
    cont_terms = set(all_cont).intersection(set(extract_ngrams(field)))
    contrd = len(cont_terms) > 0
    hl_text = get_highlight_text(neg_text, all_cont) if contrd else ""
    p_kw += pc_kw
    n_kw += nc_kw
    p_kw.sort(key=lambda x: x[1], reverse=True)
    n_kw.sort(key=lambda x: x[1], reverse=True)
    p_keywords = [kw for (kw, score) in p_kw if score < 0.5]
    n_keywords = [kw for (kw, score) in n_kw if score < 0.5]
    return p_keywords, n_keywords, contrd, hl_text
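
# A minimal end-to-end sketch (lic_txt, pos_text and neg_text are assumed to
# come from the surrounding app, e.g. an uploaded license split into permitted
# and restricted clauses; the form values are hypothetical):
#
#   summary = get_summary(lic_txt)
#   p_kw, n_kw, contrd, hl_text = get_keywords('Model', 'Classification',
#                                              'Commercial', pos_text, neg_text)
#   if contrd:
#       pass  # render hl_text with the restriction terms highlighted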