Spaces:
Runtime error
Runtime error
File size: 5,044 Bytes
10b8ce5 ed0aacc 10b8ce5 ed0aacc 10b8ce5 d9581e7 ed0aacc 10b8ce5 ed0aacc 10b8ce5 ed0aacc 10b8ce5 ed0aacc 10b8ce5 b7dd971 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from flask import Flask, render_template, request
from PIL import Image
import os
import torch
import cv2
import mediapipe as mp
from transformers import SamModel, SamProcessor
from diffusers.utils import load_image
from torchvision import transforms
import tempfile
app = Flask(__name__)

# Uploads and generated results are written under /tmp, which is typically
# writable even where the application directory is not.
UPLOAD_FOLDER = '/tmp/uploads'
OUTPUT_FOLDER = '/tmp/outputs'


def _ensure_directories(*paths):
    """Create each directory in *paths* (including parents) if it is missing.

    Every path is attempted independently, so failing to create one (for
    example a read-only ``static/`` tree) does not stop the others from being
    created. Failures are reported and otherwise ignored: directory creation
    at startup is best-effort, and request handlers surface any resulting
    I/O errors later.
    """
    for path in paths:
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as e:
            # OSError covers PermissionError as well as read-only filesystems
            # and "path exists but is a file", which previously crashed startup.
            print(f"Could not create directory {path}: {e}")


# The 'static/...' directories (relative to the current working directory)
# hold files served to the browser; the /tmp ones hold working copies.
_ensure_directories(
    UPLOAD_FOLDER,
    OUTPUT_FOLDER,
    'static/uploads',
    'static/outputs',
)
# Load the SlimSAM checkpoint once at startup so every request reuses it.
SAM_CHECKPOINT = "Zigeng/SlimSAM-uniform-50"

# Bind both names up front so they always exist: if loading fails they stay
# None instead of being undefined (which would raise NameError on import or use).
model = None
processor = None
try:
    model = SamModel.from_pretrained(SAM_CHECKPOINT)
    processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
    print("Models loaded successfully")
except Exception as e:
    print(f"Error loading models: {e}")
    # Never leave a half-initialized pair (model loaded, processor not).
    model = None
    processor = None
# Pose detection
def get_shoulder_coordinates(image_path):
    """Locate the left and right shoulders of the person in an image.

    Runs MediaPipe Pose on the image at *image_path* and converts the two
    shoulder landmarks from normalized coordinates to pixel coordinates.

    Args:
        image_path: Path of an image file readable by ``cv2.imread``.

    Returns:
        ``((left_x, left_y), (right_x, right_y))`` as integer pixel
        coordinates clamped to the image bounds, or ``None`` if the image
        cannot be read, no pose is detected, or detection raises.
    """
    try:
        # Read the image first so an unreadable file doesn't pay for
        # constructing the (heavy, model_complexity=2) pose model.
        image = cv2.imread(image_path)
        if image is None:
            print(f"Could not load image from {image_path}")
            return None

        mp_pose = mp.solutions.pose
        # Using the solution as a context manager closes its graph when done;
        # previously a new graph was created and never released on every call.
        with mp_pose.Pose(
            static_image_mode=True,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5,
        ) as pose:
            # OpenCV loads BGR; MediaPipe expects RGB.
            results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if not results.pose_landmarks:
            print("No pose landmarks detected")
            return None

        height, width = image.shape[:2]
        landmarks = results.pose_landmarks.landmark

        def to_pixel(landmark_id):
            # Landmarks are normalized by image width/height but may fall
            # outside the frame for partially visible bodies; clamp so the
            # points passed on downstream stay inside the image.
            point = landmarks[landmark_id]
            x = min(max(int(point.x * width), 0), width - 1)
            y = min(max(int(point.y * height), 0), height - 1)
            return (x, y)

        left_shoulder = to_pixel(mp_pose.PoseLandmark.LEFT_SHOULDER)
        right_shoulder = to_pixel(mp_pose.PoseLandmark.RIGHT_SHOULDER)
        print(f"Left shoulder: {left_shoulder}")
        print(f"Right shoulder: {right_shoulder}")
        return left_shoulder, right_shoulder
    except Exception as e:
        print(f"Error in pose detection: {e}")
        return None
# Which of SAM's candidate masks (multimask output) to keep for the single
# shoulder-point prompt; index 2 is the choice this app has always used.
_SAM_MASK_INDEX = 2


def _apply_tshirt(person_path, tshirt_path, coordinates):
    """Paste the t-shirt image over the region SAM segments from the shoulders.

    Args:
        person_path: Path to the uploaded person image.
        tshirt_path: Path to the uploaded t-shirt image.
        coordinates: ``(left_shoulder, right_shoulder)`` pixel tuples, as
            returned by ``get_shoulder_coordinates``.

    Returns:
        A PIL image the size of the person image, showing the t-shirt (resized
        to that size) wherever the predicted mask is set and the person image
        elsewhere.
    """
    img = load_image(person_path)
    new_tshirt = load_image(tshirt_path)

    left_shoulder, right_shoulder = coordinates
    # One image; both shoulder points together form a single prompt.
    input_points = [[list(left_shoulder), list(right_shoulder)]]

    inputs = processor(img, input_points=input_points, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0]: first (only) image; [0]: its first (only) prompt.
    mask_tensor = masks[0][0][_SAM_MASK_INDEX].to(dtype=torch.uint8)
    mask = transforms.ToPILImage()(mask_tensor * 255)

    new_tshirt = new_tshirt.resize(img.size, Image.LANCZOS)
    return Image.composite(new_tshirt, img, mask)


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload form (GET) or run the t-shirt try-on (POST).

    POST expects ``person_image`` and ``tshirt_image`` file fields. Uploads are
    saved into a private temporary directory created per request under
    UPLOAD_FOLDER, so concurrent requests cannot overwrite each other's inputs;
    the directory is removed when processing finishes. Failures are returned
    as plain text with a 4xx/5xx status code.

    NOTE(review): the result is still written to a fixed ``result.jpg``, so two
    simultaneous requests can end up displaying the same output image.
    """
    if request.method != 'POST':
        return render_template('index.html')

    try:
        person_file = request.files.get('person_image')
        tshirt_file = request.files.get('tshirt_image')
        if not person_file or not tshirt_file:
            return "Please upload both person and t-shirt images.", 400

        with tempfile.TemporaryDirectory(dir=UPLOAD_FOLDER) as work_dir:
            person_path = os.path.join(work_dir, 'person.jpg')
            tshirt_path = os.path.join(work_dir, 'tshirt.png')
            person_file.save(person_path)
            tshirt_file.save(tshirt_path)

            coordinates = get_shoulder_coordinates(person_path)
            if coordinates is None:
                return (
                    "No pose detected. Please try with a different image where "
                    "the person's shoulders are clearly visible.",
                    422,
                )
            result = _apply_tshirt(person_path, tshirt_path, coordinates)

        # Keep a working copy under /tmp and a copy under static/ for serving.
        result_path_temp = os.path.join(OUTPUT_FOLDER, 'result.jpg')
        result_path_static = os.path.join('static/outputs', 'result.jpg')
        result.save(result_path_temp)
        result.save(result_path_static)
        return render_template('index.html', result_img='outputs/result.jpg')
    except Exception as e:
        print(f"Error processing request: {e}")
        return f"Error processing images: {str(e)}", 500
if __name__ == '__main__':
    # The Werkzeug debugger lets a browser execute arbitrary Python (it is only
    # PIN-protected), so it must not be enabled on 0.0.0.0 by default.
    # Opt in with FLASK_DEBUG=1 for local development.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    # NOTE: major browsers refuse connections to port 6000 (reserved for X11);
    # set PORT (e.g. 7860 on Hugging Face Spaces) to listen elsewhere.
    port = int(os.environ.get('PORT', '6000'))
    app.run(debug=debug, host='0.0.0.0', port=port)