# Noteworthy Differences
# Classification of noteworthy differences between revisions of Wikipedia articles: an AI alignment project
# 20251114 jmd version 1
import json

import logfire
import pandas as pd
from dotenv import load_dotenv
from google import genai
from google.genai import types
from pydantic import BaseModel

from prompts import analyzer_prompts, judge_prompt
from retry_with_backoff import retry_with_backoff

# Load API keys from .env
load_dotenv()

# Set up Logfire; instrumenting wraps Google Gen AI client calls
# to capture prompts, responses, and metadata
logfire.configure()
logfire.instrument_google_genai()

# Initialize the Gemini client
client = genai.Client()

@retry_with_backoff()
def classifier(old_revision, new_revision, prompt_style):
    """
    Classify noteworthy differences between revisions of a Wikipedia article

    Args:
        old_revision: Old revision of the article
        new_revision: New revision of the article
        prompt_style: Key into analyzer_prompts selecting the prompt template

    Returns:
        noteworthy: True if the differences are noteworthy; False if not
        rationale: One-sentence rationale for the classification
    """
    # Return None for missing revisions
    if pd.isna(old_revision) or pd.isna(new_revision):
        return {"noteworthy": None, "rationale": None}

    # Get the prompt template for the given style
    prompt_template = analyzer_prompts[prompt_style]

    # Add the article revisions to the prompt
    prompt = prompt_template.replace("{{old_revision}}", old_revision).replace(
        "{{new_revision}}", new_revision
    )

    # Define the structured response schema
    class Response(BaseModel):
        noteworthy: bool
        rationale: str

    # Generate the response as JSON constrained to the schema
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=prompt,
        config=types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=Response.model_json_schema(),
        ),
    )
    return json.loads(response.text)
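
# Example call (a minimal sketch; the revision texts and the "fewshot" style
# key are illustrative assumptions — valid keys depend on analyzer_prompts):
#
#   result = classifier(
#       old_revision="The town has a population of 5,000.",
#       new_revision="The town has a population of 50,000 as of the 2020 census.",
#       prompt_style="fewshot",
#   )
#   # result -> {"noteworthy": True, "rationale": "..."}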

@retry_with_backoff()
def judge(old_revision, new_revision, rationale_1, rationale_2, mode="unaligned"):
    """
    AI judge to settle disagreements between the classification models

    Args:
        old_revision: Old revision of the article
        new_revision: New revision of the article
        rationale_1: Rationale provided by model 1 (i.e., the heuristic prompt)
        rationale_2: Rationale provided by model 2 (i.e., the few-shot prompt)
        mode: Prompt mode: "unaligned", "aligned-fewshot", or "aligned-heuristic"

    Returns:
        noteworthy: True if the differences are noteworthy; False if not
        reasoning: One-sentence reason for the judgment
    """
    # Add the article revisions to the prompt
    prompt = judge_prompt.replace("{{old_revision}}", old_revision).replace(
        "{{new_revision}}", new_revision
    )

    # Add the rationales to the prompt
    prompt = prompt.replace("{{model_1_rationale}}", rationale_1).replace(
        "{{model_2_rationale}}", rationale_2
    )

    # Optionally add alignment text to the prompt
    if mode == "unaligned":
        alignment_text = ""
    elif mode == "aligned-fewshot":
        with open("data/alignment_fewshot.txt", "r") as file:
            alignment_text = file.read()
    elif mode == "aligned-heuristic":
        with open("data/alignment_heuristic.txt", "r") as file:
            alignment_text = file.read()
    else:
        raise ValueError(f"Unknown mode: {mode}")
    prompt = prompt.replace("{{alignment_text}}", alignment_text)

    # Define the structured response schema
    class Response(BaseModel):
        noteworthy: bool
        reasoning: str

    # Generate the response as JSON constrained to the schema
    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=prompt,
        config=types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=Response.model_json_schema(),
        ),
    )
    return json.loads(response.text)
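
if __name__ == "__main__":
    # Minimal end-to-end sketch of the disagreement workflow (illustrative;
    # the "heuristic" and "fewshot" style keys are assumptions inferred from
    # the rationale_1/rationale_2 descriptions in the judge docstring)
    old = "The bridge opened in 1990."
    new = "The bridge opened in 1990 and was demolished in 2024."
    result_1 = classifier(old, new, prompt_style="heuristic")
    result_2 = classifier(old, new, prompt_style="fewshot")
    if result_1["noteworthy"] == result_2["noteworthy"]:
        print("Models agree:", result_1)
    else:
        # Only invoke the judge when the two classifiers disagree
        verdict = judge(
            old,
            new,
            result_1["rationale"],
            result_2["rationale"],
            mode="unaligned",
        )
        print("Judge verdict:", verdict)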