1use aes_gcm::{
8 Aes256Gcm, KeyInit, Nonce,
9 aead::{Aead, AeadCore, OsRng},
10};
11use argon2::{Argon2, password_hash::SaltString};
12use zeroize::Zeroize;
13
14use super::errors::UserError;
15use crate::{Result, auth::crypto::PrivateKey};
16
/// Length, in characters, of the B64 salt string returned by [`generate_salt`];
/// [`derive_encryption_key`] rejects salts of any other length.
pub const SALT_LENGTH: usize = 22;

/// Length, in bytes, of the AES-256-GCM nonce (96 bits) accepted by [`decrypt_private_key`].
pub const NONCE_LENGTH: usize = 96 / 8;

/// Length, in bytes, of the symmetric AES-256 key (256 bits) produced by
/// [`derive_encryption_key`] and required by the encrypt/decrypt functions.
pub const KEY_LENGTH: usize = 256 / 8;
25
26pub fn generate_salt() -> String {
31 use argon2::password_hash::rand_core;
32 SaltString::generate(&mut rand_core::OsRng)
33 .as_str()
34 .to_string()
35}
36
37pub fn derive_encryption_key(password: impl AsRef<str>, salt: impl AsRef<str>) -> Result<Vec<u8>> {
46 let salt_str = salt.as_ref();
47 if salt_str.len() != SALT_LENGTH {
48 return Err(UserError::InvalidSaltLength {
49 expected: SALT_LENGTH,
50 actual: salt_str.len(),
51 }
52 .into());
53 }
54
55 let salt = SaltString::from_b64(salt_str).map_err(|e| UserError::EncryptionFailed {
56 reason: format!("Invalid salt format: {e}"),
57 })?;
58
59 let argon2 = Argon2::default();
60
61 let mut key = vec![0u8; KEY_LENGTH];
62 argon2
63 .hash_password_into(
64 password.as_ref().as_bytes(),
65 salt.as_str().as_bytes(),
66 &mut key,
67 )
68 .map_err(|e| UserError::EncryptionFailed {
69 reason: format!("Key derivation failed: {e}"),
70 })?;
71
72 Ok(key)
73}
74
75pub fn encrypt_private_key(
86 private_key: &PrivateKey,
87 encryption_key: impl AsRef<[u8]>,
88) -> Result<(Vec<u8>, Vec<u8>)> {
89 let encryption_key = encryption_key.as_ref();
90 if encryption_key.len() != KEY_LENGTH {
91 return Err(UserError::EncryptionFailed {
92 reason: format!(
93 "Invalid key length: expected {}, got {}",
94 KEY_LENGTH,
95 encryption_key.len()
96 ),
97 }
98 .into());
99 }
100
101 let serialized = private_key.to_prefixed_string();
104
105 let cipher =
107 Aes256Gcm::new_from_slice(encryption_key).map_err(|e| UserError::EncryptionFailed {
108 reason: format!("Failed to create cipher: {e}"),
109 })?;
110
111 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
113
114 let ciphertext =
116 cipher
117 .encrypt(&nonce, serialized.as_bytes())
118 .map_err(|e| UserError::EncryptionFailed {
119 reason: format!("Encryption failed: {e}"),
120 })?;
121
122 Ok((ciphertext, nonce.to_vec()))
123}
124
125pub fn decrypt_private_key(
135 ciphertext: impl AsRef<[u8]>,
136 nonce: impl AsRef<[u8]>,
137 encryption_key: impl AsRef<[u8]>,
138) -> Result<PrivateKey> {
139 let encryption_key = encryption_key.as_ref();
140 let nonce_bytes = nonce.as_ref();
141 let ciphertext = ciphertext.as_ref();
142
143 if encryption_key.len() != KEY_LENGTH {
144 return Err(UserError::DecryptionFailed {
145 reason: format!(
146 "Invalid key length: expected {}, got {}",
147 KEY_LENGTH,
148 encryption_key.len()
149 ),
150 }
151 .into());
152 }
153
154 if nonce_bytes.len() != NONCE_LENGTH {
155 return Err(UserError::InvalidNonceLength {
156 expected: NONCE_LENGTH,
157 actual: nonce_bytes.len(),
158 }
159 .into());
160 }
161
162 let cipher =
164 Aes256Gcm::new_from_slice(encryption_key).map_err(|e| UserError::DecryptionFailed {
165 reason: format!("Failed to create cipher: {e}"),
166 })?;
167
168 let nonce_array: [u8; NONCE_LENGTH] =
170 nonce_bytes
171 .try_into()
172 .map_err(|_| UserError::InvalidNonceLength {
173 expected: NONCE_LENGTH,
174 actual: nonce_bytes.len(),
175 })?;
176 let nonce = Nonce::from(nonce_array);
177
178 let plaintext =
180 cipher
181 .decrypt(&nonce, ciphertext)
182 .map_err(|e| UserError::DecryptionFailed {
183 reason: format!("Decryption failed: {e}"),
184 })?;
185
186 let mut prefixed = String::from_utf8(plaintext).map_err(|e| {
188 let mut bytes = e.into_bytes();
189 bytes.zeroize();
190 UserError::DecryptionFailed {
191 reason: "Decrypted key is not valid UTF-8".to_string(),
192 }
193 })?;
194
195 let key = PrivateKey::from_prefixed_string(&prefixed).map_err(|e| {
196 prefixed.zeroize();
197 UserError::DecryptionFailed {
198 reason: format!("Failed to parse decrypted private key: {e}"),
199 }
200 })?;
201
202 prefixed.zeroize();
203 Ok(key)
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::crypto::generate_keypair;

    /// Fixed key of valid length for tests that only exercise the AES layer. It avoids
    /// running the comparatively slow Argon2 derivation where the key's origin is irrelevant.
    const TEST_KEY: [u8; KEY_LENGTH] = [7u8; KEY_LENGTH];

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_generate_salt() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();

        assert_eq!(salt1.len(), SALT_LENGTH);
        assert_eq!(salt2.len(), SALT_LENGTH);

        assert_ne!(salt1, salt2);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_derive_key_is_deterministic_and_sized() {
        let salt = generate_salt();

        let key1 = derive_encryption_key("password", &salt).unwrap();
        let key2 = derive_encryption_key("password", &salt).unwrap();

        assert_eq!(key1.len(), KEY_LENGTH);
        assert_eq!(key1, key2);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_derive_key_rejects_wrong_salt_length() {
        assert!(derive_encryption_key("password", "too-short").is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_key_encryption_round_trip() {
        let (private_key, _) = generate_keypair();
        let password = "encryption_password";
        let salt = generate_salt();

        let encryption_key = derive_encryption_key(password, &salt).unwrap();

        let (ciphertext, nonce) = encrypt_private_key(&private_key, &encryption_key).unwrap();

        let decrypted_key = decrypt_private_key(&ciphertext, &nonce, &encryption_key).unwrap();

        assert_eq!(private_key.to_bytes(), decrypted_key.to_bytes());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_encryption_wrong_key_fails() {
        let (private_key, _) = generate_keypair();
        let password1 = "password1";
        let password2 = "password2";
        let salt = generate_salt();

        let encryption_key1 = derive_encryption_key(password1, &salt).unwrap();
        let (ciphertext, nonce) = encrypt_private_key(&private_key, &encryption_key1).unwrap();

        let encryption_key2 = derive_encryption_key(password2, &salt).unwrap();
        let result = decrypt_private_key(&ciphertext, &nonce, &encryption_key2);

        assert!(result.is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_nonce_uniqueness() {
        let (private_key, _) = generate_keypair();
        let password = "password";
        let salt = generate_salt();
        let encryption_key = derive_encryption_key(password, &salt).unwrap();

        let (_, nonce1) = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let (_, nonce2) = encrypt_private_key(&private_key, &encryption_key).unwrap();

        assert_eq!(nonce1.len(), NONCE_LENGTH);
        assert_eq!(nonce2.len(), NONCE_LENGTH);
        assert_ne!(nonce1, nonce2);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_encrypt_rejects_wrong_key_length() {
        let (private_key, _) = generate_keypair();

        assert!(encrypt_private_key(&private_key, [0u8; 16]).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_decrypt_rejects_wrong_key_length() {
        let (private_key, _) = generate_keypair();
        let (ciphertext, nonce) = encrypt_private_key(&private_key, TEST_KEY).unwrap();

        assert!(decrypt_private_key(&ciphertext, &nonce, &TEST_KEY[..16]).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_decrypt_rejects_wrong_nonce_length() {
        let (private_key, _) = generate_keypair();
        let (ciphertext, nonce) = encrypt_private_key(&private_key, TEST_KEY).unwrap();

        let short_nonce = &nonce[..NONCE_LENGTH - 1];
        assert!(decrypt_private_key(&ciphertext, short_nonce, TEST_KEY).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let (private_key, _) = generate_keypair();
        let (mut ciphertext, nonce) = encrypt_private_key(&private_key, TEST_KEY).unwrap();

        // Flipping any bit must make AES-GCM authentication fail.
        ciphertext[0] ^= 0x01;

        assert!(decrypt_private_key(&ciphertext, &nonce, TEST_KEY).is_err());
    }
}