Stop API key save() from writing other providers' keys as plaintext (#1944)

save() called load(), which DECRYPTS every stored key, then re-encrypted
only the key being saved and wrote the whole dict back. The other
providers' keys were thus persisted in plaintext; on the next load()
Fernet raised InvalidToken on them and they were silently dropped.

Add _load_raw() that returns the still-encrypted on-disk dict (reusing the
existing missing/corrupt-file guards) and have save() build on that, so
untouched providers keep their ciphertext. load() now also goes through
_load_raw(), keeping its behavior identical.

Fixes #1914

Co-authored-by: EkaTantra Dev <dev@ekatantra.com>
This commit is contained in:
Sushanth Reddy
2026-06-04 09:17:13 +05:30
committed by GitHub
parent 09fe308720
commit eee2167502

View File

@@ -37,15 +37,12 @@ class APIKeyManager:
f = Fernet(self.get_or_create_key())
return f.decrypt(encrypted_key.encode()).decode()
def save(self, provider: str, api_key: str):
"""Save encrypted API key to file"""
keys = self.load()
keys[provider] = self.encrypt_api_key(api_key)
with open(self.api_keys_file, 'w', encoding="utf-8") as f:
json.dump(keys, f)
def load(self) -> Dict[str, str]:
"""Load and decrypt API keys"""
def _load_raw(self) -> Dict[str, str]:
"""Load the raw, still-encrypted keys dict from disk.
Tolerates a missing/corrupt/wrong-shaped file by returning {} — the
same robustness load() relies on at startup.
"""
if not os.path.exists(self.api_keys_file):
return {}
try:
@@ -60,7 +57,24 @@ class APIKeyManager:
# Legacy/wrong shape (e.g. a list) — .items() would raise. Ignore it.
logger.warning("API keys file has unexpected shape (%s); ignoring", type(encrypted_keys).__name__)
return {}
return encrypted_keys
def save(self, provider: str, api_key: str):
"""Save encrypted API key to file.
Operates on the raw (still-encrypted) on-disk dict so other providers'
keys stay encrypted. Loading via load() first would decrypt them and
write them back as plaintext, which then fails to decrypt on the next
load() and silently drops those providers.
"""
keys = self._load_raw()
keys[provider] = self.encrypt_api_key(api_key)
with open(self.api_keys_file, 'w', encoding="utf-8") as f:
json.dump(keys, f)
def load(self) -> Dict[str, str]:
"""Load and decrypt API keys"""
encrypted_keys = self._load_raw()
decrypted = {}
for provider, key in encrypted_keys.items():
try: