Initial Commit

This commit is contained in:
Andy
2025-07-18 00:46:05 +00:00
commit d37014f53f
94 changed files with 17458 additions and 0 deletions

244
unshackle/vaults/MySQL.py Normal file
View File

@@ -0,0 +1,244 @@
import threading
from typing import Iterator, Optional, Union
from uuid import UUID
import pymysql
from pymysql.cursors import DictCursor
from unshackle.core.services import Services
from unshackle.core.vault import Vault
class MySQL(Vault):
"""Key Vault using a remotely-accessed mysql database connection."""
def __init__(self, name: str, host: str, database: str, username: str, **kwargs):
"""
All extra arguments provided via **kwargs will be sent to pymysql.connect.
This can be used to provide more specific connection information.
"""
super().__init__(name)
self.slug = f"{host}:{database}:{username}"
self.conn_factory = ConnectionFactory(
dict(host=host, db=database, user=username, cursorclass=DictCursor, **kwargs)
)
self.permissions = self.get_permissions()
if not self.has_permission("SELECT"):
raise PermissionError(f"MySQL vault {self.slug} has no SELECT permission.")
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
if not self.has_table(service):
# no table, no key, simple
return None
if isinstance(kid, UUID):
kid = kid.hex
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
(kid, "0" * 32),
)
cek = cursor.fetchone()
if not cek:
return None
return cek["key_"]
finally:
cursor.close()
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
("0" * 32,),
)
for row in cursor.fetchall():
yield row["kid"], row["key_"]
finally:
cursor.close()
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_permission("INSERT", table=service):
raise PermissionError(f"MySQL vault {self.slug} has no INSERT permission.")
if not self.has_table(service):
try:
self.create_table(service)
except PermissionError:
return False
if isinstance(kid, UUID):
kid = kid.hex
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
(kid, key),
)
if cursor.fetchone():
# table already has this exact KID:KEY stored
return True
cursor.execute(
# TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
(kid, key),
)
finally:
conn.commit()
cursor.close()
return True
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items():
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_permission("INSERT", table=service):
raise PermissionError(f"MySQL vault {self.slug} has no INSERT permission.")
if not self.has_table(service):
try:
self.create_table(service)
except PermissionError:
return 0
if not isinstance(kid_keys, dict):
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
if not all(isinstance(kid, (str, UUID)) and isinstance(key_, str) for kid, key_ in kid_keys.items()):
raise ValueError("Expecting dict with Key of str/UUID and value of str.")
if any(isinstance(kid, UUID) for kid, key_ in kid_keys.items()):
kid_keys = {kid.hex if isinstance(kid, UUID) else kid: key_ for kid, key_ in kid_keys.items()}
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.executemany(
# TODO: SQL injection risk
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
kid_keys.items(),
)
return cursor.rowcount
finally:
conn.commit()
cursor.close()
def get_services(self) -> Iterator[str]:
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute("SHOW TABLES")
for table in cursor.fetchall():
# each entry has a key named `Tables_in_<db name>`
yield Services.get_tag(list(table.values())[0])
finally:
cursor.close()
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute(
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
(conn.db, name),
)
return list(cursor.fetchone().values())[0] == 1
finally:
cursor.close()
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
if self.has_table(name):
return
if not self.has_permission("CREATE"):
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute(
# TODO: SQL injection risk
f"""
CREATE TABLE IF NOT EXISTS {name} (
id int AUTO_INCREMENT PRIMARY KEY,
kid VARCHAR(64) NOT NULL,
key_ VARCHAR(64) NOT NULL,
UNIQUE(kid, key_)
);
"""
)
finally:
conn.commit()
cursor.close()
def get_permissions(self) -> list:
"""Get and parse Grants to a more easily usable list tuple array."""
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
cursor.execute("SHOW GRANTS")
grants = cursor.fetchall()
grants = [next(iter(x.values())) for x in grants]
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
grants = [
(
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
location.replace("`", "").split("."),
)
for perms, location in grants
]
return grants
finally:
conn.commit()
cursor.close()
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
"""Check if the current connection has a specific permission."""
grants = [x for x in self.permissions if x[0] == ["*"] or operation.upper() in x[0]]
if grants and database:
grants = [x for x in grants if x[1][0] in (database, "*")]
if grants and table:
grants = [x for x in grants if x[1][1] in (table, "*")]
return bool(grants)
class ConnectionFactory:
def __init__(self, con: dict):
self._con = con
self._store = threading.local()
def _create_connection(self) -> pymysql.Connection:
return pymysql.connect(**self._con)
def get(self) -> pymysql.Connection:
if not hasattr(self._store, "conn"):
self._store.conn = self._create_connection()
return self._store.conn