stash
This commit is contained in:
371
stash/config/plugins/community/LocalVisage/LocalVisage.py
Normal file
371
stash/config/plugins/community/LocalVisage/LocalVisage.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import os
import sys
import zipfile
import tempfile
from PythonDepManager import ensure_import
# --- VENV AUTO-CREATION WITH REQUIREMENTS AND AUTO-RESTART ---
# Plugin-local virtual environment and the pinned dependency list shipped
# next to this script; both are resolved relative to this file so the plugin
# works regardless of Stash's working directory.
venv_dir = os.path.join(os.path.dirname(__file__), "venv")
requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
|
||||
# --- PYTHON VERSION CHECK ---
# The venv is created with the interpreter that runs this script, so the very
# first run (before the venv exists) must be Python 3.10.x for the pinned
# dependencies (TensorFlow / DeepFace) to resolve. Once the venv exists, the
# host interpreter version no longer matters because we re-exec into the venv.
if not os.path.isdir(venv_dir) and not (sys.version_info.major == 3 and sys.version_info.minor == 10):
    ensure_import("stashapi:stashapp-tools>=0.2.58")
    import stashapi.log as log
    # Fixed typo in the user-facing message: "recommanded" -> "recommended".
    log.error("Error: Python version must be >= 3.10.X (recommended 3.10.11) for the first installation of the plugin. Once installed you can change back your python version in stash as this plugin will run within its own venv")
    log.error(f"Current version: {sys.version}")
    log.error("Go to https://www.python.org/downloads/release/python-31011/")
    sys.exit(1)
# --- END PYTHON VERSION CHECK ---
|
||||
|
||||
|
||||
def in_venv():
    """Return True when the current interpreter is the plugin's own venv.

    Any virtualenv is detected via the legacy ``real_prefix`` marker or the
    modern ``base_prefix``/``prefix`` mismatch; we then require the active
    prefix to be exactly our plugin-local venv directory.
    """
    legacy_venv = hasattr(sys, 'real_prefix')
    modern_venv = hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix
    if not (legacy_venv or modern_venv):
        return False
    return os.path.abspath(sys.prefix) == os.path.abspath(venv_dir)
|
||||
def install_dependencies():
    """
    Install pinned dependencies from requirements.txt into the plugin venv.

    Upgrades pip first, then installs the requirements file with the venv's
    own pip. Skips quietly when no requirements.txt ships with the plugin.
    """
    if not os.path.isfile(requirements_path):
        print("No requirements.txt found, skipping dependency installation.")
        return

    import subprocess
    # Windows venvs keep executables under Scripts/, POSIX venvs under bin/.
    if os.name == "nt":
        pip_exe = os.path.join(venv_dir, "Scripts", "pip.exe")
        py_exe = os.path.join(venv_dir, "Scripts", "python.exe")
    else:
        pip_exe = os.path.join(venv_dir, "bin", "pip")
        py_exe = os.path.join(venv_dir, "bin", "python")
    subprocess.check_call([py_exe, "-m", "pip", "install", "--upgrade", "pip"])
    subprocess.check_call([pip_exe, "install", "-r", requirements_path])
|
||||
|
||||
# First run: bootstrap the plugin-local virtual environment.
if not os.path.isdir(venv_dir):
    ensure_import("stashapi:stashapp-tools>=0.2.58")
    import stashapi.log as log
    import subprocess
    log.info("No venv found. Creating virtual environment...")
    # Create the venv with the interpreter currently running this script
    # (guaranteed to be 3.10.x by the version check above).
    subprocess.check_call([sys.executable, "-m", "venv", venv_dir])
    log.progress(0.25)
    log.info("Virtual environment created at "+ venv_dir)
    if os.path.isfile(requirements_path):
        log.info("Installing dependencies... This might take a while")
        install_dependencies()
    else:
        log.info("No requirements.txt found, skipping dependency installation.")

# If not running in the venv, restart the script using the venv's Python
if not in_venv():
    py_exe = os.path.join(venv_dir, "Scripts", "python.exe") if os.name == "nt" else os.path.join(venv_dir, "bin", "python")
    print(f"Restarting script in venv: {py_exe}")
    # os.execv replaces the current process image; nothing below this line
    # executes in the original interpreter.
    os.execv(py_exe, [py_exe] + sys.argv)
# --- END VENV AUTO-CREATION WITH REQUIREMENTS AND AUTO-RESTART ---
|
||||
|
||||
import json
import subprocess
import platform

# Set environment variables before TensorFlow is imported (via deepface).
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TF logs

# Ensure dependencies: on a venv whose requirements install was interrupted
# the first import attempt fails; (re)install and retry once.
# Fixed: catch ImportError specifically instead of a bare except, which also
# swallowed SystemExit/KeyboardInterrupt.
try:
    from deepface import DeepFace
    import numpy as np
    import psutil
    import stashapi.log as log
    from stashapi.stashapp import StashInterface
except ImportError:
    install_dependencies()

    from deepface import DeepFace
    import numpy as np
    import psutil
    import stashapi.log as log
    from stashapi.stashapp import StashInterface

# On-disk embedding database: one averaged vector per performer, per model.
VOY_DB_PATH = os.path.join(os.path.dirname(__file__), "voy_db")
os.makedirs(os.path.join(VOY_DB_PATH, "facenet"), exist_ok=True)
os.makedirs(os.path.join(VOY_DB_PATH, "arc"), exist_ok=True)
|
||||
|
||||
|
||||
def main():
    """
    Plugin entry point: read the Stash plugin payload from stdin and
    dispatch the requested task mode.
    """
    global stash

    payload = json.loads(sys.stdin.read())
    stash = StashInterface(payload["server_connection"])
    mode = payload["args"].get("mode")

    # Defaults, overridden by any user-configured plugin settings.
    config = stash.get_configuration()["plugins"]
    settings = {"voyCount": 15, "sceneCount": 0, "imgCount": 0}
    if "LocalVisage" in config:
        settings.update(config["LocalVisage"])

    # Dispatch table instead of an if/elif chain; unknown modes are a no-op.
    actions = {
        "spawn_server": lambda: spawn_server(payload["server_connection"]),
        "stop_server": kill_stashface_server,
        "rebuild_model": lambda: rebuild_model(update_only=False, settings=settings),
        "update_model": lambda: rebuild_model(update_only=True, settings=settings),
    }
    action = actions.get(mode)
    if action is not None:
        action()
|
||||
|
||||
def can_read_image(image_path):
    """
    Check if an image path can be read, handling both regular files and files
    inside ZIP archives.

    Args:
        image_path (str): Path to the image file, possibly of the form
            "/path/to/archive.zip/inner/file.jpg".

    Returns:
        tuple: (can_read, actual_path) where can_read is bool and actual_path
        is the path to use. For ZIP members, actual_path is a temporary file
        the caller should release via cleanup_temp_file().
    """
    if os.path.exists(image_path):
        return True, image_path

    # Check if it's inside a ZIP file. Search case-insensitively (fixes paths
    # like "FOO.ZIP/img.jpg", where the old `split(".zip")` never matched) and
    # split at the FIRST occurrence only, so directories containing ".zip" in
    # their name deeper in the path don't truncate the member path.
    marker = image_path.lower().find(".zip")
    if marker != -1:
        try:
            zip_path = image_path[:marker + 4]
            internal_path = image_path[marker + 4:].lstrip(os.sep + "/")
            # ZIP member names always use forward slashes, regardless of OS.
            internal_path = internal_path.replace("\\", "/")

            if os.path.exists(zip_path) and internal_path:
                with zipfile.ZipFile(zip_path, 'r') as zip_file:
                    # Check if the internal path exists in the ZIP
                    if internal_path in zip_file.namelist():
                        # Extract to temporary file and return temp path
                        suffix = os.path.splitext(internal_path)[1]
                        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                            tmp_file.write(zip_file.read(internal_path))
                            return True, tmp_file.name
        except Exception as e:
            log.warning(f"Error reading from ZIP file {image_path}: {e}")

    return False, image_path
|
||||
|
||||
def cleanup_temp_file(file_path):
    """
    Remove a file previously extracted from a ZIP archive.

    Only paths under the system temp directory are deleted, so real library
    files handed to the embedding pipeline are never touched.

    Args:
        file_path (str): Path to the (possibly temporary) file.
    """
    temp_root = tempfile.gettempdir()
    try:
        if file_path.startswith(temp_root):
            os.unlink(file_path)
    except Exception as e:
        log.warning(f"Error cleaning up temporary file {file_path}: {e}")
|
||||
|
||||
|
||||
def find_performers(settings):
    """
    Find performers eligible for model building.

    Applies the configured minimum scene/image counts both server-side (query
    filter) and client-side, and drops performers with no profile image.

    Args:
        settings (dict): Plugin settings ("sceneCount", "imgCount", "voyCount").

    Returns:
        list: Performer dicts enriched with an "images" list
        (see enrich_performers).
    """
    query = {}
    # query performers based on sceneCount and imgCount settings
    scene_count_min = settings.get("sceneCount", 0)
    img_count_min = settings.get("imgCount", 0)
    if scene_count_min > 0 or img_count_min > 0:
        query = {
            "scene_count": {"modifier": "GREATER_THAN", "value": scene_count_min - 1},
            "image_count": {"modifier": "GREATER_THAN", "value": img_count_min - 1},
        }
    # Fixed: scene_count/image_count must be requested in the fragment —
    # without them the client-side filter below always read 0 via .get() and
    # discarded every performer whenever a minimum was configured.
    performers_all = stash.find_performers(
        f=query,
        fragment="id name image_path custom_fields scene_count image_count",
    )
    performers_without_image = stash.find_performers(f={"is_missing": "image"}, fragment="id")
    performers_without_image_ids = {p["id"] for p in performers_without_image}
    performers_to_process = [p for p in performers_all if p["id"] not in performers_without_image_ids]

    # Client-side re-check of the minimums (covers the query=={} case too).
    performers_to_process = [
        p for p in performers_to_process
        if (p.get("scene_count", 0) >= scene_count_min and
            p.get("image_count", 0) >= img_count_min)
    ]
    return enrich_performers(performers_to_process, settings)
|
||||
|
||||
def enrich_performers(performers, settings):
    """
    Add extra images to each performer for embedding calculation.

    For every performer: starts with the profile image (if any), then pulls up
    to voyCount-1 additional solo images from Stash, keeping only paths that
    are readable locally (ZIP members are extracted to temp files).

    Args:
        performers (list): Performer dicts with at least "id" and "image_path".
        settings (dict): Plugin settings; "voyCount" caps images per performer.

    Returns:
        list: The same performer dicts, each with an "images" list added.
    """
    for progress, performer in enumerate(performers):
        performer["images"] = []
        if performer.get("image_path"):
            performer["images"].append(performer["image_path"])
        extra_images = stash.find_images(
            filter={
                "direction": "ASC",
                "page": 1,
                # Profile image already counts toward voyCount, hence -1.
                "per_page": settings.get("voyCount", 15) - 1,
                "q": "",
                # presumably a seeded random sort so paging stays stable
                # across requests — verify against Stash's sort semantics
                "sort": "random_11365347"
            },
            f={
                # Solo images only: exactly one performer, and it is this one.
                "performer_count": {"modifier": "EQUALS", "value": 1},
                "performers": {"modifier": "INCLUDES_ALL", "value": [performer["id"]]},
                # Exclude video/animated containers and any non-ASCII paths.
                "path": {
                    "modifier": "NOT_MATCHES_REGEX",
                    "value": r".*\.(mp4|webm|avi|mov|mkv|flv|wmv|gif)$|.*[^\x00-\x7F].*"
                }
            }
        )
        for image in extra_images:
            if image.get("visual_files") and len(image["visual_files"]) > 0:
                image_path = image["visual_files"][0]["path"]
                can_read, actual_path = can_read_image(image_path)
                if can_read:
                    performer["images"].append(actual_path)
                else:
                    log.warning(f"Image path does not exist and cannot be read: {image_path}")
            else:
                log.warning(f"No visual files found for image ID: {image['id']}")
        log.progress((progress + 1) / len(performers))
    return performers
|
||||
|
||||
def rebuild_model(update_only, settings):
    """
    Build or update the face embedding model for all performers.

    For each eligible performer, computes Facenet512 and ArcFace embeddings
    for every collected image, averages them per model, and stores one vector
    per performer under VOY_DB_PATH. The number of images used is recorded in
    the performer's custom fields so update runs can skip finished performers.

    Args:
        update_only (bool): When True, only process performers whose stored
            model used fewer than settings["voyCount"] images and for whom
            more images are now available.
        settings (dict): Plugin settings ("voyCount", "sceneCount", "imgCount").
    """
    log.info("Updating model..." if update_only else "Rebuilding model...")
    performers = find_performers(settings)
    if not performers:
        log.info("No performers found for model building.")
        return

    log.info("Database scraped, starting to rebuild model...")
    for progress, performer in enumerate(performers):
        embeddings_facenet = []
        embeddings_arc = []
        custom_fields = performer.get("custom_fields", {})
        images_used = custom_fields.get("number_of_images_used_for_voy", 0)
        # Update mode: skip performers already built from enough images,
        # or for whom no new images exist since the last build.
        if update_only and images_used >= settings["voyCount"]:
            continue
        if update_only and len(performer["images"]) <= images_used:
            continue

        for uri in performer["images"]:
            try:
                result_facenet = DeepFace.represent(
                    img_path=uri,
                    model_name="Facenet512",
                    detector_backend='yolov8',
                    normalization='Facenet2018',
                    align=True,
                    # enforce_detection=False: images without a detectable
                    # face still yield an embedding instead of raising.
                    enforce_detection=False
                )
                embeddings_facenet.append(result_facenet[0]['embedding'])
                result_arc = DeepFace.represent(
                    img_path=uri,
                    model_name="ArcFace",
                    detector_backend='yolov8',
                    enforce_detection=False,
                    align=True
                )
                embeddings_arc.append(result_arc[0]['embedding'])
            except Exception as e:
                log.warning(f"[WARN] Skipping {uri}: {e}")
            finally:
                # Clean up temporary files created for ZIP extraction
                cleanup_temp_file(uri)

        if embeddings_facenet and embeddings_arc:
            # Average all per-image embeddings into one vector per model.
            # NOTE: np.save appends ".npy", so files land as "<id>-<name>.voy.npy";
            # the server's loader accepts both suffixes.
            avg_embedding_facenet = np.mean(embeddings_facenet, axis=0).astype(np.float32)
            facenet_path = os.path.join(VOY_DB_PATH, "facenet", f"{performer['id']}-{performer['name']}.voy")
            np.save(facenet_path, avg_embedding_facenet)
            avg_embedding_arc = np.mean(embeddings_arc, axis=0).astype(np.float32)
            arc_path = os.path.join(VOY_DB_PATH, "arc", f"{performer['id']}-{performer['name']}.voy")
            np.save(arc_path, avg_embedding_arc)
            embeddings_count = max(len(embeddings_facenet), len(embeddings_arc))
            # Record how many images went into the model so "update" runs can
            # decide whether this performer needs more.
            stash.update_performer({
                "id": performer["id"],
                "custom_fields": {
                    "partial": {
                        "number_of_images_used_for_voy": embeddings_count,
                    }
                }
            })
            log.info(f"[INFO] Saved VOY for {performer['name']} with {embeddings_count} images.")
        else:
            log.warning(f"[WARN] No valid embeddings for {performer['name']}.")
        log.progress((progress + 1) / len(performers))
    log.info("Rebuilding model finished.")
    # Stop a running server so it reloads the new embeddings on next launch.
    if server_running():
        kill_stashface_server()
        # Optionally, reload server with new connection info if needed
|
||||
|
||||
def server_running():
    """
    Check if the stashface server is running.

    Scans all processes for a Python interpreter whose command line mentions
    both "stashface" and "app.py".

    Returns:
        bool: True if a matching process is found, False otherwise.
    """
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            # `or ''` guards against a None name (possible on some platforms).
            name = (proc.info.get('name') or '').lower()
            cmdline_raw = proc.info.get('cmdline')
            if not cmdline_raw:
                continue
            cmdline = [str(arg).lower() for arg in cmdline_raw]
            if 'python' in name and any('stashface' in arg and 'app.py' in arg for arg in cmdline):
                log.debug("Stashface server is already running.")
                return True
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Fixed: the old code wrapped the WHOLE loop in try/except, so one
            # vanished process aborted the entire scan and reported "not
            # running" even when the server existed. Skip just that process.
            continue
    return False
|
||||
|
||||
def kill_stashface_server():
    """
    Terminate every running stashface server process.

    Matches processes whose command line contains both "stashface" and
    "app.py"; processes that disappear or deny access mid-scan are skipped.
    """
    any_killed = False
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            args = proc.info['cmdline']
            if not args:
                continue
            if any('stashface' in arg and 'app.py' in arg for arg in args):
                log.debug(f"Killing process {proc.pid}: {' '.join(args)}")
                proc.kill()
                any_killed = True
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
    if any_killed:
        log.info("Stashface server killed.")
|
||||
|
||||
def spawn_server(server_connection=None):
    """
    Spawn the stashface server as a detached subprocess.

    Args:
        server_connection (dict | None): Stash connection info; when provided
            it is handed to the child via the SERVER_CONNECTION env variable.
    """
    if server_running():
        log.info("Stashface server is already running.")
        return

    plugin_dir = os.path.dirname(__file__)
    if os.name == "nt":
        py_exe = os.path.join(venv_dir, "Scripts", "python.exe")
    else:
        py_exe = os.path.join(venv_dir, "bin", "python")
    app_script = os.path.abspath(os.path.join(plugin_dir, "stashface", "app.py"))
    cmd = [py_exe, app_script]

    log.info("Spawning server")
    env = os.environ.copy()
    if server_connection is not None:
        env["SERVER_CONNECTION"] = json.dumps(server_connection)

    # Detach the child so it outlives this short-lived plugin task.
    popen_kwargs = dict(close_fds=True, cwd=plugin_dir, env=env)
    if platform.system() == "Windows":
        popen_kwargs["creationflags"] = subprocess.CREATE_NEW_CONSOLE
    else:
        popen_kwargs["start_new_session"] = True
    subprocess.Popen(cmd, **popen_kwargs)
    log.info("Server spawned successfully, you can now use the plugin.")
|
||||
|
||||
# Run only when executed as a Stash plugin task, not on import.
if __name__ == '__main__':
    main()
|
||||
57
stash/config/plugins/community/LocalVisage/LocalVisage.yml
Normal file
57
stash/config/plugins/community/LocalVisage/LocalVisage.yml
Normal file
@@ -0,0 +1,57 @@
|
||||
name: Local Visage
|
||||
description: Local Performer Recognition plugin using DeepFace
|
||||
# requires: PythonDepManager
|
||||
# requires: stashUserscriptLibrary7dJx1qP
|
||||
version: 1.0.1
|
||||
exec:
|
||||
- python
|
||||
- "{pluginDir}/LocalVisage.py"
|
||||
interface: raw
|
||||
ui:
|
||||
requires:
|
||||
- stashUserscriptLibrary7dJx1qP
|
||||
javascript:
|
||||
- frontend.js
|
||||
- https://cdn.jsdelivr.net/npm/@gradio/client@1.15.3/dist/index.js
|
||||
css:
|
||||
- frontend.css
|
||||
csp:
|
||||
connect-src:
|
||||
- http://localhost:7860
|
||||
- http://192.168.1.198:7860
|
||||
- http://your-server-ip-address:7860
|
||||
script-src:
|
||||
- https://cdn.jsdelivr.net/npm/html2canvas@1.4.1/dist/html2canvas.min.js
|
||||
- https://cdn.jsdelivr.net/npm/@gradio/client@1.15.3/dist/index.js
|
||||
|
||||
tasks:
|
||||
- name: Rebuild Face Recognition Model
|
||||
description: Rebuild the face recognition model entirely
|
||||
defaultArgs:
|
||||
mode: rebuild_model
|
||||
- name: Update Face Recognition Model
|
||||
description: Update the face performers model with new images if their model was built on fewer than "Target image count per voy" images
|
||||
defaultArgs:
|
||||
mode: update_model
|
||||
- name: Start server
|
||||
description: Start the face recognition server (if not started) to allow the plugin to work
|
||||
defaultArgs:
|
||||
mode: spawn_server
|
||||
- name: Stop server
|
||||
description: Stop the face recognition server
|
||||
defaultArgs:
|
||||
mode: stop_server
|
||||
|
||||
settings:
|
||||
voyCount:
|
||||
displayName: Target image count per voy (default is 15)
|
||||
description: Number of images to use to create the face recognition model (per performer)
|
||||
type: NUMBER
|
||||
imgCount:
|
||||
displayName: Minimum number of images for performer to be added to model
|
||||
description: Minimum number of images a performer must have to be included in recognition (EXCLUDING THE PERFORMER THUMBNAIL). Set to 0 for best result.
|
||||
type: NUMBER
|
||||
sceneCount:
|
||||
displayName: Minimum number of scenes for performer to be added to model
|
||||
description: Minimum number of scenes a performer must have to be included in recognition
|
||||
type: NUMBER
|
||||
136
stash/config/plugins/community/LocalVisage/frontend.css
Normal file
136
stash/config/plugins/community/LocalVisage/frontend.css
Normal file
@@ -0,0 +1,136 @@
|
||||
button.svelte-localhjf {
|
||||
background-color: var(--nav-color);
|
||||
border: 0px;
|
||||
}
|
||||
.scanner.svelte-localhjf {
|
||||
animation: svelte-localhjf-pulse 2s infinite;
|
||||
}
|
||||
@keyframes svelte-localhjf-pulse {
|
||||
0% {
|
||||
transform: scale(0.95);
|
||||
box-shadow: 0 0 0 0 var(--light);
|
||||
}
|
||||
70% {
|
||||
transform: scale(1.1);
|
||||
box-shadow: 0 0 0 10px var(--info);
|
||||
}
|
||||
100% {
|
||||
transform: scale(0.95);
|
||||
box-shadow: 0 0 0 0 var(--primary);
|
||||
}
|
||||
}
|
||||
svg.svelte-localhjf {
|
||||
fill: #ffffff;
|
||||
}
|
||||
button.svelte-localhjf {
|
||||
background-color: var(--nav-color);
|
||||
border: 0px;
|
||||
}
|
||||
.scanner.svelte-localhjf {
|
||||
animation: svelte-localhjf-pulse 2s infinite;
|
||||
}
|
||||
@keyframes svelte-localhjf-pulse {
|
||||
0% {
|
||||
transform: scale(0.95);
|
||||
box-shadow: 0 0 0 0 var(--light);
|
||||
}
|
||||
70% {
|
||||
transform: scale(1.1);
|
||||
box-shadow: 0 0 0 10px var(--info);
|
||||
}
|
||||
100% {
|
||||
transform: scale(0.95);
|
||||
box-shadow: 0 0 0 0 var(--primary);
|
||||
}
|
||||
}
|
||||
svg.svelte-localhjf {
|
||||
fill: #ffffff;
|
||||
}
|
||||
.carousel.svelte-localhja {
|
||||
display: flex;
|
||||
overflow-x: auto;
|
||||
overflow-y: auto;
|
||||
white-space: nowrap;
|
||||
overscroll-behavior-x: contain;
|
||||
overscroll-behavior-y: contain;
|
||||
scroll-snap-type: x mandatory;
|
||||
gap: 1rem;
|
||||
}
|
||||
.modal-header.svelte-localhja {
|
||||
font-size: 2.4rem;
|
||||
border-bottom: 0px;
|
||||
padding: 10px 10px 0px 10px;
|
||||
}
|
||||
.modal-footer.svelte-localhja {
|
||||
border-top: 0px;
|
||||
}
|
||||
.svelte-localhja::-webkit-scrollbar {
|
||||
width: 30px;
|
||||
}
|
||||
.svelte-localhja::-webkit-scrollbar-thumb {
|
||||
background: var(--orange);
|
||||
border-radius: 20px;
|
||||
}
|
||||
.card.svelte-localhja {
|
||||
max-width: 78%;
|
||||
}
|
||||
.performer-card.svelte-localhja {
|
||||
cursor: pointer;
|
||||
}
|
||||
.performer-card-image .svelte-localhja {
|
||||
min-width: none !important;
|
||||
aspect-ratio: 4/5;
|
||||
}
|
||||
.assigned.svelte-localhja {
|
||||
border: 5px solid var(--green);
|
||||
animation: border 1s ease-in-out;
|
||||
}
|
||||
.face-tab.svelte-localhja {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
object-fit: cover;
|
||||
}
|
||||
.selected.svelte-localhjb {
|
||||
border: 2px solid #007bff;
|
||||
}
|
||||
.face-tabs.svelte-localhjb {
|
||||
position: absolute;
|
||||
flex: 0 0 450px;
|
||||
max-width: 450px;
|
||||
min-width: 450px;
|
||||
height: 100%;
|
||||
overflow: auto;
|
||||
order: -1;
|
||||
background-color: var(--body-color);
|
||||
}
|
||||
.face-item.svelte-localhjb {
|
||||
width: 160px;
|
||||
height: 90px;
|
||||
border-radius: 5px 5px 0px 0px;
|
||||
position: relative;
|
||||
cursor: pointer;
|
||||
}
|
||||
.svelte-tabs__tab.svelte-localhjc {
|
||||
border: none;
|
||||
border-bottom: 2px solid transparent;
|
||||
color: #000000;
|
||||
cursor: pointer;
|
||||
list-style: none;
|
||||
display: inline-block;
|
||||
padding: 0.5em 0.75em;
|
||||
}
|
||||
.svelte-tabs__tab.svelte-localhjc:focus {
|
||||
outline: thin dotted;
|
||||
}
|
||||
.svelte-tabs__selected.svelte-localhjc {
|
||||
border-bottom: 2px solid #4f81e5;
|
||||
color: #4f81e5;
|
||||
}
|
||||
.svelte-tabs__tab-panel.svelte-lcocalhjd {
|
||||
margin-top: 0.5em;
|
||||
}
|
||||
.svelte-tabs__tab-list.svelte-localhje {
|
||||
border-bottom: 1px solid #cccccc;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
14579
stash/config/plugins/community/LocalVisage/frontend.js
Normal file
14579
stash/config/plugins/community/LocalVisage/frontend.js
Normal file
File diff suppressed because one or more lines are too long
25
stash/config/plugins/community/LocalVisage/manifest
Executable file
25
stash/config/plugins/community/LocalVisage/manifest
Executable file
@@ -0,0 +1,25 @@
|
||||
id: LocalVisage
|
||||
name: Local Visage
|
||||
metadata:
|
||||
description: Local Performer Recognition plugin using DeepFace
|
||||
version: 1.0.1-05b0e72
|
||||
date: "2025-07-16 04:26:21"
|
||||
requires: []
|
||||
source_repository: https://stashapp.github.io/CommunityScripts/stable/index.yml
|
||||
files:
|
||||
- stashface/utils/__init__.py
|
||||
- stashface/utils/vtt_parser.py
|
||||
- stashface/app.py
|
||||
- stashface/models/image_processor.py
|
||||
- stashface/models/data_manager.py
|
||||
- stashface/models/__init__.py
|
||||
- stashface/models/face_recognition.py
|
||||
- stashface/web/__init__.py
|
||||
- stashface/web/interface.py
|
||||
- .gitignore
|
||||
- frontend.js
|
||||
- LocalVisage.yml
|
||||
- requirements.txt
|
||||
- LocalVisage.py
|
||||
- readme.md
|
||||
- frontend.css
|
||||
69
stash/config/plugins/community/LocalVisage/readme.md
Normal file
69
stash/config/plugins/community/LocalVisage/readme.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# Local Performer Recognition
|
||||
|
||||
https://discourse.stashapp.cc/t/local-visage/2478
|
||||
|
||||
A plugin for recognizing performers from their images using [DeepFace](https://github.com/serengil/deepface). This plugin integrates seamlessly with Stash and enables automatic facial recognition by building or updating a local model trained from your existing image collection.
|
||||
|
||||
## 🔍 Features
|
||||
|
||||
- **Rebuild Face Recognition Model**
|
||||
Completely rebuild the local facial recognition model using available images per performer.
|
||||
|
||||
- **Update Face Recognition Model**
|
||||
Incrementally updates the model if performers have fewer images than the configured target count.
|
||||
|
||||
- **Automatic Server Control**
|
||||
Easily start or stop the recognition server as needed—automatically starts when an image is queried.
|
||||
|
||||
- **Identify**
|
||||
Click on the new icon next to an image to trigger performer identification.
|
||||
|
||||
## 📦 Requirements
|
||||
|
||||
- Python 3.10.11 (temporarily, see instructions below)
|
||||
- `PythonDepManager`
|
||||
- `stashUserscriptLibrary7djx1qp` (add repo https://7djx1qp.github.io/stash-plugins/)
|
||||
|
||||
## ⚙️ Tasks
|
||||
|
||||
| Task | Description |
|
||||
| ---------------------------------- | --------------------------------------------------------------------- |
|
||||
| **Rebuild Face Recognition Model** | Fully rebuild the DeepFace model for all performers. |
|
||||
| **Update Face Recognition Model** | Add more images for performers with less than the target image count. |
|
||||
| **Start Server** | Start the local DeepFace server if it's not already running. |
|
||||
| **Stop Server** | Gracefully stop the running recognition server. |
|
||||
|
||||
## 🔧 Settings
|
||||
|
||||
| Setting | Description |
|
||||
| ------------------------------ | ------------------------------------------------------------------------------- |
|
||||
| **Target image count per voy** | Number of images to use per performer when training the model. Default is `15`. |
|
||||
|
||||
## 🚀 Installation & Setup
|
||||
|
||||
### 1. Set Python Path to 3.10.11
|
||||
|
||||
To ensure compatibility with DeepFace and the plugin’s dependency resolution process:
|
||||
|
||||
- Temporarily set the Python path in your system/environment to **Python 3.10.11**.
|
||||
|
||||
### 2. Rebuild the Model
|
||||
|
||||
Run the **"Rebuild Face Recognition Model"** task. This will:
|
||||
|
||||
- Set up a virtual environment
|
||||
- Install all necessary Python dependencies (DeepFace, etc.)
|
||||
- Build the recognition model
|
||||
|
||||
### 3. Restore Python Path (Optional)
|
||||
|
||||
Once setup is complete, you can revert your Python path to its original version. The plugin will continue working with the generated virtual environment.
|
||||
|
||||
## 🖼 Usage
|
||||
|
||||
1. Once the model is built, navigate to an image in your Stash UI.
|
||||
2. Click the **Performer Recognition** icon overlaying the image.
|
||||
3. The plugin will:
|
||||
- Automatically start the recognition server if it's not already running
|
||||
- Query the server to identify the performer
|
||||
- Display the matched performer from the trained database
|
||||
129
stash/config/plugins/community/LocalVisage/requirements.txt
Normal file
129
stash/config/plugins/community/LocalVisage/requirements.txt
Normal file
@@ -0,0 +1,129 @@
|
||||
# Don't install this manually. The plugin will create a venv and install the requirements automatically.
|
||||
#
|
||||
#nvidia-cublas-cu12==12.4.5.8
|
||||
#nvidia-cuda-cupti-cu12==12.4.127
|
||||
#nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
#nvidia-cuda-runtime-cu12==12.4.127
|
||||
#nvidia-cudnn-cu12==9.1.0.70
|
||||
#nvidia-cufft-cu12==11.2.1.3
|
||||
#nvidia-curand-cu12==10.3.5.147
|
||||
#nvidia-cusolver-cu12==11.6.1.9
|
||||
#nvidia-cusparse-cu12==12.3.1.170
|
||||
#nvidia-cusparselt-cu12==0.6.2
|
||||
#nvidia-nccl-cu12==2.21.5
|
||||
#nvidia-nvjitlink-cu12==12.4.127
|
||||
#nvidia-nvtx-cu12==12.4.127
|
||||
stashapp-tools>=0.2.58
|
||||
absl-py==2.2.2
|
||||
aiofiles==24.1.0
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
astunparse==1.6.3
|
||||
beautifulsoup4==4.13.4
|
||||
blinker==1.9.0
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
contourpy==1.3.2
|
||||
cycler==0.12.1
|
||||
deepface @ git+https://github.com/serengil/deepface.git@cc484b54be5188eb47faf132995af16a871d70b9
|
||||
fastapi==0.115.12
|
||||
ffmpy==0.5.0
|
||||
filelock==3.18.0
|
||||
fire==0.7.0
|
||||
flask==3.1.0
|
||||
flask-cors==5.0.1
|
||||
flatbuffers==25.2.10
|
||||
fonttools==4.57.0
|
||||
fsspec==2025.3.2
|
||||
gast==0.6.0
|
||||
gdown==5.2.0
|
||||
google-pasta==0.2.0
|
||||
gradio==5.25.2
|
||||
gradio-client==1.8.0
|
||||
groovy==0.1.2
|
||||
grpcio==1.71.0
|
||||
gunicorn==23.0.0
|
||||
h11==0.14.0
|
||||
h5py==3.13.0
|
||||
httpcore==1.0.8
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.30.2
|
||||
idna==3.10
|
||||
itsdangerous==2.2.0
|
||||
jinja2==3.1.6
|
||||
joblib==1.4.2
|
||||
keras==3.9.2
|
||||
kiwisolver==1.4.8
|
||||
libclang==18.1.1
|
||||
lz4==4.4.4
|
||||
markdown==3.8
|
||||
markdown-it-py==3.0.0
|
||||
markupsafe==3.0.2
|
||||
matplotlib==3.10.1
|
||||
mdurl==0.1.2
|
||||
ml-dtypes==0.5.1
|
||||
mpmath==1.3.0
|
||||
mtcnn==1.0.0
|
||||
namex==0.0.8
|
||||
networkx==3.4.2
|
||||
numpy==2.1.3
|
||||
opencv-python==4.11.0.86
|
||||
opt-einsum==3.4.0
|
||||
optree==0.15.0
|
||||
orjson==3.10.16
|
||||
packaging==25.0
|
||||
pandas==2.2.3
|
||||
pillow==11.2.1
|
||||
protobuf==5.29.4
|
||||
psutil==7.0.0
|
||||
py-cpuinfo==9.0.0
|
||||
pycryptodomex==3.22.0
|
||||
pydantic==2.11.3
|
||||
pydantic-core==2.33.1
|
||||
pydub==0.25.1
|
||||
pygments==2.19.1
|
||||
pyparsing==3.2.3
|
||||
pysocks==1.7.1
|
||||
python-dateutil==2.9.0.post0
|
||||
python-multipart==0.0.20
|
||||
pytz==2025.2
|
||||
pyyaml==6.0.2
|
||||
pyzipper==0.3.6
|
||||
requests==2.32.3
|
||||
retina-face==0.0.17
|
||||
rich==14.0.0
|
||||
ruff==0.11.6
|
||||
safehttpx==0.1.6
|
||||
scipy==1.15.2
|
||||
seaborn==0.13.2
|
||||
semantic-version==2.10.0
|
||||
setuptools==78.1.0
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.6
|
||||
starlette==0.46.2
|
||||
sympy==1.13.1
|
||||
tensorboard==2.19.0
|
||||
tensorboard-data-server==0.7.2
|
||||
tensorflow==2.19.0
|
||||
termcolor==3.0.1
|
||||
tf-keras==2.19.0
|
||||
tomlkit==0.13.2
|
||||
torch==2.6.0
|
||||
torchvision==0.21.0
|
||||
tqdm==4.67.1
|
||||
typer==0.15.2
|
||||
typing-extensions==4.13.2
|
||||
typing-inspection==0.4.0
|
||||
tzdata==2025.2
|
||||
ultralytics==8.3.69
|
||||
ultralytics-thop==2.0.14
|
||||
urllib3==2.4.0
|
||||
uvicorn==0.34.2
|
||||
#voyager==2.1.0
|
||||
websockets==15.0.1
|
||||
werkzeug==3.1.3
|
||||
wheel==0.45.1
|
||||
wrapt==1.17.2
|
||||
38
stash/config/plugins/community/LocalVisage/stashface/app.py
Normal file
38
stash/config/plugins/community/LocalVisage/stashface/app.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import os
|
||||
import sys
|
||||
# Set DeepFace home directory
|
||||
os.environ["DEEPFACE_HOME"] = "."
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # Suppress TF logs
|
||||
# Add the plugins directory to sys.path
|
||||
plugins_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
if plugins_dir not in sys.path:
|
||||
sys.path.insert(0, plugins_dir)
|
||||
|
||||
|
||||
from stashapi.stashapp import StashInterface
|
||||
|
||||
|
||||
try:
|
||||
from models.data_manager import DataManager
|
||||
from web.interface import WebInterface
|
||||
except ImportError as e:
|
||||
print(f"Error importing modules: {e}")
|
||||
input("Ensure you have installed the required dependencies. Press Enter to exit.")
|
||||
|
||||
|
||||
|
||||
def main():
    """Main entry point for the application"""
    # Initialize data manager
    voy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../voy_db"))
    data_manager = DataManager(voy_root_folder=voy_root)

    # Initialize and launch web interface
    interface = WebInterface(data_manager, default_threshold=0.5)
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1 @@
|
||||
# models package
|
||||
@@ -0,0 +1,116 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
import numpy as np
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../py_dependencies/numpy_1.26.4")))
|
||||
|
||||
server_connection = json.loads(os.environ.get("SERVER_CONNECTION"))
|
||||
from stashapi.stashapp import StashInterface
|
||||
|
||||
class DataManager:
|
||||
    def __init__(self, voy_root_folder):
        """
        Initialize the data manager using folders of .voy files for each model.

        Parameters:
            voy_root_folder: Path to the root folder containing 'facenet' and 'arc' subfolders.
        """
        self.voy_root_folder = voy_root_folder
        # Per-model embedding store: model -> {stash_id: {"name", "embedding"}}
        self.embeddings = {
            "facenet": {},  # Dict[str, Dict[str, Any]]
            "arc": {}
        }
        self._load_voy_files()
        # Connection info comes from the SERVER_CONNECTION env var set by the
        # plugin process that spawned this server.
        self.stash = StashInterface(server_connection)
|
||||
|
||||
    def _load_voy_files(self):
        """Load all .voy files for each model into memory."""
        for model in ["facenet", "arc"]:
            folder = os.path.join(self.voy_root_folder, model)
            self.embeddings[model] = {}
            if not os.path.isdir(folder):
                continue
            for fname in os.listdir(folder):
                # np.save appends ".npy" on write, so accept both
                # "<id>-<name>.voy" and "<id>-<name>.voy.npy".
                if fname.endswith(".voy.npy") or fname.endswith(".voy"):
                    try:
                        # Remove .voy or .voy.npy
                        if fname.endswith(".voy.npy"):
                            id_name = fname[:-8]
                        else:
                            id_name = fname[:-4]
                        # Filenames encode "<stash_id>-<performer name>";
                        # split on the first dash only, since performer names
                        # may themselves contain dashes.
                        stash_id, name = id_name.split("-", 1)
                        path = os.path.join(folder, fname)
                        embedding = np.load(path)
                        self.embeddings[model][stash_id] = {
                            "name": name,
                            "embedding": embedding
                        }
                    except Exception as e:
                        # Best-effort load: a corrupt or misnamed file is
                        # reported and skipped rather than aborting startup.
                        print(f"Error loading {fname} for {model}: {e}")
|
||||
|
||||
def get_all_ids(self, model: str = "facenet") -> List[str]:
|
||||
"""Return all performer IDs for a given model."""
|
||||
return list(self.embeddings.get(model, {}).keys())
|
||||
|
||||
def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get performer information from the loaded embeddings.
|
||||
Parameters:
|
||||
stash_id: Stash ID of the performer
|
||||
confidence: Confidence score (0-1)
|
||||
Returns:
|
||||
Dictionary with performer information or None if not found
|
||||
"""
|
||||
performer = self.stash.find_performer(stash_id)
|
||||
if not performer:
|
||||
# Fallback to embedding name if performer not found
|
||||
for model in self.embeddings:
|
||||
if stash_id in self.embeddings[model]:
|
||||
name = self.embeddings[model][stash_id].get("name", "Unknown")
|
||||
break
|
||||
else:
|
||||
name = "Unknown"
|
||||
return {
|
||||
'id': stash_id,
|
||||
"name": name,
|
||||
"image": None,
|
||||
"confidence": int(confidence * 100),
|
||||
}
|
||||
return {
|
||||
'id': stash_id,
|
||||
"name": performer['name'],
|
||||
"image": urlparse(performer['image_path']).path if performer.get('image_path') else None,
|
||||
"confidence": int(confidence * 100),
|
||||
'country': performer.get('country'),
|
||||
'distance': int(confidence * 100),
|
||||
'performer_url': f"/performers/{stash_id}"
|
||||
}
|
||||
|
||||
def query_index(self, model: str, embedding: np.ndarray, limit: int = 5):
|
||||
"""
|
||||
Query the loaded embeddings for the closest matches using cosine similarity for a given model.
|
||||
Parameters:
|
||||
model: 'facenet' or 'arc'
|
||||
embedding: The embedding to compare
|
||||
limit: Number of top matches to return
|
||||
Returns:
|
||||
List of (stash_id, distance) tuples, sorted by distance ascending
|
||||
"""
|
||||
results = []
|
||||
for stash_id, data in self.embeddings.get(model, {}).items():
|
||||
db_embedding = data["embedding"]
|
||||
sim = np.dot(embedding, db_embedding) / (np.linalg.norm(embedding) * np.linalg.norm(db_embedding))
|
||||
distance = 1 - sim
|
||||
results.append((stash_id, distance))
|
||||
results.sort(key=lambda x: x[1])
|
||||
return results[:limit]
|
||||
|
||||
def query_facenet_index(self, embedding: np.ndarray, limit: int = 5):
|
||||
"""Query the Facenet index."""
|
||||
return self.query_index("facenet", embedding, limit)
|
||||
|
||||
def query_arc_index(self, embedding: np.ndarray, limit: int = 5):
|
||||
"""Query the ArcFace index."""
|
||||
return self.query_index("arc", embedding, limit)
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from deepface import DeepFace
|
||||
|
||||
class EnsembleFaceRecognition:
    """Fuses per-model face-recognition candidates into one ranked list."""

    def __init__(self, model_weights: Dict[str, float] = None):
        """
        Initialize ensemble face recognition system.

        Parameters:
            model_weights: Dictionary mapping model names to their weights.
                           If None, all models are weighted equally.
        """
        self.model_weights = model_weights or {}
        # Scores are multiplied by this factor (then clamped to 1.0) so that
        # strong agreement maps close to full confidence.
        self.boost_factor = 1.8

    def normalize_distances(self, distances: np.ndarray) -> np.ndarray:
        """Scale distances into [0, 1] within one model's prediction set."""
        lo, hi = np.min(distances), np.max(distances)
        if hi == lo:
            # Degenerate case: all candidates equally far away.
            return np.zeros_like(distances)
        return (distances - lo) / (hi - lo)

    def compute_model_confidence(self, distances: np.ndarray, temperature: float = 0.1) -> np.ndarray:
        """Turn a model's distances into a softmax confidence distribution."""
        weights = np.exp(-self.normalize_distances(distances) / temperature)
        return weights / np.sum(weights)

    def get_face_embeddings(self, image_path: str) -> Dict[str, np.ndarray]:
        """Get face embeddings for each model from an image path."""
        shared = dict(img_path=image_path, detector_backend='skip', align=True)
        return {
            'facenet': DeepFace.represent(model_name='Facenet512', normalization='Facenet2018', **shared)[0]['embedding'],
            'arc': DeepFace.represent(model_name='ArcFace', **shared)[0]['embedding']
        }

    def ensemble_prediction(
        self,
        model_predictions: Dict[str, Tuple[List[str], List[float]]],
        temperature: float = 0.1,
        min_agreement: float = 0.5
    ) -> List[Tuple[str, float]]:
        """
        Combine predictions from multiple models.

        Parameters:
            model_predictions: Dictionary mapping model names to their (names, distances) predictions.
            temperature: Temperature parameter for softmax scaling.
            min_agreement: Minimum agreement threshold between models.

        Returns:
            final_predictions: List of (name, confidence) tuples, best first.
        """
        votes = {}
        confidences = {}

        # Each model casts one weighted vote for its top candidate and
        # contributes that candidate's softmax confidence.
        for model_name, (names, distances) in model_predictions.items():
            weight = self.model_weights.get(model_name, 1.0)
            softmax = self.compute_model_confidence(np.array(distances), temperature)
            winner = names[0]
            votes[winner] = votes.get(winner, 0) + weight
            confidences.setdefault(winner, []).append(softmax[0])

        # Normalize votes by the total available weight (or model count when
        # no explicit weights were supplied).
        denom = sum(self.model_weights.values()) if self.model_weights else len(model_predictions)

        ranked = []
        for name, vote in votes.items():
            share = vote / denom
            if share < min_agreement:
                continue
            boosted = share * np.mean(confidences[name]) * self.boost_factor
            ranked.append((name, min(boosted, 1.0)))

        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked
|
||||
|
||||
def extract_faces(image_path):
    """Extract faces from an image using DeepFace (YoloV8 backend).

    Parameters:
        image_path: Path or image array accepted by DeepFace.

    Returns:
        List of DeepFace face dicts ('face', 'facial_area', 'confidence').
        DeepFace raises ValueError when no face is detected.
    """
    return DeepFace.extract_faces(img_path=image_path, detector_backend="yolov8")
|
||||
|
||||
def extract_faces_mediapipe(image_path, enforce_detection=False, align=False):
    """Extract faces from an image using the MediaPipe backend.

    Parameters:
        image_path: Path or image array accepted by DeepFace.
        enforce_detection: If False, do not raise when no face is found.
        align: Whether DeepFace should align detected faces.

    Returns:
        List of DeepFace face dicts ('face', 'facial_area', 'confidence').
    """
    return DeepFace.extract_faces(img_path=image_path, detector_backend="mediapipe",
                                  enforce_detection=enforce_detection,
                                  align=align)
|
||||
@@ -0,0 +1,159 @@
|
||||
import io
|
||||
import base64
|
||||
import numpy as np
|
||||
from uuid import uuid4
|
||||
from PIL import Image as PILImage
|
||||
from typing import List, Dict, Any, Tuple
|
||||
import logging
|
||||
|
||||
from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
|
||||
from models.data_manager import DataManager
|
||||
from utils.vtt_parser import parse_vtt_offsets
|
||||
|
||||
def get_face_predictions_ensemble(face, data_manager, results=3, max_distance=0.8):
    """
    Identify one face by running Facenet512 and ArcFace, then fusing the two
    candidate lists with EnsembleFaceRecognition.

    Parameters:
        face: Face image array
        data_manager: DataManager instance
        results: Number of candidates to request per model
        max_distance: Candidates at or beyond this cosine distance are discarded

    Returns:
        List of (stash_id, confidence) tuples
    """
    from deepface import DeepFace

    mirrored = np.fliplr(face)

    def _mean_embedding(model_name, normalization=None):
        # Average the embeddings of the face and its mirror image for a more
        # pose-robust representation.
        opts = {'detector_backend': 'skip', 'model_name': model_name, 'align': True}
        if normalization is not None:
            opts['normalization'] = normalization
        straight = DeepFace.represent(img_path=face, **opts)[0]['embedding']
        flipped = DeepFace.represent(img_path=mirrored, **opts)[0]['embedding']
        return np.mean([straight, flipped], axis=0)

    candidates = {
        'facenet': data_manager.query_facenet_index(
            _mean_embedding('Facenet512', normalization='Facenet2018'), limit=results),
        'arc': data_manager.query_arc_index(
            _mean_embedding('ArcFace'), limit=results),
    }

    # Keep only sufficiently close matches and repackage each model's hits as
    # parallel (ids, distances) lists for the ensemble step.
    model_predictions = {}
    for model, hits in candidates.items():
        close = [(stash_id, dist) for stash_id, dist in hits if dist < max_distance]
        if close:
            ids, dists = zip(*close)
            model_predictions[model] = (list(ids), list(dists))

    if not model_predictions:
        return []

    return EnsembleFaceRecognition().ensemble_prediction(model_predictions)
|
||||
|
||||
def image_search_performer(image, data_manager, threshold=0.5, results=3):
    """
    Search for a single performer in an image using both Facenet and ArcFace.

    Only the first detected face is identified.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Minimum confidence (0-1) a prediction needs to be returned
        results: Number of results to return

    Returns:
        List of performer information dictionaries

    Raises:
        ValueError: If no faces are detected in the image.
    """
    image_array = np.array(image)
    try:
        faces = extract_faces(image_array)
    except ValueError:
        raise ValueError("No faces found")

    predictions = get_face_predictions_ensemble(faces[0]['face'], data_manager, results)
    logging.info(f"Predictions: {predictions}")
    response = []
    for stash_id, confidence in predictions:
        if confidence < threshold:
            continue
        performer_info = data_manager.get_performer_info(stash_id, confidence)
        if performer_info:
            response.append(performer_info)
    # Log instead of print: stray stdout output is debug residue and is lost
    # (or noisy) when running under a task runner.
    logging.info("Response: %s", response)
    return response
|
||||
|
||||
def image_search_performers(image, data_manager, threshold=0.5, results=3):
    """
    Search for multiple performers in an image using both Facenet and ArcFace.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Minimum confidence (0-1) a prediction needs to be kept
        results: Number of results to return

    Returns:
        List of dictionaries with face image and performer information

    Raises:
        ValueError: If no faces are detected in the image.
    """
    pixels = np.array(image)
    try:
        faces = extract_faces(pixels)
    except ValueError:
        raise ValueError("No faces found")

    response = []
    for face in faces:
        predictions = get_face_predictions_ensemble(face['face'], data_manager, results)

        # Crop the detected face out of the original image and JPEG-encode it
        # as base64 for the response payload.
        box = face['facial_area']
        crop = image.crop((box['x'], box['y'], box['x'] + box['w'], box['y'] + box['h']))
        buffer = io.BytesIO()
        crop.save(buffer, format='JPEG')
        encoded = base64.b64encode(buffer.getvalue()).decode('ascii')

        # Resolve every prediction above the confidence threshold.
        performers = []
        for stash_id, confidence in predictions:
            if confidence >= threshold:
                info = data_manager.get_performer_info(stash_id, confidence)
                if info:
                    performers.append(info)

        response.append({
            'image': encoded,
            'confidence': face['confidence'],
            'performers': performers
        })
    return response
|
||||
|
||||
def find_faces_in_sprite(image, vtt_data):
    """
    Find faces in a sprite sheet using its WebVTT thumbnail index.

    Parameters:
        image: Sprite sheet image as an array accepted by PIL.Image.fromarray
        vtt_data: Base64 encoded VTT data (optionally with a data-URI prefix)

    Returns:
        List of dictionaries with face information per sprite frame
    """
    raw_vtt = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
    sheet = PILImage.fromarray(image)

    found = []
    for frame_no, (left, top, width, height, time_seconds) in enumerate(parse_vtt_offsets(raw_vtt)):
        # The VTT geometry is x/y/width/height, so the crop box runs from
        # (left, top) to (left + width, top + height).
        tile = sheet.crop((left, top, left + width, top + height))
        detections = extract_faces_mediapipe(np.asarray(tile), enforce_detection=False, align=False)
        detections = [d for d in detections if d['confidence'] > 0.6]
        if not detections:
            continue
        box = detections[0]['facial_area']
        found.append({
            'id': str(uuid4()),
            "offset": (left, top, width, height),
            "frame": frame_no,
            "time": time_seconds,
            'size': box['w'] * box['h']
        })

    return found
|
||||
@@ -0,0 +1 @@
|
||||
# utils package
|
||||
@@ -0,0 +1,44 @@
|
||||
from typing import List, Tuple, Generator
|
||||
|
||||
def parse_vtt_offsets(vtt_content: bytes) -> Generator[Tuple[int, int, int, int, float], None, None]:
    """
    Parse WebVTT sprite-index content and yield thumbnail geometry + time.

    Each cue looks like::

        00:00:00.000 --> 00:00:41.000
        sprite.jpg#xywh=0,0,160,90

    Parameters:
        vtt_content: Raw VTT file content as bytes

    Yields:
        (x, y, width, height, time_seconds) tuples: the '#xywh=' geometry of
        the cue and the cue's start time in seconds.
    """
    time_seconds = 0.0
    coords = None  # (x, y, w, h) of the current cue, once its xywh line is seen

    for line in vtt_content.decode("utf-8").split("\n"):
        line = line.strip()

        if "-->" in line:
            # New cue: convert the start timestamp ('HH:MM:SS.mmm') to seconds
            # and forget the previous cue's geometry.
            hours, minutes, seconds = line.split("-->")[0].strip().split(":")
            time_seconds = int(hours) * 3600 + int(minutes) * 60 + float(seconds)
            coords = None
        elif "xywh=" in line:
            coords = tuple(int(v) for v in line.split("xywh=")[-1].split(","))
        else:
            continue

        # Only yield once both the timestamp and the geometry are known.
        # BUG FIX: the previous truthiness test ('if not left') silently
        # dropped every tile with x == 0, i.e. the sprite sheet's entire
        # first column; an explicit None check keeps them.
        if coords is None:
            continue

        x, y, w, h = coords
        yield x, y, w, h, time_seconds
|
||||
@@ -0,0 +1 @@
|
||||
# web package
|
||||
@@ -0,0 +1,174 @@
|
||||
import gradio as gr
|
||||
from typing import Dict, Any
|
||||
|
||||
from models.data_manager import DataManager
|
||||
from models.image_processor import (
|
||||
image_search_performer,
|
||||
image_search_performers,
|
||||
find_faces_in_sprite
|
||||
)
|
||||
|
||||
class WebInterface:
    """Gradio web UI exposing face search, vector search and sprite scanning as tabs."""

    def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
        """
        Initialize the web interface.

        Parameters:
            data_manager: DataManager instance
            default_threshold: Default confidence threshold shown in the sliders
        """
        self.data_manager = data_manager
        self.default_threshold = default_threshold

    def image_search(self, img, threshold, results):
        """Wrapper for the image search function"""
        return image_search_performer(img, self.data_manager, threshold, results)

    def multiple_image_search(self, img, threshold, results):
        """Wrapper for the multiple image search function"""
        return image_search_performers(img, self.data_manager, threshold, results)

    def vector_search(self, vector_json, threshold, results):
        """Wrapper for the vector search function (deprecated, returns a stub)"""
        return {'status': 'not implemented'}

    def _create_image_search_interface(self):
        """Create the single face search interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person and we'll tell you who it is.")

            with gr.Row():
                with gr.Column():
                    # NOTE(review): default gr.Image() hands the callback a
                    # numpy array, whereas the multi-face tab uses
                    # type="pil" — confirm image_search accepts both.
                    img_input = gr.Image()
                    threshold = gr.Slider(
                        label="threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=self.default_threshold
                    )
                    results_count = gr.Slider(
                        label="results",
                        minimum=0,
                        maximum=50,
                        value=3,
                        step=1
                    )
                    search_btn = gr.Button("Search")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=self.image_search,
                inputs=[img_input, threshold, results_count],
                outputs=output
            )

        return interface

    def _create_multiple_image_search_interface(self):
        """Create the multiple face search interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person(s) and we'll tell you who it is.")

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil")
                    threshold = gr.Slider(
                        label="threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=self.default_threshold
                    )
                    results_count = gr.Slider(
                        label="results",
                        minimum=0,
                        maximum=50,
                        value=3,
                        step=1
                    )
                    search_btn = gr.Button("Search")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=self.multiple_image_search,
                inputs=[img_input, threshold, results_count],
                outputs=output
            )

        return interface

    def _create_vector_search_interface(self):
        """Create the vector search interface (deprecated)"""
        with gr.Blocks() as interface:
            gr.Markdown("# Vector Search (deprecated)")

            with gr.Row():
                with gr.Column():
                    vector_input = gr.Textbox()
                    threshold = gr.Slider(
                        label="threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=self.default_threshold
                    )
                    results_count = gr.Slider(
                        label="results",
                        minimum=0,
                        maximum=50,
                        value=3,
                        step=1
                    )
                    search_btn = gr.Button("Search")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=self.vector_search,
                inputs=[vector_input, threshold, results_count],
                outputs=output
            )

        return interface

    def _create_faces_in_sprite_interface(self):
        """Create the faces in sprite interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Find Faces in Sprite")

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image()
                    vtt_input = gr.Textbox(label="VTT file")
                    search_btn = gr.Button("Process")

                with gr.Column():
                    output = gr.JSON(label="Results")

            # find_faces_in_sprite is used directly as the callback; it does
            # not read instance state, so no wrapper method is needed.
            search_btn.click(
                fn=find_faces_in_sprite,
                inputs=[img_input, vtt_input],
                outputs=output
            )

        return interface

    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
        """Launch the web interface.

        Parameters:
            server_name: Host/interface to bind (0.0.0.0 = all interfaces).
            server_port: TCP port for the Gradio server.
            share: Whether to create a public Gradio share link.
        """
        with gr.Blocks() as demo:
            with gr.Tabs() as tabs:
                with gr.TabItem("Single Face Search"):
                    self._create_image_search_interface()
                with gr.TabItem("Multiple Face Search"):
                    self._create_multiple_image_search_interface()
                with gr.TabItem("Vector Search"):
                    self._create_vector_search_interface()
                with gr.TabItem("Faces in Sprite"):
                    self._create_faces_in_sprite_interface()

        demo.queue().launch(server_name=server_name,server_port=server_port,share=share, ssr_mode=False)
|
||||
Reference in New Issue
Block a user