Merge branch 'dev' into Config-Filenames

This commit is contained in:
CodeName393
2026-02-03 23:01:18 +09:00
committed by GitHub
50 changed files with 9370 additions and 1356 deletions

View File

@@ -0,0 +1,145 @@
"""API key tier management for remote services."""
import logging
from typing import Any, Dict, List, Optional
from aiohttp import web
log = logging.getLogger("api.keys")
def get_api_key_from_request(request: web.Request) -> Optional[str]:
"""
Extract API key from request headers.
Args:
request: aiohttp request object
Returns:
API key string or None
"""
api_key = request.headers.get("X-API-Key")
if api_key:
return api_key
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:] # len("Bearer ") == 7
return None
def get_api_key_config(app: web.Application, api_key: str) -> Optional[Dict[str, Any]]:
"""
Get configuration for a specific API key.
Args:
app: aiohttp application
api_key: API key to look up
Returns:
API key configuration dict or None if not found
"""
config = app.get("config", {})
# Check new-style tiered API keys
api_keys = config.get("api_keys", [])
for key_config in api_keys:
if isinstance(key_config, dict) and key_config.get("key") == api_key:
return key_config
# Check legacy users list (backward compatibility)
users = config.get("users", [])
if api_key in users:
return {
"key": api_key,
"tier": "basic",
"allowed_cdms": []
}
return None
def is_premium_user(app: web.Application, api_key: str) -> bool:
"""
Check if an API key belongs to a premium user.
Premium users can use server-side CDM for decryption.
Args:
app: aiohttp application
api_key: API key to check
Returns:
True if premium user, False otherwise
"""
key_config = get_api_key_config(app, api_key)
if not key_config:
return False
tier = key_config.get("tier", "basic")
return tier == "premium"
def get_allowed_cdms(app: web.Application, api_key: str) -> List[str]:
"""
Get list of CDMs that an API key is allowed to use.
Args:
app: aiohttp application
api_key: API key to check
Returns:
List of allowed CDM names, or empty list if not premium
"""
key_config = get_api_key_config(app, api_key)
if not key_config:
return []
allowed_cdms = key_config.get("allowed_cdms", [])
# Handle wildcard
if allowed_cdms == "*" or allowed_cdms == ["*"]:
return ["*"]
return allowed_cdms if isinstance(allowed_cdms, list) else []
def get_default_cdm(app: web.Application, api_key: str) -> Optional[str]:
"""
Get default CDM for an API key.
Args:
app: aiohttp application
api_key: API key to check
Returns:
Default CDM name or None
"""
key_config = get_api_key_config(app, api_key)
if not key_config:
return None
return key_config.get("default_cdm")
def can_use_cdm(app: web.Application, api_key: str, cdm_name: str) -> bool:
"""
Check if an API key can use a specific CDM.
Args:
app: aiohttp application
api_key: API key to check
cdm_name: CDM name to check access for
Returns:
True if allowed, False otherwise
"""
allowed_cdms = get_allowed_cdms(app, api_key)
# Wildcard access
if "*" in allowed_cdms:
return True
# Specific CDM access
return cdm_name in allowed_cdms

View File

@@ -227,6 +227,7 @@ def _perform_download(
range_=params.get("range", ["SDR"]),
channels=params.get("channels"),
no_atmos=params.get("no_atmos", False),
split_audio=params.get("split_audio"),
wanted=params.get("wanted", []),
latest_episode=params.get("latest_episode", False),
lang=params.get("lang", ["orig"]),

View File

@@ -191,12 +191,73 @@ def serialize_title(title: Title_T) -> Dict[str, Any]:
return result
def serialize_video_track(track: Video) -> Dict[str, Any]:
def serialize_drm(drm_list) -> Optional[List[Dict[str, Any]]]:
"""Serialize DRM objects to JSON-serializable list."""
if not drm_list:
return None
if not isinstance(drm_list, list):
drm_list = [drm_list]
result = []
for drm in drm_list:
drm_info = {}
drm_class = drm.__class__.__name__
drm_info["type"] = drm_class.lower()
# Get PSSH - handle both Widevine and PlayReady
if hasattr(drm, "_pssh") and drm._pssh:
try:
pssh_obj = drm._pssh
# Try to get base64 representation
if hasattr(pssh_obj, "dumps"):
# pywidevine PSSH has dumps() method
drm_info["pssh"] = pssh_obj.dumps()
elif hasattr(pssh_obj, "__bytes__"):
# Convert to base64
import base64
drm_info["pssh"] = base64.b64encode(bytes(pssh_obj)).decode()
elif hasattr(pssh_obj, "to_base64"):
drm_info["pssh"] = pssh_obj.to_base64()
else:
# Fallback - str() works for pywidevine PSSH
pssh_str = str(pssh_obj)
# Check if it's already base64-like or an object repr
if not pssh_str.startswith("<"):
drm_info["pssh"] = pssh_str
except Exception:
pass
# Get KIDs
if hasattr(drm, "kids") and drm.kids:
drm_info["kids"] = [str(kid) for kid in drm.kids]
# Get content keys if available
if hasattr(drm, "content_keys") and drm.content_keys:
drm_info["content_keys"] = {str(k): v for k, v in drm.content_keys.items()}
# Get license URL - essential for remote licensing
if hasattr(drm, "license_url") and drm.license_url:
drm_info["license_url"] = str(drm.license_url)
elif hasattr(drm, "_license_url") and drm._license_url:
drm_info["license_url"] = str(drm._license_url)
result.append(drm_info)
return result if result else None
def serialize_video_track(track: Video, include_url: bool = False) -> Dict[str, Any]:
"""Convert video track to JSON-serializable dict."""
codec_name = track.codec.name if hasattr(track.codec, "name") else str(track.codec)
range_name = track.range.name if hasattr(track.range, "name") else str(track.range)
return {
# Get descriptor for N_m3u8DL-RE compatibility (HLS, DASH, URL, etc.)
descriptor_name = None
if hasattr(track, "descriptor") and track.descriptor:
descriptor_name = track.descriptor.name if hasattr(track.descriptor, "name") else str(track.descriptor)
result = {
"id": str(track.id),
"codec": codec_name,
"codec_display": VIDEO_CODEC_MAP.get(codec_name, codec_name),
@@ -208,15 +269,24 @@ def serialize_video_track(track: Video) -> Dict[str, Any]:
"range": range_name,
"range_display": DYNAMIC_RANGE_MAP.get(range_name, range_name),
"language": str(track.language) if track.language else None,
"drm": str(track.drm) if hasattr(track, "drm") and track.drm else None,
"drm": serialize_drm(track.drm) if hasattr(track, "drm") and track.drm else None,
"descriptor": descriptor_name,
}
if include_url and hasattr(track, "url") and track.url:
result["url"] = str(track.url)
return result
def serialize_audio_track(track: Audio) -> Dict[str, Any]:
def serialize_audio_track(track: Audio, include_url: bool = False) -> Dict[str, Any]:
"""Convert audio track to JSON-serializable dict."""
codec_name = track.codec.name if hasattr(track.codec, "name") else str(track.codec)
return {
# Get descriptor for N_m3u8DL-RE compatibility
descriptor_name = None
if hasattr(track, "descriptor") and track.descriptor:
descriptor_name = track.descriptor.name if hasattr(track.descriptor, "name") else str(track.descriptor)
result = {
"id": str(track.id),
"codec": codec_name,
"codec_display": AUDIO_CODEC_MAP.get(codec_name, codec_name),
@@ -225,20 +295,33 @@ def serialize_audio_track(track: Audio) -> Dict[str, Any]:
"language": str(track.language) if track.language else None,
"atmos": track.atmos if hasattr(track, "atmos") else False,
"descriptive": track.descriptive if hasattr(track, "descriptive") else False,
"drm": str(track.drm) if hasattr(track, "drm") and track.drm else None,
"drm": serialize_drm(track.drm) if hasattr(track, "drm") and track.drm else None,
"descriptor": descriptor_name,
}
if include_url and hasattr(track, "url") and track.url:
result["url"] = str(track.url)
return result
def serialize_subtitle_track(track: Subtitle) -> Dict[str, Any]:
def serialize_subtitle_track(track: Subtitle, include_url: bool = False) -> Dict[str, Any]:
"""Convert subtitle track to JSON-serializable dict."""
return {
# Get descriptor for compatibility
descriptor_name = None
if hasattr(track, "descriptor") and track.descriptor:
descriptor_name = track.descriptor.name if hasattr(track.descriptor, "name") else str(track.descriptor)
result = {
"id": str(track.id),
"codec": track.codec.name if hasattr(track.codec, "name") else str(track.codec),
"language": str(track.language) if track.language else None,
"forced": track.forced if hasattr(track, "forced") else False,
"sdh": track.sdh if hasattr(track, "sdh") else False,
"cc": track.cc if hasattr(track, "cc") else False,
"descriptor": descriptor_name,
}
if include_url and hasattr(track, "url") and track.url:
result["url"] = str(track.url)
return result
async def list_titles_handler(data: Dict[str, Any], request: Optional[web.Request] = None) -> web.Response:
@@ -665,9 +748,17 @@ def validate_download_parameters(data: Dict[str, Any]) -> Optional[str]:
return f"Invalid vcodec: {data['vcodec']}. Must be one of: {', '.join(valid_vcodecs)}"
if "acodec" in data and data["acodec"]:
valid_acodecs = ["AAC", "AC3", "EAC3", "OPUS", "FLAC", "ALAC", "VORBIS", "DTS"]
if data["acodec"].upper() not in valid_acodecs:
return f"Invalid acodec: {data['acodec']}. Must be one of: {', '.join(valid_acodecs)}"
valid_acodecs = ["AAC", "AC3", "EC3", "EAC3", "DD", "DD+", "AC4", "OPUS", "FLAC", "ALAC", "VORBIS", "OGG", "DTS"]
if isinstance(data["acodec"], str):
acodec_values = [v.strip() for v in data["acodec"].split(",") if v.strip()]
elif isinstance(data["acodec"], list):
acodec_values = [str(v).strip() for v in data["acodec"] if str(v).strip()]
else:
return "acodec must be a string or list"
invalid = [value for value in acodec_values if value.upper() not in valid_acodecs]
if invalid:
return f"Invalid acodec: {', '.join(invalid)}. Must be one of: {', '.join(valid_acodecs)}"
if "sub_format" in data and data["sub_format"]:
valid_sub_formats = ["SRT", "VTT", "ASS", "SSA"]

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,9 @@ from unshackle.core import __version__
from unshackle.core.api.errors import APIError, APIErrorCode, build_error_response, handle_api_exception
from unshackle.core.api.handlers import (cancel_download_job_handler, download_handler, get_download_job_handler,
list_download_jobs_handler, list_titles_handler, list_tracks_handler)
from unshackle.core.api.remote_handlers import (remote_decrypt, remote_get_chapters, remote_get_license,
remote_get_manifest, remote_get_titles, remote_get_tracks,
remote_list_services, remote_search)
from unshackle.core.services import Services
from unshackle.core.update_checker import UpdateChecker
@@ -413,7 +416,7 @@ async def download(request: web.Request) -> web.Response:
description: Video codec to download (e.g., H264, H265, VP9, AV1) (default - None)
acodec:
type: string
description: Audio codec to download (e.g., AAC, AC3, EAC3) (default - None)
description: Audio codec(s) to download (e.g., AAC or AAC,EC3) (default - None)
vbitrate:
type: integer
description: Video bitrate in kbps (default - None)
@@ -730,6 +733,16 @@ def setup_routes(app: web.Application) -> None:
app.router.add_get("/api/download/jobs/{job_id}", download_job_detail)
app.router.add_delete("/api/download/jobs/{job_id}", cancel_download_job)
# Remote service endpoints
app.router.add_get("/api/remote/services", remote_list_services)
app.router.add_post("/api/remote/{service}/search", remote_search)
app.router.add_post("/api/remote/{service}/titles", remote_get_titles)
app.router.add_post("/api/remote/{service}/tracks", remote_get_tracks)
app.router.add_post("/api/remote/{service}/manifest", remote_get_manifest)
app.router.add_post("/api/remote/{service}/chapters", remote_get_chapters)
app.router.add_post("/api/remote/{service}/license", remote_get_license)
app.router.add_post("/api/remote/{service}/decrypt", remote_decrypt)
def setup_swagger(app: web.Application) -> None:
"""Setup Swagger UI documentation."""
@@ -754,5 +767,14 @@ def setup_swagger(app: web.Application) -> None:
web.get("/api/download/jobs", download_jobs),
web.get("/api/download/jobs/{job_id}", download_job_detail),
web.delete("/api/download/jobs/{job_id}", cancel_download_job),
# Remote service routes
web.get("/api/remote/services", remote_list_services),
web.post("/api/remote/{service}/search", remote_search),
web.post("/api/remote/{service}/titles", remote_get_titles),
web.post("/api/remote/{service}/tracks", remote_get_tracks),
web.post("/api/remote/{service}/manifest", remote_get_manifest),
web.post("/api/remote/{service}/chapters", remote_get_chapters),
web.post("/api/remote/{service}/license", remote_get_license),
web.post("/api/remote/{service}/decrypt", remote_decrypt),
]
)

View File

@@ -0,0 +1,236 @@
"""Session serialization helpers for remote services."""
from http.cookiejar import CookieJar
from typing import Any, Dict, Optional
import requests
from unshackle.core.credential import Credential
def serialize_session(session: requests.Session) -> Dict[str, Any]:
"""
Serialize a requests.Session into a JSON-serializable dictionary.
Extracts cookies, headers, and other session data that can be
transferred to a remote client for downloading.
Args:
session: The requests.Session to serialize
Returns:
Dictionary containing serialized session data
"""
session_data = {
"cookies": {},
"headers": {},
"proxies": session.proxies.copy() if session.proxies else {},
}
# Serialize cookies
if session.cookies:
for cookie in session.cookies:
session_data["cookies"][cookie.name] = {
"value": cookie.value,
"domain": cookie.domain,
"path": cookie.path,
"secure": cookie.secure,
"expires": cookie.expires,
}
# Serialize headers (exclude proxy-authorization for security)
if session.headers:
for key, value in session.headers.items():
# Skip proxy-related headers as they're server-specific
if key.lower() not in ["proxy-authorization"]:
session_data["headers"][key] = value
return session_data
def deserialize_session(
session_data: Dict[str, Any], target_session: Optional[requests.Session] = None
) -> requests.Session:
"""
Deserialize session data into a requests.Session.
Applies cookies, headers, and other session data from a remote server
to a local session for downloading.
Args:
session_data: Dictionary containing serialized session data
target_session: Optional existing session to update (creates new if None)
Returns:
requests.Session with applied session data
"""
if target_session is None:
target_session = requests.Session()
# Apply cookies
if "cookies" in session_data:
for cookie_name, cookie_data in session_data["cookies"].items():
target_session.cookies.set(
name=cookie_name,
value=cookie_data["value"],
domain=cookie_data.get("domain"),
path=cookie_data.get("path", "/"),
secure=cookie_data.get("secure", False),
expires=cookie_data.get("expires"),
)
# Apply headers
if "headers" in session_data:
target_session.headers.update(session_data["headers"])
# Note: We don't apply proxies from remote as the local client
# should use its own proxy configuration
return target_session
def extract_session_tokens(session: requests.Session) -> Dict[str, Any]:
"""
Extract authentication tokens and similar data from a session.
Looks for common authentication patterns like Bearer tokens,
API keys in headers, etc.
Args:
session: The requests.Session to extract tokens from
Returns:
Dictionary containing extracted tokens
"""
tokens = {}
# Check for Authorization header
if "Authorization" in session.headers:
tokens["authorization"] = session.headers["Authorization"]
# Check for common API key headers
for key in ["X-API-Key", "Api-Key", "X-Auth-Token"]:
if key in session.headers:
tokens[key.lower().replace("-", "_")] = session.headers[key]
return tokens
def apply_session_tokens(tokens: Dict[str, Any], target_session: requests.Session) -> None:
"""
Apply authentication tokens to a session.
Args:
tokens: Dictionary containing tokens to apply
target_session: Session to apply tokens to
"""
# Apply Authorization header
if "authorization" in tokens:
target_session.headers["Authorization"] = tokens["authorization"]
# Apply other token headers
token_header_map = {
"x_api_key": "X-API-Key",
"api_key": "Api-Key",
"x_auth_token": "X-Auth-Token",
}
for token_key, header_name in token_header_map.items():
if token_key in tokens:
target_session.headers[header_name] = tokens[token_key]
def serialize_cookies(cookie_jar: Optional[CookieJar]) -> Dict[str, Any]:
"""
Serialize a CookieJar into a JSON-serializable dictionary.
Args:
cookie_jar: The CookieJar to serialize
Returns:
Dictionary containing serialized cookies
"""
if not cookie_jar:
return {}
cookies = {}
for cookie in cookie_jar:
cookies[cookie.name] = {
"value": cookie.value,
"domain": cookie.domain,
"path": cookie.path,
"secure": cookie.secure,
"expires": cookie.expires,
}
return cookies
def deserialize_cookies(cookies_data: Dict[str, Any]) -> CookieJar:
"""
Deserialize cookies into a CookieJar.
Args:
cookies_data: Dictionary containing serialized cookies
Returns:
CookieJar with cookies
"""
import http.cookiejar
cookie_jar = http.cookiejar.CookieJar()
for cookie_name, cookie_data in cookies_data.items():
cookie = http.cookiejar.Cookie(
version=0,
name=cookie_name,
value=cookie_data["value"],
port=None,
port_specified=False,
domain=cookie_data.get("domain", ""),
domain_specified=bool(cookie_data.get("domain")),
domain_initial_dot=cookie_data.get("domain", "").startswith("."),
path=cookie_data.get("path", "/"),
path_specified=True,
secure=cookie_data.get("secure", False),
expires=cookie_data.get("expires"),
discard=False,
comment=None,
comment_url=None,
rest={},
)
cookie_jar.set_cookie(cookie)
return cookie_jar
def serialize_credential(credential: Optional[Credential]) -> Optional[Dict[str, str]]:
"""
Serialize a Credential into a JSON-serializable dictionary.
Args:
credential: The Credential to serialize
Returns:
Dictionary containing username and password, or None
"""
if not credential:
return None
return {"username": credential.username, "password": credential.password}
def deserialize_credential(credential_data: Optional[Dict[str, str]]) -> Optional[Credential]:
"""
Deserialize credential data into a Credential object.
Args:
credential_data: Dictionary containing username and password
Returns:
Credential object or None
"""
if not credential_data:
return None
return Credential(username=credential_data["username"], password=credential_data["password"])

View File

@@ -17,6 +17,10 @@ def find(*names: str) -> Optional[Path]:
if local_binaries_dir.exists():
candidate_paths = [local_binaries_dir / f"{name}{ext}", local_binaries_dir / name / f"{name}{ext}"]
for subdir in local_binaries_dir.iterdir():
if subdir.is_dir():
candidate_paths.append(subdir / f"{name}{ext}")
for path in candidate_paths:
if path.is_file():
# On Unix-like systems, check if file is executable
@@ -52,6 +56,8 @@ Mkvpropedit = find("mkvpropedit")
DoviTool = find("dovi_tool")
HDR10PlusTool = find("hdr10plus_tool", "HDR10Plus_tool")
Mp4decrypt = find("mp4decrypt")
Docker = find("docker")
ML_Worker = find("ML-Worker")
__all__ = (
@@ -71,5 +77,7 @@ __all__ = (
"DoviTool",
"HDR10PlusTool",
"Mp4decrypt",
"Docker",
"ML_Worker",
"find",
)

View File

@@ -1,4 +1,5 @@
from .custom_remote_cdm import CustomRemoteCDM
from .decrypt_labs_remote_cdm import DecryptLabsRemoteCDM
from .monalisa import MonaLisaCDM
__all__ = ["DecryptLabsRemoteCDM", "CustomRemoteCDM"]
__all__ = ["DecryptLabsRemoteCDM", "CustomRemoteCDM", "MonaLisaCDM"]

View File

@@ -0,0 +1,3 @@
from .monalisa_cdm import MonaLisaCDM
__all__ = ["MonaLisaCDM"]

View File

@@ -0,0 +1,371 @@
"""
MonaLisa CDM - WASM-based Content Decryption Module wrapper.
This module provides key extraction from MonaLisa-protected content using
a WebAssembly module that runs locally via wasmtime.
"""
import base64
import ctypes
import json
import re
import uuid
from pathlib import Path
from typing import Dict, Optional, Union
import wasmtime
from unshackle.core import binaries
class MonaLisaCDM:
"""
MonaLisa CDM wrapper for WASM-based key extraction.
This CDM differs from Widevine/PlayReady in that it does not use a
challenge/response flow with a license server. Instead, the license
(ticket) is provided directly by the service API, and keys are extracted
locally via the WASM module.
"""
DYNAMIC_BASE = 6065008
DYNAMICTOP_PTR = 821968
LICENSE_KEY_OFFSET = 0x5C8C0C
LICENSE_KEY_LENGTH = 16
ENV_STRINGS = (
"USER=web_user",
"LOGNAME=web_user",
"PATH=/",
"PWD=/",
"HOME=/home/web_user",
"LANG=zh_CN.UTF-8",
"_=./this.program",
)
def __init__(self, device_path: Path):
"""
Initialize the MonaLisa CDM.
Args:
device_path: Path to the device file (.mld).
"""
device_path = Path(device_path)
self.device_path = device_path
self.base_dir = device_path.parent
if not self.device_path.is_file():
raise FileNotFoundError(f"Device file not found at: {self.device_path}")
try:
data = json.loads(self.device_path.read_text(encoding="utf-8", errors="replace"))
except Exception as e:
raise ValueError(f"Invalid device file (JSON): {e}")
wasm_path_str = data.get("wasm_path")
if not wasm_path_str:
raise ValueError("Device file missing 'wasm_path'")
wasm_filename = Path(wasm_path_str).name
wasm_path = self.base_dir / wasm_filename
if not wasm_path.exists():
raise FileNotFoundError(f"WASM file not found at: {wasm_path}")
try:
self.engine = wasmtime.Engine()
if wasm_path.suffix.lower() == ".wat":
self.module = wasmtime.Module.from_file(self.engine, str(wasm_path))
else:
self.module = wasmtime.Module(self.engine, wasm_path.read_bytes())
except Exception as e:
raise RuntimeError(f"Failed to load WASM module: {e}")
self.store = None
self.memory = None
self.instance = None
self.exports = {}
self.ctx = None
@staticmethod
def get_worker_path() -> Optional[Path]:
"""Get ML-Worker binary path from the unshackle binaries system."""
if binaries.ML_Worker:
return Path(binaries.ML_Worker)
return None
def open(self) -> int:
"""
Open a CDM session.
Returns:
Session ID (always 1 for MonaLisa).
Raises:
RuntimeError: If session initialization fails.
"""
try:
self.store = wasmtime.Store(self.engine)
memory_type = wasmtime.MemoryType(wasmtime.Limits(256, 256))
self.memory = wasmtime.Memory(self.store, memory_type)
self._write_i32(self.DYNAMICTOP_PTR, self.DYNAMIC_BASE)
imports = self._build_imports()
self.instance = wasmtime.Instance(self.store, self.module, imports)
ex = self.instance.exports(self.store)
self.exports = {
"___wasm_call_ctors": ex["s"],
"_monalisa_context_alloc": ex["D"],
"monalisa_set_license": ex["F"],
"_monalisa_set_canvas_id": ex["t"],
"_monalisa_version_get": ex["A"],
"monalisa_get_line_number": ex["v"],
"stackAlloc": ex["N"],
"stackSave": ex["L"],
"stackRestore": ex["M"],
}
self.exports["___wasm_call_ctors"](self.store)
self.ctx = self.exports["_monalisa_context_alloc"](self.store)
return 1
except Exception as e:
raise RuntimeError(f"Failed to initialize session: {e}")
def close(self, session_id: int = 1) -> None:
"""
Close the CDM session and release resources.
Args:
session_id: The session ID to close (unused, for API compatibility).
"""
self.store = None
self.memory = None
self.instance = None
self.exports = {}
self.ctx = None
def extract_keys(self, license_data: Union[str, bytes]) -> Dict:
"""
Extract decryption keys from license/ticket data.
Args:
license_data: The license ticket, either as base64 string or raw bytes.
Returns:
Dictionary with keys: kid (hex), key (hex), type ("CONTENT").
Raises:
RuntimeError: If session not open or license validation fails.
ValueError: If license_data is empty.
"""
if not self.instance or not self.memory or self.ctx is None:
raise RuntimeError("Session not open. Call open() first.")
if not license_data:
raise ValueError("license_data is empty")
if isinstance(license_data, bytes):
license_b64 = base64.b64encode(license_data).decode("utf-8")
else:
license_b64 = license_data
ret = self._ccall(
"monalisa_set_license",
int,
self.ctx,
license_b64,
len(license_b64),
"0",
)
if ret != 0:
raise RuntimeError(f"License validation failed with code: {ret}")
key_bytes = self._extract_license_key_bytes()
# Extract DCID from license to generate KID
try:
decoded = base64.b64decode(license_b64).decode("ascii", errors="ignore")
except Exception:
decoded = ""
m = re.search(
r"DCID-[A-Z0-9]+-[A-Z0-9]+-\d{8}-\d{6}-[A-Z0-9]+-\d{10}-[A-Z0-9]+",
decoded,
)
if m:
kid_bytes = uuid.uuid5(uuid.NAMESPACE_DNS, m.group()).bytes
else:
kid_bytes = uuid.UUID(int=0).bytes
return {"kid": kid_bytes.hex(), "key": key_bytes.hex(), "type": "CONTENT"}
def _extract_license_key_bytes(self) -> bytes:
"""Extract the 16-byte decryption key from WASM memory."""
data_ptr = self.memory.data_ptr(self.store)
data_len = self.memory.data_len(self.store)
if self.LICENSE_KEY_OFFSET + self.LICENSE_KEY_LENGTH > data_len:
raise RuntimeError("License key offset beyond memory bounds")
mem_ptr = ctypes.cast(data_ptr, ctypes.POINTER(ctypes.c_ubyte * data_len))
start = self.LICENSE_KEY_OFFSET
end = self.LICENSE_KEY_OFFSET + self.LICENSE_KEY_LENGTH
return bytes(mem_ptr.contents[start:end])
def _ccall(self, func_name: str, return_type: type, *args):
"""Call a WASM function with automatic string conversion."""
stack = 0
converted_args = []
for arg in args:
if isinstance(arg, str):
if stack == 0:
stack = self.exports["stackSave"](self.store)
max_length = (len(arg) << 2) + 1
ptr = self.exports["stackAlloc"](self.store, max_length)
self._string_to_utf8(arg, ptr, max_length)
converted_args.append(ptr)
else:
converted_args.append(arg)
result = self.exports[func_name](self.store, *converted_args)
if stack != 0:
self.exports["stackRestore"](self.store, stack)
if return_type is bool:
return bool(result)
return result
def _write_i32(self, addr: int, value: int) -> None:
"""Write a 32-bit integer to WASM memory."""
data = self.memory.data_ptr(self.store)
mem_ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_int32))
mem_ptr[addr >> 2] = value
def _string_to_utf8(self, data: str, ptr: int, max_length: int) -> int:
"""Convert string to UTF-8 and write to WASM memory."""
encoded = data.encode("utf-8")
write_length = min(len(encoded), max_length - 1)
mem_data = self.memory.data_ptr(self.store)
mem_ptr = ctypes.cast(mem_data, ctypes.POINTER(ctypes.c_ubyte))
for i in range(write_length):
mem_ptr[ptr + i] = encoded[i]
mem_ptr[ptr + write_length] = 0
return write_length
def _write_ascii_to_memory(self, string: str, buffer: int, dont_add_null: int = 0) -> None:
"""Write ASCII string to WASM memory."""
mem_data = self.memory.data_ptr(self.store)
mem_ptr = ctypes.cast(mem_data, ctypes.POINTER(ctypes.c_ubyte))
encoded = string.encode("utf-8")
for i, byte_val in enumerate(encoded):
mem_ptr[buffer + i] = byte_val
if dont_add_null == 0:
mem_ptr[buffer + len(encoded)] = 0
def _build_imports(self):
"""Build the WASM import stubs required by the MonaLisa module."""
def sys_fcntl64(a, b, c):
return 0
def fd_write(a, b, c, d):
return 0
def fd_close(a):
return 0
def sys_ioctl(a, b, c):
return 0
def sys_open(a, b, c):
return 0
def sys_rmdir(a):
return 0
def sys_unlink(a):
return 0
def clock():
return 0
def time(a):
return 0
def emscripten_run_script(a):
return None
def fd_seek(a, b, c, d, e):
return 0
def emscripten_resize_heap(a):
return 0
def fd_read(a, b, c, d):
return 0
def emscripten_run_script_string(a):
return 0
def emscripten_run_script_int(a):
return 1
def emscripten_memcpy_big(dest, src, num):
mem_data = self.memory.data_ptr(self.store)
data_len = self.memory.data_len(self.store)
if num is None:
num = data_len - 1
mem_ptr = ctypes.cast(mem_data, ctypes.POINTER(ctypes.c_ubyte))
for i in range(num):
if dest + i < data_len and src + i < data_len:
mem_ptr[dest + i] = mem_ptr[src + i]
return dest
def environ_get(environ_ptr, environ_buf):
buf_size = 0
for index, string in enumerate(self.ENV_STRINGS):
ptr = environ_buf + buf_size
self._write_i32(environ_ptr + index * 4, ptr)
self._write_ascii_to_memory(string, ptr)
buf_size += len(string) + 1
return 0
def environ_sizes_get(penviron_count, penviron_buf_size):
self._write_i32(penviron_count, len(self.ENV_STRINGS))
buf_size = sum(len(s) + 1 for s in self.ENV_STRINGS)
self._write_i32(penviron_buf_size, buf_size)
return 0
i32 = wasmtime.ValType.i32()
return [
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32], [i32]), sys_fcntl64),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32, i32], [i32]), fd_write),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), fd_close),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32], [i32]), sys_ioctl),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32], [i32]), sys_open),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), sys_rmdir),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), sys_unlink),
wasmtime.Func(self.store, wasmtime.FuncType([], [i32]), clock),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), time),
wasmtime.Func(self.store, wasmtime.FuncType([i32], []), emscripten_run_script),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32, i32, i32], [i32]), fd_seek),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32], [i32]), emscripten_memcpy_big),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), emscripten_resize_heap),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32], [i32]), environ_get),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32], [i32]), environ_sizes_get),
wasmtime.Func(self.store, wasmtime.FuncType([i32, i32, i32, i32], [i32]), fd_read),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), emscripten_run_script_string),
wasmtime.Func(self.store, wasmtime.FuncType([i32], [i32]), emscripten_run_script_int),
self.memory,
]

View File

@@ -94,6 +94,7 @@ class Config:
self.update_checks: bool = kwargs.get("update_checks", True)
self.update_check_interval: int = kwargs.get("update_check_interval", 24)
self.scene_naming: bool = kwargs.get("scene_naming", True)
self.dash_naming: bool = kwargs.get("dash_naming", False)
self.series_year: bool = kwargs.get("series_year", True)
self.unicode_filenames: bool = kwargs.get("unicode_filenames", False)
self.insert_episodename_into_filenames: bool = kwargs.get("insert_episodename_into_filenames", True)

View File

@@ -1,6 +1,7 @@
import os
import subprocess
import textwrap
import threading
import time
from functools import partial
from http.cookiejar import CookieJar
@@ -49,6 +50,138 @@ def rpc(caller: Callable, secret: str, method: str, params: Optional[list[Any]]
return
class _Aria2Manager:
"""Singleton manager to run one aria2c process and enqueue downloads via RPC."""
def __init__(self) -> None:
self._proc: Optional[subprocess.Popen] = None
self._rpc_port: Optional[int] = None
self._rpc_secret: Optional[str] = None
self._rpc_uri: Optional[str] = None
self._session: Session = Session()
self._max_concurrent_downloads: int = 0
self._max_connection_per_server: int = 1
self._split_default: int = 5
self._file_allocation: str = "prealloc"
self._proxy: Optional[str] = None
self._lock: threading.Lock = threading.Lock()
def _build_args(self) -> list[str]:
args = [
"--continue=true",
f"--max-concurrent-downloads={self._max_concurrent_downloads}",
f"--max-connection-per-server={self._max_connection_per_server}",
f"--split={self._split_default}",
"--max-file-not-found=5",
"--max-tries=5",
"--retry-wait=2",
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--console-log-level=warn",
"--download-result=default",
f"--file-allocation={self._file_allocation}",
"--summary-interval=0",
"--enable-rpc=true",
f"--rpc-listen-port={self._rpc_port}",
f"--rpc-secret={self._rpc_secret}",
]
if self._proxy:
args.extend(["--all-proxy", self._proxy])
return args
def ensure_started(
self,
proxy: Optional[str],
max_workers: Optional[int],
) -> None:
with self._lock:
if self._proc and self._proc.poll() is None:
return
if not binaries.Aria2:
debug_logger = get_debug_logger()
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_binary_missing",
message="Aria2c executable not found in PATH or local binaries directory",
context={"searched_names": ["aria2c", "aria2"]},
)
raise EnvironmentError("Aria2c executable not found...")
if not max_workers:
max_workers = min(32, (os.cpu_count() or 1) + 4)
elif not isinstance(max_workers, int):
raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")
self._rpc_port = get_free_port()
self._rpc_secret = get_random_bytes(16).hex()
self._rpc_uri = f"http://127.0.0.1:{self._rpc_port}/jsonrpc"
self._max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
self._max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
self._split_default = int(config.aria2c.get("split", 5))
self._file_allocation = config.aria2c.get("file_allocation", "prealloc")
self._proxy = proxy or None
args = self._build_args()
self._proc = subprocess.Popen(
[binaries.Aria2, *args], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
# Give aria2c a moment to start up and bind to the RPC port
time.sleep(0.5)
@property
def rpc_uri(self) -> str:
assert self._rpc_uri
return self._rpc_uri
@property
def rpc_secret(self) -> str:
assert self._rpc_secret
return self._rpc_secret
@property
def session(self) -> Session:
return self._session
def add_uris(self, uris: list[str], options: dict[str, Any]) -> str:
"""Add a single download with multiple URIs via RPC."""
gid = rpc(
caller=partial(self._session.post, url=self.rpc_uri),
secret=self.rpc_secret,
method="aria2.addUri",
params=[uris, options],
)
return gid or ""
def get_global_stat(self) -> dict[str, Any]:
return rpc(
caller=partial(self.session.post, url=self.rpc_uri),
secret=self.rpc_secret,
method="aria2.getGlobalStat",
) or {}
def tell_status(self, gid: str) -> Optional[dict[str, Any]]:
return rpc(
caller=partial(self.session.post, url=self.rpc_uri),
secret=self.rpc_secret,
method="aria2.tellStatus",
params=[gid, ["status", "errorCode", "errorMessage", "files", "completedLength", "totalLength"]],
)
def remove(self, gid: str) -> None:
rpc(
caller=partial(self.session.post, url=self.rpc_uri),
secret=self.rpc_secret,
method="aria2.forceRemove",
params=[gid],
)
_manager = _Aria2Manager()
def download(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
@@ -58,6 +191,7 @@ def download(
proxy: Optional[str] = None,
max_workers: Optional[int] = None,
) -> Generator[dict[str, Any], None, None]:
"""Enqueue downloads to the singleton aria2c instance via stdin and track per-call progress via RPC."""
debug_logger = get_debug_logger()
if not urls:
@@ -92,102 +226,10 @@ def download(
if not isinstance(urls, list):
urls = [urls]
if not binaries.Aria2:
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_binary_missing",
message="Aria2c executable not found in PATH or local binaries directory",
context={"searched_names": ["aria2c", "aria2"]},
)
raise EnvironmentError("Aria2c executable not found...")
if proxy and not proxy.lower().startswith("http://"):
raise ValueError("Only HTTP proxies are supported by aria2(c)")
if cookies and not isinstance(cookies, CookieJar):
cookies = cookiejar_from_dict(cookies)
url_files = []
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {"url": url}
else:
url_data: dict[str, Any] = url
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
url_text = url_data["url"]
url_text += f"\n\tdir={output_dir}"
url_text += f"\n\tout={url_filename}"
if cookies:
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
url_text += f"\n\theader=Cookie: {cookie_header}"
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
url_text += f"\n\theader={header_name}: {header_value}"
else:
url_text += f"\n\t{key}={value}"
url_files.append(url_text)
url_file = "\n".join(url_files)
rpc_port = get_free_port()
rpc_secret = get_random_bytes(16).hex()
rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
rpc_session = Session()
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
split = int(config.aria2c.get("split", 5))
file_allocation = config.aria2c.get("file_allocation", "prealloc")
if len(urls) > 1:
split = 1
file_allocation = "none"
arguments = [
# [Basic Options]
"--input-file",
"-",
"--all-proxy",
proxy or "",
"--continue=true",
# [Connection Options]
f"--max-concurrent-downloads={max_concurrent_downloads}",
f"--max-connection-per-server={max_connection_per_server}",
f"--split={split}", # each split uses their own connection
"--max-file-not-found=5", # counted towards --max-tries
"--max-tries=5",
"--retry-wait=2",
# [Advanced Options]
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--console-log-level=warn",
"--download-result=default",
f"--file-allocation={file_allocation}",
"--summary-interval=0",
# [RPC Options]
"--enable-rpc=true",
f"--rpc-listen-port={rpc_port}",
f"--rpc-secret={rpc_secret}",
]
for header, value in (headers or {}).items():
if header.lower() == "cookie":
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
if header.lower() == "accept-encoding":
# we cannot set an allowed encoding, or it will return compressed
# and the code is not set up to uncompress the data
continue
if header.lower() == "referer":
arguments.extend(["--referer", value])
continue
if header.lower() == "user-agent":
arguments.extend(["--user-agent", value])
continue
arguments.extend(["--header", f"{header}: {value}"])
_manager.ensure_started(proxy=proxy, max_workers=max_workers)
if debug_logger:
first_url = urls[0] if isinstance(urls[0], str) else urls[0].get("url", "")
@@ -202,128 +244,151 @@ def download(
"first_url": url_display,
"output_dir": str(output_dir),
"filename": filename,
"max_concurrent_downloads": max_concurrent_downloads,
"max_connection_per_server": max_connection_per_server,
"split": split,
"file_allocation": file_allocation,
"has_proxy": bool(proxy),
"rpc_port": rpc_port,
},
)
yield dict(total=len(urls))
# Build options for each URI and add via RPC
gids: list[str] = []
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {"url": url}
else:
url_data: dict[str, Any] = url
url_filename = filename.format(i=i, ext=get_extension(url_data["url"]))
opts: dict[str, Any] = {
"dir": str(output_dir),
"out": url_filename,
"split": str(1 if len(urls) > 1 else int(config.aria2c.get("split", 5))),
}
# Cookies as header
if cookies:
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
opts.setdefault("header", []).append(f"Cookie: {cookie_header}")
# Global headers
for header, value in (headers or {}).items():
if header.lower() == "cookie":
raise ValueError("You cannot set Cookies as a header manually, please use the `cookies` param.")
if header.lower() == "accept-encoding":
continue
if header.lower() == "referer":
opts["referer"] = str(value)
continue
if header.lower() == "user-agent":
opts["user-agent"] = str(value)
continue
opts.setdefault("header", []).append(f"{header}: {value}")
# Per-url extra args
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
opts.setdefault("header", []).append(f"{header_name}: {header_value}")
else:
opts[key] = str(value)
# Add via RPC
gid = _manager.add_uris([url_data["url"]], opts)
if gid:
gids.append(gid)
yield dict(total=len(gids))
completed: set[str] = set()
try:
p = subprocess.Popen([binaries.Aria2, *arguments], stdin=subprocess.PIPE, stdout=subprocess.DEVNULL)
while len(completed) < len(gids):
if DOWNLOAD_CANCELLED.is_set():
# Remove tracked downloads on cancel
for gid in gids:
if gid not in completed:
_manager.remove(gid)
yield dict(downloaded="[yellow]CANCELLED")
raise KeyboardInterrupt()
p.stdin.write(url_file.encode())
p.stdin.close()
stats = _manager.get_global_stat()
dl_speed = int(stats.get("downloadSpeed", -1))
while p.poll() is None:
global_stats: dict[str, Any] = (
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.getGlobalStat")
or {}
)
# Aggregate progress across all GIDs for this call
total_completed = 0
total_size = 0
number_stopped = int(global_stats.get("numStoppedTotal", 0))
download_speed = int(global_stats.get("downloadSpeed", -1))
# Check each tracked GID
for gid in gids:
if gid in completed:
continue
if number_stopped:
yield dict(completed=number_stopped)
if download_speed != -1:
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
status = _manager.tell_status(gid)
if not status:
continue
stopped_downloads: list[dict[str, Any]] = (
rpc(
caller=partial(rpc_session.post, url=rpc_uri),
secret=rpc_secret,
method="aria2.tellStopped",
params=[0, 999999],
)
or []
)
completed_length = int(status.get("completedLength", 0))
total_length = int(status.get("totalLength", 0))
total_completed += completed_length
total_size += total_length
for dl in stopped_downloads:
if dl["status"] == "error":
used_uri = next(
uri["uri"]
for file in dl["files"]
if file["selected"] == "true"
for uri in file["uris"]
if uri["status"] == "used"
)
error = f"Download Error (#{dl['gid']}): {dl['errorMessage']} ({dl['errorCode']}), {used_uri}"
error_pretty = "\n ".join(
textwrap.wrap(error, width=console.width - 20, initial_indent="")
)
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_download_error",
message=f"Aria2c download failed: {dl['errorMessage']}",
context={
"gid": dl["gid"],
"error_code": dl["errorCode"],
"error_message": dl["errorMessage"],
"used_uri": used_uri[:200] + "..." if len(used_uri) > 200 else used_uri,
"completed_length": dl.get("completedLength"),
"total_length": dl.get("totalLength"),
},
)
raise ValueError(error)
state = status.get("status")
if state in ("complete", "error"):
completed.add(gid)
yield dict(completed=len(completed))
if number_stopped == len(urls):
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
break
if state == "error":
used_uri = None
try:
used_uri = next(
uri["uri"]
for file in status.get("files", [])
for uri in file.get("uris", [])
if uri.get("status") == "used"
)
except Exception:
used_uri = "unknown"
error = f"Download Error (#{gid}): {status.get('errorMessage')} ({status.get('errorCode')}), {used_uri}"
error_pretty = "\n ".join(textwrap.wrap(error, width=console.width - 20, initial_indent=""))
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_download_error",
message=f"Aria2c download failed: {status.get('errorMessage')}",
context={
"gid": gid,
"error_code": status.get("errorCode"),
"error_message": status.get("errorMessage"),
"used_uri": used_uri[:200] + "..." if used_uri and len(used_uri) > 200 else used_uri,
"completed_length": status.get("completedLength"),
"total_length": status.get("totalLength"),
},
)
raise ValueError(error)
# Yield aggregate progress for this call's downloads
if total_size > 0:
# Yield both advance (bytes downloaded this iteration) and total for rich progress
if dl_speed != -1:
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s", advance=0, completed=total_completed, total=total_size)
else:
yield dict(advance=0, completed=total_completed, total=total_size)
elif dl_speed != -1:
yield dict(downloaded=f"{filesize.decimal(dl_speed)}/s")
time.sleep(1)
p.wait()
if p.returncode != 0:
if debug_logger:
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_failed",
message=f"Aria2c exited with code {p.returncode}",
context={
"returncode": p.returncode,
"url_count": len(urls),
"output_dir": str(output_dir),
},
)
raise subprocess.CalledProcessError(p.returncode, arguments)
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="downloader_aria2c_complete",
message="Aria2c download completed successfully",
context={
"url_count": len(urls),
"output_dir": str(output_dir),
"filename": filename,
},
)
except ConnectionResetError:
# interrupted while passing URI to download
raise KeyboardInterrupt()
except subprocess.CalledProcessError as e:
if e.returncode in (7, 0xC000013A):
# 7 is when Aria2(c) handled the CTRL+C
# 0xC000013A is when it never got the chance to
raise KeyboardInterrupt()
raise
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
yield dict(downloaded="[yellow]CANCELLED")
DOWNLOAD_CANCELLED.set()
raise
except Exception as e:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
DOWNLOAD_CANCELLED.set()
yield dict(downloaded="[red]FAILED")
if debug_logger and not isinstance(e, (subprocess.CalledProcessError, ValueError)):
if debug_logger and not isinstance(e, ValueError):
debug_logger.log(
level="ERROR",
operation="downloader_aria2c_exception",
@@ -335,8 +400,6 @@ def download(
},
)
raise
finally:
rpc(caller=partial(rpc_session.post, url=rpc_uri), secret=rpc_secret, method="aria2.shutdown")
def aria2c(

View File

@@ -10,6 +10,7 @@ import requests
from requests.cookies import cookiejar_from_dict, get_cookie_header
from unshackle.core import binaries
from unshackle.core.binaries import FFMPEG, Mp4decrypt, ShakaPackager
from unshackle.core.config import config
from unshackle.core.console import console
from unshackle.core.constants import DOWNLOAD_CANCELLED
@@ -19,7 +20,7 @@ PERCENT_RE = re.compile(r"(\d+\.\d+%)")
SPEED_RE = re.compile(r"(\d+\.\d+(?:MB|KB)ps)")
SIZE_RE = re.compile(r"(\d+\.\d+(?:MB|GB|KB)/\d+\.\d+(?:MB|GB|KB))")
WARN_RE = re.compile(r"(WARN : Response.*|WARN : One or more errors occurred.*)")
ERROR_RE = re.compile(r"(ERROR.*)")
ERROR_RE = re.compile(r"(\bERROR\b.*|\bFAILED\b.*|\bException\b.*)")
DECRYPTION_ENGINE = {
"shaka": "SHAKA_PACKAGER",
@@ -181,17 +182,33 @@ def build_download_args(
"--tmp-dir": output_dir,
"--thread-count": thread_count,
"--download-retry-count": retry_count,
"--write-meta-json": False,
}
if FFMPEG:
args["--ffmpeg-binary-path"] = str(FFMPEG)
if proxy:
args["--custom-proxy"] = proxy
if skip_merge:
args["--skip-merge"] = skip_merge
if ad_keyword:
args["--ad-keyword"] = ad_keyword
if content_keys:
args["--key"] = next((f"{kid.hex}:{key.lower()}" for kid, key in content_keys.items()), None)
args["--decryption-engine"] = DECRYPTION_ENGINE.get(config.decryption.lower()) or "SHAKA_PACKAGER"
decryption_config = config.decryption.lower()
engine_name = DECRYPTION_ENGINE.get(decryption_config) or "SHAKA_PACKAGER"
args["--decryption-engine"] = engine_name
binary_path = None
if engine_name == "SHAKA_PACKAGER":
if ShakaPackager:
binary_path = str(ShakaPackager)
elif engine_name == "MP4DECRYPT":
if Mp4decrypt:
binary_path = str(Mp4decrypt)
if binary_path:
args["--decryption-binary-path"] = binary_path
if custom_args:
args.update(custom_args)
@@ -288,7 +305,10 @@ def download(
log_file_path: Path | None = None
if debug_logger:
log_file_path = output_dir / f".n_m3u8dl_re_{filename}.log"
arguments.extend(["--log-file-path", str(log_file_path)])
arguments.extend([
"--log-file-path", str(log_file_path),
"--log-level", "DEBUG",
])
track_url_display = track.url[:200] + "..." if len(track.url) > 200 else track.url
debug_logger.log(
@@ -376,6 +396,14 @@ def download(
raise subprocess.CalledProcessError(process.returncode, arguments)
if debug_logger:
output_dir_exists = output_dir.exists()
output_files = []
if output_dir_exists:
try:
output_files = [f.name for f in output_dir.iterdir() if f.is_file()][:20]
except Exception:
output_files = ["<error listing files>"]
debug_logger.log(
level="DEBUG",
operation="downloader_n_m3u8dl_re_complete",
@@ -384,10 +412,38 @@ def download(
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"output_dir": str(output_dir),
"output_dir_exists": output_dir_exists,
"output_files_count": len(output_files),
"output_files": output_files,
"filename": filename,
},
)
# Warn if no output was produced - include N_m3u8DL-RE's logs for diagnosis
if not output_dir_exists or not output_files:
# Read N_m3u8DL-RE's log file for debugging
n_m3u8dl_log = ""
if log_file_path and log_file_path.exists():
try:
n_m3u8dl_log = log_file_path.read_text(encoding="utf-8", errors="replace")
except Exception:
n_m3u8dl_log = "<failed to read log file>"
debug_logger.log(
level="WARNING",
operation="downloader_n_m3u8dl_re_no_output",
message="N_m3u8DL-RE exited successfully but produced no output files",
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"output_dir": str(output_dir),
"output_dir_exists": output_dir_exists,
"selection_args": selection_args,
"track_url": track.url[:200] + "..." if len(track.url) > 200 else track.url,
"n_m3u8dl_re_log": n_m3u8dl_log,
},
)
except ConnectionResetError:
# interrupted while passing URI to download
raise KeyboardInterrupt()
@@ -419,6 +475,7 @@ def download(
)
raise
finally:
# Clean up temporary debug files
if log_file_path and log_file_path.exists():
try:
log_file_path.unlink()

View File

@@ -122,7 +122,7 @@ def download(
last_speed_refresh = now
download_sizes.clear()
if content_length and written < content_length:
if not segmented and content_length and written < content_length:
raise IOError(f"Failed to read {content_length} bytes from the track URI.")
yield dict(file_downloaded=save_path, written=written)
@@ -264,7 +264,7 @@ def requests(
try:
with ThreadPoolExecutor(max_workers=max_workers) as pool:
for future in as_completed(pool.submit(download, session=session, segmented=False, **url) for url in urls):
for future in as_completed(pool.submit(download, session=session, segmented=True, **url) for url in urls):
try:
yield from future.result()
except KeyboardInterrupt:

View File

@@ -1,10 +1,11 @@
from typing import Union
from unshackle.core.drm.clearkey import ClearKey
from unshackle.core.drm.monalisa import MonaLisa
from unshackle.core.drm.playready import PlayReady
from unshackle.core.drm.widevine import Widevine
DRM_T = Union[ClearKey, Widevine, PlayReady]
DRM_T = Union[ClearKey, Widevine, PlayReady, MonaLisa]
__all__ = ("ClearKey", "Widevine", "PlayReady", "DRM_T")
__all__ = ("ClearKey", "Widevine", "PlayReady", "MonaLisa", "DRM_T")

View File

@@ -0,0 +1,280 @@
"""
MonaLisa DRM System.
A WASM-based DRM system that uses local key extraction and two-stage
segment decryption (ML-Worker binary + AES-ECB).
"""
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Optional, Union
from uuid import UUID
from Cryptodome.Cipher import AES
from Cryptodome.Util.Padding import unpad
class MonaLisa:
"""
MonaLisa DRM System.
Unlike Widevine/PlayReady, MonaLisa does not use a challenge/response flow
with a license server. Instead, the PSSH value (ticket) is provided directly
by the service API, and keys are extracted locally via a WASM module.
Decryption is performed in two stages:
1. ML-Worker binary: Removes MonaLisa encryption layer (bbts -> ents)
2. AES-ECB decryption: Final decryption with service-provided key
"""
class Exceptions:
class TicketNotFound(Exception):
"""Raised when no PSSH/ticket data is provided."""
class KeyExtractionFailed(Exception):
"""Raised when key extraction from the ticket fails."""
class WorkerNotFound(Exception):
"""Raised when the ML-Worker binary is not found."""
class DecryptionFailed(Exception):
"""Raised when segment decryption fails."""
def __init__(
self,
ticket: Union[str, bytes],
aes_key: Union[str, bytes],
device_path: Path,
**kwargs: Any,
):
"""
Initialize MonaLisa DRM.
Args:
ticket: PSSH value from service API (base64 string or raw bytes).
aes_key: AES-ECB key for second-stage decryption (hex string or bytes).
device_path: Path to the CDM device file (.mld).
**kwargs: Additional metadata stored in self.data.
Raises:
TicketNotFound: If ticket/PSSH is empty.
KeyExtractionFailed: If key extraction fails.
"""
if not ticket:
raise MonaLisa.Exceptions.TicketNotFound("No PSSH/ticket data provided.")
self._ticket = ticket
# Store AES key for second-stage decryption
if isinstance(aes_key, str):
self._aes_key = bytes.fromhex(aes_key)
else:
self._aes_key = aes_key
self._device_path = device_path
self._kid: Optional[UUID] = None
self._key: Optional[str] = None
self.data: dict = kwargs or {}
# Extract keys immediately
self._extract_keys()
def _extract_keys(self) -> None:
"""Extract keys from the ticket using the MonaLisa CDM."""
# Import here to avoid circular import
from unshackle.core.cdm.monalisa import MonaLisaCDM
try:
cdm = MonaLisaCDM(device_path=self._device_path)
session_id = cdm.open()
try:
keys = cdm.extract_keys(self._ticket)
if keys:
kid_hex = keys.get("kid")
if kid_hex:
self._kid = UUID(hex=kid_hex)
self._key = keys.get("key")
finally:
cdm.close(session_id)
except Exception as e:
raise MonaLisa.Exceptions.KeyExtractionFailed(f"Failed to extract keys: {e}")
@classmethod
def from_ticket(
cls,
ticket: Union[str, bytes],
aes_key: Union[str, bytes],
device_path: Path,
) -> MonaLisa:
"""
Create a MonaLisa DRM instance from a PSSH/ticket.
Args:
ticket: PSSH value from service API.
aes_key: AES-ECB key for second-stage decryption.
device_path: Path to the CDM device file (.mld).
Returns:
MonaLisa DRM instance with extracted keys.
"""
return cls(ticket=ticket, aes_key=aes_key, device_path=device_path)
@property
def kid(self) -> Optional[UUID]:
"""Get the Key ID."""
return self._kid
@property
def key(self) -> Optional[str]:
"""Get the content key as hex string."""
return self._key
@property
def pssh(self) -> str:
"""
Get the raw PSSH/ticket value as a string.
Returns:
The raw PSSH value as a base64 string.
"""
if isinstance(self._ticket, bytes):
return self._ticket.decode("utf-8")
return self._ticket
@property
def content_id(self) -> Optional[str]:
"""
Extract the Content ID from the PSSH for display.
The PSSH contains an embedded Content ID at bytes 21-75 with format:
H5DCID-V3-P1-YYYYMMDD-HHMMSS-MEDIAID-TIMESTAMP-SUFFIX
Returns:
The Content ID string if extractable, None otherwise.
"""
import base64
try:
# Decode base64 PSSH to get raw bytes
if isinstance(self._ticket, bytes):
data = self._ticket
else:
data = base64.b64decode(self._ticket)
# Content ID is at bytes 21-75 (55 bytes)
if len(data) >= 76:
content_id = data[21:76].decode("ascii")
# Validate it looks like a content ID
if content_id.startswith("H5DCID-"):
return content_id
except Exception:
pass
return None
@property
def content_keys(self) -> dict[UUID, str]:
"""
Get content keys in the same format as Widevine/PlayReady.
Returns:
Dictionary mapping KID to key hex string.
"""
if self._kid and self._key:
return {self._kid: self._key}
return {}
def decrypt_segment(self, segment_path: Path) -> None:
"""
Decrypt a single segment using two-stage decryption.
Stage 1: ML-Worker binary (bbts -> ents)
Stage 2: AES-ECB decryption (ents -> ts)
Args:
segment_path: Path to the encrypted segment file.
Raises:
WorkerNotFound: If ML-Worker binary is not available.
DecryptionFailed: If decryption fails at any stage.
"""
if not self._key:
return
# Import here to avoid circular import
from unshackle.core.cdm.monalisa import MonaLisaCDM
worker_path = MonaLisaCDM.get_worker_path()
if not worker_path or not worker_path.exists():
raise MonaLisa.Exceptions.WorkerNotFound("ML-Worker not found.")
bbts_path = segment_path.with_suffix(".bbts")
ents_path = segment_path.with_suffix(".ents")
try:
if segment_path.exists():
segment_path.replace(bbts_path)
else:
raise MonaLisa.Exceptions.DecryptionFailed(f"Segment file does not exist: {segment_path}")
# Stage 1: ML-Worker decryption
cmd = [str(worker_path), self._key, str(bbts_path), str(ents_path)]
startupinfo = None
if sys.platform == "win32":
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
process = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
startupinfo=startupinfo,
)
if process.returncode != 0:
raise MonaLisa.Exceptions.DecryptionFailed(
f"ML-Worker failed for {segment_path.name}: {process.stderr}"
)
if not ents_path.exists():
raise MonaLisa.Exceptions.DecryptionFailed(
f"Decrypted .ents file was not created for {segment_path.name}"
)
# Stage 2: AES-ECB decryption
with open(ents_path, "rb") as f:
ents_data = f.read()
crypto = AES.new(self._aes_key, AES.MODE_ECB)
decrypted_data = unpad(crypto.decrypt(ents_data), AES.block_size)
# Write decrypted segment back to original path
with open(segment_path, "wb") as f:
f.write(decrypted_data)
except MonaLisa.Exceptions.DecryptionFailed:
raise
except Exception as e:
raise MonaLisa.Exceptions.DecryptionFailed(f"Failed to decrypt segment {segment_path.name}: {e}")
finally:
if ents_path.exists():
os.remove(ents_path)
if bbts_path != segment_path and bbts_path.exists():
os.remove(bbts_path)
def decrypt(self, _path: Path) -> None:
"""
MonaLisa uses per-segment decryption during download via the
on_segment_downloaded callback. By the time this method is called,
the content has already been decrypted and muxed into a container.
Args:
path: Path to the file (ignored).
"""
pass

View File

@@ -151,6 +151,11 @@ class DASH:
if not track_fps and segment_base is not None:
track_fps = segment_base.get("timescale")
scan_type = None
scan_type_str = get("scanType")
if scan_type_str and scan_type_str.lower() == "interlaced":
scan_type = Video.ScanType.INTERLACED
track_args = dict(
range_=self.get_video_range(
codecs, findall("SupplementalProperty"), findall("EssentialProperty")
@@ -159,6 +164,7 @@ class DASH:
width=get("width") or 0,
height=get("height") or 0,
fps=track_fps or None,
scan_type=scan_type,
)
elif content_type == "audio":
track_type = Audio
@@ -366,6 +372,9 @@ class DASH:
if not end_number:
end_number = len(segment_durations)
# Handle high startNumber in DVR/catch-up manifests where startNumber > segment count
if start_number > end_number:
end_number = start_number + len(segment_durations) - 1
for t, n in zip(segment_durations, range(start_number, end_number + 1)):
segments.append(
@@ -467,8 +476,9 @@ class DASH:
track.data["dash"]["timescale"] = int(segment_timescale)
track.data["dash"]["segment_durations"] = segment_durations
if init_data and isinstance(track, (Video, Audio)):
if isinstance(cdm, PlayReadyCdm):
if not track.drm and init_data and isinstance(track, (Video, Audio)):
prefers_playready = isinstance(cdm, PlayReadyCdm) or (hasattr(cdm, "is_playready") and cdm.is_playready)
if prefers_playready:
try:
track.drm = [PlayReady.from_init_data(init_data)]
except PlayReady.Exceptions.PSSHNotFound:
@@ -572,8 +582,64 @@ class DASH:
for control_file in save_dir.glob("*.aria2__temp"):
control_file.unlink()
# Verify output directory exists and contains files
if not save_dir.exists():
error_msg = f"Output directory does not exist: {save_dir}"
if debug_logger:
debug_logger.log(
level="ERROR",
operation="manifest_dash_download_output_missing",
message=error_msg,
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"save_path": str(save_path),
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
raise FileNotFoundError(error_msg)
segments_to_merge = [x for x in sorted(save_dir.iterdir()) if x.is_file()]
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="manifest_dash_download_complete",
message="DASH download complete, preparing to merge",
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"save_dir_exists": save_dir.exists(),
"segments_found": len(segments_to_merge),
"segment_files": [f.name for f in segments_to_merge[:10]], # Limit to first 10
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
if not segments_to_merge:
error_msg = f"No segment files found in output directory: {save_dir}"
if debug_logger:
# List all contents of the directory for debugging
all_contents = list(save_dir.iterdir()) if save_dir.exists() else []
debug_logger.log(
level="ERROR",
operation="manifest_dash_download_no_segments",
message=error_msg,
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"directory_contents": [str(p) for p in all_contents],
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
raise FileNotFoundError(error_msg)
if skip_merge:
# N_m3u8DL-RE handles merging and decryption internally
shutil.move(segments_to_merge[0], save_path)
@@ -800,7 +866,7 @@ class DASH:
urn = (protection.get("schemeIdUri") or "").lower()
if urn == WidevineCdm.urn:
pssh_text = protection.findtext("pssh")
pssh_text = protection.findtext("pssh") or protection.findtext("{urn:mpeg:cenc:2013}pssh")
if not pssh_text:
continue
pssh = PSSH(pssh_text)
@@ -831,6 +897,7 @@ class DASH:
elif urn in ("urn:uuid:9a04f079-9840-4286-ab92-e65be0885f95", "urn:microsoft:playready"):
pr_pssh_b64 = (
protection.findtext("pssh")
or protection.findtext("{urn:mpeg:cenc:2013}pssh")
or protection.findtext("pro")
or protection.findtext("{urn:microsoft:playready}pro")
)

View File

@@ -30,7 +30,7 @@ from requests import Session
from unshackle.core import binaries
from unshackle.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from unshackle.core.downloaders import requests as requests_downloader
from unshackle.core.drm import DRM_T, ClearKey, PlayReady, Widevine
from unshackle.core.drm import DRM_T, ClearKey, MonaLisa, PlayReady, Widevine
from unshackle.core.events import events
from unshackle.core.tracks import Audio, Subtitle, Tracks, Video
from unshackle.core.utilities import get_debug_logger, get_extension, is_close_match, try_ensure_utf8
@@ -316,6 +316,10 @@ class HLS:
progress(downloaded="[red]FAILED")
raise
if not initial_drm_licensed and session_drm and isinstance(session_drm, MonaLisa):
if license_widevine:
license_widevine(session_drm)
if DOWNLOAD_LICENCE_ONLY.is_set():
progress(downloaded="[yellow]SKIPPED")
return
@@ -591,7 +595,11 @@ class HLS:
segment_keys = getattr(segment, "keys", None)
if segment_keys:
key = HLS.get_supported_key(segment_keys)
if cdm:
cdm_segment_keys = HLS.filter_keys_for_cdm(segment_keys, cdm)
key = HLS.get_supported_key(cdm_segment_keys) if cdm_segment_keys else HLS.get_supported_key(segment_keys)
else:
key = HLS.get_supported_key(segment_keys)
if encryption_data and encryption_data[0] != key and i != 0 and segment not in unwanted_segments:
decrypt(include_this_segment=False)
@@ -650,6 +658,44 @@ class HLS:
# finally merge all the discontinuity save files together to the final path
segments_to_merge = find_segments_recursively(save_dir)
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="manifest_hls_download_complete",
message="HLS download complete, preparing to merge",
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"save_dir_exists": save_dir.exists(),
"segments_found": len(segments_to_merge),
"segment_files": [f.name for f in segments_to_merge[:10]], # Limit to first 10
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
if not segments_to_merge:
error_msg = f"No segment files found in output directory: {save_dir}"
if debug_logger:
all_contents = list(save_dir.iterdir()) if save_dir.exists() else []
debug_logger.log(
level="ERROR",
operation="manifest_hls_download_no_segments",
message=error_msg,
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"save_dir_exists": save_dir.exists(),
"directory_contents": [str(p) for p in all_contents],
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
raise FileNotFoundError(error_msg)
if len(segments_to_merge) == 1:
shutil.move(segments_to_merge[0], save_path)
else:
@@ -889,7 +935,8 @@ class HLS:
elif key.keyformat and key.keyformat.lower() == WidevineCdm.urn:
return key
elif key.keyformat and key.keyformat.lower() in {
f"urn:uuid:{PR_PSSH.SYSTEM_ID}", "com.microsoft.playready"
f"urn:uuid:{PR_PSSH.SYSTEM_ID}",
"com.microsoft.playready",
}:
return key
else:
@@ -927,9 +974,7 @@ class HLS:
pssh=WV_PSSH(key.uri.split(",")[-1]),
**key._extra_params, # noqa
)
elif key.keyformat and key.keyformat.lower() in {
f"urn:uuid:{PR_PSSH.SYSTEM_ID}", "com.microsoft.playready"
}:
elif key.keyformat and key.keyformat.lower() in {f"urn:uuid:{PR_PSSH.SYSTEM_ID}", "com.microsoft.playready"}:
drm = PlayReady(
pssh=PR_PSSH(key.uri.split(",")[-1]),
pssh_b64=key.uri.split(",")[-1],

View File

@@ -314,8 +314,63 @@ class ISM:
for control_file in save_dir.glob("*.aria2__temp"):
control_file.unlink()
# Verify output directory exists and contains files
if not save_dir.exists():
error_msg = f"Output directory does not exist: {save_dir}"
if debug_logger:
debug_logger.log(
level="ERROR",
operation="manifest_ism_download_output_missing",
message=error_msg,
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"save_path": str(save_path),
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
raise FileNotFoundError(error_msg)
segments_to_merge = [x for x in sorted(save_dir.iterdir()) if x.is_file()]
if debug_logger:
debug_logger.log(
level="DEBUG",
operation="manifest_ism_download_complete",
message="ISM download complete, preparing to merge",
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"save_dir_exists": save_dir.exists(),
"segments_found": len(segments_to_merge),
"segment_files": [f.name for f in segments_to_merge[:10]], # Limit to first 10
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
if not segments_to_merge:
error_msg = f"No segment files found in output directory: {save_dir}"
if debug_logger:
all_contents = list(save_dir.iterdir()) if save_dir.exists() else []
debug_logger.log(
level="ERROR",
operation="manifest_ism_download_no_segments",
message=error_msg,
context={
"track_id": getattr(track, "id", None),
"track_type": track.__class__.__name__,
"save_dir": str(save_dir),
"directory_contents": [str(p) for p in all_contents],
"downloader": downloader.__name__,
"skip_merge": skip_merge,
},
)
raise FileNotFoundError(error_msg)
if skip_merge:
shutil.move(segments_to_merge[0], save_path)
else:

View File

@@ -1,7 +1,8 @@
from .basic import Basic
from .gluetun import Gluetun
from .hola import Hola
from .nordvpn import NordVPN
from .surfsharkvpn import SurfsharkVPN
from .windscribevpn import WindscribeVPN
__all__ = ("Basic", "Hola", "NordVPN", "SurfsharkVPN", "WindscribeVPN")
__all__ = ("Basic", "Gluetun", "Hola", "NordVPN", "SurfsharkVPN", "WindscribeVPN")

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,5 @@
import json
import random
import re
from typing import Optional
@@ -46,8 +47,21 @@ class NordVPN(Proxy):
HTTP proxies under port 80 were disabled on the 15th of Feb, 2021:
https://nordvpn.com/blog/removing-http-proxies
Supports:
- Country code: "us", "ca", "gb"
- Country ID: "228"
- Specific server: "us1234"
- City selection: "us:seattle", "ca:calgary"
"""
query = query.lower()
city = None
# Check if query includes city specification (e.g., "ca:calgary")
if ":" in query:
query, city = query.split(":", maxsplit=1)
city = city.strip()
if re.match(r"^[a-z]{2}\d+$", query):
# country and nordvpn server id, e.g., us1, fr1234
hostname = f"{query}.nordvpn.com"
@@ -64,7 +78,12 @@ class NordVPN(Proxy):
# NordVPN doesnt have servers in this region
return
server_mapping = self.server_map.get(country["code"].lower())
# Check server_map for pinned servers (can include city)
server_map_key = f"{country['code'].lower()}:{city}" if city else country["code"].lower()
server_mapping = self.server_map.get(server_map_key) or (
self.server_map.get(country["code"].lower()) if not city else None
)
if server_mapping:
# country was set to a specific server ID in config
hostname = f"{country['code'].lower()}{server_mapping}.nordvpn.com"
@@ -76,7 +95,19 @@ class NordVPN(Proxy):
f"The NordVPN Country {query} currently has no recommended servers. "
"Try again later. If the issue persists, double-check the query."
)
hostname = recommended_servers[0]["hostname"]
# Filter by city if specified
if city:
city_servers = self.filter_servers_by_city(recommended_servers, city)
if not city_servers:
raise ValueError(
f"No servers found in city '{city}' for country '{country['name']}'. "
"Try a different city or check the city name spelling."
)
recommended_servers = city_servers
# Pick a random server from the filtered list
hostname = random.choice(recommended_servers)["hostname"]
if hostname.startswith("gb"):
# NordVPN uses the alpha2 of 'GB' in API responses, but 'UK' in the hostname
@@ -95,6 +126,41 @@ class NordVPN(Proxy):
):
return country
@staticmethod
def filter_servers_by_city(servers: list[dict], city: str) -> list[dict]:
"""
Filter servers by city name.
The API returns servers with location data that includes city information.
This method filters servers to only those in the specified city.
Args:
servers: List of server dictionaries from the NordVPN API
city: City name to filter by (case-insensitive)
Returns:
List of servers in the specified city
"""
city_lower = city.lower()
filtered = []
for server in servers:
# Each server has a 'locations' list with location data
locations = server.get("locations", [])
for location in locations:
# City data can be in different formats:
# - {"city": {"name": "Seattle", ...}}
# - {"city": "Seattle"}
city_data = location.get("city")
if city_data:
# Handle both dict and string formats
city_name = city_data.get("name") if isinstance(city_data, dict) else city_data
if city_name and city_name.lower() == city_lower:
filtered.append(server)
break # Found a match, no need to check other locations for this server
return filtered
@staticmethod
def get_recommended_servers(country_id: int) -> list[dict]:
"""

View File

@@ -44,8 +44,21 @@ class SurfsharkVPN(Proxy):
def get_proxy(self, query: str) -> Optional[str]:
"""
Get an HTTP(SSL) proxy URI for a SurfsharkVPN server.
Supports:
- Country code: "us", "ca", "gb"
- Country ID: "228"
- Specific server: "us-bos" (Boston)
- City selection: "us:seattle", "ca:toronto"
"""
query = query.lower()
city = None
# Check if query includes city specification (e.g., "us:seattle")
if ":" in query:
query, city = query.split(":", maxsplit=1)
city = city.strip()
if re.match(r"^[a-z]{2}\d+$", query):
# country and surfsharkvpn server id, e.g., au-per, be-anr, us-bos
hostname = f"{query}.prod.surfshark.com"
@@ -62,13 +75,18 @@ class SurfsharkVPN(Proxy):
# SurfsharkVPN doesnt have servers in this region
return
server_mapping = self.server_map.get(country["countryCode"].lower())
# Check server_map for pinned servers (can include city)
server_map_key = f"{country['countryCode'].lower()}:{city}" if city else country["countryCode"].lower()
server_mapping = self.server_map.get(server_map_key) or (
self.server_map.get(country["countryCode"].lower()) if not city else None
)
if server_mapping:
# country was set to a specific server ID in config
hostname = f"{country['code'].lower()}{server_mapping}.prod.surfshark.com"
else:
# get the random server ID
random_server = self.get_random_server(country["countryCode"])
random_server = self.get_random_server(country["countryCode"], city)
if not random_server:
raise ValueError(
f"The SurfsharkVPN Country {query} currently has no random servers. "
@@ -92,18 +110,44 @@ class SurfsharkVPN(Proxy):
):
return country
def get_random_server(self, country_id: str):
def get_random_server(self, country_id: str, city: Optional[str] = None):
"""
Get the list of random Server for a Country.
Get a random server for a Country, optionally filtered by city.
Note: There may not always be more than one recommended server.
Args:
country_id: The country code (e.g., "US", "CA")
city: Optional city name to filter by (case-insensitive)
Note: The API may include a 'location' field with city information.
If not available, this will return any server from the country.
"""
country = [x["connectionName"] for x in self.countries if x["countryCode"].lower() == country_id.lower()]
servers = [x for x in self.countries if x["countryCode"].lower() == country_id.lower()]
# Filter by city if specified
if city:
city_lower = city.lower()
# Check if servers have a 'location' field for city filtering
city_servers = [
x
for x in servers
if x.get("location", "").lower() == city_lower or x.get("city", "").lower() == city_lower
]
if city_servers:
servers = city_servers
else:
raise ValueError(
f"No servers found in city '{city}' for country '{country_id}'. "
"Try a different city or check the city name spelling."
)
# Get connection names from filtered servers
connection_names = [x["connectionName"] for x in servers]
try:
country = random.choice(country)
return country
except Exception:
raise ValueError("Could not get random countrycode from the countries list.")
return random.choice(connection_names)
except (IndexError, KeyError):
raise ValueError(f"Could not get random server for country '{country_id}'.")
@staticmethod
def get_countries() -> list[dict]:

View File

@@ -45,22 +45,27 @@ class WindscribeVPN(Proxy):
"""
Get an HTTPS proxy URI for a WindscribeVPN server.
Note: Windscribe's static OpenVPN credentials work reliably on US, AU, and NZ servers.
Supports:
- Country code: "us", "ca", "gb"
- City selection: "us:seattle", "ca:toronto"
"""
query = query.lower()
supported_regions = {"us", "au", "nz"}
city = None
if query not in supported_regions and query not in self.server_map:
raise ValueError(
f"Windscribe proxy does not currently support the '{query.upper()}' region. "
f"Supported regions with reliable credentials: {', '.join(sorted(supported_regions))}. "
)
# Check if query includes city specification (e.g., "ca:toronto")
if ":" in query:
query, city = query.split(":", maxsplit=1)
city = city.strip()
if query in self.server_map:
# Check server_map for pinned servers (can include city)
server_map_key = f"{query}:{city}" if city else query
if server_map_key in self.server_map:
hostname = self.server_map[server_map_key]
elif query in self.server_map and not city:
hostname = self.server_map[query]
else:
if re.match(r"^[a-z]+$", query):
hostname = self.get_random_server(query)
hostname = self.get_random_server(query, city)
else:
raise ValueError(f"The query provided is unsupported and unrecognized: {query}")
@@ -70,22 +75,42 @@ class WindscribeVPN(Proxy):
hostname = hostname.split(':')[0]
return f"https://{self.username}:{self.password}@{hostname}:443"
def get_random_server(self, country_code: str) -> Optional[str]:
def get_random_server(self, country_code: str, city: Optional[str] = None) -> Optional[str]:
"""
Get a random server hostname for a country.
Get a random server hostname for a country, optionally filtered by city.
Returns None if no servers are available for the country.
Args:
country_code: The country code (e.g., "us", "ca")
city: Optional city name to filter by (case-insensitive)
Returns:
A random hostname from matching servers, or None if none available.
"""
hostnames = []
# Collect hostnames from ALL locations matching the country code
for location in self.countries:
if location.get("country_code", "").lower() == country_code.lower():
hostnames = []
for group in location.get("groups", []):
# Filter by city if specified
if city:
group_city = group.get("city", "")
if group_city.lower() != city.lower():
continue
# Collect hostnames from this group
for host in group.get("hosts", []):
if hostname := host.get("hostname"):
hostnames.append(hostname)
if hostnames:
return random.choice(hostnames)
if hostnames:
return random.choice(hostnames)
elif city:
# No servers found for the specified city
raise ValueError(
f"No servers found in city '{city}' for country code '{country_code}'. "
"Try a different city or check the city name spelling."
)
return None

View File

@@ -53,8 +53,55 @@ class Service(metaclass=ABCMeta):
if not ctx.parent or not ctx.parent.params.get("no_proxy"):
if ctx.parent:
proxy = ctx.parent.params["proxy"]
proxy_query = ctx.parent.params.get("proxy_query")
proxy_provider_name = ctx.parent.params.get("proxy_provider")
else:
proxy = None
proxy_query = None
proxy_provider_name = None
# Check for service-specific proxy mapping
service_name = self.__class__.__name__
service_config_dict = config.services.get(service_name, {})
proxy_map = service_config_dict.get("proxy_map", {})
if proxy_map and proxy_query:
# Build the full proxy query key (e.g., "nordvpn:ca" or "us")
if proxy_provider_name:
full_proxy_key = f"{proxy_provider_name}:{proxy_query}"
else:
full_proxy_key = proxy_query
# Check if there's a mapping for this query
mapped_value = proxy_map.get(full_proxy_key)
if mapped_value:
self.log.info(f"Found service-specific proxy mapping: {full_proxy_key} -> {mapped_value}")
# Query the proxy provider with the mapped value
if proxy_provider_name:
# Specific provider requested
proxy_provider = next(
(x for x in ctx.obj.proxy_providers if x.__class__.__name__.lower() == proxy_provider_name),
None,
)
if proxy_provider:
mapped_proxy_uri = proxy_provider.get_proxy(mapped_value)
if mapped_proxy_uri:
proxy = mapped_proxy_uri
self.log.info(f"Using mapped proxy from {proxy_provider.__class__.__name__}: {proxy}")
else:
self.log.warning(f"Failed to get proxy for mapped value '{mapped_value}', using default")
else:
self.log.warning(f"Proxy provider '{proxy_provider_name}' not found, using default proxy")
else:
# No specific provider, try all providers
for proxy_provider in ctx.obj.proxy_providers:
mapped_proxy_uri = proxy_provider.get_proxy(mapped_value)
if mapped_proxy_uri:
proxy = mapped_proxy_uri
self.log.info(f"Using mapped proxy from {proxy_provider.__class__.__name__}: {proxy}")
break
else:
self.log.warning(f"No provider could resolve mapped value '{mapped_value}', using default")
if not proxy:
# don't override the explicit proxy set by the user, even if they may be geoblocked

View File

@@ -58,6 +58,7 @@ class Services(click.MultiCommand):
def get_path(name: str) -> Path:
"""Get the directory path of a command."""
tag = Services.get_tag(name)
for service in _SERVICES:
if service.parent.stem == tag:
return service.parent
@@ -72,19 +73,22 @@ class Services(click.MultiCommand):
"""
original_value = value
value = value.lower()
for path in _SERVICES:
tag = path.parent.stem
if value in (tag.lower(), *_ALIASES.get(tag, [])):
return tag
return original_value
@staticmethod
def load(tag: str) -> Service:
"""Load a Service module by Service tag."""
module = _MODULES.get(tag)
if not module:
raise KeyError(f"There is no Service added by the Tag '{tag}'")
return module
if module:
return module
raise KeyError(f"There is no Service added by the Tag '{tag}'")
__all__ = ("Services",)

View File

@@ -47,6 +47,8 @@ class Movie(Title):
def __str__(self) -> str:
if self.year:
if config.dash_naming:
return f"{self.name} - {self.year}"
return f"{self.name} ({self.year})"
return self.name
@@ -86,11 +88,21 @@ class Movie(Title):
# likely a movie or HD source, so it's most likely widescreen so
# 16:9 canvas makes the most sense.
resolution = int(primary_video_track.width * (9 / 16))
name += f" {resolution}p"
# Determine scan type suffix - default to "p", use "i" only if explicitly interlaced
scan_suffix = "p"
scan_type = getattr(primary_video_track, 'scan_type', None)
if scan_type and str(scan_type).lower() == "interlaced":
scan_suffix = "i"
name += f" {resolution}{scan_suffix}"
# Service
# Service (use track source if available)
if show_service:
name += f" {self.service.__name__}"
source_name = None
if self.tracks:
first_track = next(iter(self.tracks), None)
if first_track and hasattr(first_track, "source") and first_track.source:
source_name = first_track.source
name += f" {source_name or self.service.__name__}"
# 'WEB-DL'
name += " WEB-DL"

View File

@@ -101,9 +101,14 @@ class Song(Title):
name = str(self).split(" / ")[1]
if config.scene_naming:
# Service
# Service (use track source if available)
if show_service:
name += f" {self.service.__name__}"
source_name = None
if self.tracks:
first_track = next(iter(self.tracks), None)
if first_track and hasattr(first_track, "source") and first_track.source:
source_name = first_track.source
name += f" {source_name or self.service.__name__}"
# 'WEB-DL'
name += " WEB-DL"

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from rich.padding import Padding
from rich.rule import Rule
from unshackle.core.binaries import DoviTool, HDR10PlusTool
from unshackle.core.binaries import FFMPEG, DoviTool, HDR10PlusTool
from unshackle.core.config import config
from unshackle.core.console import console
@@ -109,7 +109,7 @@ class Hybrid:
"""Simple ffmpeg execution without progress tracking"""
p = subprocess.run(
[
"ffmpeg",
str(FFMPEG) if FFMPEG else "ffmpeg",
"-nostdin",
"-i",
str(save_path),

View File

@@ -314,6 +314,7 @@ class Tracks:
progress: Optional[partial] = None,
audio_expected: bool = True,
title_language: Optional[Language] = None,
skip_subtitles: bool = False,
) -> tuple[Path, int, list[str]]:
"""
Multiplex all the Tracks into a Matroska Container file.
@@ -328,6 +329,7 @@ class Tracks:
if embedded audio metadata should be added.
title_language: The title's intended language. Used to select the best video track
for audio metadata when multiple video tracks exist.
skip_subtitles: Skip muxing subtitle tracks into the container.
"""
if self.videos and not self.audio and audio_expected:
video_track = None
@@ -439,34 +441,35 @@ class Tracks:
]
)
for st in self.subtitles:
if not st.path or not st.path.exists():
raise ValueError("Text Track must be downloaded before muxing...")
events.emit(events.Types.TRACK_MULTIPLEX, track=st)
default = bool(self.audio and is_close_match(st.language, [self.audio[0].language]) and st.forced)
cl.extend(
[
"--track-name",
f"0:{st.get_track_name() or ''}",
"--language",
f"0:{st.language}",
"--sub-charset",
"0:UTF-8",
"--forced-track",
f"0:{st.forced}",
"--default-track",
f"0:{default}",
"--hearing-impaired-flag",
f"0:{st.sdh}",
"--original-flag",
f"0:{st.is_original_lang}",
"--compression",
"0:none", # disable extra compression (probably zlib)
"(",
str(st.path),
")",
]
)
if not skip_subtitles:
for st in self.subtitles:
if not st.path or not st.path.exists():
raise ValueError("Text Track must be downloaded before muxing...")
events.emit(events.Types.TRACK_MULTIPLEX, track=st)
default = bool(self.audio and is_close_match(st.language, [self.audio[0].language]) and st.forced)
cl.extend(
[
"--track-name",
f"0:{st.get_track_name() or ''}",
"--language",
f"0:{st.language}",
"--sub-charset",
"0:UTF-8",
"--forced-track",
f"0:{st.forced}",
"--default-track",
f"0:{default}",
"--hearing-impaired-flag",
f"0:{st.sdh}",
"--original-flag",
f"0:{st.is_original_lang}",
"--compression",
"0:none", # disable extra compression (probably zlib)
"(",
str(st.path),
")",
]
)
if self.chapters:
chapters_path = config.directories.temp / config.filenames.chapters.format(

View File

@@ -186,6 +186,10 @@ class Video(Track):
# for some reason there's no Dolby Vision info tag
raise ValueError(f"The M3U Range Tag '{tag}' is not a supported Video Range")
class ScanType(str, Enum):
PROGRESSIVE = "progressive"
INTERLACED = "interlaced"
def __init__(
self,
*args: Any,
@@ -195,6 +199,7 @@ class Video(Track):
width: Optional[int] = None,
height: Optional[int] = None,
fps: Optional[Union[str, int, float]] = None,
scan_type: Optional[Video.ScanType] = None,
**kwargs: Any,
) -> None:
"""
@@ -232,6 +237,8 @@ class Video(Track):
raise TypeError(f"Expected height to be a {int}, not {height!r}")
if not isinstance(fps, (str, int, float, type(None))):
raise TypeError(f"Expected fps to be a {str}, {int}, or {float}, not {fps!r}")
if not isinstance(scan_type, (Video.ScanType, type(None))):
raise TypeError(f"Expected scan_type to be a {Video.ScanType}, not {scan_type!r}")
self.codec = codec
self.range = range_ or Video.Range.SDR
@@ -256,6 +263,7 @@ class Video(Track):
except Exception as e:
raise ValueError("Expected fps to be a number, float, or a string as numerator/denominator form, " + str(e))
self.scan_type = scan_type
self.needs_duration_fix = False
def __str__(self) -> str:

View File

@@ -19,6 +19,7 @@ from urllib.parse import ParseResult, urlparse
from uuid import uuid4
import chardet
import pycountry
import requests
from construct import ValidationError
from fontTools import ttLib
@@ -277,6 +278,80 @@ def ap_case(text: str, keep_spaces: bool = False, stop_words: tuple[str] = None)
)
# Common country code aliases that differ from ISO 3166-1 alpha-2
COUNTRY_CODE_ALIASES = {
"uk": "gb", # United Kingdom -> Great Britain
}
def get_country_name(code: str) -> Optional[str]:
"""
Convert a 2-letter country code to full country name.
Args:
code: ISO 3166-1 alpha-2 country code (e.g., 'ca', 'us', 'gb', 'uk')
Returns:
Full country name (e.g., 'Canada', 'United States', 'United Kingdom') or None if not found
Examples:
>>> get_country_name('ca')
'Canada'
>>> get_country_name('US')
'United States'
>>> get_country_name('uk')
'United Kingdom'
"""
# Handle common aliases
code = COUNTRY_CODE_ALIASES.get(code.lower(), code.lower())
try:
country = pycountry.countries.get(alpha_2=code.upper())
if country:
return country.name
except (KeyError, LookupError):
pass
return None
def get_country_code(name: str) -> Optional[str]:
"""
Convert a country name to its 2-letter ISO 3166-1 alpha-2 code.
Args:
name: Full country name (e.g., 'Canada', 'United States', 'United Kingdom')
Returns:
2-letter country code in uppercase (e.g., 'CA', 'US', 'GB') or None if not found
Examples:
>>> get_country_code('Canada')
'CA'
>>> get_country_code('united states')
'US'
>>> get_country_code('United Kingdom')
'GB'
"""
try:
# Try exact name match first
country = pycountry.countries.get(name=name.title())
if country:
return country.alpha_2.upper()
# Try common name (e.g., "Bolivia" vs "Bolivia, Plurinational State of")
country = pycountry.countries.get(common_name=name.title())
if country:
return country.alpha_2.upper()
# Try fuzzy search as fallback
results = pycountry.countries.search_fuzzy(name)
if results:
return results[0].alpha_2.upper()
except (KeyError, LookupError):
pass
return None
def get_ip_info(session: Optional[requests.Session] = None) -> dict:
"""
Use ipinfo.io to get IP location information.

View File

@@ -5,6 +5,8 @@ import click
from click.shell_completion import CompletionItem
from pywidevine.cdm import Cdm as WidevineCdm
from unshackle.core.tracks.audio import Audio
class VideoCodecChoice(click.Choice):
"""
@@ -241,6 +243,52 @@ class QualityList(click.ParamType):
return sorted(resolutions, reverse=True)
class AudioCodecList(click.ParamType):
"""Parses comma-separated audio codecs like 'AAC,EC3'."""
name = "audio_codec_list"
def __init__(self, codec_enum):
self.codec_enum = codec_enum
self._name_to_codec: dict[str, Audio.Codec] = {}
for codec in codec_enum:
self._name_to_codec[codec.name.lower()] = codec
self._name_to_codec[codec.value.lower()] = codec
aliases = {
"eac3": "EC3",
"ddp": "EC3",
"vorbis": "OGG",
}
for alias, target in aliases.items():
if target in codec_enum.__members__:
self._name_to_codec[alias] = codec_enum[target]
def convert(self, value: Any, param: Optional[click.Parameter] = None, ctx: Optional[click.Context] = None) -> list:
if not value:
return []
if isinstance(value, self.codec_enum):
return [value]
if isinstance(value, list):
if all(isinstance(v, self.codec_enum) for v in value):
return value
values = [str(v).strip() for v in value]
else:
values = [v.strip() for v in str(value).split(",")]
codecs = []
for val in values:
if not val:
continue
key = val.lower()
if key in self._name_to_codec:
codecs.append(self._name_to_codec[key])
else:
valid = sorted(set(self._name_to_codec.keys()))
self.fail(f"'{val}' is not valid. Choices: {', '.join(valid)}", param, ctx)
return list(dict.fromkeys(codecs)) # Remove duplicates, preserve order
class MultipleChoice(click.Choice):
"""
The multiple choice type allows multiple values to be checked against
@@ -288,5 +336,6 @@ class MultipleChoice(click.Choice):
SEASON_RANGE = SeasonRange()
LANGUAGE_RANGE = LanguageRange()
QUALITY_LIST = QualityList()
AUDIO_CODEC_LIST = AudioCodecList(Audio.Codec)
# VIDEO_CODEC_CHOICE will be created dynamically when imported