File size: 3,301 Bytes
a4f4613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from transformers import pipeline
import tempfile
from PIL import Image

# Load the trained ViT model for Food101 classification, plus a CLIP model that
# is used for zero-shot classification against the same Food101 label set.
# Both pipelines are built once at import time and reused by every request.
vit_classifier = pipeline(
    task="image-classification",
    model="alimoh02/vit-base-food101",
)
clip_detector = pipeline(
    task="zero-shot-image-classification",
    model="openai/clip-vit-large-patch14",
)

# The 101 Food101 class names. Order matters: classify_food turns a numeric
# ViT label into a name by indexing into this list, and CLIP receives these
# strings verbatim as its candidate labels.
# NOTE(review): the index mapping assumes this order matches the ViT
# checkpoint's label ids — confirm against the model's id2label.
labels_food101 = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare beet_salad beignets
    bibimbap bread_pudding breakfast_burrito bruschetta caesar_salad cannoli caprese_salad
    carrot_cake ceviche cheesecake cheese_plate chicken_curry chicken_quesadilla
    chicken_wings chocolate_cake chocolate_mousse churros clam_chowder club_sandwich
    crab_cakes creme_brulee croque_madame cup_cakes deviled_eggs donuts dumplings
    edamame eggs_benedict escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari fried_rice frozen_yogurt
    garlic_bread gnocchi greek_salad grilled_cheese_sandwich grilled_salmon guacamole
    gyoza hamburger hot_and_sour_soup hot_dog huevos_rancheros hummus ice_cream
    lasagna lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup
    mussels nachos omelette onion_rings oysters pad_thai paella pancakes
    panna_cotta peking_duck pho pizza pork_chop poutine prime_rib pulled_pork_sandwich
    ramen ravioli red_velvet_cake risotto samosa sashimi scallops seaweed_salad
    shrimp_and_grits spaghetti_bolognese spaghetti_carbonara spring_rolls steak
    strawberry_shortcake sushi tacos takoyaki tiramisu tuna_tartare waffles
""".split()

# Classification functions
def _vit_label_name(raw_label) -> str:
    """Translate one raw ViT pipeline label into a Food101 class name.

    A label that parses as an in-range integer is treated as an index into
    ``labels_food101``. Any other label is returned unchanged as a string;
    this covers labels that are already names and indices outside the list.
    Negative numbers count as out of range, so ``"-1"`` does not silently
    map to the last class.
    """
    try:
        index = int(raw_label)
    except (TypeError, ValueError):
        # Not a numeric label: assume it is already a readable name.
        return str(raw_label)
    if 0 <= index < len(labels_food101):
        return labels_food101[index]
    return str(raw_label)


def classify_food(image: Image.Image):
    """Run both models on one food photo and return their scores side by side.

    Args:
        image: The uploaded picture as a PIL image (the Gradio input is
            configured with ``type="pil"``). ``None`` when the form is
            submitted without an image.

    Returns:
        A dict with the keys ``"ViT Classification"`` and
        ``"CLIP Zero-Shot Classification"``. Each value maps a label name to
        that model's score, rounded to 4 decimal places.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a pipeline traceback.
    """
    if image is None:
        raise gr.Error("Please upload a food image.")

    vit_output = {
        _vit_label_name(result["label"]): round(result["score"], 4)
        for result in vit_classifier(image)
    }

    # The zero-shot pipeline accepts a PIL image directly, just like the ViT
    # pipeline above. Writing the image to a temporary file first is needless
    # and, with delete=False, leaves one file behind per request.
    clip_results = clip_detector(image, candidate_labels=labels_food101)
    clip_output = {
        str(result["label"]): round(result["score"], 4) for result in clip_results
    }

    return {
        "ViT Classification": vit_output,
        "CLIP Zero-Shot Classification": clip_output,
    }

# Example images offered in the Gradio UI. The interface has exactly one
# input, so each example is a one-element list holding that input's value.
# NOTE(review): these are bare relative filenames — presumably they ship next
# to this script and the app is started from that directory; confirm.
example_images = [
    [filename]
    for filename in (
        "Cheeseburger.jpg",
        "Sushi.jpg",
        "Brownie.jpg",
        "Tiramisu.jpg",
        "Guacamole.jpg",
        "Samosa.jpg",
        "Oysters.jpg",
    )
]

# Gradio UI: one PIL image input, a JSON view of both models' scores, and the
# bundled example images (not pre-computed, since cache_examples is False).
# NOTE(review): the description is German while the title and input label are
# English — consider unifying the UI language.
image_input = gr.Image(type="pil", label="Upload a food image")
json_output = gr.JSON()

iface = gr.Interface(
    fn=classify_food,
    inputs=image_input,
    outputs=json_output,
    examples=example_images,
    cache_examples=False,
    title="Food Image Classification Comparison",
    description="Vergleiche ein trainiertes ViT-Modell (Food101) mit einem CLIP Zero-Shot-Modell.",
)

# Start the web server; this runs as soon as the module is executed.
iface.launch()