albertchristopher committed
Commit 6b9b532 · verified · 1 Parent(s): 0ed32d5

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +97 -78
src/streamlit_app.py CHANGED
@@ -1,78 +1,97 @@
- # streamlit_app.py
-         return out
-     except Exception as e:
-         st.error(f"HF Inference API error: {e}")
-         return None
-
-
- if run:
-     if not text.strip():
-         st.warning("Please paste some text to summarize.")
-         st.stop()
-
-     if engine.startswith("HF Inference API"):
-         if not hf_token.strip():
-             st.error("Please provide an HF_TOKEN to use the Inference API fallback.")
-             st.stop()
-         with st.spinner("Calling HF Inference API…"):
-             summary = summarize_via_hf_api(text, hf_token)
-         if summary:
-             st.success("Done!")
-             st.markdown("### Summary")
-             st.write(summary)
-             st.stop()
-
-     info_box = st.empty()
-     info_box.info("Loading BitNet model. On CPU this can take several minutes on first run; subsequent runs are cached.")
-
-     @st.cache_resource(show_spinner=False)
-     def _load():
-         return load_bitnet_model()
-
-     tok, model = _load()
-     info_box.empty()
-
-     with st.spinner("Summarizing with BitNet (map-reduce)…"):
-         summary = map_reduce_summarize(
-             text=text,
-             tokenizer=tok,
-             model=model,
-             max_chunk_tokens=chunk_tokens,
-             overlap=chunk_overlap,
-             chunk_max_new_tokens=chunk_max_new,
-             final_max_new_tokens=final_max_new,
-             temperature=temperature,
-             top_p=top_p,
-         )
-
-     st.success("Done!")
-     st.markdown("### Summary")
-     st.write(summary)
-
-     with st.expander("Debug / details"):
-         st.markdown(
-             "- **Engine:** BitNet (local)\n"
-             f"- **chunk size:** {chunk_tokens} tokens, **overlap:** {chunk_overlap} tokens\n"
-             f"- **temperature:** {temperature}, **top_p:** {top_p}\n"
-             f"- **chunk max_new_tokens:** {chunk_max_new}, **final max_new_tokens:** {final_max_new}"
-         )
-
- st.markdown("---")
- st.caption(
-     "Built with Docker + Streamlit + Transformers + Hugging Face Hub. Model: microsoft/bitnet-b1.58-2B-4T.\n"
-     "Tip: Select a GPU in Space settings for faster startup."
- )
+ import os
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
+ from huggingface_hub import InferenceClient
+
+ # Cache model loading to avoid re-downloading on every rerun
+ @st.cache_resource
+ def load_model(model_name):
+     """
+     Load the specified model and tokenizer. Returns a transformers pipeline for summarization or text generation.
+     """
+     if model_name == "microsoft/bitnet-b1.58-2B-4T":
+         # Load the BitNet model (causal LM) and its tokenizer
+         dtype = torch.float32  # use float32 on CPU
+         tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+         model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map="auto")
+         # Create a text-generation pipeline for BitNet
+         gen_pipeline = pipeline(
+             "text-generation",
+             model=model, tokenizer=tokenizer,
+             max_new_tokens=256,   # default max summary length
+             temperature=0.2,      # low temperature for more focused output
+             pad_token_id=tokenizer.eos_token_id
+         )
+         return gen_pipeline, tokenizer  # return the tokenizer as well for prompt preparation
+     else:
+         # For seq2seq models like T5 or BART, use the summarization pipeline
+         summarizer = pipeline("summarization", model=model_name, tokenizer=model_name, device=-1)
+         return summarizer, None
+
+ # Set page configuration
+ st.set_page_config(page_title="Text Summarizer", page_icon="🤖")
+ st.title("📃 Text Summarizer")
+
+ # Model selection: local models and an option for the Hugging Face Inference API
+ model_options = [
+     "t5-small",
+     "facebook/bart-large-cnn",
+     "microsoft/bitnet-b1.58-2B-4T",
+     "Use Hugging Face Inference API (bart-large-cnn)"
+ ]
+ model_choice = st.selectbox("Choose a summarization model:", model_options,
+                             help="Select a model to use for generating the summary. 'Inference API' will call a hosted model via Hugging Face.")
+
+ # Input methods: text area and file uploader
+ text_input = st.text_area("Enter text to summarize (English only):", height=200)
+ uploaded_file = st.file_uploader("...or upload a text file", type=["txt"])
+ if uploaded_file is not None:
+     # If a file is uploaded, read it once and decode it (assuming a UTF-8 text file)
+     raw_bytes = uploaded_file.read()
+     try:
+         file_content = raw_bytes.decode("utf-8")
+     except Exception:
+         file_content = raw_bytes.decode("latin-1")  # fallback decoding
+     text_to_summarize = file_content
+ else:
+     text_to_summarize = text_input
+
+ # Button to generate the summary
+ if st.button("Summarize"):
+     if not text_to_summarize or text_to_summarize.strip() == "":
+         st.warning("Please provide some text (or upload a file) to summarize.")
+     else:
+         st.write("Generating summary...")
+         # Local model inference
+         if model_choice != "Use Hugging Face Inference API (bart-large-cnn)":
+             summarizer_pipeline, tok = load_model(model_choice)
+             if model_choice == "microsoft/bitnet-b1.58-2B-4T":
+                 # Prepare the BitNet prompt for summarization
+                 prompt = (
+                     "Summarize the text below in 2-3 concise sentences focusing on key facts and implications.\n"
+                     f"Text:\n{text_to_summarize}\nSummary:"
+                 )
+                 # Use the text-generation pipeline to complete the prompt
+                 result = summarizer_pipeline(prompt, max_new_tokens=200, do_sample=False)[0]["generated_text"]
+                 # The pipeline returns the full prompt + completion; keep only the part after 'Summary:'
+                 summary = result.split("Summary:")[-1].strip()
+             else:
+                 # For T5 or BART, use the summarization pipeline directly
+                 summary = summarizer_pipeline(text_to_summarize, max_length=150, min_length=30, do_sample=False)[0]["summary_text"]
+             st.subheader("Summary")
+             st.write(summary)
+         else:
+             # Use the Hugging Face Inference API with a hosted model (bart-large-cnn)
+             hf_token = os.getenv("HF_TOKEN") or (st.secrets["HF_TOKEN"] if "HF_TOKEN" in st.secrets else None)
+             client = InferenceClient(model="facebook/bart-large-cnn", token=hf_token)
+             # Call the summarization API
+             try:
+                 result = client.summarization(text_to_summarize)
+                 # The result is an object with a `summary_text` attribute (or a dict with 'summary_text')
+                 summary_text = result.summary_text if hasattr(result, "summary_text") else result["summary_text"]
+             except Exception as e:
+                 st.error(f"Error using Inference API: {e}")
+                 summary_text = None
+             if summary_text:
+                 st.subheader("Summary")
+                 st.write(summary_text)
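
For reference, the hosted-model branch added in this commit can be sanity-checked outside Streamlit with a minimal sketch like the one below. This is not part of the commit; it assumes `huggingface_hub` is installed, an `HF_TOKEN` environment variable is set, and the sample text passed to the client is purely illustrative.

```python
# Standalone sketch of the Inference API path used in src/streamlit_app.py (not part of this commit).
import os
from huggingface_hub import InferenceClient

client = InferenceClient(model="facebook/bart-large-cnn", token=os.getenv("HF_TOKEN"))
result = client.summarization(
    "Streamlit lets you build data apps in pure Python. This app adds a summarizer "
    "UI that can run local models or call a hosted bart-large-cnn endpoint."
)
# Depending on the huggingface_hub version, the result is a dataclass or a plain dict.
print(result.summary_text if hasattr(result, "summary_text") else result["summary_text"])
```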