# koalpaca-search-demo / src/streamlit_app.py
import os
import sys
import gc
import tempfile
# κΆŒν•œ 문제 해결을 μœ„ν•œ ν™˜κ²½ λ³€μˆ˜ μ„€μ • (μ΅œμƒλ‹¨μ— μœ„μΉ˜)
temp_dir = tempfile.gettempdir()
os.environ["STREAMLIT_HOME"] = temp_dir
os.environ["STREAMLIT_CONFIG_DIR"] = os.path.join(temp_dir, ".streamlit")
os.environ["STREAMLIT_SERVER_HEADLESS"] = "true"
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# μΊμ‹œ 디렉토리도 temp둜 μ„€μ •
os.environ["TRANSFORMERS_CACHE"] = os.path.join(temp_dir, "transformers_cache")
os.environ["HF_HOME"] = os.path.join(temp_dir, "huggingface")
# Work around a torch.classes path conflict with Streamlit's module watcher
try:
import torch
import importlib.util
torch_classes_path = os.path.join(os.path.dirname(importlib.util.find_spec("torch").origin), "classes")
if hasattr(torch, "classes"):
torch.classes.__path__ = [torch_classes_path]
except Exception:
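    # Best-effort workaround; safe to skip if torch is missing or its layout differs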
pass
import streamlit as st
# Page config must be the first Streamlit call, before any other st.* command
st.set_page_config(
    page_title="TinyLlama Demo",
    page_icon="🦙",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Import the transformers stack and record whether it is available
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    TRANSFORMERS_AVAILABLE = True
except ImportError as e:
    TRANSFORMERS_AVAILABLE = False
    st.error(f"Could not import the transformers library: {e}")
st.title("🦙 TinyLlama 1.1B (CPU-only) Demo")
if not TRANSFORMERS_AVAILABLE:
st.error("ν•„μš”ν•œ 라이브러리λ₯Ό μ„€μΉ˜ν•΄μ£Όμ„Έμš”:")
st.code("""
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install transformers
pip install streamlit
""", language="bash")
st.stop()
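# st.cache_resource keeps the loaded model and tokenizer across Streamlit reruns,
# so the weights are downloaded and read from disk only once per process.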
@st.cache_resource(show_spinner=False)
def load_tinyllama_model():
"""TinyLlama 1.1B λͺ¨λΈ λ‘œλ“œ (CPU Only)"""
try:
# μ—¬λŸ¬ κ°€λŠ₯ν•œ λͺ¨λΈ 이름 μ‹œλ„
model_options = [
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"microsoft/DialoGPT-small" # λ°±μ—… μ˜΅μ…˜
]
for model_name in model_options:
try:
st.info(f"λͺ¨λΈ μ‹œλ„ 쀑: {model_name}")
# ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
cache_dir=os.environ.get("TRANSFORMERS_CACHE")
)
# λͺ¨λΈ λ‘œλ“œ
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
trust_remote_code=True,
cache_dir=os.environ.get("TRANSFORMERS_CACHE"),
device_map="cpu"
)
# CPU둜 λͺ…μ‹œμ  이동 및 평가 λͺ¨λ“œ
model = model.to("cpu")
model.eval()
# ν† ν¬λ‚˜μ΄μ € μ„€μ •
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# λ©”λͺ¨λ¦¬ 정리
gc.collect()
return model, tokenizer, f"βœ… {model_name} λ‘œλ“œ 성곡!"
except Exception as model_error:
st.warning(f"{model_name} λ‘œλ“œ μ‹€νŒ¨: {str(model_error)}")
continue
return None, None, "❌ λͺ¨λ“  λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨"
except Exception as e:
return None, None, f"❌ 전체 λ‘œλ“œ μ‹€νŒ¨: {str(e)}"
def generate_text(model, tokenizer, prompt, max_new_tokens=150, temperature=0.7):
"""μ•ˆμ „ν•œ ν…μŠ€νŠΈ 생성 ν•¨μˆ˜"""
try:
# μž…λ ₯ 길이 μ œν•œ
max_input_length = 400
# 토큰화
inputs = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=max_input_length,
padding=False
)
# CPU둜 이동
inputs = {k: v.to("cpu") for k, v in inputs.items()}
input_length = inputs['input_ids'].shape[1]
# μ•ˆμ „ν•œ 생성 길이 계산
safe_max_tokens = min(max_new_tokens, 800 - input_length)
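        # Keep a small floor so near-limit prompts still produce some output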
if safe_max_tokens < 20:
safe_max_tokens = 20
# 생성 μ„€μ •
generation_kwargs = {
"max_new_tokens": safe_max_tokens,
"temperature": temperature,
"do_sample": True,
"top_p": 0.9,
"top_k": 50,
"repetition_penalty": 1.1,
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"use_cache": True,
"early_stopping": True
}
# λ©”λͺ¨λ¦¬ 정리
gc.collect()
# 생성 μ‹€ν–‰
with st.spinner("ν…μŠ€νŠΈ 생성 쀑..."):
with torch.no_grad():
outputs = model.generate(
**inputs,
**generation_kwargs
)
# μƒˆλ‘œ μƒμ„±λœ λΆ€λΆ„λ§Œ μΆ”μΆœ
new_tokens = outputs[0][input_length:]
generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
return generated_text.strip()
    except Exception as e:
        raise RuntimeError(f"Generation error: {e}") from e
def main():
# λͺ¨λΈ λ‘œλ“œ
with st.spinner("TinyLlama λͺ¨λΈ λ‘œλ”© 쀑... (처음 μ‹€ν–‰ μ‹œ λ‹€μš΄λ‘œλ“œλ‘œ 인해 μ‹œκ°„μ΄ 걸릴 수 μžˆμŠ΅λ‹ˆλ‹€)"):
model, tokenizer, status = load_tinyllama_model()
st.info(status)
if not (model and tokenizer):
st.error("λͺ¨λΈ λ‘œλ“œμ— μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€. 인터넷 연결을 ν™•μΈν•˜κ³  λ‹€μ‹œ μ‹œλ„ν•΄μ£Όμ„Έμš”.")
return
# μ‚¬μ΄λ“œλ°” μ„€μ •
st.sidebar.header("βš™οΈ 생성 μ„€μ •")
max_new_tokens = st.sidebar.slider("μ΅œλŒ€ μƒˆ 토큰 수", 20, 200, 100)
temperature = st.sidebar.slider("Temperature (μ°½μ˜μ„±)", 0.1, 1.0, 0.7, 0.1)
# 도움말
st.sidebar.header("πŸ“– μ‚¬μš© κ°€μ΄λ“œ")
st.sidebar.info("""
**Tips:**
- ν”„λ‘¬ν”„νŠΈλŠ” λͺ…ν™•ν•˜κ³  κ°„κ²°ν•˜κ²Œ
- CPU μ „μš©μ΄λ―€λ‘œ 생성에 μ‹œκ°„μ΄ κ±Έλ¦½λ‹ˆλ‹€
- 첫 μ‹€ν–‰ μ‹œ λͺ¨λΈ λ‹€μš΄λ‘œλ“œλ‘œ μ‹œκ°„μ΄ 더 κ±Έλ¦½λ‹ˆλ‹€
""")
    # Main interface
    st.header("💬 Text Generation")
    # Example prompts
    col1, col2 = st.columns([2, 1])
    with col1:
        example_prompts = [
            "Custom input",
            "The future of artificial intelligence is",
            "Once upon a time in a magical forest,",
            "Python is a programming language that",
            "Climate change is an important issue because",
            "The benefits of reading books include"
        ]
        selected_prompt = st.selectbox("Choose an example prompt:", example_prompts)
with col2:
st.write("") # 곡간 확보
st.write("") # 곡간 확보
if st.button("🎲 랜덀 예제", help="λžœλ€ν•œ 예제 ν”„λ‘¬ν”„νŠΈ 선택"):
import random
random_prompt = random.choice(example_prompts[1:]) # 첫 번째 μ œμ™Έ
st.session_state.random_prompt = random_prompt
# ν”„λ‘¬ν”„νŠΈ μž…λ ₯
if selected_prompt == "μ‚¬μš©μž μ •μ˜ μž…λ ₯":
default_prompt = st.session_state.get('random_prompt', '')
prompt = st.text_area(
"ν”„λ‘¬ν”„νŠΈλ₯Ό μž…λ ₯ν•˜μ„Έμš”:",
value=default_prompt,
height=100,
placeholder="여기에 ν…μŠ€νŠΈλ₯Ό μž…λ ₯ν•˜μ„Έμš”..."
)
else:
prompt = st.text_area(
"ν”„λ‘¬ν”„νŠΈ:",
value=selected_prompt,
height=100
)
    # Show the prompt's token count
    if prompt and tokenizer:
        try:
            token_count = len(tokenizer.encode(prompt))
            st.caption(f"Current prompt token count: {token_count}")
            if token_count > 400:
                st.warning("⚠️ The prompt is too long and will be truncated to 400 tokens.")
        except Exception:
            pass
# 생성 λ²„νŠΌ
col1, col2, col3 = st.columns([1, 1, 2])
with col1:
generate_btn = st.button("πŸš€ 생성 μ‹œμž‘", type="primary", use_container_width=True)
with col2:
clear_btn = st.button("πŸ—‘οΈ κ²°κ³Ό μ§€μš°κΈ°", use_container_width=True)
# κ²°κ³Ό μ§€μš°κΈ°
if clear_btn:
if 'generated_result' in st.session_state:
del st.session_state['generated_result']
st.rerun()
# ν…μŠ€νŠΈ 생성
if generate_btn and prompt.strip():
try:
# 생성 μ§„ν–‰λ₯  ν‘œμ‹œ
progress_container = st.container()
with progress_container:
progress_bar = st.progress(0)
status_text = st.empty()
status_text.text("토큰화 쀑...")
progress_bar.progress(20)
status_text.text("ν…μŠ€νŠΈ 생성 쀑... (CPUμ—μ„œ μ‹€ν–‰λ˜λ―€λ‘œ μ‹œκ°„μ΄ κ±Έλ¦½λ‹ˆλ‹€)")
progress_bar.progress(40)
# μ‹€μ œ 생성
generated_text = generate_text(
model, tokenizer, prompt.strip(),
max_new_tokens, temperature
)
progress_bar.progress(80)
status_text.text("κ²°κ³Ό 처리 쀑...")
# κ²°κ³Ό μ €μž₯
full_result = prompt + generated_text
st.session_state['generated_result'] = {
'prompt': prompt,
'generated': generated_text,
'full_text': full_result
}
progress_bar.progress(100)
status_text.text("μ™„λ£Œ!")
# μ§„ν–‰λ₯  ν‘œμ‹œ 제거
progress_bar.empty()
status_text.empty()
except Exception as e:
st.error(f"생성 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}")
st.info("πŸ’‘ λ‹€μ‹œ μ‹œλ„ν•˜κ±°λ‚˜ 더 짧은 ν”„λ‘¬ν”„νŠΈλ₯Ό μ‚¬μš©ν•΄λ³΄μ„Έμš”.")
elif generate_btn:
st.warning("⚠️ ν”„λ‘¬ν”„νŠΈλ₯Ό μž…λ ₯ν•΄μ£Όμ„Έμš”.")
# κ²°κ³Ό ν‘œμ‹œ
if 'generated_result' in st.session_state:
result = st.session_state['generated_result']
st.header("πŸ“ 생성 κ²°κ³Ό")
# νƒ­μœΌλ‘œ ꡬ뢄
tab1, tab2 = st.tabs(["🎯 μƒμ„±λœ ν…μŠ€νŠΈλ§Œ", "πŸ“„ 전체 ν…μŠ€νŠΈ"])
with tab1:
st.markdown("**μƒˆλ‘œ μƒμ„±λœ λΆ€λΆ„:**")
st.markdown(f'<div style="background-color: #f0f2f6; padding: 15px; border-radius: 10px; border-left: 4px solid #4CAF50;">{result["generated"]}</div>', unsafe_allow_html=True)
with tab2:
st.markdown("**전체 ν…μŠ€νŠΈ (ν”„λ‘¬ν”„νŠΈ + 생성):**")
st.text_area(
"전체 κ²°κ³Ό:",
value=result['full_text'],
height=200,
disabled=True
)
# λ‹€μš΄λ‘œλ“œ λ²„νŠΌ
st.download_button(
label="πŸ’Ύ ν…μŠ€νŠΈ 파일둜 μ €μž₯",
data=result['full_text'],
file_name=f"tinyllama_output_{len(result['full_text'])}.txt",
mime="text/plain",
use_container_width=True
)
# μ‹œμŠ€ν…œ 정보 μ‚¬μ΄λ“œλ°”
st.sidebar.header("πŸ’» μ‹œμŠ€ν…œ 정보")
st.sidebar.write(f"**Python:** {sys.version.split()[0]}")
if TRANSFORMERS_AVAILABLE:
st.sidebar.write(f"**PyTorch:** {torch.__version__}")
st.sidebar.write(f"**CUDA μ‚¬μš© κ°€λŠ₯:** {'βœ…' if torch.cuda.is_available() else '❌'}")
st.sidebar.write(f"**μ‹€ν–‰ λͺ¨λ“œ:** CPU μ „μš©")
# μ„±λŠ₯ 팁
with st.sidebar.expander("πŸš€ μ„±λŠ₯ μ΅œμ ν™” 팁"):
st.markdown("""
**속도 ν–₯상:**
- ν”„λ‘¬ν”„νŠΈλ₯Ό 100단어 μ΄ν•˜λ‘œ μœ μ§€
- 토큰 수λ₯Ό 150개 μ΄ν•˜λ‘œ μ œν•œ
- μ—¬λŸ¬ νƒ­μ—μ„œ λ™μ‹œ μ‹€ν–‰ν•˜μ§€ μ•ŠκΈ°
**λ©”λͺ¨λ¦¬ μ ˆμ•½:**
- λ‹€λ₯Έ 무거운 μ• ν”Œλ¦¬μΌ€μ΄μ…˜ μ’…λ£Œ
- λΈŒλΌμš°μ € νƒ­ μ΅œμ†Œν™”
""")
if __name__ == "__main__":
main()