1use aes_gcm::{
8 Aes256Gcm, KeyInit, Nonce,
9 aead::{Aead, AeadCore, OsRng},
10};
11use argon2::{
12 Argon2,
13 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core},
14};
15use zeroize::Zeroize;
16
17use super::errors::UserError;
18use crate::{Result, auth::crypto::PrivateKey};
19
/// Exact length, in characters, of the Base64 salt strings accepted by
/// [`derive_encryption_key`].
///
/// 22 is the unpadded Base64 length of a 16-byte salt. The assumption is that this matches what
/// `SaltString::generate` produces in [`hash_password`]; confirm against the password-hash
/// version in use.
pub const SALT_LENGTH: usize = 22;

/// Length, in bytes, of an AES-GCM nonce (the standard 96-bit GCM nonce).
pub const NONCE_LENGTH: usize = 96 / 8;

/// Length, in bytes, of an AES-256 key (256 bits).
pub const KEY_LENGTH: usize = 256 / 8;
28
29pub fn hash_password(password: impl AsRef<str>) -> Result<(String, String)> {
39 let salt = SaltString::generate(&mut rand_core::OsRng);
40
41 let argon2 = Argon2::default();
42
43 let password_hash = argon2
44 .hash_password(password.as_ref().as_bytes(), &salt)
45 .map_err(|e| UserError::EncryptionFailed {
46 reason: format!("Password hashing failed: {e}"),
47 })?
48 .to_string();
49
50 let salt_string = salt.as_str().to_string();
51
52 Ok((password_hash, salt_string))
53}
54
55pub fn verify_password(password: impl AsRef<str>, password_hash: impl AsRef<str>) -> Result<()> {
64 let parsed_hash = PasswordHash::new(password_hash.as_ref())
65 .map_err(|_| UserError::PasswordVerificationFailed)?;
66
67 Argon2::default()
68 .verify_password(password.as_ref().as_bytes(), &parsed_hash)
69 .map_err(|_| UserError::InvalidPassword.into())
70}
71
72pub fn derive_encryption_key(password: impl AsRef<str>, salt: impl AsRef<str>) -> Result<Vec<u8>> {
81 let salt_str = salt.as_ref();
82 if salt_str.len() != SALT_LENGTH {
83 return Err(UserError::InvalidSaltLength {
84 expected: SALT_LENGTH,
85 actual: salt_str.len(),
86 }
87 .into());
88 }
89
90 let salt = SaltString::from_b64(salt_str).map_err(|e| UserError::EncryptionFailed {
91 reason: format!("Invalid salt format: {e}"),
92 })?;
93
94 let argon2 = Argon2::default();
95
96 let mut key = vec![0u8; KEY_LENGTH];
97 argon2
98 .hash_password_into(
99 password.as_ref().as_bytes(),
100 salt.as_str().as_bytes(),
101 &mut key,
102 )
103 .map_err(|e| UserError::EncryptionFailed {
104 reason: format!("Key derivation failed: {e}"),
105 })?;
106
107 Ok(key)
108}
109
110pub fn encrypt_private_key(
121 private_key: &PrivateKey,
122 encryption_key: impl AsRef<[u8]>,
123) -> Result<(Vec<u8>, Vec<u8>)> {
124 let encryption_key = encryption_key.as_ref();
125 if encryption_key.len() != KEY_LENGTH {
126 return Err(UserError::EncryptionFailed {
127 reason: format!(
128 "Invalid key length: expected {}, got {}",
129 KEY_LENGTH,
130 encryption_key.len()
131 ),
132 }
133 .into());
134 }
135
136 let serialized = private_key.to_prefixed_string();
139
140 let cipher =
142 Aes256Gcm::new_from_slice(encryption_key).map_err(|e| UserError::EncryptionFailed {
143 reason: format!("Failed to create cipher: {e}"),
144 })?;
145
146 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
148
149 let ciphertext =
151 cipher
152 .encrypt(&nonce, serialized.as_bytes())
153 .map_err(|e| UserError::EncryptionFailed {
154 reason: format!("Encryption failed: {e}"),
155 })?;
156
157 Ok((ciphertext, nonce.to_vec()))
158}
159
160pub fn decrypt_private_key(
170 ciphertext: impl AsRef<[u8]>,
171 nonce: impl AsRef<[u8]>,
172 encryption_key: impl AsRef<[u8]>,
173) -> Result<PrivateKey> {
174 let encryption_key = encryption_key.as_ref();
175 let nonce_bytes = nonce.as_ref();
176 let ciphertext = ciphertext.as_ref();
177
178 if encryption_key.len() != KEY_LENGTH {
179 return Err(UserError::DecryptionFailed {
180 reason: format!(
181 "Invalid key length: expected {}, got {}",
182 KEY_LENGTH,
183 encryption_key.len()
184 ),
185 }
186 .into());
187 }
188
189 if nonce_bytes.len() != NONCE_LENGTH {
190 return Err(UserError::InvalidNonceLength {
191 expected: NONCE_LENGTH,
192 actual: nonce_bytes.len(),
193 }
194 .into());
195 }
196
197 let cipher =
199 Aes256Gcm::new_from_slice(encryption_key).map_err(|e| UserError::DecryptionFailed {
200 reason: format!("Failed to create cipher: {e}"),
201 })?;
202
203 let nonce_array: [u8; NONCE_LENGTH] =
205 nonce_bytes
206 .try_into()
207 .map_err(|_| UserError::InvalidNonceLength {
208 expected: NONCE_LENGTH,
209 actual: nonce_bytes.len(),
210 })?;
211 let nonce = Nonce::from(nonce_array);
212
213 let plaintext =
215 cipher
216 .decrypt(&nonce, ciphertext)
217 .map_err(|e| UserError::DecryptionFailed {
218 reason: format!("Decryption failed: {e}"),
219 })?;
220
221 let mut prefixed = String::from_utf8(plaintext).map_err(|e| {
223 let mut bytes = e.into_bytes();
224 bytes.zeroize();
225 UserError::DecryptionFailed {
226 reason: "Decrypted key is not valid UTF-8".to_string(),
227 }
228 })?;
229
230 let key = PrivateKey::from_prefixed_string(&prefixed).map_err(|e| {
231 prefixed.zeroize();
232 UserError::DecryptionFailed {
233 reason: format!("Failed to parse decrypted private key: {e}"),
234 }
235 })?;
236
237 prefixed.zeroize();
238 Ok(key)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::crypto::generate_keypair;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_password_hash_and_verify() {
        let password = "test_password_123";

        let (hash, _salt) = hash_password(password).unwrap();

        assert!(verify_password(password, &hash).is_ok());

        assert!(verify_password("wrong_password", &hash).is_err());
    }

    #[test]
    fn test_verify_rejects_unparseable_hash() {
        assert!(verify_password("anything", "not a phc string").is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_password_hash_unique() {
        let password = "test_password_123";

        let (hash1, _) = hash_password(password).unwrap();
        let (hash2, _) = hash_password(password).unwrap();

        assert_ne!(hash1, hash2);

        assert!(verify_password(password, &hash1).is_ok());
        assert!(verify_password(password, &hash2).is_ok());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_hash_password_returns_embedded_salt() {
        let (hash, salt) = hash_password("test_password_123").unwrap();

        assert_eq!(salt.len(), SALT_LENGTH);

        let parsed = PasswordHash::new(&hash).unwrap();
        let embedded = parsed.salt.expect("PHC hash should embed its salt");
        assert_eq!(embedded.as_str(), salt.as_str());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_derived_key_is_deterministic_and_distinct_from_stored_hash() {
        let password = "encryption_password";
        let (hash, salt) = hash_password(password).unwrap();

        let key1 = derive_encryption_key(password, &salt).unwrap();
        let key2 = derive_encryption_key(password, &salt).unwrap();
        assert_eq!(key1.len(), KEY_LENGTH);
        assert_eq!(key1, key2);

        // The stored password hash must never equal (and so reveal) the encryption key.
        let parsed = PasswordHash::new(&hash).unwrap();
        let digest = parsed.hash.expect("PHC hash should contain an output");
        assert_ne!(digest.as_bytes(), key1.as_slice());
    }

    #[test]
    fn test_derive_rejects_bad_salt() {
        // Wrong length is rejected before any hashing happens.
        assert!(derive_encryption_key("password", "").is_err());
        assert!(derive_encryption_key("password", "c2FsdA").is_err());
        // Right length, but characters outside the salt alphabet.
        assert!(derive_encryption_key("password", "!".repeat(SALT_LENGTH)).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_key_encryption_round_trip() {
        let (private_key, _) = generate_keypair();
        let password = "encryption_password";
        let (_, salt) = hash_password(password).unwrap();

        let encryption_key = derive_encryption_key(password, &salt).unwrap();

        let (ciphertext, nonce) = encrypt_private_key(&private_key, &encryption_key).unwrap();
        assert_eq!(nonce.len(), NONCE_LENGTH);

        let decrypted_key = decrypt_private_key(&ciphertext, &nonce, &encryption_key).unwrap();

        assert_eq!(private_key.to_bytes(), decrypted_key.to_bytes());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_encryption_wrong_key_fails() {
        let (private_key, _) = generate_keypair();
        let password1 = "password1";
        let password2 = "password2";
        let (_, salt) = hash_password(password1).unwrap();

        let encryption_key1 = derive_encryption_key(password1, &salt).unwrap();
        let (ciphertext, nonce) = encrypt_private_key(&private_key, &encryption_key1).unwrap();

        let encryption_key2 = derive_encryption_key(password2, &salt).unwrap();
        let result = decrypt_private_key(&ciphertext, &nonce, &encryption_key2);

        assert!(result.is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_malformed_inputs_are_rejected() {
        let (private_key, _) = generate_keypair();
        let password = "password";
        let (_, salt) = hash_password(password).unwrap();
        let encryption_key = derive_encryption_key(password, &salt).unwrap();
        let (ciphertext, nonce) = encrypt_private_key(&private_key, &encryption_key).unwrap();

        // Wrong key length, in both directions.
        let short_key = &encryption_key[..KEY_LENGTH - 1];
        assert!(encrypt_private_key(&private_key, short_key).is_err());
        assert!(decrypt_private_key(&ciphertext, &nonce, short_key).is_err());

        // Wrong nonce length.
        let short_nonce = &nonce[..NONCE_LENGTH - 1];
        assert!(decrypt_private_key(&ciphertext, short_nonce, &encryption_key).is_err());

        // Any modification of the ciphertext must fail GCM authentication.
        let mut tampered = ciphertext.clone();
        tampered[0] ^= 0x01;
        assert!(decrypt_private_key(&tampered, &nonce, &encryption_key).is_err());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_nonce_uniqueness() {
        let (private_key, _) = generate_keypair();
        let password = "password";
        let (_, salt) = hash_password(password).unwrap();
        let encryption_key = derive_encryption_key(password, &salt).unwrap();

        let (_, nonce1) = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let (_, nonce2) = encrypt_private_key(&private_key, &encryption_key).unwrap();

        assert_ne!(nonce1, nonce2);
    }
}
331}