stash
This commit is contained in:
38
stash/config/plugins/community/LocalVisage/stashface/app.py
Normal file
38
stash/config/plugins/community/LocalVisage/stashface/app.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import os
import sys

# DeepFace caches model weights under DEEPFACE_HOME; keep them local to the plugin.
os.environ["DEEPFACE_HOME"] = "."
# Force CPU-only TensorFlow and silence its log spam.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TF logs

# Add the plugins directory (two levels up) to sys.path so sibling
# plugin packages are importable.
plugins_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if plugins_dir not in sys.path:
    sys.path.insert(0, plugins_dir)


from stashapi.stashapp import StashInterface


try:
    from models.data_manager import DataManager
    from web.interface import WebInterface
except ImportError as e:
    print(f"Error importing modules: {e}")
    input("Ensure you have installed the required dependencies. Press Enter to exit.")
    # Nothing below can work without these modules; exit now instead of
    # failing later with a NameError when main() runs.
    sys.exit(1)
|
||||
|
||||
|
||||
|
||||
def main():
    """Entry point: wire up the data manager and launch the Gradio web UI."""
    # Embeddings live in the voy_db folder next to this script's parent dir.
    voy_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "../voy_db"))
    manager = DataManager(voy_root_folder=voy_folder)

    ui = WebInterface(manager, default_threshold=0.5)
    ui.launch(server_name="0.0.0.0", server_port=7860, share=False)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1 @@
|
||||
# models package
|
||||
@@ -0,0 +1,116 @@
|
||||
import os
import sys
import json
from urllib.parse import urlparse
from typing import Dict, Any, Optional, List

# Prefer the numpy version vendored with the plugin. This must happen
# BEFORE `import numpy`, otherwise the already-imported system numpy wins
# and the path insert has no effect.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../py_dependencies/numpy_1.26.4")))

import numpy as np

# Stash passes its connection details to plugins via this environment variable.
_raw_connection = os.environ.get("SERVER_CONNECTION")
if _raw_connection is None:
    # Fail with a clear message instead of json.loads(None) -> TypeError.
    raise RuntimeError("SERVER_CONNECTION environment variable is not set")
server_connection = json.loads(_raw_connection)

from stashapi.stashapp import StashInterface
|
||||
|
||||
class DataManager:
    def __init__(self, voy_root_folder):
        """
        Manage performer embeddings stored as .voy files, one subfolder per model.

        Parameters:
            voy_root_folder: Path to the root folder containing 'facenet' and 'arc' subfolders.
        """
        self.voy_root_folder = voy_root_folder
        # model name -> {stash_id: {"name": str, "embedding": np.ndarray}}
        self.embeddings = {
            "facenet": {},
            "arc": {}
        }
        self._load_voy_files()
        self.stash = StashInterface(server_connection)

    def _load_voy_files(self):
        """Read every .voy / .voy.npy file under each model's folder into memory."""
        for model_name in ("facenet", "arc"):
            model_dir = os.path.join(self.voy_root_folder, model_name)
            self.embeddings[model_name] = {}
            if not os.path.isdir(model_dir):
                continue
            for entry in os.listdir(model_dir):
                if not (entry.endswith(".voy.npy") or entry.endswith(".voy")):
                    continue
                try:
                    # File names are "<stash_id>-<name>.voy" or "...voy.npy".
                    suffix_len = 8 if entry.endswith(".voy.npy") else 4
                    stem = entry[:-suffix_len]
                    stash_id, performer_name = stem.split("-", 1)
                    vector = np.load(os.path.join(model_dir, entry))
                    self.embeddings[model_name][stash_id] = {
                        "name": performer_name,
                        "embedding": vector
                    }
                except Exception as exc:
                    print(f"Error loading {entry} for {model_name}: {exc}")

    def get_all_ids(self, model: str = "facenet") -> List[str]:
        """Return every performer stash ID known for the given model."""
        return [*self.embeddings.get(model, {})]

    def get_performer_info(self, stash_id: str, confidence: float) -> Optional[Dict[str, Any]]:
        """
        Build a result dictionary for a performer, resolving details via Stash.

        Parameters:
            stash_id: Stash ID of the performer
            confidence: Confidence score (0-1)

        Returns:
            Dictionary with performer information (falls back to the name
            stored with the embedding when Stash does not know the performer).
        """
        performer = self.stash.find_performer(stash_id)
        if performer:
            image_path = performer.get('image_path')
            return {
                'id': stash_id,
                "name": performer['name'],
                # Strip scheme/host so the front-end gets a relative path.
                "image": urlparse(image_path).path if image_path else None,
                "confidence": int(confidence * 100),
                'country': performer.get('country'),
                'distance': int(confidence * 100),
                'performer_url': f"/performers/{stash_id}"
            }

        # Performer unknown to Stash: fall back to the embedding's stored name.
        name = "Unknown"
        for model_entries in self.embeddings.values():
            if stash_id in model_entries:
                name = model_entries[stash_id].get("name", "Unknown")
                break
        return {
            'id': stash_id,
            "name": name,
            "image": None,
            "confidence": int(confidence * 100),
        }

    def query_index(self, model: str, embedding: np.ndarray, limit: int = 5):
        """
        Rank stored embeddings by cosine distance to the query embedding.

        Parameters:
            model: 'facenet' or 'arc'
            embedding: The embedding to compare
            limit: Number of top matches to return

        Returns:
            List of (stash_id, distance) tuples, sorted by distance ascending
        """
        query_norm = np.linalg.norm(embedding)
        scored = []
        for stash_id, entry in self.embeddings.get(model, {}).items():
            stored = entry["embedding"]
            cosine = np.dot(embedding, stored) / (query_norm * np.linalg.norm(stored))
            scored.append((stash_id, 1 - cosine))
        scored.sort(key=lambda pair: pair[1])
        return scored[:limit]

    def query_facenet_index(self, embedding: np.ndarray, limit: int = 5):
        """Query the Facenet index."""
        return self.query_index("facenet", embedding, limit)

    def query_arc_index(self, embedding: np.ndarray, limit: int = 5):
        """Query the ArcFace index."""
        return self.query_index("arc", embedding, limit)
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from deepface import DeepFace
|
||||
|
||||
class EnsembleFaceRecognition:
    def __init__(self, model_weights: Dict[str, float] = None):
        """
        Ensemble face recognition across several embedding models.

        Parameters:
            model_weights: Dictionary mapping model names to their weights.
                           If None, all models are weighted equally.
        """
        self.model_weights = model_weights or {}
        # Final scores are scaled up by this factor, then clamped at 1.0.
        self.boost_factor = 1.8

    def normalize_distances(self, distances: np.ndarray) -> np.ndarray:
        """Rescale distances linearly into [0,1] within one model's predictions."""
        lo = np.min(distances)
        hi = np.max(distances)
        if hi == lo:
            return np.zeros_like(distances)
        return (distances - lo) / (hi - lo)

    def compute_model_confidence(self, distances: np.ndarray, temperature: float = 0.1) -> np.ndarray:
        """Softmax over negated normalized distances -> per-candidate confidences."""
        weights = np.exp(-self.normalize_distances(distances) / temperature)
        return weights / np.sum(weights)

    def get_face_embeddings(self, image_path: str) -> Dict[str, np.ndarray]:
        """Compute Facenet512 and ArcFace embeddings for an already-cropped face."""
        facenet = DeepFace.represent(img_path=image_path, detector_backend='skip', model_name='Facenet512', normalization='Facenet2018', align=True)[0]['embedding']
        arc = DeepFace.represent(img_path=image_path, detector_backend='skip', model_name='ArcFace', align=True)[0]['embedding']
        return {'facenet': facenet, 'arc': arc}

    def ensemble_prediction(
        self,
        model_predictions: Dict[str, Tuple[List[str], List[float]]],
        temperature: float = 0.1,
        min_agreement: float = 0.5
    ) -> List[Tuple[str, float]]:
        """
        Merge per-model predictions into a single ranked candidate list.

        Parameters:
            model_predictions: Dictionary mapping model names to their (names, distances) predictions.
            temperature: Temperature parameter for softmax scaling.
            min_agreement: Minimum agreement threshold between models.

        Returns:
            final_predictions: List of (name, confidence) tuples.
        """
        votes: Dict[str, float] = {}
        confidences_by_name: Dict[str, list] = {}

        # Each model casts a weighted vote for its single best candidate.
        for model_name, (names, distances) in model_predictions.items():
            weight = self.model_weights.get(model_name, 1.0)
            scores = self.compute_model_confidence(np.array(distances), temperature)
            best = names[0]
            votes[best] = votes.get(best, 0) + weight
            confidences_by_name.setdefault(best, []).append(scores[0])

        if self.model_weights:
            total_weight = sum(self.model_weights.values())
        else:
            total_weight = len(model_predictions)

        ranked = []
        for name, vote in votes.items():
            agreement = vote / total_weight
            if agreement < min_agreement:
                continue
            score = agreement * np.mean(confidences_by_name[name]) * self.boost_factor
            ranked.append((name, min(score, 1.0)))

        ranked.sort(key=lambda item: item[1], reverse=True)
        return ranked
|
||||
|
||||
def extract_faces(image_path):
    """Detect and crop faces using DeepFace's YoloV8 detector backend."""
    # DeepFace raises ValueError when no face is found; callers handle that.
    return DeepFace.extract_faces(detector_backend="yolov8", img_path=image_path)
|
||||
|
||||
def extract_faces_mediapipe(image_path, enforce_detection=False, align=False):
    """Detect faces using DeepFace's (lighter-weight) MediaPipe backend."""
    return DeepFace.extract_faces(
        img_path=image_path,
        detector_backend="mediapipe",
        enforce_detection=enforce_detection,
        align=align,
    )
|
||||
@@ -0,0 +1,159 @@
|
||||
import io
|
||||
import base64
|
||||
import numpy as np
|
||||
from uuid import uuid4
|
||||
from PIL import Image as PILImage
|
||||
from typing import List, Dict, Any, Tuple
|
||||
import logging
|
||||
|
||||
from models.face_recognition import EnsembleFaceRecognition, extract_faces, extract_faces_mediapipe
|
||||
from models.data_manager import DataManager
|
||||
from utils.vtt_parser import parse_vtt_offsets
|
||||
|
||||
def _flip_averaged_embedding(face, model_name, normalization=None):
    """Embed a face and its horizontal mirror, returning the averaged embedding.

    Averaging the original and mirrored embeddings makes the representation
    less sensitive to head pose. `detector_backend='skip'` because the face
    is already cropped.
    """
    from deepface import DeepFace
    kwargs = dict(detector_backend='skip', model_name=model_name, align=True)
    if normalization is not None:
        kwargs['normalization'] = normalization
    original = DeepFace.represent(img_path=face, **kwargs)[0]['embedding']
    mirrored = DeepFace.represent(img_path=np.fliplr(face), **kwargs)[0]['embedding']
    return np.mean([original, mirrored], axis=0)


def get_face_predictions_ensemble(face, data_manager, results=3, max_distance=0.8):
    """
    Get predictions for a single face using both Facenet and ArcFace, then ensemble.

    Parameters:
        face: Face image array (already cropped)
        data_manager: DataManager instance
        results: Number of results to return
        max_distance: Matches with cosine distance >= this are discarded

    Returns:
        List of (stash_id, confidence) tuples
    """
    embedding_facenet = _flip_averaged_embedding(face, 'Facenet512', normalization='Facenet2018')
    embedding_arc = _flip_averaged_embedding(face, 'ArcFace')

    # Query both indexes, drop weak matches, and build the per-model inputs
    # expected by the ensemble.
    model_predictions = {}
    for model_name, preds in (
        ('facenet', data_manager.query_facenet_index(embedding_facenet, limit=results)),
        ('arc', data_manager.query_arc_index(embedding_arc, limit=results)),
    ):
        kept = [(stash_id, dist) for stash_id, dist in preds if dist < max_distance]
        if kept:
            names, dists = zip(*kept)
            model_predictions[model_name] = (list(names), list(dists))

    if not model_predictions:
        return []

    ensemble = EnsembleFaceRecognition()
    return ensemble.ensemble_prediction(model_predictions)
|
||||
|
||||
def image_search_performer(image, data_manager, threshold=0.5, results=3):
    """
    Search for a performer in an image using both Facenet and ArcFace.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold (predictions below it are dropped)
        results: Number of results to return

    Returns:
        List of performer information dictionaries

    Raises:
        ValueError: if no faces are detected in the image
    """
    image_array = np.array(image)
    try:
        faces = extract_faces(image_array)
    except ValueError:
        raise ValueError("No faces found")

    # Single-performer search only considers the first detected face.
    predictions = get_face_predictions_ensemble(faces[0]['face'], data_manager, results)
    logging.info(f"Predictions: {predictions}")
    response = []
    for stash_id, confidence in predictions:
        if confidence < threshold:
            continue
        performer_info = data_manager.get_performer_info(stash_id, confidence)
        if performer_info:
            response.append(performer_info)
    # Was a stray debug print(response); use logging like the rest of the module.
    logging.debug("Response: %s", response)
    return response
|
||||
|
||||
def image_search_performers(image, data_manager, threshold=0.5, results=3):
    """
    Search for multiple performers in an image using both Facenet and ArcFace.

    Parameters:
        image: PIL Image object
        data_manager: DataManager instance
        threshold: Confidence threshold
        results: Number of results to return

    Returns:
        List of dictionaries with face image and performer information

    Raises:
        ValueError: if no faces are detected in the image
    """
    image_array = np.array(image)
    try:
        faces = extract_faces(image_array)
    except ValueError:
        raise ValueError("No faces found")

    response = []
    for face in faces:
        predictions = get_face_predictions_ensemble(face['face'], data_manager, results)

        # Crop the face region out of the source image and base64-encode it as JPEG.
        box = face['facial_area']
        face_crop = image.crop((box['x'], box['y'], box['x'] + box['w'], box['y'] + box['h']))
        buffer = io.BytesIO()
        face_crop.save(buffer, format='JPEG')
        encoded_face = base64.b64encode(buffer.getvalue()).decode('ascii')

        # Keep only confident matches that resolve to performer info.
        performers = []
        for stash_id, confidence in predictions:
            if confidence < threshold:
                continue
            info = data_manager.get_performer_info(stash_id, confidence)
            if info:
                performers.append(info)

        response.append({
            'image': encoded_face,
            'confidence': face['confidence'],
            'performers': performers
        })
    return response
|
||||
|
||||
def find_faces_in_sprite(image, vtt_data):
    """
    Find faces in a sprite image using VTT data

    Parameters:
        image: Image as a numpy array (converted to PIL internally)
        vtt_data: Base64 encoded VTT data (optionally with a data-URI prefix)

    Returns:
        List of dictionaries with face information
    """
    vtt_bytes = base64.b64decode(vtt_data.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.fromarray(image)

    results = []
    for frame_index, (left, top, right, bottom, time_seconds) in enumerate(parse_vtt_offsets(vtt_bytes)):
        # NOTE(review): the VTT parser yields xywh fragments, so `right`/`bottom`
        # here are really width/height — hence left+right / top+bottom below.
        cut_frame = sprite.crop((left, top, left + right, top + bottom))
        detections = extract_faces_mediapipe(np.asarray(cut_frame), enforce_detection=False, align=False)
        detections = [d for d in detections if d['confidence'] > 0.6]
        if not detections:
            continue
        # Only the largest-confidence (first) detection contributes the size.
        box = detections[0]['facial_area']
        results.append({
            'id': str(uuid4()),
            "offset": (left, top, right, bottom),
            "frame": frame_index,
            "time": time_seconds,
            'size': box['w'] * box['h'],
        })

    return results
|
||||
@@ -0,0 +1 @@
|
||||
# utils package
|
||||
@@ -0,0 +1,44 @@
|
||||
from typing import List, Tuple, Generator
|
||||
|
||||
def parse_vtt_offsets(vtt_content: bytes) -> Generator[Tuple[int, int, int, int, float], None, None]:
    """
    Parse VTT sprite-sheet cues and yield their offsets and timestamps.

    Parameters:
        vtt_content: Raw VTT file content as bytes

    Yields:
        Tuples of (left, top, width, height, time_seconds). Note the last two
        values of each cue's ``xywh=`` fragment are width and height.
    """
    time_seconds = 0.0
    coords = None  # (left, top, width, height) of the current cue, once seen

    for raw_line in vtt_content.decode("utf-8").split("\n"):
        line = raw_line.strip()

        if "-->" in line:
            # Start timestamp, e.g. "00:00:00.000 --> 00:00:41.000".
            hours, minutes, seconds = line.split("-->")[0].strip().split(":")
            time_seconds = int(hours) * 3600 + int(minutes) * 60 + float(seconds)
            coords = None
        elif "xywh=" in line:
            left, top, width, height = (
                int(v) for v in line.split("xywh=")[-1].split(",")
            )
            coords = (left, top, width, height)
        else:
            continue

        # Yield only once the cue's geometry line has been seen.
        # NOTE: compare against None, not truthiness — the original's
        # `if not left` silently skipped valid cues whose left coordinate was 0.
        if coords is None:
            continue

        yield coords[0], coords[1], coords[2], coords[3], time_seconds
|
||||
@@ -0,0 +1 @@
|
||||
# web package
|
||||
@@ -0,0 +1,174 @@
|
||||
import gradio as gr
|
||||
from typing import Dict, Any
|
||||
|
||||
from models.data_manager import DataManager
|
||||
from models.image_processor import (
|
||||
image_search_performer,
|
||||
image_search_performers,
|
||||
find_faces_in_sprite
|
||||
)
|
||||
|
||||
class WebInterface:
    def __init__(self, data_manager: DataManager, default_threshold: float = 0.5):
        """
        Gradio front-end for the face recognition service.

        Parameters:
            data_manager: DataManager instance
            default_threshold: Default confidence threshold
        """
        self.data_manager = data_manager
        self.default_threshold = default_threshold

    def image_search(self, img, threshold, results):
        """Wrapper for the single-face image search function"""
        return image_search_performer(img, self.data_manager, threshold, results)

    def multiple_image_search(self, img, threshold, results):
        """Wrapper for the multiple-face image search function"""
        return image_search_performers(img, self.data_manager, threshold, results)

    def vector_search(self, vector_json, threshold, results):
        """Wrapper for the vector search function (deprecated)"""
        return {'status': 'not implemented'}

    def _threshold_and_results_sliders(self):
        """Build the threshold/results sliders shared by every search tab.

        Factored out: the identical slider pair was previously duplicated
        verbatim in three interface builders.
        """
        threshold = gr.Slider(
            label="threshold",
            minimum=0.0,
            maximum=1.0,
            value=self.default_threshold
        )
        results_count = gr.Slider(
            label="results",
            minimum=0,
            maximum=50,
            value=3,
            step=1
        )
        return threshold, results_count

    def _create_image_search_interface(self):
        """Create the single face search interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person and we'll tell you who it is.")

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image()
                    threshold, results_count = self._threshold_and_results_sliders()
                    search_btn = gr.Button("Search")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=self.image_search,
                inputs=[img_input, threshold, results_count],
                outputs=output
            )

        return interface

    def _create_multiple_image_search_interface(self):
        """Create the multiple face search interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Who is in the photo?")
            gr.Markdown("Upload an image of a person(s) and we'll tell you who it is.")

            with gr.Row():
                with gr.Column():
                    # PIL type is required here: image_search_performers crops
                    # each detected face out of the PIL image.
                    img_input = gr.Image(type="pil")
                    threshold, results_count = self._threshold_and_results_sliders()
                    search_btn = gr.Button("Search")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=self.multiple_image_search,
                inputs=[img_input, threshold, results_count],
                outputs=output
            )

        return interface

    def _create_vector_search_interface(self):
        """Create the vector search interface (deprecated)"""
        with gr.Blocks() as interface:
            gr.Markdown("# Vector Search (deprecated)")

            with gr.Row():
                with gr.Column():
                    vector_input = gr.Textbox()
                    threshold, results_count = self._threshold_and_results_sliders()
                    search_btn = gr.Button("Search")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=self.vector_search,
                inputs=[vector_input, threshold, results_count],
                outputs=output
            )

        return interface

    def _create_faces_in_sprite_interface(self):
        """Create the faces in sprite interface"""
        with gr.Blocks() as interface:
            gr.Markdown("# Find Faces in Sprite")

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image()
                    vtt_input = gr.Textbox(label="VTT file")
                    search_btn = gr.Button("Process")

                with gr.Column():
                    output = gr.JSON(label="Results")

            search_btn.click(
                fn=find_faces_in_sprite,
                inputs=[img_input, vtt_input],
                outputs=output
            )

        return interface

    def launch(self, server_name="0.0.0.0", server_port=7860, share=True):
        """Assemble all tabs into one app and start the Gradio server."""
        with gr.Blocks() as demo:
            with gr.Tabs():
                with gr.TabItem("Single Face Search"):
                    self._create_image_search_interface()
                with gr.TabItem("Multiple Face Search"):
                    self._create_multiple_image_search_interface()
                with gr.TabItem("Vector Search"):
                    self._create_vector_search_interface()
                with gr.TabItem("Faces in Sprite"):
                    self._create_faces_in_sprite_interface()

        demo.queue().launch(server_name=server_name, server_port=server_port, share=share, ssr_mode=False)
|
||||
Reference in New Issue
Block a user