View Javadoc
1   package net.avcompris.commons3.web;
2   
3   import static com.google.common.base.Preconditions.checkNotNull;
4   import static org.springframework.http.HttpHeaders.AUTHORIZATION;
5   
6   import java.io.IOException;
7   import java.lang.reflect.Method;
8   import java.lang.reflect.Modifier;
9   
10  import javax.annotation.Nullable;
11  import javax.servlet.ServletException;
12  import javax.servlet.http.Cookie;
13  import javax.servlet.http.HttpServletRequest;
14  import javax.servlet.http.HttpServletResponse;
15  
16  import org.apache.commons.lang3.NotImplementedException;
17  import org.apache.commons.lang3.tuple.Pair;
18  import org.apache.commons.logging.Log;
19  import org.springframework.http.HttpHeaders;
20  import org.springframework.http.ResponseEntity;
21  import org.springframework.web.bind.annotation.RequestMapping;
22  
23  import net.avcompris.commons3.api.User;
24  import net.avcompris.commons3.api.exception.ServiceException;
25  import net.avcompris.commons3.api.exception.UnauthenticatedException;
26  import net.avcompris.commons3.client.SessionPropagator;
27  import net.avcompris.commons3.core.AuthService;
28  import net.avcompris.commons3.core.CorrelationService;
29  import net.avcompris.commons3.utils.Clock;
30  import net.avcompris.commons3.utils.LogFactory;
31  
32  public abstract class AbstractController {
33  
34  	public static final String CORRELATION_ID_ATTRIBUTE_NAME = "Correlation-ID";
35  	public static final String USER_SESSION_ID_ATTRIBUTE_NAME = "user_session_id";
36  
37  	private static final Log logger = LogFactory.getLog(AbstractController.class);
38  
39  	protected final CorrelationService correlationService;
40  	protected final Clock clock;
41  
42  	@Nullable
43  	private final SessionPropagator sessionPropagator;
44  
45  	protected AbstractController(final CorrelationService correlationService, final SessionPropagator sessionPropagator,
46  			final Clock clock) {
47  
48  		this.correlationService = checkNotNull(correlationService, "correlationService");
49  		this.sessionPropagator = checkNotNull(sessionPropagator, "sessionPropagator");
50  		this.clock = checkNotNull(clock, "clock");
51  	}
52  
53  	@Nullable
54  	private static String getCookie(final HttpServletRequest request, final String cookieName) {
55  
56  		checkNotNull(request, "request");
57  		checkNotNull(cookieName, "cookieName");
58  
59  		final Cookie[] cookies = request.getCookies();
60  
61  		if (cookies == null) {
62  			return null;
63  		}
64  
65  		for (final Cookie cookie : cookies) {
66  
67  			if (cookieName.contentEquals(cookie.getName())) {
68  
69  				return cookie.getValue();
70  			}
71  		}
72  
73  		return null;
74  	}
75  
76  	@Nullable
77  	private final String getAuthorization(final HttpServletRequest request) {
78  
79  		final String authorizationHeader = request.getHeader(AUTHORIZATION);
80  
81  		if (authorizationHeader != null) {
82  
83  			return authorizationHeader;
84  		}
85  
86  		final String authorizationCookie = getCookie(request, AUTHORIZATION);
87  
88  		if (authorizationCookie == null) {
89  
90  			return null;
91  		}
92  
93  		final String userSessionId = getUserSessionId(request);
94  
95  		return userSessionId == null
96  
97  				? authorizationCookie
98  
99  				: null;
100 	}
101 
102 	@Nullable
103 	protected final String getUserSessionId(final HttpServletRequest request) {
104 
105 		final String userSessionIdParam = request.getParameter(USER_SESSION_ID_ATTRIBUTE_NAME);
106 		final String userSessionIdHeader = request.getHeader(USER_SESSION_ID_ATTRIBUTE_NAME);
107 		final String userSessionIdCookie = getCookie(request, USER_SESSION_ID_ATTRIBUTE_NAME);
108 
109 		if (userSessionIdParam != null) {
110 
111 			return userSessionIdParam;
112 
113 		} else if (userSessionIdHeader != null) {
114 
115 			return userSessionIdHeader;
116 
117 		} else if (userSessionIdCookie != null) {
118 
119 			return userSessionIdCookie;
120 
121 		} else {
122 
123 			return null;
124 		}
125 	}
126 
127 	protected final <T> ResponseEntity<T> wrapAuthenticated(final HttpServletRequest request,
128 			final HttpServletResponse response, final AuthService authService, final AuthenticatedAction<T> action)
129 			throws ServiceException {
130 
131 		checkNotNull(request, "request");
132 		checkNotNull(action, "action");
133 
134 		// TODO: How to factor this code with wrapAuthenticatedServletAction()?!
135 
136 		final long startMs = System.currentTimeMillis();
137 
138 		final String authorization = getAuthorization(request);
139 		final String userSessionId = getUserSessionId(request);
140 		final String correlationId = getCorrelationId(request);
141 
142 		sessionPropagator.setAuthorizationHeader(authorization);
143 		sessionPropagator.setUserSessionId(userSessionId);
144 
145 		@Nullable
146 		final User user = authService.getAuthenticatedUser(authorization, userSessionId);
147 
148 		if (user == null) {
149 
150 			throw new UnauthenticatedException();
151 		}
152 
153 		if (userSessionId != null) {
154 
155 			// Use this to propage userSessionId all the way to ApplicationErrorController
156 			//
157 			request.setAttribute(USER_SESSION_ID_ATTRIBUTE_NAME, userSessionId);
158 		}
159 
160 		final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
161 
162 		final String methodName = endpoint.getRight().getName();
163 
164 		final Log logger = LogFactory.getLog(endpoint.getLeft());
165 
166 		if (logger.isInfoEnabled()) {
167 			logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
168 		}
169 
170 		final ResponseEntity<T> result;
171 
172 		authService.setLastActiveAt(correlationId, user);
173 
174 		try {
175 
176 			result = action.action(correlationId, user);
177 
178 		} catch (final ServiceException e) {
179 
180 			final int httpErrorCode = e.getHttpErrorCode();
181 
182 			final long elapsedMs = System.currentTimeMillis() - startMs;
183 
184 			logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
185 
186 			throw e;
187 		}
188 
189 		final long elapsedMs = System.currentTimeMillis() - startMs;
190 
191 		if (logger.isInfoEnabled()) {
192 			logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
193 		}
194 
195 		return enrich(response, userSessionId, result, correlationId);
196 	}
197 
198 	protected final <T> ResponseEntity<T> wrapAuthenticatedServletAction(final HttpServletRequest request,
199 			final HttpServletResponse response, final AuthService authService,
200 			final AuthenticatedServletAction<T> action) throws ServiceException, IOException, ServletException {
201 
202 		checkNotNull(request, "request");
203 		checkNotNull(action, "action");
204 
205 		// TODO: How to factor this code with wrapAuthenticated()?!
206 
207 		final long startMs = System.currentTimeMillis();
208 
209 		final String authorization = getAuthorization(request);
210 		final String userSessionId = getUserSessionId(request);
211 		final String correlationId = getCorrelationId(request);
212 
213 		sessionPropagator.setAuthorizationHeader(authorization);
214 		sessionPropagator.setUserSessionId(userSessionId);
215 
216 		@Nullable
217 		final User user = authService.getAuthenticatedUser(authorization, userSessionId);
218 
219 		if (user == null) {
220 
221 			throw new UnauthenticatedException();
222 		}
223 
224 		if (userSessionId != null) {
225 
226 			// Use this to propage userSessionId all the way to ApplicationErrorController
227 			//
228 			request.setAttribute(USER_SESSION_ID_ATTRIBUTE_NAME, userSessionId);
229 		}
230 
231 		final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
232 
233 		final String methodName = endpoint.getRight().getName();
234 
235 		final Log logger = LogFactory.getLog(endpoint.getLeft());
236 
237 		if (logger.isInfoEnabled()) {
238 			logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
239 		}
240 
241 		authService.setLastActiveAt(correlationId, user);
242 
243 		final ResponseEntity<T> result;
244 
245 		try {
246 
247 			result = action.action(correlationId, user);
248 
249 		} catch (final ServiceException e) {
250 
251 			final int httpErrorCode = e.getHttpErrorCode();
252 
253 			final long elapsedMs = System.currentTimeMillis() - startMs;
254 
255 			logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
256 
257 			throw e;
258 
259 		} catch (final IOException | ServletException e) {
260 
261 			final long elapsedMs = System.currentTimeMillis() - startMs;
262 
263 			logger.error(methodName + "() elapsedMs: " + elapsedMs, e);
264 
265 			throw e;
266 		}
267 
268 		final long elapsedMs = System.currentTimeMillis() - startMs;
269 
270 		if (logger.isInfoEnabled()) {
271 			logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
272 		}
273 
274 		return enrich(response, userSessionId, result, correlationId);
275 	}
276 
277 	protected final <T> ResponseEntity<T> wrapAuthenticatedOrNot(final HttpServletRequest request,
278 			final HttpServletResponse response, final AuthService authService, final AuthenticatedAction<T> action)
279 			throws ServiceException {
280 
281 		checkNotNull(request, "request");
282 		checkNotNull(action, "action");
283 
284 		final long startMs = System.currentTimeMillis();
285 
286 		final String authorization = getAuthorization(request);
287 		final String userSessionId = getUserSessionId(request);
288 		final String correlationId = getCorrelationId(request);
289 
290 		sessionPropagator.setAuthorizationHeader(authorization);
291 		sessionPropagator.setUserSessionId(userSessionId);
292 
293 		@Nullable
294 		final User user = authService.getAuthenticatedUser(authorization, userSessionId);
295 
296 		final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
297 
298 		final String methodName = endpoint.getRight().getName();
299 
300 		final Log logger = LogFactory.getLog(endpoint.getLeft());
301 
302 		if (logger.isInfoEnabled()) {
303 			logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
304 		}
305 
306 		final ResponseEntity<T> result;
307 
308 		if (user != null) {
309 
310 			authService.setLastActiveAt(correlationId, user);
311 		}
312 
313 		try {
314 
315 			result = action.action(correlationId, user);
316 
317 		} catch (final ServiceException e) {
318 
319 			final int httpErrorCode = e.getHttpErrorCode();
320 
321 			final long elapsedMs = System.currentTimeMillis() - startMs;
322 
323 			logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
324 
325 			throw e;
326 		}
327 
328 		final long elapsedMs = System.currentTimeMillis() - startMs;
329 
330 		if (logger.isInfoEnabled()) {
331 			logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
332 		}
333 
334 		return enrich(response, userSessionId, result, correlationId);
335 	}
336 
337 	protected final <T> ResponseEntity<T> wrapWithoutCorrelationId(final HttpServletRequest request,
338 			final HttpServletResponse response, final AuthService authService, final AuthenticatedAction<T> action)
339 			throws ServiceException {
340 
341 		checkNotNull(request, "request");
342 		checkNotNull(action, "action");
343 
344 		final long startMs = System.currentTimeMillis();
345 
346 		final String authorization = getAuthorization(request);
347 		final String userSessionId = getUserSessionId(request);
348 
349 		String fakeCorrelationId = "N/A";
350 
351 		try {
352 
353 			fakeCorrelationId = getCorrelationId(request);
354 
355 		} catch (final Throwable e) {
356 
357 			e.printStackTrace(System.out);
358 
359 			fakeCorrelationId = "N/A";
360 		}
361 
362 		LogFactory.resetCorrelationId();
363 
364 		LogFactory.setCorrelationId(fakeCorrelationId);
365 
366 		sessionPropagator.setAuthorizationHeader(authorization);
367 		sessionPropagator.setUserSessionId(userSessionId);
368 
369 		final User user = authService.getAuthenticatedUser(authorization, userSessionId);
370 
371 		if (user == null) {
372 
373 			throw new UnauthenticatedException();
374 		}
375 
376 		final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
377 
378 		final String methodName = endpoint.getRight().getName();
379 
380 		final Log logger = LogFactory.getLog(endpoint.getLeft());
381 
382 		if (logger.isInfoEnabled()) {
383 			logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
384 		}
385 
386 		final ResponseEntity<T> result;
387 
388 		try {
389 
390 			result = action.action(fakeCorrelationId, user);
391 
392 		} catch (final ServiceException e) {
393 
394 			final int httpErrorCode = e.getHttpErrorCode();
395 
396 			final long elapsedMs = System.currentTimeMillis() - startMs;
397 
398 			logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
399 
400 			throw e;
401 		}
402 
403 		final long elapsedMs = System.currentTimeMillis() - startMs;
404 
405 		if (logger.isInfoEnabled()) {
406 			logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
407 		}
408 
409 		return enrich(response, userSessionId, result, fakeCorrelationId);
410 	}
411 
412 	private String getCorrelationId(final HttpServletRequest request) throws ServiceException {
413 
414 		final String correlationIdParam = request.getParameter("correlationId");
415 		final String correlationIdHeader = request.getHeader(CORRELATION_ID_ATTRIBUTE_NAME);
416 
417 		final long startMs = System.currentTimeMillis();
418 
419 		final String correlationId = correlationService.getCorrelationId(correlationIdParam, correlationIdHeader);
420 
421 		LogFactory.resetCorrelationId();
422 
423 		LogFactory.setCorrelationId(correlationId);
424 
425 		request.setAttribute(CORRELATION_ID_ATTRIBUTE_NAME, correlationId);
426 
427 		final long elapsedMs = System.currentTimeMillis() - startMs;
428 
429 		if (logger.isDebugEnabled()) {
430 			logger.debug("getCorrelationId(), elapsedMs: " + elapsedMs);
431 		}
432 
433 		return correlationId;
434 	}
435 
436 	protected final <T> ResponseEntity<T> wrapNonAuthenticated(final HttpServletRequest request,
437 			final UnauthenticatedAction<T> action) throws ServiceException {
438 
439 		checkNotNull(request, "request");
440 		checkNotNull(action, "action");
441 
442 		final long startMs = System.currentTimeMillis();
443 
444 		final String correlationId = getCorrelationId(request);
445 
446 		final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
447 
448 		final String methodName = endpoint.getRight().getName();
449 
450 		final Log logger = LogFactory.getLog(endpoint.getLeft());
451 
452 		if (logger.isInfoEnabled()) {
453 			logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
454 		}
455 
456 		final ResponseEntity<T> result;
457 
458 		try {
459 
460 			result = action.action(correlationId);
461 
462 		} catch (final ServiceException e) {
463 
464 			final int httpErrorCode = e.getHttpErrorCode();
465 
466 			final long elapsedMs = System.currentTimeMillis() - startMs;
467 
468 			logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
469 
470 			throw e;
471 		}
472 
473 		final long elapsedMs = System.currentTimeMillis() - startMs;
474 
475 		if (logger.isInfoEnabled()) {
476 			logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
477 		}
478 
479 		return enrich(null, null, result, correlationId);
480 	}
481 
482 	protected final <T> ResponseEntity<T> wrapNonAuthenticatedWithoutCorrelationId(final HttpServletRequest request,
483 			final UnauthenticatedAnonymousAction<T> action) throws ServiceException {
484 
485 		checkNotNull(request, "request");
486 		checkNotNull(action, "action");
487 
488 		final long startMs = System.currentTimeMillis();
489 
490 		final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
491 
492 		final String methodName = endpoint.getRight().getName();
493 
494 		String fakeCorrelationId = "N/A";
495 
496 		try {
497 
498 			fakeCorrelationId = getCorrelationId(request);
499 
500 		} catch (final Throwable e) {
501 
502 			e.printStackTrace(System.out);
503 
504 			fakeCorrelationId = "N/A";
505 		}
506 
507 		LogFactory.resetCorrelationId();
508 
509 		LogFactory.setCorrelationId(fakeCorrelationId);
510 
511 		final Log logger = LogFactory.getLog(endpoint.getLeft());
512 
513 		if (logger.isInfoEnabled()) {
514 			logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
515 		}
516 
517 		final ResponseEntity<T> result;
518 
519 		try {
520 
521 			result = action.action();
522 
523 		} catch (final ServiceException e) {
524 
525 			final int httpErrorCode = e.getHttpErrorCode();
526 
527 			final long elapsedMs = System.currentTimeMillis() - startMs;
528 
529 			logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
530 
531 			throw e;
532 		}
533 
534 		final long elapsedMs = System.currentTimeMillis() - startMs;
535 
536 		if (logger.isInfoEnabled()) {
537 			logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
538 		}
539 
540 		// return enrich(null, null, result, correlationId);
541 
542 		return result;
543 	}
544 
545 	private <T> ResponseEntity<T> enrich(@Nullable final HttpServletResponse httpServletResponse, //
546 			@Nullable final String userSessionId, //
547 			final ResponseEntity<T> response, final String correlationId) {
548 
549 		checkNotNull(response, "response");
550 		checkNotNull(correlationId, "correlationId");
551 
552 		final HttpHeaders headers = new HttpHeaders();
553 
554 		headers.putAll(response.getHeaders());
555 
556 		headers.add(CORRELATION_ID_ATTRIBUTE_NAME, correlationId);
557 
558 		if (httpServletResponse != null && userSessionId != null) {
559 
560 			setUserSessionCookie(httpServletResponse, userSessionId);
561 		}
562 
563 		return ResponseEntity.status(response.getStatusCode()) //
564 				.headers(headers) //
565 				.body(response.getBody());
566 	}
567 
568 	private static Pair<Class<?>, Method> extractControllerCurrentEndpoint() {
569 
570 		for (final StackTraceElement ste : Thread.currentThread().getStackTrace()) {
571 
572 			final String className = ste.getClassName();
573 			final String methodName = ste.getMethodName();
574 
575 			if (!className.endsWith("Controller")) {
576 				continue;
577 			}
578 
579 			final Class<?> controllerClass;
580 
581 			try {
582 
583 				controllerClass = Class.forName(className);
584 
585 			} catch (final ClassNotFoundException e) {
586 
587 				continue;
588 			}
589 
590 			if (Modifier.isAbstract(controllerClass.getModifiers())) {
591 				continue;
592 			}
593 
594 			final Method method = extractDeclaredMethod(controllerClass, methodName);
595 
596 			if (method.getAnnotation(RequestMapping.class) == null || !Modifier.isPublic(method.getModifiers())) {
597 				continue;
598 			}
599 
600 			return Pair.of(controllerClass, method);
601 		}
602 
603 		throw new IllegalStateException("Cannot extract controller current endpoint");
604 	}
605 
606 	private static Method extractDeclaredMethod(final Class<?> controllerClass, final String methodName) {
607 
608 		for (final Method method : controllerClass.getDeclaredMethods()) {
609 
610 			if (methodName.contentEquals(method.getName()) || Modifier.isPublic(method.getModifiers())) {
611 
612 				return method;
613 			}
614 		}
615 
616 		throw new IllegalStateException(
617 				"Cannot extract method: " + methodName + " from controllerClass: " + controllerClass.getName());
618 	}
619 
620 	@FunctionalInterface
621 	protected interface UnauthenticatedAnonymousAction<T> {
622 
623 		ResponseEntity<T> action() throws ServiceException;
624 	}
625 
626 	@FunctionalInterface
627 	protected interface UnauthenticatedAction<T> {
628 
629 		ResponseEntity<T> action(String correlationId) throws ServiceException;
630 	}
631 
632 	@FunctionalInterface
633 	protected interface AuthenticatedAction<T> {
634 
635 		ResponseEntity<T> action(String correlationId, User user) throws ServiceException;
636 	}
637 
638 	@FunctionalInterface
639 	protected interface AuthenticatedServletAction<T> {
640 
641 		ResponseEntity<T> action(String correlationId, User user)
642 				throws ServiceException, ServletException, IOException;
643 	}
644 
645 	protected final <T extends ResponseEntity<?>> T handleServiceException(final ServiceException e) {
646 
647 		checkNotNull(e, "e");
648 
649 		throw new NotImplementedException("");
650 	}
651 
652 	protected static HttpHeaders headers(final String... headerNameValuePairs) {
653 
654 		final HttpHeaders headers = new HttpHeaders();
655 
656 		for (int i = 0; i < headerNameValuePairs.length / 2; ++i) {
657 
658 			final String headerName = headerNameValuePairs[i * 2];
659 			final String headerValue = headerNameValuePairs[i * 2 + 1];
660 
661 			headers.set(headerName, headerValue);
662 		}
663 
664 		return headers;
665 	}
666 
667 	protected abstract boolean isSecure();
668 
669 	protected abstract boolean isHttpOnly();
670 
671 	protected final void setUserSessionCookie(final HttpServletResponse response, final String userSessionId) {
672 
673 		checkNotNull(response, "response");
674 		checkNotNull(userSessionId, "userSessionId");
675 
676 		final Cookie cookie = new Cookie(USER_SESSION_ID_ATTRIBUTE_NAME, userSessionId);
677 
678 		// cookie.setDomain("...");
679 		cookie.setSecure(isSecure());
680 		cookie.setHttpOnly(isHttpOnly());
681 		cookie.setMaxAge(3600);
682 		cookie.setPath("/");
683 
684 		response.addCookie(cookie);
685 	}
686 }