1 package net.avcompris.commons3.web;
2
import static com.google.common.base.Preconditions.checkNotNull;
import static org.springframework.http.HttpHeaders.AUTHORIZATION;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

import javax.annotation.Nullable;
import javax.servlet.ServletException;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestMapping;

import net.avcompris.commons3.api.User;
import net.avcompris.commons3.api.exception.ServiceException;
import net.avcompris.commons3.api.exception.UnauthenticatedException;
import net.avcompris.commons3.client.SessionPropagator;
import net.avcompris.commons3.core.AuthService;
import net.avcompris.commons3.core.CorrelationService;
import net.avcompris.commons3.utils.Clock;
import net.avcompris.commons3.utils.LogFactory;
31
32 public abstract class AbstractController {
33
34 public static final String CORRELATION_ID_ATTRIBUTE_NAME = "Correlation-ID";
35 public static final String USER_SESSION_ID_ATTRIBUTE_NAME = "user_session_id";
36
37 private static final Log logger = LogFactory.getLog(AbstractController.class);
38
39 protected final CorrelationService correlationService;
40 protected final Clock clock;
41
42 @Nullable
43 private final SessionPropagator sessionPropagator;
44
45 protected AbstractController(final CorrelationService correlationService, final SessionPropagator sessionPropagator,
46 final Clock clock) {
47
48 this.correlationService = checkNotNull(correlationService, "correlationService");
49 this.sessionPropagator = checkNotNull(sessionPropagator, "sessionPropagator");
50 this.clock = checkNotNull(clock, "clock");
51 }
52
53 @Nullable
54 private static String getCookie(final HttpServletRequest request, final String cookieName) {
55
56 checkNotNull(request, "request");
57 checkNotNull(cookieName, "cookieName");
58
59 final Cookie[] cookies = request.getCookies();
60
61 if (cookies == null) {
62 return null;
63 }
64
65 for (final Cookie cookie : cookies) {
66
67 if (cookieName.contentEquals(cookie.getName())) {
68
69 return cookie.getValue();
70 }
71 }
72
73 return null;
74 }
75
76 @Nullable
77 private final String getAuthorization(final HttpServletRequest request) {
78
79 final String authorizationHeader = request.getHeader(AUTHORIZATION);
80
81 if (authorizationHeader != null) {
82
83 return authorizationHeader;
84 }
85
86 final String authorizationCookie = getCookie(request, AUTHORIZATION);
87
88 if (authorizationCookie == null) {
89
90 return null;
91 }
92
93 final String userSessionId = getUserSessionId(request);
94
95 return userSessionId == null
96
97 ? authorizationCookie
98
99 : null;
100 }
101
102 @Nullable
103 protected final String getUserSessionId(final HttpServletRequest request) {
104
105 final String userSessionIdParam = request.getParameter(USER_SESSION_ID_ATTRIBUTE_NAME);
106 final String userSessionIdHeader = request.getHeader(USER_SESSION_ID_ATTRIBUTE_NAME);
107 final String userSessionIdCookie = getCookie(request, USER_SESSION_ID_ATTRIBUTE_NAME);
108
109 if (userSessionIdParam != null) {
110
111 return userSessionIdParam;
112
113 } else if (userSessionIdHeader != null) {
114
115 return userSessionIdHeader;
116
117 } else if (userSessionIdCookie != null) {
118
119 return userSessionIdCookie;
120
121 } else {
122
123 return null;
124 }
125 }
126
127 protected final <T> ResponseEntity<T> wrapAuthenticated(final HttpServletRequest request,
128 final HttpServletResponse response, final AuthService authService, final AuthenticatedAction<T> action)
129 throws ServiceException {
130
131 checkNotNull(request, "request");
132 checkNotNull(action, "action");
133
134
135
136 final long startMs = System.currentTimeMillis();
137
138 final String authorization = getAuthorization(request);
139 final String userSessionId = getUserSessionId(request);
140 final String correlationId = getCorrelationId(request);
141
142 sessionPropagator.setAuthorizationHeader(authorization);
143 sessionPropagator.setUserSessionId(userSessionId);
144
145 @Nullable
146 final User user = authService.getAuthenticatedUser(authorization, userSessionId);
147
148 if (user == null) {
149
150 throw new UnauthenticatedException();
151 }
152
153 if (userSessionId != null) {
154
155
156
157 request.setAttribute(USER_SESSION_ID_ATTRIBUTE_NAME, userSessionId);
158 }
159
160 final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
161
162 final String methodName = endpoint.getRight().getName();
163
164 final Log logger = LogFactory.getLog(endpoint.getLeft());
165
166 if (logger.isInfoEnabled()) {
167 logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
168 }
169
170 final ResponseEntity<T> result;
171
172 authService.setLastActiveAt(correlationId, user);
173
174 try {
175
176 result = action.action(correlationId, user);
177
178 } catch (final ServiceException e) {
179
180 final int httpErrorCode = e.getHttpErrorCode();
181
182 final long elapsedMs = System.currentTimeMillis() - startMs;
183
184 logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
185
186 throw e;
187 }
188
189 final long elapsedMs = System.currentTimeMillis() - startMs;
190
191 if (logger.isInfoEnabled()) {
192 logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
193 }
194
195 return enrich(response, userSessionId, result, correlationId);
196 }
197
198 protected final <T> ResponseEntity<T> wrapAuthenticatedServletAction(final HttpServletRequest request,
199 final HttpServletResponse response, final AuthService authService,
200 final AuthenticatedServletAction<T> action) throws ServiceException, IOException, ServletException {
201
202 checkNotNull(request, "request");
203 checkNotNull(action, "action");
204
205
206
207 final long startMs = System.currentTimeMillis();
208
209 final String authorization = getAuthorization(request);
210 final String userSessionId = getUserSessionId(request);
211 final String correlationId = getCorrelationId(request);
212
213 sessionPropagator.setAuthorizationHeader(authorization);
214 sessionPropagator.setUserSessionId(userSessionId);
215
216 @Nullable
217 final User user = authService.getAuthenticatedUser(authorization, userSessionId);
218
219 if (user == null) {
220
221 throw new UnauthenticatedException();
222 }
223
224 if (userSessionId != null) {
225
226
227
228 request.setAttribute(USER_SESSION_ID_ATTRIBUTE_NAME, userSessionId);
229 }
230
231 final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
232
233 final String methodName = endpoint.getRight().getName();
234
235 final Log logger = LogFactory.getLog(endpoint.getLeft());
236
237 if (logger.isInfoEnabled()) {
238 logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
239 }
240
241 authService.setLastActiveAt(correlationId, user);
242
243 final ResponseEntity<T> result;
244
245 try {
246
247 result = action.action(correlationId, user);
248
249 } catch (final ServiceException e) {
250
251 final int httpErrorCode = e.getHttpErrorCode();
252
253 final long elapsedMs = System.currentTimeMillis() - startMs;
254
255 logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
256
257 throw e;
258
259 } catch (final IOException | ServletException e) {
260
261 final long elapsedMs = System.currentTimeMillis() - startMs;
262
263 logger.error(methodName + "() elapsedMs: " + elapsedMs, e);
264
265 throw e;
266 }
267
268 final long elapsedMs = System.currentTimeMillis() - startMs;
269
270 if (logger.isInfoEnabled()) {
271 logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
272 }
273
274 return enrich(response, userSessionId, result, correlationId);
275 }
276
277 protected final <T> ResponseEntity<T> wrapAuthenticatedOrNot(final HttpServletRequest request,
278 final HttpServletResponse response, final AuthService authService, final AuthenticatedAction<T> action)
279 throws ServiceException {
280
281 checkNotNull(request, "request");
282 checkNotNull(action, "action");
283
284 final long startMs = System.currentTimeMillis();
285
286 final String authorization = getAuthorization(request);
287 final String userSessionId = getUserSessionId(request);
288 final String correlationId = getCorrelationId(request);
289
290 sessionPropagator.setAuthorizationHeader(authorization);
291 sessionPropagator.setUserSessionId(userSessionId);
292
293 @Nullable
294 final User user = authService.getAuthenticatedUser(authorization, userSessionId);
295
296 final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
297
298 final String methodName = endpoint.getRight().getName();
299
300 final Log logger = LogFactory.getLog(endpoint.getLeft());
301
302 if (logger.isInfoEnabled()) {
303 logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
304 }
305
306 final ResponseEntity<T> result;
307
308 if (user != null) {
309
310 authService.setLastActiveAt(correlationId, user);
311 }
312
313 try {
314
315 result = action.action(correlationId, user);
316
317 } catch (final ServiceException e) {
318
319 final int httpErrorCode = e.getHttpErrorCode();
320
321 final long elapsedMs = System.currentTimeMillis() - startMs;
322
323 logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
324
325 throw e;
326 }
327
328 final long elapsedMs = System.currentTimeMillis() - startMs;
329
330 if (logger.isInfoEnabled()) {
331 logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
332 }
333
334 return enrich(response, userSessionId, result, correlationId);
335 }
336
337 protected final <T> ResponseEntity<T> wrapWithoutCorrelationId(final HttpServletRequest request,
338 final HttpServletResponse response, final AuthService authService, final AuthenticatedAction<T> action)
339 throws ServiceException {
340
341 checkNotNull(request, "request");
342 checkNotNull(action, "action");
343
344 final long startMs = System.currentTimeMillis();
345
346 final String authorization = getAuthorization(request);
347 final String userSessionId = getUserSessionId(request);
348
349 String fakeCorrelationId = "N/A";
350
351 try {
352
353 fakeCorrelationId = getCorrelationId(request);
354
355 } catch (final Throwable e) {
356
357 e.printStackTrace(System.out);
358
359 fakeCorrelationId = "N/A";
360 }
361
362 LogFactory.resetCorrelationId();
363
364 LogFactory.setCorrelationId(fakeCorrelationId);
365
366 sessionPropagator.setAuthorizationHeader(authorization);
367 sessionPropagator.setUserSessionId(userSessionId);
368
369 final User user = authService.getAuthenticatedUser(authorization, userSessionId);
370
371 if (user == null) {
372
373 throw new UnauthenticatedException();
374 }
375
376 final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
377
378 final String methodName = endpoint.getRight().getName();
379
380 final Log logger = LogFactory.getLog(endpoint.getLeft());
381
382 if (logger.isInfoEnabled()) {
383 logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
384 }
385
386 final ResponseEntity<T> result;
387
388 try {
389
390 result = action.action(fakeCorrelationId, user);
391
392 } catch (final ServiceException e) {
393
394 final int httpErrorCode = e.getHttpErrorCode();
395
396 final long elapsedMs = System.currentTimeMillis() - startMs;
397
398 logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
399
400 throw e;
401 }
402
403 final long elapsedMs = System.currentTimeMillis() - startMs;
404
405 if (logger.isInfoEnabled()) {
406 logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
407 }
408
409 return enrich(response, userSessionId, result, fakeCorrelationId);
410 }
411
412 private String getCorrelationId(final HttpServletRequest request) throws ServiceException {
413
414 final String correlationIdParam = request.getParameter("correlationId");
415 final String correlationIdHeader = request.getHeader(CORRELATION_ID_ATTRIBUTE_NAME);
416
417 final long startMs = System.currentTimeMillis();
418
419 final String correlationId = correlationService.getCorrelationId(correlationIdParam, correlationIdHeader);
420
421 LogFactory.resetCorrelationId();
422
423 LogFactory.setCorrelationId(correlationId);
424
425 request.setAttribute(CORRELATION_ID_ATTRIBUTE_NAME, correlationId);
426
427 final long elapsedMs = System.currentTimeMillis() - startMs;
428
429 if (logger.isDebugEnabled()) {
430 logger.debug("getCorrelationId(), elapsedMs: " + elapsedMs);
431 }
432
433 return correlationId;
434 }
435
436 protected final <T> ResponseEntity<T> wrapNonAuthenticated(final HttpServletRequest request,
437 final UnauthenticatedAction<T> action) throws ServiceException {
438
439 checkNotNull(request, "request");
440 checkNotNull(action, "action");
441
442 final long startMs = System.currentTimeMillis();
443
444 final String correlationId = getCorrelationId(request);
445
446 final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
447
448 final String methodName = endpoint.getRight().getName();
449
450 final Log logger = LogFactory.getLog(endpoint.getLeft());
451
452 if (logger.isInfoEnabled()) {
453 logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
454 }
455
456 final ResponseEntity<T> result;
457
458 try {
459
460 result = action.action(correlationId);
461
462 } catch (final ServiceException e) {
463
464 final int httpErrorCode = e.getHttpErrorCode();
465
466 final long elapsedMs = System.currentTimeMillis() - startMs;
467
468 logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
469
470 throw e;
471 }
472
473 final long elapsedMs = System.currentTimeMillis() - startMs;
474
475 if (logger.isInfoEnabled()) {
476 logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
477 }
478
479 return enrich(null, null, result, correlationId);
480 }
481
482 protected final <T> ResponseEntity<T> wrapNonAuthenticatedWithoutCorrelationId(final HttpServletRequest request,
483 final UnauthenticatedAnonymousAction<T> action) throws ServiceException {
484
485 checkNotNull(request, "request");
486 checkNotNull(action, "action");
487
488 final long startMs = System.currentTimeMillis();
489
490 final Pair<Class<?>, Method> endpoint = extractControllerCurrentEndpoint();
491
492 final String methodName = endpoint.getRight().getName();
493
494 String fakeCorrelationId = "N/A";
495
496 try {
497
498 fakeCorrelationId = getCorrelationId(request);
499
500 } catch (final Throwable e) {
501
502 e.printStackTrace(System.out);
503
504 fakeCorrelationId = "N/A";
505 }
506
507 LogFactory.resetCorrelationId();
508
509 LogFactory.setCorrelationId(fakeCorrelationId);
510
511 final Log logger = LogFactory.getLog(endpoint.getLeft());
512
513 if (logger.isInfoEnabled()) {
514 logger.info(methodName + "() started... +ms: " + (System.currentTimeMillis() - startMs));
515 }
516
517 final ResponseEntity<T> result;
518
519 try {
520
521 result = action.action();
522
523 } catch (final ServiceException e) {
524
525 final int httpErrorCode = e.getHttpErrorCode();
526
527 final long elapsedMs = System.currentTimeMillis() - startMs;
528
529 logger.error(methodName + "() ERROR. " + httpErrorCode + ". elapsedMs: " + elapsedMs, e);
530
531 throw e;
532 }
533
534 final long elapsedMs = System.currentTimeMillis() - startMs;
535
536 if (logger.isInfoEnabled()) {
537 logger.info(methodName + "() ended. " + result.getStatusCode() + ". elapsedMs: " + elapsedMs);
538 }
539
540
541
542 return result;
543 }
544
545 private <T> ResponseEntity<T> enrich(@Nullable final HttpServletResponse httpServletResponse,
546 @Nullable final String userSessionId,
547 final ResponseEntity<T> response, final String correlationId) {
548
549 checkNotNull(response, "response");
550 checkNotNull(correlationId, "correlationId");
551
552 final HttpHeaders headers = new HttpHeaders();
553
554 headers.putAll(response.getHeaders());
555
556 headers.add(CORRELATION_ID_ATTRIBUTE_NAME, correlationId);
557
558 if (httpServletResponse != null && userSessionId != null) {
559
560 setUserSessionCookie(httpServletResponse, userSessionId);
561 }
562
563 return ResponseEntity.status(response.getStatusCode())
564 .headers(headers)
565 .body(response.getBody());
566 }
567
568 private static Pair<Class<?>, Method> extractControllerCurrentEndpoint() {
569
570 for (final StackTraceElement ste : Thread.currentThread().getStackTrace()) {
571
572 final String className = ste.getClassName();
573 final String methodName = ste.getMethodName();
574
575 if (!className.endsWith("Controller")) {
576 continue;
577 }
578
579 final Class<?> controllerClass;
580
581 try {
582
583 controllerClass = Class.forName(className);
584
585 } catch (final ClassNotFoundException e) {
586
587 continue;
588 }
589
590 if (Modifier.isAbstract(controllerClass.getModifiers())) {
591 continue;
592 }
593
594 final Method method = extractDeclaredMethod(controllerClass, methodName);
595
596 if (method.getAnnotation(RequestMapping.class) == null || !Modifier.isPublic(method.getModifiers())) {
597 continue;
598 }
599
600 return Pair.of(controllerClass, method);
601 }
602
603 throw new IllegalStateException("Cannot extract controller current endpoint");
604 }
605
606 private static Method extractDeclaredMethod(final Class<?> controllerClass, final String methodName) {
607
608 for (final Method method : controllerClass.getDeclaredMethods()) {
609
610 if (methodName.contentEquals(method.getName()) || Modifier.isPublic(method.getModifiers())) {
611
612 return method;
613 }
614 }
615
616 throw new IllegalStateException(
617 "Cannot extract method: " + methodName + " from controllerClass: " + controllerClass.getName());
618 }
619
620 @FunctionalInterface
621 protected interface UnauthenticatedAnonymousAction<T> {
622
623 ResponseEntity<T> action() throws ServiceException;
624 }
625
626 @FunctionalInterface
627 protected interface UnauthenticatedAction<T> {
628
629 ResponseEntity<T> action(String correlationId) throws ServiceException;
630 }
631
632 @FunctionalInterface
633 protected interface AuthenticatedAction<T> {
634
635 ResponseEntity<T> action(String correlationId, User user) throws ServiceException;
636 }
637
638 @FunctionalInterface
639 protected interface AuthenticatedServletAction<T> {
640
641 ResponseEntity<T> action(String correlationId, User user)
642 throws ServiceException, ServletException, IOException;
643 }
644
645 protected final <T extends ResponseEntity<?>> T handleServiceException(final ServiceException e) {
646
647 checkNotNull(e, "e");
648
649 throw new NotImplementedException("");
650 }
651
652 protected static HttpHeaders headers(final String... headerNameValuePairs) {
653
654 final HttpHeaders headers = new HttpHeaders();
655
656 for (int i = 0; i < headerNameValuePairs.length / 2; ++i) {
657
658 final String headerName = headerNameValuePairs[i * 2];
659 final String headerValue = headerNameValuePairs[i * 2 + 1];
660
661 headers.set(headerName, headerValue);
662 }
663
664 return headers;
665 }
666
667 protected abstract boolean isSecure();
668
669 protected abstract boolean isHttpOnly();
670
671 protected final void setUserSessionCookie(final HttpServletResponse response, final String userSessionId) {
672
673 checkNotNull(response, "response");
674 checkNotNull(userSessionId, "userSessionId");
675
676 final Cookie cookie = new Cookie(USER_SESSION_ID_ATTRIBUTE_NAME, userSessionId);
677
678
679 cookie.setSecure(isSecure());
680 cookie.setHttpOnly(isHttpOnly());
681 cookie.setMaxAge(3600);
682 cookie.setPath("/");
683
684 response.addCookie(cookie);
685 }
686 }