test-tinystoris / app.py
mahesh1209's picture
Create app.py
64a69e3 verified
import hmac
import os
import re

import torch
import torch.nn as nn
import torch.optim as optim
from flask import Flask, request, jsonify, render_template
# Shared secret that POST /generate expects in the x-api-key header.
# NOTE(review): when the API_KEY environment variable is unset this falls back
# to a hard-coded value visible in the source; set API_KEY in every deployment.
API_KEY = os.environ.get("API_KEY", "mysecretkey")
stories = [
"The sun rose over the quiet village.",
"A cat chased a butterfly through the garden.",
"Rain tapped gently on the windowpane.",
"Children laughed and played in the park.",
"The old man told stories by the fire."
]
text = " ".join(stories).lower()
tokens = text.split()
vocab = sorted(set(tokens))
word2idx = {w: i for i, w in enumerate(vocab)}
encoded = [word2idx[w] for w in tokens]
seq_len = 4
inputs, targets = [], []
for i in range(len(encoded) - seq_len):
inputs.append(encoded[i:i+seq_len])
targets.append(encoded[i+1:i+seq_len+1])
inputs = torch.tensor(inputs)
targets = torch.tensor(targets)
class SLM(nn.Module):
    """Tiny fixed-context language model.

    Embeds each token of a fixed-length window, concatenates the embeddings
    and projects them to one logit per vocabulary entry.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: size of each token embedding.
        context_len: number of tokens in every input window. Defaults to the
            module-level ``seq_len`` (the previous hard-wired behaviour), so
            existing ``SLM(vocab_size, embed_dim)`` calls are unchanged.
    """

    def __init__(self, vocab_size, embed_dim, context_len=None):
        super().__init__()
        if context_len is None:
            # Backward-compatible default: the window length of the training data.
            context_len = seq_len
        self.context_len = context_len
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # The whole window's embeddings are flattened into one vector, so the
        # input width is tied to the window length.
        self.fc = nn.Linear(embed_dim * context_len, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, context_len) to logits of shape (batch, vocab_size)."""
        flat = self.embed(x).view(x.size(0), -1)
        return self.fc(flat)
# Full-batch training of the toy model, run at import time. Only the last
# column of `targets` is used as the label, so the model learns to predict the
# token that directly follows each input window.
# NOTE(review): `model` is not referenced by generate_sentence in this file.
model = SLM(len(vocab), embed_dim=10)
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
model.train()  # nothing in the loop switches to eval mode, so setting it once suffices
for epoch in range(50):  # reduced for speed
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, targets[:, -1])
    loss.backward()
    optimizer.step()
def generate_sentence(phrase, corpus=None):
    """Return the first sentence in ``corpus`` that contains every word of ``phrase``.

    Despite the name, nothing is generated: the corpus is searched. Matching
    is case-insensitive and on whole words, with punctuation ignored. For
    example, "quiet village" matches "The sun rose over the quiet village.",
    but "he an" does not match a sentence just because it contains "the"
    and "and".

    Args:
        phrase: free-text query string; at least two words are required.
        corpus: sentences to search. Defaults to the module-level ``stories``.

    Returns:
        The matching sentence in its original casing. Otherwise, a
        human-readable hint when the phrase has fewer than two words or
        nothing matches.
    """
    if corpus is None:
        corpus = stories
    phrase = phrase.strip().lower()
    # Tokenise on word characters so punctuation ("village.", "park,") is ignored.
    phrase_words = re.findall(r"\w+", phrase)
    if len(phrase_words) < 2:
        return "Please enter at least 2 words from a sentence."
    wanted = set(phrase_words)
    for story in corpus:
        # Compare whole words. A substring test would let "a" match any
        # sentence containing the letter a, or "he" match "the".
        if wanted <= set(re.findall(r"\w+", story.lower())):
            return story
    return f'No match found for "{phrase}". Try phrases like: quiet village, chased butterfly, laughed played, told stories'
# Flask application. Templates are loaded from the "templates" folder.
app = Flask(__name__, template_folder="templates")


@app.route("/", methods=("GET",))
def home():
    """Serve the landing page rendered from templates/index.html."""
    page = "index.html"
    return render_template(page)
@app.route("/generate", methods=["POST"])
def generate():
    """Authenticated JSON endpoint: look up the sentence matching ``phrase``.

    Expects an ``x-api-key`` header equal to ``API_KEY`` and a JSON object
    body such as ``{"phrase": "quiet village"}``.

    Returns:
        200 with ``{"input": phrase, "output": result}`` on success.
        401 with ``{"error": "Unauthorized"}`` when the key is missing or wrong.
        400 when the body is not a JSON object or ``phrase`` is not a string.
    """
    key = request.headers.get("x-api-key")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Bytes are compared because compare_digest rejects
    # non-ASCII str arguments.
    if key is None or not hmac.compare_digest(key.encode(), API_KEY.encode()):
        return jsonify({"error": "Unauthorized"}), 401
    data = request.get_json()
    # Valid JSON may still be null, a list or a number; calling .get on those
    # previously raised AttributeError and produced a 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    phrase = data.get("phrase", "")
    if not isinstance(phrase, str):
        return jsonify({"error": '"phrase" must be a string'}), 400
    result = generate_sentence(phrase)
    return jsonify({"input": phrase, "output": result})
# Proxy route for the web frontend. No API key is injected or checked here;
# the handler calls generate_sentence directly.
# NOTE(review): because this route is unauthenticated, the x-api-key check on
# /generate does not actually protect generate_sentence. Confirm this is intended.
@app.route("/frontend_generate", methods=["POST"])
def frontend_generate():
    """Unauthenticated JSON endpoint used by the frontend; same payload and response as /generate.

    Returns:
        200 with ``{"input": phrase, "output": result}`` on success.
        400 when the body is not a JSON object or ``phrase`` is not a string.
    """
    data = request.get_json()
    # Valid JSON may still be null, a list or a number; calling .get on those
    # previously raised AttributeError and produced a 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    phrase = data.get("phrase", "")
    if not isinstance(phrase, str):
        return jsonify({"error": '"phrase" must be a string'}), 400
    # Same lookup as /generate, without the API-key check.
    result = generate_sentence(phrase)
    return jsonify({"input": phrase, "output": result})
if __name__ == "__main__":
    # Development server listening on every network interface (0.0.0.0), port 7860.
    app.run(port=7860, host="0.0.0.0")