View Javadoc
1   package net.avcompris.examples.users3.dao.impl;
2   
3   import static com.google.common.base.Preconditions.checkNotNull;
4   import static com.google.common.collect.Sets.newHashSet;
5   import static net.avcompris.commons3.databeans.DataBeans.instantiate;
6   import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
7   
8   import java.io.IOException;
9   import java.sql.Connection;
10  import java.sql.PreparedStatement;
11  import java.sql.ResultSet;
12  import java.sql.SQLException;
13  import java.sql.SQLIntegrityConstraintViolationException;
14  import java.util.Set;
15  
16  import javax.annotation.Nullable;
17  import javax.sql.DataSource;
18  
19  import org.apache.commons.codec.digest.DigestUtils;
20  import org.joda.time.DateTime;
21  import org.springframework.beans.factory.annotation.Autowired;
22  import org.springframework.beans.factory.annotation.Value;
23  import org.springframework.stereotype.Component;
24  
25  import net.avcompris.commons.query.impl.SqlWhereClause;
26  import net.avcompris.commons3.api.UserSessionFiltering;
27  import net.avcompris.commons3.dao.impl.AbstractDaoInRDS;
28  import net.avcompris.commons3.utils.Clock;
29  import net.avcompris.examples.users3.dao.AuthDao;
30  import net.avcompris.examples.users3.dao.UserSessionDto;
31  import net.avcompris.examples.users3.dao.UserSessionsDto;
32  import net.avcompris.examples.users3.dao.UserSessionsDtoQuery;
33  
34  @Component
35  public final class AuthDaoInRDS extends AbstractDaoInRDS implements AuthDao {
36  
37  	private static final int SESSION_TIMEOUT_MINUTES = 60;
38  
39  	private final String sessionsTableName;
40  
41  	private final boolean debug;
42  
43  	@Autowired
44  	public AuthDaoInRDS( //
45  			@Value("#{rds.dataSource}") final DataSource dataSource, //
46  			@Value("#{rds.tableNames.auth}") final String tableName, //
47  			final Clock clock) {
48  
49  		super(dataSource, tableName, clock);
50  
51  		sessionsTableName = tableName + "_sessions";
52  
53  		debug = System.getProperty("debug") != null;
54  	}
55  
56  	private static String hashPassword(final String passwordSalt, final String password) {
57  
58  		return DigestUtils.sha256Hex(passwordSalt + password);
59  	}
60  
61  	@Override
62  	public void setUserPassword(final String username, final String password) throws SQLException, IOException {
63  
64  		checkNotNull(username, "username");
65  		checkNotNull(password, "password");
66  
67  		final String passwordSalt = randomAlphanumeric(20);
68  
69  		final String passwordHash = hashPassword(passwordSalt, password);
70  
71  		try (Connection cxn = getConnection()) {
72  
73  			final int updated;
74  
75  			try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + tableName //
76  					+ " SET" //
77  					+ " password_salt = ?," //
78  					+ " password_hash = ?" //
79  					+ " WHERE username = ?" //
80  			)) {
81  
82  				setString(pstmt, 1, passwordSalt);
83  				setString(pstmt, 2, passwordHash);
84  				setString(pstmt, 3, username);
85  
86  				updated = pstmt.executeUpdate();
87  			}
88  
89  			if (updated == 0) {
90  
91  				try (PreparedStatement pstmt = cxn.prepareStatement("INSERT INTO " + tableName //
92  						+ " (username, password_salt, password_hash)" //
93  						+ " VALUES (?, ?, ?)" //
94  				)) {
95  
96  					setString(pstmt, 1, username);
97  					setString(pstmt, 2, passwordSalt);
98  					setString(pstmt, 3, passwordHash);
99  
100 					pstmt.executeUpdate();
101 				}
102 			}
103 		}
104 	}
105 
106 	@Override
107 	public void removeUserPassword(final String username) throws SQLException, IOException {
108 
109 		checkNotNull(username, "username");
110 
111 		try (Connection cxn = getConnection()) {
112 
113 			try (PreparedStatement pstmt = cxn.prepareStatement("DELETE FROM " + tableName //
114 					+ " WHERE username = ?" //
115 			)) {
116 
117 				setString(pstmt, 1, username);
118 
119 				pstmt.executeUpdate();
120 			}
121 		}
122 	}
123 
124 	@Override
125 	@Nullable
126 	public String getUsernameByAuthorization(final String authorization, //
127 			final DateTime updatedAt //
128 	) throws SQLException, IOException {
129 
130 		checkNotNull(authorization, "authorization");
131 		checkNotNull(updatedAt, "updatedAt");
132 
133 		// throw new NotImplementedException("");
134 
135 		return null;
136 	}
137 
138 	@Override
139 	@Nullable
140 	public String getUsernameBySessionId(final String userSessionId, //
141 			final DateTime updatedAt //
142 	) throws SQLException, IOException {
143 
144 		checkNotNull(userSessionId, "userSessionId");
145 		checkNotNull(updatedAt, "updatedAt");
146 
147 		final long startMs = System.currentTimeMillis();
148 
149 		if (debug) {
150 			System.out.println(AuthDaoInRDS.class.getSimpleName() + ".getUsernameBySessionId(), userSessionId: "
151 					+ userSessionId + "...");
152 		}
153 
154 		final String username;
155 
156 		try (Connection cxn = getConnection()) {
157 
158 			try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
159 					+ " SET" //
160 					+ " updated_at = ?" //
161 					+ " WHERE user_session_id = ?" //
162 					+ " AND expired_at IS NULL" //
163 			)) {
164 
165 				setDateTime(pstmt, 1, updatedAt);
166 				setString(pstmt, 2, userSessionId);
167 
168 				final int updated = pstmt.executeUpdate();
169 
170 				if (updated == 0) {
171 					return null;
172 				}
173 			}
174 
175 			try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
176 					+ " SET" //
177 					+ " expired_at = ?" //
178 					+ " WHERE user_session_id = ?" //
179 					+ " AND updated_at >= expires_at" //
180 			)) {
181 
182 				setDateTime(pstmt, 1, updatedAt);
183 				setString(pstmt, 2, userSessionId);
184 
185 				final int updated = pstmt.executeUpdate();
186 
187 				if (updated != 0) {
188 					return null;
189 				}
190 			}
191 
192 			try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
193 					+ " SET" //
194 					+ " expires_at = ?" //
195 					+ " WHERE user_session_id = ?" //
196 					+ " AND expired_at IS NULL" //
197 			)) {
198 
199 				setDateTime(pstmt, 1, updatedAt.plusMinutes(SESSION_TIMEOUT_MINUTES));
200 				setString(pstmt, 2, userSessionId);
201 
202 				final int updated = pstmt.executeUpdate();
203 
204 				if (updated == 0) {
205 					return null;
206 				}
207 			}
208 
209 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
210 					+ " username" //
211 					+ " FROM " + sessionsTableName //
212 					+ " WHERE user_session_id = ?" //
213 					+ " AND expired_at IS NULL" //
214 			)) {
215 
216 				setString(pstmt, 1, userSessionId);
217 
218 				try (ResultSet rs = pstmt.executeQuery()) {
219 
220 					if (rs.next()) {
221 						username = getString(rs, 1);
222 					} else {
223 						return null;
224 					}
225 				}
226 			}
227 		}
228 
229 		final long elapsedMs = System.currentTimeMillis() - startMs;
230 
231 		if (debug) {
232 			System.out
233 					.println(AuthDaoInRDS.class.getSimpleName() + ".getUsernameBySessionId(), elapsedMs: " + elapsedMs);
234 		}
235 
236 		return username;
237 	}
238 
239 	@Override
240 	public boolean isValidUserPassword(final String username, final String password) throws SQLException, IOException {
241 
242 		checkNotNull(username, "username");
243 		checkNotNull(password, "password");
244 
245 		try (Connection cxn = getConnection()) {
246 
247 			final Set<String> passwordSalts = newHashSet();
248 
249 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
250 					+ " password_salt" //
251 					+ " FROM " + tableName //
252 					+ " WHERE username = ?" //
253 			)) {
254 
255 				setString(pstmt, 1, username);
256 
257 				try (ResultSet rs = pstmt.executeQuery()) {
258 
259 					while (rs.next()) {
260 
261 						final String passwordSalt = rs.getString(1);
262 
263 						passwordSalts.add(passwordSalt);
264 					}
265 				}
266 			}
267 
268 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
269 					+ " 1" //
270 					+ " FROM " + tableName //
271 					+ " WHERE username = ?" //
272 					+ " AND password_salt = ?" //
273 					+ " AND password_hash = ?" //
274 			)) {
275 
276 				setString(pstmt, 1, username);
277 
278 				for (final String passwordSalt : passwordSalts) {
279 
280 					setString(pstmt, 2, passwordSalt);
281 
282 					final String passwordHash = hashPassword(passwordSalt, password);
283 
284 					setString(pstmt, 3, passwordHash);
285 
286 					try (ResultSet rs = pstmt.executeQuery()) {
287 
288 						// Return true if and only if there is at least one such record in the database.
289 
290 						if (rs.next()) {
291 
292 							return true;
293 						}
294 					}
295 				}
296 			}
297 		}
298 
299 		return false;
300 	}
301 
302 	@Override
303 	public UserSessionDto newUserSession(final String username, final DateTime createdAt)
304 			throws SQLException, IOException {
305 
306 		checkNotNull(username, "username");
307 		checkNotNull(createdAt, "createdAt");
308 
309 		final DateTime expiresAt = createdAt.plusMinutes(SESSION_TIMEOUT_MINUTES);
310 
311 		try (Connection cxn = getConnection()) {
312 
313 			try (PreparedStatement pstmt = cxn.prepareStatement("INSERT INTO " + sessionsTableName //
314 					+ " (user_session_id," //
315 					+ " username," //
316 					+ " created_at," //
317 					+ " updated_at," //
318 					+ " expires_at)" //
319 					+ " VALUES (?, ?, ?, ?, ?)" //
320 			)) {
321 
322 				setString(pstmt, 2, username);
323 				setDateTime(pstmt, 3, createdAt);
324 				setDateTime(pstmt, 4, createdAt);
325 				setDateTime(pstmt, 5, expiresAt);
326 
327 				final String userSessionId = retryUntil(4_000, 0, () -> {
328 
329 					final String newSessionId = "S-" //
330 							+ System.currentTimeMillis() + "-" //
331 							+ randomAlphanumeric(20);
332 
333 					try {
334 
335 						setString(pstmt, 1, newSessionId);
336 
337 						pstmt.executeUpdate();
338 
339 					} catch (final SQLIntegrityConstraintViolationException e) {
340 
341 						return null;
342 
343 					} catch (final SQLException e) {
344 
345 						if (isPSQLUniqueViolation(e)) {
346 
347 							return null;
348 						}
349 
350 						throw e;
351 					}
352 
353 					return newSessionId;
354 
355 				});
356 
357 				return instantiate(MutableUserSessionDto.class) //
358 						.setUserSessionId(userSessionId).setUsername(username) //
359 						.setCreatedAt(createdAt) //
360 						.setUpdatedAt(createdAt) //
361 						.setExpiresAt(expiresAt);
362 			}
363 		}
364 	}
365 
366 	@Override
367 	@Nullable
368 	public UserSessionDto getUserSession(final String userSessionId, //
369 			final DateTime updatedAt //
370 	) throws SQLException, IOException {
371 
372 		checkNotNull(userSessionId, "userSessionId");
373 		checkNotNull(updatedAt, "updatedAt");
374 
375 		final UserSessionDto dto;
376 
377 		try (Connection cxn = getConnection()) {
378 
379 			cxn.setAutoCommit(false);
380 
381 			try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
382 					+ " SET" //
383 					+ " updated_at = ?," //
384 					+ " expired_at = ?" //
385 					+ " WHERE user_session_id = ?" //
386 					+ " AND ? >= expires_at" //
387 					+ " AND expired_at IS NULL" //
388 			)) {
389 
390 				setDateTime(pstmt, 1, updatedAt);
391 				setDateTime(pstmt, 2, updatedAt);
392 				setString(pstmt, 3, userSessionId);
393 				setDateTime(pstmt, 4, updatedAt);
394 
395 				pstmt.executeUpdate();
396 			}
397 
398 			try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
399 					+ " SET" //
400 					+ " updated_at = ?," //
401 					+ " expires_at = ?" //
402 					+ " WHERE user_session_id = ?" //
403 					+ " AND expired_at IS NULL" //
404 			)) {
405 
406 				setDateTime(pstmt, 1, updatedAt);
407 				setDateTime(pstmt, 2, updatedAt.plusMinutes(SESSION_TIMEOUT_MINUTES));
408 				setString(pstmt, 3, userSessionId);
409 
410 				pstmt.executeUpdate();
411 			}
412 
413 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
414 					+ " user_session_id," //
415 					+ " username," //
416 					+ " created_at," //
417 					+ " updated_at," //
418 					+ " expires_at," //
419 					+ " expired_at"
420 
421 					+ " FROM " + sessionsTableName //
422 					+ " WHERE user_session_id = ?" //
423 			)) {
424 
425 				setString(pstmt, 1, userSessionId);
426 
427 				try (ResultSet rs = pstmt.executeQuery()) {
428 
429 					if (!rs.next()) {
430 						return null;
431 					}
432 
433 					dto = resultSet2UserSessionDto(rs);
434 				}
435 			}
436 
437 			cxn.commit();
438 		}
439 
440 		return dto;
441 	}
442 
443 	@Override
444 	public void terminateSession(final String userSessionId, //
445 			@Nullable final DateTime updatedAt, //
446 			final DateTime expiredAt //
447 	) throws SQLException, IOException {
448 
449 		checkNotNull(userSessionId, "userSessionId");
450 		checkNotNull(expiredAt, "expiredAt");
451 
452 		try (Connection cxn = getConnection()) {
453 
454 			if (updatedAt != null) {
455 
456 				try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
457 						+ " SET" //
458 						+ " updated_at = ?," //
459 						+ " expired_at = ?" //
460 						+ " WHERE user_session_id = ?" //
461 						+ " AND expired_at IS NULL" //
462 				)) {
463 
464 					setDateTime(pstmt, 1, updatedAt);
465 					setDateTime(pstmt, 2, expiredAt);
466 					setString(pstmt, 3, userSessionId);
467 
468 					pstmt.executeUpdate();
469 				}
470 
471 			} else {
472 
473 				try (PreparedStatement pstmt = cxn.prepareStatement("UPDATE " + sessionsTableName //
474 						+ " SET" //
475 						+ " expired_at = ?" //
476 						+ " WHERE user_session_id = ?" //
477 						+ " AND expired_at IS NULL" //
478 				)) {
479 
480 					setDateTime(pstmt, 1, expiredAt);
481 					setString(pstmt, 2, userSessionId);
482 
483 					pstmt.executeUpdate();
484 				}
485 			}
486 		}
487 	}
488 
489 	@Override
490 	@Nullable
491 	public DateTime getLastActiveAt(final String username) throws SQLException, IOException {
492 
493 		checkNotNull(username, "username");
494 
495 		try (Connection cxn = getConnection()) {
496 
497 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
498 					+ " MAX(last_active_at)" //
499 					+ " FROM " + tableName //
500 					+ " WHERE username = ?" //
501 			)) {
502 
503 				setString(pstmt, 1, username);
504 
505 				try (ResultSet rs = pstmt.executeQuery()) {
506 
507 					if (!rs.next()) {
508 						return null;
509 					}
510 
511 					return getDateTime(rs, 1);
512 				}
513 			}
514 		}
515 	}
516 
517 	private static UserSessionDto resultSet2UserSessionDto(final ResultSet rs) throws SQLException {
518 
519 		return instantiate(MutableUserSessionDto.class) //
520 				.setUserSessionId(getString(rs, "user_session_id")) //
521 				.setUsername(getString(rs, "username")) //
522 				.setCreatedAt(getDateTime(rs, "created_at")) //
523 				.setUpdatedAt(getDateTime(rs, "updated_at")) //
524 				.setExpiresAt(getDateTime(rs, "expires_at")) //
525 				.setExpiredAt(getDateTime(rs, "expired_at"));
526 	}
527 
528 	@Override
529 	public UserSessionsDto getUserSessions(final UserSessionsDtoQuery query) throws SQLException, IOException {
530 
531 		checkNotNull(query, "query");
532 
533 		final String sqlWhereClause = SqlWhereClause //
534 				.build(query.getFiltering(), UserSessionFiltering.Field.class) //
535 				.getSQL(" WHERE");
536 
537 		final MutableUserSessionsDto sessions = instantiate(MutableUserSessionsDto.class) //
538 				.setSqlWhereClause(sqlWhereClause);
539 
540 		final String orderDirective = toSQLOrderByDirective(query.getSortBys());
541 
542 		final String limitClause = toSQLLimitClause(query.getStart(), query.getLimit());
543 
544 		final int total;
545 
546 		try (Connection cxn = getConnection()) {
547 
548 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
549 					+ " COUNT(1)" //
550 					+ " FROM " + sessionsTableName //
551 					+ sqlWhereClause)) {
552 
553 				try (ResultSet rs = pstmt.executeQuery()) {
554 
555 					if (rs.next()) {
556 
557 						total = getInt(rs, 1);
558 
559 					} else {
560 
561 						throw new IllegalStateException();
562 					}
563 				}
564 			}
565 
566 			try (PreparedStatement pstmt = cxn.prepareStatement("SELECT" //
567 					+ " user_session_id," //
568 					+ " username," //
569 					+ " created_at," //
570 					+ " updated_at," //
571 					+ " expires_at," //
572 					+ " expired_at"
573 
574 					+ " FROM " + sessionsTableName //
575 					+ sqlWhereClause //
576 					+ orderDirective //
577 					+ limitClause)) {
578 
579 				try (ResultSet rs = pstmt.executeQuery()) {
580 
581 					while (rs.next()) {
582 
583 						final UserSessionDto session = resultSet2UserSessionDto(rs);
584 
585 						sessions.addToResults(session);
586 					}
587 				}
588 			}
589 		}
590 
591 		return sessions.setTotal(total);
592 	}
593 }