1 package net.avcompris.examples.users3.dao.impl;
2
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Sets.newHashSet;
import static net.avcompris.commons3.databeans.DataBeans.instantiate;
import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLIntegrityConstraintViolationException;
import java.util.Set;

import javax.annotation.Nullable;
import javax.sql.DataSource;

import org.apache.commons.codec.digest.DigestUtils;
import org.joda.time.DateTime;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import net.avcompris.commons.query.impl.SqlWhereClause;
import net.avcompris.commons3.api.UserSessionFiltering;
import net.avcompris.commons3.dao.impl.AbstractDaoInRDS;
import net.avcompris.commons3.utils.Clock;
import net.avcompris.examples.users3.dao.AuthDao;
import net.avcompris.examples.users3.dao.UserSessionDto;
import net.avcompris.examples.users3.dao.UserSessionsDto;
import net.avcompris.examples.users3.dao.UserSessionsDtoQuery;
33
34 @Component
35 public final class AuthDaoInRDS extends AbstractDaoInRDS implements AuthDao {
36
37 private static final int SESSION_TIMEOUT_MINUTES = 60;
38
39 private final String sessionsTableName;
40
41 private final boolean debug;
42
43 @Autowired
44 public AuthDaoInRDS(
45 @Value("#{rds.dataSource}") final DataSource dataSource,
46 @Value("#{rds.tableNames.auth}") final String tableName,
47 final Clock clock) {
48
49 super(dataSource, tableName, clock);
50
51 sessionsTableName = tableName + "_sessions";
52
53 debug = System.getProperty("debug") != null;
54 }
55
56 private static String hashPassword(final String passwordSalt, final String password) {
57
58 return DigestUtils.sha256Hex(passwordSalt + password);
59 }
60
61 @Override
62 public void setUserPassword(final String username, final String password) throws SQLException, IOException {
63
64 checkNotNull(username, "username");
65 checkNotNull(password, "password");
66
67 final String passwordSalt = randomAlphanumeric(20);
68
69 final String passwordHash = hashPassword(passwordSalt, password);
70
71 try (Connection cxn = getConnection()) {
72
73 final int updated;
74
75 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + tableName
76 + " SET"
77 + " password_salt = ?,"
78 + " password_hash = ?"
79 + " WHERE username = ?"
80 )) {
81
82 setString(pstmt, 1, passwordSalt);
83 setString(pstmt, 2, passwordHash);
84 setString(pstmt, 3, username);
85
86 updated = pstmt.executeUpdate();
87 }
88
89 if (updated == 0) {
90
91 try (PreparedStatement pstmt = cxn.prepareStatement("INSERT INTO " + tableName
92 + " (username, password_salt, password_hash)"
93 + " VALUES (?, ?, ?)"
94 )) {
95
96 setString(pstmt, 1, username);
97 setString(pstmt, 2, passwordSalt);
98 setString(pstmt, 3, passwordHash);
99
100 pstmt.executeUpdate();
101 }
102 }
103 }
104 }
105
106 @Override
107 public void removeUserPassword(final String username) throws SQLException, IOException {
108
109 checkNotNull(username, "username");
110
111 try (Connection cxn = getConnection()) {
112
113 try (PreparedStatement pstmt = cxn.prepareStatement("DELETE FROM " + tableName
114 + " WHERE username = ?"
115 )) {
116
117 setString(pstmt, 1, username);
118
119 pstmt.executeUpdate();
120 }
121 }
122 }
123
124 @Override
125 @Nullable
126 public String getUsernameByAuthorization(final String authorization,
127 final DateTime updatedAt
128 ) throws SQLException, IOException {
129
130 checkNotNull(authorization, "authorization");
131 checkNotNull(updatedAt, "updatedAt");
132
133
134
135 return null;
136 }
137
138 @Override
139 @Nullable
140 public String getUsernameBySessionId(final String userSessionId,
141 final DateTime updatedAt
142 ) throws SQLException, IOException {
143
144 checkNotNull(userSessionId, "userSessionId");
145 checkNotNull(updatedAt, "updatedAt");
146
147 final long startMs = System.currentTimeMillis();
148
149 if (debug) {
150 System.out.println(AuthDaoInRDS.class.getSimpleName() + ".getUsernameBySessionId(), userSessionId: "
151 + userSessionId + "...");
152 }
153
154 final String username;
155
156 try (Connection cxn = getConnection()) {
157
158 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
159 + " SET"
160 + " updated_at = ?"
161 + " WHERE user_session_id = ?"
162 + " AND expired_at IS NULL"
163 )) {
164
165 setDateTime(pstmt, 1, updatedAt);
166 setString(pstmt, 2, userSessionId);
167
168 final int updated = pstmt.executeUpdate();
169
170 if (updated == 0) {
171 return null;
172 }
173 }
174
175 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
176 + " SET"
177 + " expired_at = ?"
178 + " WHERE user_session_id = ?"
179 + " AND updated_at >= expires_at"
180 )) {
181
182 setDateTime(pstmt, 1, updatedAt);
183 setString(pstmt, 2, userSessionId);
184
185 final int updated = pstmt.executeUpdate();
186
187 if (updated != 0) {
188 return null;
189 }
190 }
191
192 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
193 + " SET"
194 + " expires_at = ?"
195 + " WHERE user_session_id = ?"
196 + " AND expired_at IS NULL"
197 )) {
198
199 setDateTime(pstmt, 1, updatedAt.plusMinutes(SESSION_TIMEOUT_MINUTES));
200 setString(pstmt, 2, userSessionId);
201
202 final int updated = pstmt.executeUpdate();
203
204 if (updated == 0) {
205 return null;
206 }
207 }
208
209 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
210 + " username"
211 + " FROM " + sessionsTableName
212 + " WHERE user_session_id = ?"
213 + " AND expired_at IS NULL"
214 )) {
215
216 setString(pstmt, 1, userSessionId);
217
218 try (ResultSet rs = pstmt.executeQuery()) {
219
220 if (rs.next()) {
221 username = getString(rs, 1);
222 } else {
223 return null;
224 }
225 }
226 }
227 }
228
229 final long elapsedMs = System.currentTimeMillis() - startMs;
230
231 if (debug) {
232 System.out
233 .println(AuthDaoInRDS.class.getSimpleName() + ".getUsernameBySessionId(), elapsedMs: " + elapsedMs);
234 }
235
236 return username;
237 }
238
239 @Override
240 public boolean isValidUserPassword(final String username, final String password) throws SQLException, IOException {
241
242 checkNotNull(username, "username");
243 checkNotNull(password, "password");
244
245 try (Connection cxn = getConnection()) {
246
247 final Set<String> passwordSalts = newHashSet();
248
249 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
250 + " password_salt"
251 + " FROM " + tableName
252 + " WHERE username = ?"
253 )) {
254
255 setString(pstmt, 1, username);
256
257 try (ResultSet rs = pstmt.executeQuery()) {
258
259 while (rs.next()) {
260
261 final String passwordSalt = rs.getString(1);
262
263 passwordSalts.add(passwordSalt);
264 }
265 }
266 }
267
268 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
269 + " 1"
270 + " FROM " + tableName
271 + " WHERE username = ?"
272 + " AND password_salt = ?"
273 + " AND password_hash = ?"
274 )) {
275
276 setString(pstmt, 1, username);
277
278 for (final String passwordSalt : passwordSalts) {
279
280 setString(pstmt, 2, passwordSalt);
281
282 final String passwordHash = hashPassword(passwordSalt, password);
283
284 setString(pstmt, 3, passwordHash);
285
286 try (ResultSet rs = pstmt.executeQuery()) {
287
288
289
290 if (rs.next()) {
291
292 return true;
293 }
294 }
295 }
296 }
297 }
298
299 return false;
300 }
301
302 @Override
303 public UserSessionDto newUserSession(final String username, final DateTime createdAt)
304 throws SQLException, IOException {
305
306 checkNotNull(username, "username");
307 checkNotNull(createdAt, "createdAt");
308
309 final DateTime expiresAt = createdAt.plusMinutes(SESSION_TIMEOUT_MINUTES);
310
311 try (Connection cxn = getConnection()) {
312
313 try (PreparedStatement pstmt = cxn.prepareStatement("INSERT INTO " + sessionsTableName
314 + " (user_session_id,"
315 + " username,"
316 + " created_at,"
317 + " updated_at,"
318 + " expires_at)"
319 + " VALUES (?, ?, ?, ?, ?)"
320 )) {
321
322 setString(pstmt, 2, username);
323 setDateTime(pstmt, 3, createdAt);
324 setDateTime(pstmt, 4, createdAt);
325 setDateTime(pstmt, 5, expiresAt);
326
327 final String userSessionId = retryUntil(4_000, 0, () -> {
328
329 final String newSessionId = "S-"
330 + System.currentTimeMillis() + "-"
331 + randomAlphanumeric(20);
332
333 try {
334
335 setString(pstmt, 1, newSessionId);
336
337 pstmt.executeUpdate();
338
339 } catch (final SQLIntegrityConstraintViolationException e) {
340
341 return null;
342
343 } catch (final SQLException e) {
344
345 if (isPSQLUniqueViolation(e)) {
346
347 return null;
348 }
349
350 throw e;
351 }
352
353 return newSessionId;
354
355 });
356
357 return instantiate(MutableUserSessionDto.class)
358 .setUserSessionId(userSessionId).setUsername(username)
359 .setCreatedAt(createdAt)
360 .setUpdatedAt(createdAt)
361 .setExpiresAt(expiresAt);
362 }
363 }
364 }
365
366 @Override
367 @Nullable
368 public UserSessionDto getUserSession(final String userSessionId,
369 final DateTime updatedAt
370 ) throws SQLException, IOException {
371
372 checkNotNull(userSessionId, "userSessionId");
373 checkNotNull(updatedAt, "updatedAt");
374
375 final UserSessionDto dto;
376
377 try (Connection cxn = getConnection()) {
378
379 cxn.setAutoCommit(false);
380
381 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
382 + " SET"
383 + " updated_at = ?,"
384 + " expired_at = ?"
385 + " WHERE user_session_id = ?"
386 + " AND ? >= expires_at"
387 + " AND expired_at IS NULL"
388 )) {
389
390 setDateTime(pstmt, 1, updatedAt);
391 setDateTime(pstmt, 2, updatedAt);
392 setString(pstmt, 3, userSessionId);
393 setDateTime(pstmt, 4, updatedAt);
394
395 pstmt.executeUpdate();
396 }
397
398 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
399 + " SET"
400 + " updated_at = ?,"
401 + " expires_at = ?"
402 + " WHERE user_session_id = ?"
403 + " AND expired_at IS NULL"
404 )) {
405
406 setDateTime(pstmt, 1, updatedAt);
407 setDateTime(pstmt, 2, updatedAt.plusMinutes(SESSION_TIMEOUT_MINUTES));
408 setString(pstmt, 3, userSessionId);
409
410 pstmt.executeUpdate();
411 }
412
413 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
414 + " user_session_id,"
415 + " username,"
416 + " created_at,"
417 + " updated_at,"
418 + " expires_at,"
419 + " expired_at"
420
421 + " FROM " + sessionsTableName
422 + " WHERE user_session_id = ?"
423 )) {
424
425 setString(pstmt, 1, userSessionId);
426
427 try (ResultSet rs = pstmt.executeQuery()) {
428
429 if (!rs.next()) {
430 return null;
431 }
432
433 dto = resultSet2UserSessionDto(rs);
434 }
435 }
436
437 cxn.commit();
438 }
439
440 return dto;
441 }
442
443 @Override
444 public void terminateSession(final String userSessionId,
445 @Nullable final DateTime updatedAt,
446 final DateTime expiredAt
447 ) throws SQLException, IOException {
448
449 checkNotNull(userSessionId, "userSessionId");
450 checkNotNull(expiredAt, "expiredAt");
451
452 try (Connection cxn = getConnection()) {
453
454 if (updatedAt != null) {
455
456 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
457 + " SET"
458 + " updated_at = ?,"
459 + " expired_at = ?"
460 + " WHERE user_session_id = ?"
461 + " AND expired_at IS NULL"
462 )) {
463
464 setDateTime(pstmt, 1, updatedAt);
465 setDateTime(pstmt, 2, expiredAt);
466 setString(pstmt, 3, userSessionId);
467
468 pstmt.executeUpdate();
469 }
470
471 } else {
472
473 try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName
474 + " SET"
475 + " expired_at = ?"
476 + " WHERE user_session_id = ?"
477 + " AND expired_at IS NULL"
478 )) {
479
480 setDateTime(pstmt, 1, expiredAt);
481 setString(pstmt, 2, userSessionId);
482
483 pstmt.executeUpdate();
484 }
485 }
486 }
487 }
488
489 @Override
490 @Nullable
491 public DateTime getLastActiveAt(final String username) throws SQLException, IOException {
492
493 checkNotNull(username, "username");
494
495 try (Connection cxn = getConnection()) {
496
497 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
498 + " MAX(last_active_at)"
499 + " FROM " + tableName
500 + " WHERE username = ?"
501 )) {
502
503 setString(pstmt, 1, username);
504
505 try (ResultSet rs = pstmt.executeQuery()) {
506
507 if (!rs.next()) {
508 return null;
509 }
510
511 return getDateTime(rs, 1);
512 }
513 }
514 }
515 }
516
517 private static UserSessionDto resultSet2UserSessionDto(final ResultSet rs) throws SQLException {
518
519 return instantiate(MutableUserSessionDto.class)
520 .setUserSessionId(getString(rs, "user_session_id"))
521 .setUsername(getString(rs, "username"))
522 .setCreatedAt(getDateTime(rs, "created_at"))
523 .setUpdatedAt(getDateTime(rs, "updated_at"))
524 .setExpiresAt(getDateTime(rs, "expires_at"))
525 .setExpiredAt(getDateTime(rs, "expired_at"));
526 }
527
528 @Override
529 public UserSessionsDto getUserSessions(final UserSessionsDtoQuery query) throws SQLException, IOException {
530
531 checkNotNull(query, "query");
532
533 final String sqlWhereClause = SqlWhereClause
534 .build(query.getFiltering(), UserSessionFiltering.Field.class)
535 .getSQL(" WHERE");
536
537 final MutableUserSessionsDto sessions = instantiate(MutableUserSessionsDto.class)
538 .setSqlWhereClause(sqlWhereClause);
539
540 final String orderDirective = toSQLOrderByDirective(query.getSortBys());
541
542 final String limitClause = toSQLLimitClause(query.getStart(), query.getLimit());
543
544 final int total;
545
546 try (Connection cxn = getConnection()) {
547
548 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
549 + " COUNT(1)"
550 + " FROM " + sessionsTableName
551 + sqlWhereClause)) {
552
553 try (ResultSet rs = pstmt.executeQuery()) {
554
555 if (rs.next()) {
556
557 total = getInt(rs, 1);
558
559 } else {
560
561 throw new IllegalStateException();
562 }
563 }
564 }
565
566 try (PreparedStatement pstmt = cxn.prepareStatement("SELECT"
567 + " user_session_id,"
568 + " username,"
569 + " created_at,"
570 + " updated_at,"
571 + " expires_at,"
572 + " expired_at"
573
574 + " FROM " + sessionsTableName
575 + sqlWhereClause
576 + orderDirective
577 + limitClause)) {
578
579 try (ResultSet rs = pstmt.executeQuery()) {
580
581 while (rs.next()) {
582
583 final UserSessionDto session = resultSet2UserSessionDto(rs);
584
585 sessions.addToResults(session);
586 }
587 }
588 }
589 }
590
591 return sessions.setTotal(total);
592 }
593 }