diff --git a/src/api_key_manager.py b/src/api_key_manager.py index 6bf3a6d..d29ac03 100644 --- a/src/api_key_manager.py +++ b/src/api_key_manager.py @@ -1,7 +1,10 @@ import os import json +import logging from typing import Dict -from cryptography.fernet import Fernet +from cryptography.fernet import Fernet, InvalidToken + +logger = logging.getLogger(__name__) class APIKeyManager: def __init__(self, data_dir: str): @@ -47,8 +50,12 @@ class APIKeyManager: return {} with open(self.api_keys_file, 'r', encoding="utf-8") as f: encrypted_keys = json.load(f) - return { - provider: self.decrypt_api_key(key) - for provider, key in encrypted_keys.items() - } + + decrypted = {} + for provider, key in encrypted_keys.items(): + try: + decrypted[provider] = self.decrypt_api_key(key) + except (InvalidToken, ValueError) as e: + logger.warning("Failed to decrypt API key for %s: %s", provider, e) + return decrypted diff --git a/tests/test_api_key_manager_resilience.py b/tests/test_api_key_manager_resilience.py new file mode 100644 index 0000000..8654a69 --- /dev/null +++ b/tests/test_api_key_manager_resilience.py @@ -0,0 +1,35 @@ +import os +import json +from src.api_key_manager import APIKeyManager +from cryptography.fernet import Fernet + +def test_api_key_manager_load_resilience(tmp_path): + mgr = APIKeyManager(str(tmp_path)) + + # Save a valid key + mgr.save("good_provider", "good_value") + + # Create another key manager/Fernet instance with a different key to produce an undecryptable token + other_key = Fernet.generate_key() + other_f = Fernet(other_key) + undecryptable_token = other_f.encrypt(b"bad_value").decode() + + # Manually edit api_keys.json to include the undecryptable token + with open(mgr.api_keys_file, "r", encoding="utf-8") as f: + keys = json.load(f) + + keys["bad_provider"] = undecryptable_token + # Also add a malformed/garbage token (causes ValueError/binascii.Error) + keys["garbage_provider"] = "not-a-valid-base64-fernet-token" + + with open(mgr.api_keys_file, "w", encoding="utf-8") as f: + json.dump(keys, f) + + # Load keys + loaded = mgr.load() + + # Assert load() returns the still-decryptable key and skips the bad ones without raising + assert "good_provider" in loaded + assert loaded["good_provider"] == "good_value" + assert "bad_provider" not in loaded + assert "garbage_provider" not in loaded