Amelia-James commited on
Commit
7ae1348
·
verified ·
1 Parent(s): b80e1f8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MBartForConditionalGeneration, MBartTokenizer, MarianMTModel, MarianTokenizer
2
+ import streamlit as st
3
+
4
+ # Load multilingual summarization model and tokenizer
5
+ multilingual_summarization_model = MBartForConditionalGeneration.from_pretrained('facebook/mbart-large-50')
6
+ multilingual_summarization_tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-50')
7
+
8
+ # Dictionary of languages and their corresponding Hugging Face model codes
9
+ LANGUAGES = {
10
+ "English": "en_XX",
11
+ "French": "fr_XX",
12
+ "Spanish": "es_XX",
13
+ "German": "de_DE",
14
+ "Chinese": "zh_CN",
15
+ "Russian": "ru_RU",
16
+ "Arabic": "ar_AR",
17
+ "Portuguese": "pt_PT",
18
+ "Hindi": "hi_IN",
19
+ "Italian": "it_IT",
20
+ "Japanese": "ja_XX",
21
+ "Korean": "ko_KR",
22
+ "Dutch": "nl_NL",
23
+ "Polish": "pl_PL",
24
+ "Turkish": "tr_TR",
25
+ "Swedish": "sv_SE",
26
+ "Greek": "el_EL",
27
+ "Finnish": "fi_FI",
28
+ "Hungarian": "hu_HU",
29
+ "Danish": "da_DK",
30
+ "Norwegian": "no_NO",
31
+ "Czech": "cs_CZ",
32
+ "Romanian": "ro_RO",
33
+ "Thai": "th_TH",
34
+ "Hebrew": "he_IL",
35
+ "Vietnamese": "vi_VN",
36
+ "Indonesian": "id_ID",
37
+ "Malay": "ms_MY",
38
+ "Bengali": "bn_BD",
39
+ "Ukrainian": "uk_UA",
40
+ "Urdu": "ur_PK",
41
+ "Swahili": "sw_KE",
42
+ "Serbian": "sr_SR",
43
+ "Croatian": "hr_HR",
44
+ "Slovak": "sk_SK",
45
+ "Lithuanian": "lt_LT",
46
+ "Latvian": "lv_LV",
47
+ "Estonian": "et_EE",
48
+ "Bulgarian": "bg_BG",
49
+ "Macedonian": "mk_MK",
50
+ "Albanian": "sq_AL",
51
+ "Georgian": "ka_GE",
52
+ "Armenian": "hy_AM",
53
+ "Kazakh": "kk_KZ",
54
+ "Uzbek": "uz_UZ",
55
+ "Tajik": "tg_TJ",
56
+ "Kyrgyz": "ky_KG",
57
+ "Turkmen": "tk_TM"
58
+ }
59
+
60
+ # Function to get the appropriate translation model and tokenizer
61
+ def get_translation_model(source_lang, target_lang):
62
+ model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
63
+ model = MarianMTModel.from_pretrained(model_name)
64
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
65
+ return model, tokenizer
66
+
67
+ # Function to translate text
68
+ def translate_text(text, source_lang, target_lang):
69
+ model, tokenizer = get_translation_model(source_lang, target_lang)
70
+ inputs = tokenizer([text], return_tensors="pt", truncation=True)
71
+ translated_ids = model.generate(inputs['input_ids'], max_length=1024)
72
+ translated_text = tokenizer.decode(translated_ids[0], skip_special_tokens=True)
73
+ return translated_text
74
+
75
+ # Summarization function with multi-language support
76
+ def summarize_text(text, source_language="English", target_language="English"):
77
+ source_lang_code = LANGUAGES[source_language]
78
+ target_lang_code = LANGUAGES[target_language]
79
+
80
+ # If the input language is not English, translate to English
81
+ if source_lang_code != "en_XX":
82
+ text = translate_text(text, source_lang_code, "en_XX")
83
+
84
+ # Summarize the text using mBART
85
+ inputs = multilingual_summarization_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
86
+ summary_ids = multilingual_summarization_model.generate(inputs['input_ids'], num_beams=4, max_length=200, early_stopping=True)
87
+ summary = multilingual_summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
88
+
89
+ # Translate summary to the target language if needed
90
+ if target_lang_code != "en_XX":
91
+ summary = translate_text(summary, "en_XX", target_lang_code)
92
+
93
+ return summary
94
+
95
+ # Streamlit interface
96
+ st.title("Multi-Language Text Summarization Tool")
97
+
98
+ text = st.text_area("Input Text")
99
+ source_language = st.selectbox("Source Language", options=list(LANGUAGES.keys()), index=list(LANGUAGES.keys()).index("English"))
100
+ target_language = st.selectbox("Target Language", options=list(LANGUAGES.keys()), index=list(LANGUAGES.keys()).index("English"))
101
+
102
+ if st.button("Summarize"):
103
+ if text:
104
+ summary = summarize_text(text, source_language, target_language)
105
+ st.subheader("Summary")
106
+ st.write(summary)
107
+ else:
108
+ st.warning("Please enter text to summarize.")