Initial Commit

This commit is contained in:
Andy
2025-07-18 00:46:05 +00:00
commit d37014f53f
94 changed files with 17458 additions and 0 deletions

View File

View File

@@ -0,0 +1,169 @@
import re
from typing import Any, Optional, Union
import click
from click.shell_completion import CompletionItem
from pywidevine.cdm import Cdm as WidevineCdm
class ContextData:
    """Plain holder for the objects passed to it at construction time."""

    def __init__(self, config: dict, cdm: WidevineCdm, proxy_providers: list, profile: Optional[str] = None):
        """
        Store every argument unchanged under an attribute of the same name.

        Parameters:
            config: Configuration mapping.
            cdm: Widevine CDM instance.
            proxy_providers: List of proxy providers.
            profile: Optional profile name; defaults to None.
        """
        self.profile = profile
        self.proxy_providers = proxy_providers
        self.cdm = cdm
        self.config = config
class SeasonRange(click.ParamType):
    """Click parameter type that expands season/episode tokens into '{s}x{e}' strings."""

    name = "ep_range"
    # Episode bounds used when a token names a season without an episode.
    MIN_EPISODE = 0
    MAX_EPISODE = 999

    def parse_tokens(self, *tokens: str) -> list[str]:
        """
        Parse multiple tokens or ranged tokens as '{s}x{e}' strings.

        Supports exclusioning by putting a `-` before the token. Exclusions are
        applied after every token has been expanded, so their position among the
        tokens does not matter.

        The result contains each entry once, in the order it was first produced.

        Example:
            >>> sr = SeasonRange()
            >>> sr.parse_tokens("S01E01")
            ["1x1"]
            >>> sr.parse_tokens("S02E01", "S02E03-S02E05")
            ["2x1", "2x3", "2x4", "2x5"]
            >>> sr.parse_tokens("S01-S05", "-S03", "-S02E01")
            ["1x0", "1x1", ..., "2x0", (...), "2x2", (...), "4x0", ..., "5x0", ...]

        Errors (bad syntax, more than two range sides, reversed range) are
        reported through `self.fail`.
        """
        if not tokens:
            return []

        computed: list[str] = []
        # A set makes the final exclusion filter a constant-time lookup per entry.
        exclusions: set[str] = set()

        for token in tokens:
            exclude = token.startswith("-")
            if exclude:
                token = token[1:]

            parsed = [
                re.match(r"^S(?P<season>\d+)(E(?P<episode>\d+))?$", x, re.IGNORECASE) for x in re.split(r"[:-]", token)
            ]
            if len(parsed) > 2:
                self.fail(f"Invalid token, only a left and right range is acceptable: {token}")
            if len(parsed) == 1:
                # A single token is treated as a range from itself to itself.
                parsed.append(parsed[0])
            if any(x is None for x in parsed):
                self.fail(f"Invalid token, syntax error occurred: {token}")

            left, right = parsed
            from_season = int(left.group("season"))
            to_season = int(right.group("season"))
            # A missing episode opens the range fully on that side.
            left_episode = left.group("episode")
            right_episode = right.group("episode")
            from_episode = int(left_episode) if left_episode is not None else self.MIN_EPISODE
            to_episode = int(right_episode) if right_episode is not None else self.MAX_EPISODE

            if from_season > to_season:
                self.fail(f"Invalid range, left side season cannot be bigger than right side season: {token}")
            if from_season == to_season and from_episode > to_episode:
                self.fail(f"Invalid range, left side episode cannot be bigger than right side episode: {token}")

            for s in range(from_season, to_season + 1):
                first = from_episode if s == from_season else 0
                last = to_episode if s == to_season else self.MAX_EPISODE
                for e in range(first, last + 1):
                    entry = f"{s}x{e}"
                    if exclude:
                        exclusions.add(entry)
                    else:
                        computed.append(entry)

        # dict.fromkeys de-duplicates while keeping first-seen order; filtering against
        # the exclusion set removes every occurrence, including entries that were
        # produced by more than one token.
        return [entry for entry in dict.fromkeys(computed) if entry not in exclusions]

    def convert(
        self, value: str, param: Optional[click.Parameter] = None, ctx: Optional[click.Context] = None
    ) -> list[str]:
        """Split `value` on commas/semicolons (surrounding whitespace ignored) and parse each token."""
        return self.parse_tokens(*re.split(r"\s*[,;]\s*", value))
class LanguageRange(click.ParamType):
    """Click parameter type that turns a comma/semicolon separated string into a list of strings."""

    name = "lang_range"

    def convert(
        self, value: Union[str, list], param: Optional[click.Parameter] = None, ctx: Optional[click.Context] = None
    ) -> list[str]:
        """
        Convert `value` to a list.

        A list is returned as-is, any other falsy value gives an empty list, and a
        string is split on `,` or `;` with whitespace around the separator dropped.
        """
        if isinstance(value, list):
            return value
        return re.split(r"\s*[,;]\s*", value) if value else []
class QualityList(click.ParamType):
    """Click parameter type that parses resolutions such as "1080p,720" into a list of ints."""

    name = "quality_list"

    def convert(
        self, value: Union[str, list[str]], param: Optional[click.Parameter] = None, ctx: Optional[click.Context] = None
    ) -> list[int]:
        """
        Convert `value` to a list of integer resolutions, sorted highest first.

        Parameters:
            value: A comma separated string, a list/tuple of items, or a single item.
                Each item may be an int or a string with an optional trailing "p"
                (case-insensitive). Falsy input gives an empty list.
            param: The parameter being converted, forwarded to `self.fail`.
            ctx: The invocation context, forwarded to `self.fail`.

        Items of any other type, and strings that are not integers, are reported
        through `self.fail`.
        """
        if not value:
            return []

        if isinstance(value, str):
            items = value.split(",")
        elif isinstance(value, (list, tuple)):
            items = value
        else:
            # A single non-string value (e.g. a bare int) is treated as a one-item list.
            items = [value]

        resolutions = []
        for resolution in items:
            # bool is a subclass of int but is never a meaningful resolution.
            if isinstance(resolution, int) and not isinstance(resolution, bool):
                resolutions.append(resolution)
                continue
            if not isinstance(resolution, str):
                # Checked explicitly: calling .lower() on a non-string raises AttributeError,
                # which would bypass a TypeError handler and surface as a raw traceback.
                self.fail(
                    f"Expected string for int() conversion, got {resolution!r} of type {type(resolution).__name__}",
                    param,
                    ctx,
                )
            try:
                # Strip first so that "720p " still loses its trailing "p".
                resolutions.append(int(resolution.strip().lower().rstrip("p")))
            except ValueError:
                self.fail(f"{resolution!r} is not a valid integer", param, ctx)

        return sorted(resolutions, reverse=True)
class MultipleChoice(click.Choice):
    """
    The multiple choice type allows multiple values to be checked against
    a fixed set of supported values.

    It internally uses and is based off of click.Choice.
    """

    name = "multiple_choice"

    def __repr__(self) -> str:
        return f"MultipleChoice({list(self.choices)})"

    def convert(
        self, value: Any, param: Optional[click.Parameter] = None, ctx: Optional[click.Context] = None
    ) -> list[Any]:
        """
        Validate every value against the configured choices.

        Parameters:
            value: A comma separated string or a list of values. Falsy input
                gives an empty list; any other type is reported through `self.fail`.
            param: The parameter being converted.
            ctx: The invocation context.

        Returns the list of results of click.Choice.convert for each value.
        """
        if not value:
            return []

        if isinstance(value, str):
            values = value.split(",")
        elif isinstance(value, list):
            values = value
        else:
            self.fail(f"{value!r} is not a supported value.", param, ctx)

        # A plain loop is used on purpose: zero-argument super() does not work
        # inside a comprehension on Python versions before 3.12.
        chosen_values: list[Any] = []
        for item in values:
            chosen_values.append(super().convert(item, param, ctx))
        return chosen_values

    def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
        """
        Complete choices that start with the incomplete value.

        Only the text after the last comma is completed.

        Parameters:
            ctx: Invocation context for this command.
            param: The parameter that is requesting completion.
            incomplete: Value being completed. May be empty.
        """
        incomplete = incomplete.rsplit(",", 1)[-1]
        # super() must be called without the instance; super(self) raises TypeError
        # because its first argument has to be a type.
        return super().shell_complete(ctx, param, incomplete)
# Module-level instances of the parameter types defined above.
SEASON_RANGE = SeasonRange()
LANGUAGE_RANGE = LanguageRange()
QUALITY_LIST = QualityList()

View File

@@ -0,0 +1,51 @@
import itertools
from typing import Any, Iterable, Iterator, Sequence, Tuple, Type, Union
def as_lists(*args: Any) -> Iterator[Any]:
    """Yield each argument as a list; non-list arguments are wrapped in a one-item list."""
    for obj in args:
        if isinstance(obj, list):
            # Lists are passed through as the same object, not copied.
            yield obj
        else:
            yield [obj]
def as_list(*args: Any) -> list:
    """
    Merge all arguments into one new flat list.

    List arguments contribute their items; every other argument contributes itself.

    Example:
        >>> as_list('foo', ['buzz', 'bizz'], 'bazz', 'bozz', ['bar'], ['bur'])
        ['foo', 'buzz', 'bizz', 'bazz', 'bozz', 'bar', 'bur']
    """
    merged: list = []
    for chunk in as_lists(*args):
        merged.extend(chunk)
    return merged
def flatten(items: Any, ignore_types: Union[Type, Tuple[Type, ...]] = str) -> Iterator:
    """
    Flattens items recursively.

    Parameters:
        items: Any object. Iterables are descended into; everything else is
            yielded as-is.
        ignore_types: Type, or tuple of types, that are yielded whole even
            though they are iterable. Defaults to str.

    Example:
        >>> list(flatten(["foo", [["bar", ["buzz", [""]], "bee"]]]))
        ['foo', 'bar', 'buzz', '', 'bee']
        >>> list(flatten("foo"))
        ['foo']
        >>> list(flatten({1}, set))
        [{1}]
        >>> list(flatten(["ab"], bytes))
        ['a', 'b']
    """
    if isinstance(items, ignore_types) or not isinstance(items, Iterable):
        yield items
    elif isinstance(items, str) and len(items) == 1:
        # Iterating a one-character string yields that same string again, so
        # descending into it would never terminate when str is not ignored.
        yield items
    else:
        for item in items:
            yield from flatten(item, ignore_types)
def merge_dict(source: dict, destination: dict) -> None:
    """
    Recursively merge Source into Destination in-place.

    For every key in `source`: a dict value is merged into the dict stored under
    the same key in `destination`; any other value overwrites the destination
    entry. When the destination holds a non-dict under a key for which the source
    provides a dict, the destination entry is replaced by a new dict so the merge
    can proceed (source wins, as it does for plain values).

    Parameters:
        source: Mapping to read from. Nothing happens if it is empty or None.
        destination: Mapping that is modified in place.
    """
    if not source:
        return
    for key, value in source.items():
        if isinstance(value, dict):
            # get node or create one; a non-dict node cannot receive nested keys,
            # so it is replaced rather than indexed into.
            node = destination.get(key)
            if not isinstance(node, dict):
                node = destination[key] = {}
            merge_dict(value, node)
        else:
            destination[key] = value

View File

@@ -0,0 +1,30 @@
import logging
import os
import random
from datetime import datetime, timedelta
# Logger used for the ESN generation warnings below.
log = logging.getLogger("NF-ESN")
def chrome_esn_generator():
    """
    Return the ESN stored in the ".esn" file, creating or refreshing it when needed.

    The file path is relative, so it is resolved against the current working
    directory. A new ESN ("NFCDIE-03-" followed by 30 random hex characters) is
    written when the file does not exist or when its modification time is more
    than six hours old. The file content is then read back and returned.
    """
    esn_path = ".esn"
    # The candidate suffix is always generated, even if the stored ESN ends up being reused.
    candidate = "".join(random.choice("0123456789ABCDEF") for _ in range(30))

    def write_esn():
        with open(esn_path, "w") as handle:
            handle.write(f"NFCDIE-03-{candidate}")

    if not os.path.isfile(esn_path):
        log.warning("Generating a new Chrome ESN")
        write_esn()

    modified_at = datetime.fromtimestamp(os.path.getmtime(esn_path))
    if datetime.now() - modified_at > timedelta(hours=6):
        log.warning("Old ESN detected, Generating a new Chrome ESN")
        write_esn()

    with open(esn_path, "r") as handle:
        return handle.read()

View File

@@ -0,0 +1,24 @@
import platform
def get_os_arch(name: str) -> str:
    """
    Build a '{name}-{os}-{arch}' string for the running platform.

    The OS part is "win" for Windows, "osx" for Darwin and "linux" for anything
    else. The architecture part is "x64" for x86_64/amd64; any other value of
    platform.machine() is used lower-cased as-is.
    """
    system = platform.system().lower()
    machine = platform.machine().lower()

    # Anything that is neither Windows nor macOS is treated as Linux.
    os_name = {"windows": "win", "darwin": "osx"}.get(system, "linux")
    os_arch = "x64" if machine in ("x86_64", "amd64") else machine

    return f"{name}-{os_name}-{os_arch}"

View File

@@ -0,0 +1,77 @@
import ssl
from typing import Optional
from requests.adapters import HTTPAdapter
class SSLCiphers(HTTPAdapter):
    """
    HTTP adapter that applies a custom TLS cipher list and OpenSSL security level.

    The security level is appended to the cipher list as `@SECLEVEL=<n>`, so it
    must not be included in `cipher_list` by the caller. Levels 0 to 5 are
    accepted and 0 is the default; the original guidance for this adapter is to
    prefer a level above 0, with 2 being the usual default.

    Summary of the levels:
        0: Everything is permitted (compatible with previous OpenSSL versions).
        1: Minimum of 80 bits of security. RSA/DSA/DH keys shorter than 1024 bits
           and ECC keys shorter than 160 bits are prohibited, as are export cipher
           suites, SSL version 2 and cipher suites using MD5 for the MAC.
        2: 112 bits. RSA/DSA/DH keys shorter than 2048 bits and ECC keys shorter
           than 224 bits are prohibited. RC4 cipher suites and SSL version 3 are
           not allowed. Compression is disabled.
        3: 128 bits. RSA/DSA/DH keys shorter than 3072 bits and ECC keys shorter
           than 256 bits are prohibited. Cipher suites without forward secrecy are
           prohibited, TLS below 1.1 is not permitted, session tickets are disabled.
        4: 192 bits. RSA/DSA/DH keys shorter than 7680 bits and ECC keys shorter
           than 384 bits are prohibited. Cipher suites using SHA1 for the MAC are
           prohibited. TLS below 1.2 is not permitted.
        5: 256 bits. RSA/DSA/DH keys shorter than 15360 bits and ECC keys shorter
           than 512 bits are prohibited.
    """

    def __init__(self, cipher_list: Optional[str] = None, security_level: int = 0, *args, **kwargs):
        """
        Build the SSL context used for every connection made through this adapter.

        Raises:
            TypeError: If a truthy `cipher_list` is not a str, or `security_level` is not an int.
            ValueError: If `cipher_list` contains "@SECLEVEL", or `security_level` is outside 0-5.
        """
        if cipher_list:
            if not isinstance(cipher_list, str):
                raise TypeError(f"Expected cipher_list to be a str, not {cipher_list!r}")
            if "@SECLEVEL" in cipher_list:
                raise ValueError("You must not specify the Security Level manually in the cipher list.")

        if not isinstance(security_level, int):
            raise TypeError(f"Expected security_level to be an int, not {security_level!r}")
        if not 0 <= security_level <= 5:
            raise ValueError(f"The security_level must be a value between 0 and 5, not {security_level}")

        # cpython's default cipher list differs to Python-requests cipher list
        base_ciphers = cipher_list or "DEFAULT"
        ciphers = f"{base_ciphers}:@SECLEVEL={security_level}"

        context = ssl.create_default_context()
        # For some reason this is needed to avoid a verification error
        context.check_hostname = False
        context.set_ciphers(ciphers)

        self._ssl_context = context
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        """Create the pool manager with this adapter's SSL context."""
        kwargs.update(ssl_context=self._ssl_context)
        return super().init_poolmanager(*args, **kwargs)

    def proxy_manager_for(self, *args, **kwargs):
        """Create the proxy manager with this adapter's SSL context."""
        kwargs.update(ssl_context=self._ssl_context)
        return super().proxy_manager_for(*args, **kwargs)

View File

@@ -0,0 +1,25 @@
import json
import subprocess
from pathlib import Path
from typing import Union
from unshackle.core import binaries
def ffprobe(uri: Union[bytes, Path]) -> dict:
    """
    Run ffprobe on the provided data and return its stream information.

    Parameters:
        uri: A Path, which is opened through the lavfi `movie` source with the
            `[out+subcc]` outputs, or bytes, which are piped to ffprobe on stdin.

    Returns the parsed JSON output of ffprobe, or an empty dict if ffprobe exits
    with a non-zero status.

    Raises:
        EnvironmentError: If no ffprobe executable is configured in `binaries`.
    """
    if not binaries.FFProbe:
        raise EnvironmentError('FFProbe executable "ffprobe" not found but is required.')

    command = [binaries.FFProbe, "-v", "quiet", "-of", "json", "-show_streams"]
    stdin_data = None

    if isinstance(uri, Path):
        # Normalise backslashes and escape colons for the lavfi filter argument.
        escaped = str(uri).replace("\\", "/").replace(":", "\\\\:")
        command += ["-f", "lavfi", "-i", f"movie={escaped}[out+subcc]"]
    elif isinstance(uri, bytes):
        command.append("pipe:")
        stdin_data = uri

    try:
        result = subprocess.run(command, input=stdin_data, check=True, capture_output=True)
    except subprocess.CalledProcessError:
        return {}

    return json.loads(result.stdout.decode("utf8"))

View File

@@ -0,0 +1,279 @@
from __future__ import annotations
import logging
import os
import re
import shutil
import subprocess
import tempfile
from difflib import SequenceMatcher
from pathlib import Path
from typing import Optional, Tuple
import requests
from unshackle.core.config import config
from unshackle.core.titles.episode import Episode
from unshackle.core.titles.movie import Movie
from unshackle.core.titles.title import Title
# Matches every run of characters that is not an ASCII letter or digit (case-insensitive).
STRIP_RE = re.compile(r"[^a-z0-9]+", re.I)
# Matches a trailing four-digit year starting with 1 or 2, optionally in parentheses,
# together with any whitespace before it.
YEAR_RE = re.compile(r"\s*\(?[12][0-9]{3}\)?$")
# Headers sent with every TMDB request made in this module.
HEADERS = {"User-Agent": "unshackle-tags/1.0"}
log = logging.getLogger("TAGS")
def _api_key() -> Optional[str]:
    """Return the configured TMDB API key, falling back to the TMDB_API_KEY environment variable."""
    configured = config.tmdb_api_key
    if configured:
        return configured
    return os.getenv("TMDB_API_KEY")
def _clean(s: str) -> str:
    """Return `s` lower-cased with everything except ASCII letters and digits removed."""
    alphanumeric_only = STRIP_RE.sub("", s)
    return alphanumeric_only.lower()
def _strip_year(s: str) -> str:
    """Return `s` without a trailing year (as matched by YEAR_RE) and without surrounding whitespace."""
    without_year = YEAR_RE.sub("", s)
    return without_year.strip()
def fuzzy_match(a: str, b: str, threshold: float = 0.8) -> bool:
    """
    Return True if ``a`` and ``b`` are a close match.

    Both strings are normalised with `_clean` first; they match when their
    SequenceMatcher ratio is at least ``threshold``.
    """
    left, right = _clean(a), _clean(b)
    return SequenceMatcher(None, left, right).ratio() >= threshold
def search_tmdb(title: str, year: Optional[int], kind: str) -> Tuple[Optional[int], Optional[str]]:
    """
    Search TMDB for `title` and return the ID and title of the closest result.

    Parameters:
        title: Title to search for; a trailing year is stripped before searching.
        year: Optional year filter, sent as `year` for movies and as
            `first_air_date_year` for every other kind.
        kind: TMDB search endpoint name, used verbatim in the URL.

    Returns `(id, title)` of the result whose title, name, original title or
    original name is most similar to the search title. If no candidate scores
    above zero the first result is returned instead. `(None, None)` is returned
    when no API key is available or the search has no results.

    HTTP errors are not handled here; they propagate from requests.
    """
    api_key = _api_key()
    if not api_key:
        return None, None

    search_title = _strip_year(title)
    log.debug("Searching TMDB for %r (%s, %s)", search_title, kind, year)

    params = {"api_key": api_key, "query": search_title}
    if year is not None:
        year_key = "year" if kind == "movie" else "first_air_date_year"
        params[year_key] = year

    response = requests.get(
        f"https://api.themoviedb.org/3/search/{kind}",
        params=params,
        headers=HEADERS,
        timeout=30,
    )
    response.raise_for_status()

    results = response.json().get("results") or []
    log.debug("TMDB returned %d results", len(results))
    if not results:
        return None, None

    wanted = _clean(search_title)
    best_ratio = 0.0
    best_id: Optional[int] = None
    best_title: Optional[str] = None

    for result in results:
        # Every name field the result may carry; missing or empty ones are dropped.
        names = [result.get(field) for field in ("title", "name", "original_title", "original_name")]
        for candidate in filter(None, names):
            ratio = SequenceMatcher(None, wanted, _clean(candidate)).ratio()
            # Strictly greater, so the earliest candidate wins a tie.
            if ratio > best_ratio:
                best_ratio, best_id, best_title = ratio, result.get("id"), candidate

    log.debug(
        "Best candidate ratio %.2f for %r (ID %s)",
        best_ratio,
        best_title,
        best_id,
    )

    if best_id is not None:
        return best_id, best_title

    fallback = results[0]
    return fallback.get("id"), fallback.get("title") or fallback.get("name")
def get_title(tmdb_id: int, kind: str) -> Optional[str]:
    """
    Fetch the name/title of a TMDB entry by ID.

    Parameters:
        tmdb_id: TMDB ID of the entry.
        kind: TMDB endpoint name, used verbatim in the URL.

    Returns the entry's `title`, or its `name` if there is no title. None is
    returned when no API key is available, when the request fails, or when the
    response body is not valid JSON.
    """
    api_key = _api_key()
    if not api_key:
        return None

    try:
        r = requests.get(
            f"https://api.themoviedb.org/3/{kind}/{tmdb_id}",
            params={"api_key": api_key},
            headers=HEADERS,
            timeout=30,
        )
        r.raise_for_status()
        # Decoded inside the try so a malformed body is handled like any other failure.
        js = r.json()
    except (requests.RequestException, ValueError) as exc:
        log.debug("Failed to fetch TMDB title: %s", exc)
        return None

    return js.get("title") or js.get("name")
def get_year(tmdb_id: int, kind: str) -> Optional[int]:
    """
    Fetch the release year of a TMDB entry by ID.

    Parameters:
        tmdb_id: TMDB ID of the entry.
        kind: TMDB endpoint name, used verbatim in the URL.

    Returns the first four characters of `release_date` (or `first_air_date`)
    as an int when they are digits. None is returned when no API key is
    available, when the request fails, when the response body is not valid JSON,
    or when no usable date is present.
    """
    api_key = _api_key()
    if not api_key:
        return None

    try:
        r = requests.get(
            f"https://api.themoviedb.org/3/{kind}/{tmdb_id}",
            params={"api_key": api_key},
            headers=HEADERS,
            timeout=30,
        )
        r.raise_for_status()
        # Decoded inside the try so a malformed body is handled like any other failure.
        js = r.json()
    except (requests.RequestException, ValueError) as exc:
        log.debug("Failed to fetch TMDB year: %s", exc)
        return None

    date = js.get("release_date") or js.get("first_air_date")
    if date and len(date) >= 4 and date[:4].isdigit():
        return int(date[:4])
    return None
def external_ids(tmdb_id: int, kind: str) -> dict:
    """
    Fetch the external IDs TMDB lists for an entry.

    Returns the decoded JSON response, or an empty dict when no API key is
    available. HTTP errors are not handled here; they propagate from requests.
    """
    api_key = _api_key()
    if not api_key:
        return {}

    log.debug("Fetching external IDs for %s %s", kind, tmdb_id)
    response = requests.get(
        f"https://api.themoviedb.org/3/{kind}/{tmdb_id}/external_ids",
        params={"api_key": api_key},
        headers=HEADERS,
        timeout=30,
    )
    response.raise_for_status()

    payload = response.json()
    log.debug("External IDs response: %s", payload)
    return payload
def _apply_tags(path: Path, tags: dict[str, str]) -> None:
    """
    Write `tags` to the file at `path` as global tags using mkvpropedit.

    Nothing happens when `tags` is empty or mkvpropedit is not on PATH. The tags
    are written to a temporary XML file that is removed afterwards. A failing
    mkvpropedit run is logged but not raised.

    Parameters:
        path: File to tag.
        tags: Mapping of tag name to tag value.
    """
    # Local import: only this function needs it.
    from xml.sax.saxutils import escape

    if not tags:
        return

    mkvpropedit = shutil.which("mkvpropedit")
    if not mkvpropedit:
        log.debug("mkvpropedit not found on PATH; skipping tags")
        return

    log.debug("Applying tags to %s: %s", path, tags)
    xml_lines = ["<?xml version='1.0' encoding='UTF-8'?>", "<Tags>", " <Tag>", " <Targets/>"]
    # Names and values are escaped: text such as a description containing "&" or "<"
    # would otherwise produce a malformed tag file.
    xml_lines.extend(
        f" <Simple><Name>{escape(str(name))}</Name><String>{escape(str(value))}</String></Simple>"
        for name, value in tags.items()
    )
    xml_lines.extend([" </Tag>", "</Tags>"])

    # Written as UTF-8 to match the encoding declared in the XML header,
    # independent of the platform's default encoding.
    with tempfile.NamedTemporaryFile("w", suffix=".xml", delete=False, encoding="utf-8") as f:
        f.write("\n".join(xml_lines))
        tmp_path = Path(f.name)

    try:
        result = subprocess.run(
            [mkvpropedit, str(path), "--tags", f"global:{tmp_path}"],
            check=False,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        if result.returncode == 0:
            log.debug("Tags applied via mkvpropedit")
        else:
            log.debug("mkvpropedit exited with code %s; tags may not have been applied", result.returncode)
    finally:
        tmp_path.unlink(missing_ok=True)
def tag_file(path: Path, title: Title, tmdb_id: Optional[int] = None) -> None:
    """
    Tag the file at `path` with information about `title`.

    Custom tags (the configured group tag and a description shortened to about
    255 characters) are always applied when available. If a TMDB API key is set
    and `title` is a Movie or Episode, TMDB, IMDB and TVDB links are added too.
    Tagging is best-effort: failed TMDB requests are logged and the custom tags
    are still applied.

    Parameters:
        path: File to tag.
        title: Title the file belongs to.
        tmdb_id: Known TMDB ID. When omitted, TMDB is searched by name and year
            and the result is only used if it fuzzy-matches the title's name.
    """
    log.debug("Tagging file %s with title %r", path, title)
    standard_tags: dict[str, str] = {}
    custom_tags: dict[str, str] = {}

    # To add custom information to the tags
    # custom_tags["Text to the left side"] = "Text to the right side"
    if config.tag:
        custom_tags["Group"] = config.tag

    description = getattr(title, "description", None)
    if description:
        if len(description) > 255:
            # Cut at the last space within the limit so no word is split.
            truncated = description[:255]
            if " " in truncated:
                truncated = truncated.rsplit(" ", 1)[0]
            description = truncated + "..."
        custom_tags["Description"] = description

    api_key = _api_key()
    if not api_key:
        log.debug("No TMDB API key set; applying basic tags only")
        _apply_tags(path, custom_tags)
        return

    if isinstance(title, Movie):
        kind = "movie"
        name = title.name
        year = title.year
    elif isinstance(title, Episode):
        kind = "tv"
        name = title.title
        year = title.year
    else:
        _apply_tags(path, custom_tags)
        return

    tmdb_title: Optional[str] = None
    if tmdb_id is None:
        try:
            tmdb_id, tmdb_title = search_tmdb(name, year, kind)
        except requests.RequestException as exc:
            # Handled like the external ID lookup below: a network problem
            # must not abort tagging.
            log.debug("TMDB search failed: %s", exc)
            _apply_tags(path, custom_tags)
            return
        log.debug("Search result: %r (ID %s)", tmdb_title, tmdb_id)
        if not tmdb_id or not tmdb_title or not fuzzy_match(tmdb_title, name):
            log.debug("TMDB search did not match; skipping external ID lookup")
            _apply_tags(path, custom_tags)
            return

    # kind is always "movie" or "tv" here, which is also the TMDB URL segment.
    standard_tags["TMDB"] = f"https://www.themoviedb.org/{kind}/{tmdb_id}"

    try:
        ids = external_ids(tmdb_id, kind)
    except requests.RequestException as exc:
        log.debug("Failed to fetch external IDs: %s", exc)
        ids = {}
    else:
        log.debug("External IDs found: %s", ids)

    imdb_id = ids.get("imdb_id")
    if imdb_id:
        standard_tags["IMDB"] = f"https://www.imdb.com/title/{imdb_id}"

    tvdb_id = ids.get("tvdb_id")
    if tvdb_id:
        tvdb_prefix = "movies" if kind == "movie" else "series"
        standard_tags["TVDB"] = f"https://thetvdb.com/dereferrer/{tvdb_prefix}/{tvdb_id}"

    # Standard tags come last so they win over a custom tag of the same name.
    merged_tags = {
        **custom_tags,
        **standard_tags,
    }
    _apply_tags(path, merged_tags)
# Public names of this module; the underscore-prefixed helpers are left out.
__all__ = [
"search_tmdb",
"get_title",
"get_year",
"external_ids",
"tag_file",
"fuzzy_match",
]

View File

@@ -0,0 +1,192 @@
import re
import sys
import typing
from typing import Optional
from pycaption import Caption, CaptionList, CaptionNode, CaptionReadError, WebVTTReader, WebVTTWriter
class CaptionListExt(CaptionList):
    """CaptionList that also carries a `first_segment_mpegts` attribute, initialised to 0."""

    @typing.no_type_check
    def __init__(self, iterable=None, layout_info=None):
        # Set before delegating to the base class initialiser.
        self.first_segment_mpegts = 0
        super(CaptionListExt, self).__init__(iterable, layout_info)
class CaptionExt(Caption):
    """Caption that additionally records its segment index, MPEGTS value and cue time."""

    @typing.no_type_check
    def __init__(self, start, end, nodes, style=None, layout_info=None, segment_index=0, mpegts=0, cue_time=0.0):
        # A missing or empty style is replaced by a fresh dict.
        if not style:
            style = {}
        self.cue_time: float = cue_time
        self.mpegts: float = mpegts
        self.segment_index: int = segment_index
        super(CaptionExt, self).__init__(start, end, nodes, style, layout_info)
class WebVTTReaderExt(WebVTTReader):
    """WebVTT reader that tracks per-segment X-TIMESTAMP-MAP data for merged segmented WebVTT."""

    # HLS extension support <https://datatracker.ietf.org/doc/html/rfc8216#section-3.5>
    RE_TIMESTAMP_MAP = re.compile(r"X-TIMESTAMP-MAP.*")
    RE_MPEGTS = re.compile(r"MPEGTS:(\d+)")
    # Groups: 1 = whole timestamp, 2 = hours (optional), 3 = minutes, 4 = seconds, 5 = milliseconds.
    RE_LOCAL = re.compile(r"LOCAL:((?:(\d{1,}):)?(\d{2}):(\d{2})\.(\d{3}))")

    def _parse(self, lines: list[str]) -> CaptionList:
        """
        Parse WebVTT lines into a CaptionListExt of CaptionExt objects.

        Every line containing "WEBVTT" starts a new segment. Each caption records
        the index of its segment together with the MPEGTS and LOCAL values of
        that segment's X-TIMESTAMP-MAP header (0 when the header is absent).

        Raises:
            CaptionReadError: If a timing line cannot be parsed; the line index
                is appended to the message.
        """
        captions = CaptionListExt()
        start = None
        end = None
        nodes: list[CaptionNode] = []
        layout_info = None
        found_timing = False
        segment_index = -1
        mpegts = 0
        cue_time = 0.0

        # The first segment MPEGTS is needed to calculate the rest. It is possible that
        # the first segment contains no cue and is ignored by pycaption, this acts as a fallback.
        captions.first_segment_mpegts = 0

        for i, line in enumerate(lines):
            if "-->" in line:
                found_timing = True
                timing_line = i
                last_start_time = captions[-1].start if captions else 0
                try:
                    start, end, layout_info = self._parse_timing_line(line, last_start_time)
                except CaptionReadError as e:
                    new_msg = f"{e.args[0]} (line {timing_line})"
                    tb = sys.exc_info()[2]
                    raise type(e)(new_msg).with_traceback(tb) from None
            elif "" == line:
                # A blank line closes the cue that is currently being collected.
                if found_timing and nodes:
                    found_timing = False
                    caption = CaptionExt(
                        start,
                        end,
                        nodes,
                        layout_info=layout_info,
                        segment_index=segment_index,
                        mpegts=mpegts,
                        cue_time=cue_time,
                    )
                    captions.append(caption)
                    nodes = []
            elif "WEBVTT" in line:
                # Merged segmented VTT doesn't have index information, track manually.
                segment_index += 1
                mpegts = 0
                cue_time = 0.0
            elif m := self.RE_TIMESTAMP_MAP.match(line):
                if r := self.RE_MPEGTS.search(m.group()):
                    mpegts = int(r.group(1))

                cue_time = self._parse_local(m.group())

                # Early assignment in case the first segment contains no cue.
                if segment_index == 0:
                    captions.first_segment_mpegts = mpegts
            else:
                if found_timing:
                    if nodes:
                        nodes.append(CaptionNode.create_break())
                    nodes.append(CaptionNode.create_text(self._decode(line)))
                else:
                    # it's a comment or some metadata; ignore it
                    pass

        # Add a last caption if there are remaining nodes
        if nodes:
            caption = CaptionExt(
                start,
                end,
                nodes,
                layout_info=layout_info,
                segment_index=segment_index,
                mpegts=mpegts,
                # Passed here as well so the final cue carries the same data as all others.
                cue_time=cue_time,
            )
            captions.append(caption)

        return captions

    @staticmethod
    def _parse_local(string: str) -> float:
        """
        Parse WebVTT LOCAL time and convert it to seconds.

        Returns 0.0 when `string` contains no LOCAL timestamp. The hours field is
        optional and counts as 0 when absent.
        """
        m = WebVTTReaderExt.RE_LOCAL.search(string)
        if not m:
            return 0.0

        _, hours, minutes, seconds, milliseconds = m.groups()
        # The hours group is None when the timestamp is written as MM:SS.mmm.
        return (int(milliseconds) / 1000) + int(seconds) + (int(minutes) * 60) + (int(hours or 0) * 3600)
def merge_segmented_webvtt(vtt_raw: str, segment_durations: Optional[list[int]] = None, timescale: int = 1) -> str:
    """
    Merge Segmented WebVTT data.

    Parameters:
        vtt_raw: The concatenated WebVTT files to merge. All WebVTT headers must be
            appropriately spaced apart, or it may produce unwanted effects like
            considering headers as captions, timestamp lines, etc.
        segment_durations: A list of each segment's duration. If not provided it will try
            to get it from the X-TIMESTAMP-MAP headers, specifically the MPEGTS number.
        timescale: The number of time units per second.

    Consecutive captions with equal text that are at most 1ms apart are spliced
    into a single caption. All X-TIMESTAMP-MAP header information is removed from
    the output. Consider this function the opposite of a WebVTT Segmenter, a
    WebVTT Joiner of sorts.

    NOTE: re-basing of the start and end timestamps from the X-TIMESTAMP-MAP data
    is currently disabled (see the commented-out code below), so timestamps are
    written as they were read.

    Algorithm borrowed from N_m3u8DL-RE and shaka-player.
    """
    MPEG_TIMESCALE = 90_000

    vtt = WebVTTReaderExt().read(vtt_raw)
    for lang in vtt.get_languages():
        prev_caption = None
        duplicate_index: set[int] = set()
        captions = vtt.get_captions(lang)

        if captions[0].segment_index == 0:
            first_segment_mpegts = captions[0].mpegts
        else:
            first_segment_mpegts = segment_durations[0] if segment_durations else captions.first_segment_mpegts

        caption: CaptionExt
        for i, caption in enumerate(captions):
            # DASH WebVTT doesn't have MPEGTS timestamp like HLS. Instead,
            # calculate the timestamp from SegmentTemplate/SegmentList duration.
            likely_dash = first_segment_mpegts == 0 and caption.mpegts == 0
            if likely_dash and segment_durations:
                duration = segment_durations[caption.segment_index]
                caption.mpegts = MPEG_TIMESCALE * (duration / timescale)

            if caption.mpegts == 0:
                continue

            # Commented to fix DSNP subs being out of sync and mistimed.
            # seconds = (caption.mpegts - first_segment_mpegts) / MPEG_TIMESCALE - caption.cue_time
            # offset = seconds * 1_000_000  # pycaption use microseconds
            # if caption.start < offset:
            #     caption.start += offset
            #     caption.end += offset

            # If the difference between current and previous captions is <=1ms
            # and the payload is equal then splice.
            if (
                prev_caption
                and not caption.is_empty()
                and (caption.start - prev_caption.end) <= 1000  # 1ms in microseconds
                and caption.get_text() == prev_caption.get_text()
            ):
                prev_caption.end = caption.end
                duplicate_index.add(i)
                # prev_caption stays on the caption that is kept, so a run of three or
                # more duplicates keeps extending it instead of a caption that is
                # about to be removed.
                continue

            prev_caption = caption

        # Remove duplicate
        captions[:] = [c for c_index, c in enumerate(captions) if c_index not in duplicate_index]

    return WebVTTWriter().write(vtt)

View File

@@ -0,0 +1,24 @@
from typing import Union
from lxml import etree
from lxml.etree import ElementTree
def load_xml(xml: Union[str, bytes]) -> ElementTree:
    """
    Parse XML data and strip the namespaces from all tag and attribute names.

    Parameters:
        xml: The XML document; a str is encoded as UTF-8 before parsing.

    Returns the root element produced by `etree.fromstring`, with every tag and
    attribute renamed to its local name and unused namespace declarations
    cleaned up.

    NOTE(review): parsing uses lxml's default parser settings; confirm those are
    acceptable if this is ever fed untrusted input.
    """
    if not isinstance(xml, bytes):
        xml = xml.encode("utf8")
    root = etree.fromstring(xml)
    # iter() replaces the deprecated getiterator().
    for elem in root.iter():
        if not hasattr(elem.tag, "find"):
            # e.g. comment elements
            continue
        elem.tag = etree.QName(elem).localname
        # Iterate over a snapshot because attributes are renamed inside the loop.
        for name, value in list(elem.attrib.items()):
            local_name = etree.QName(name).localname
            if local_name == name:
                continue
            del elem.attrib[name]
            elem.attrib[local_name] = value
    etree.cleanup_namespaces(root)
    return root