1 package net.avcompris.commons3.core.tests;
2
3 import static com.google.common.base.Preconditions.checkNotNull;
4 import static com.google.common.collect.Sets.newHashSet;
5 import static net.avcompris.commons3.databeans.DataBeans.validate;
6 import static org.apache.commons.lang3.RandomStringUtils.randomAlphanumeric;
7 import static org.apache.commons.lang3.StringUtils.join;
8 import static org.apache.commons.lang3.StringUtils.uncapitalize;
9 import static org.mockito.Mockito.mock;
10 import static org.mockito.Mockito.when;
11
12 import java.lang.annotation.Annotation;
13 import java.lang.reflect.Array;
14 import java.lang.reflect.Constructor;
15 import java.lang.reflect.InvocationTargetException;
16 import java.lang.reflect.Method;
17 import java.util.ArrayList;
18 import java.util.List;
19 import java.util.Set;
20
21 import javax.annotation.Nullable;
22 import javax.servlet.http.HttpServletRequest;
23 import javax.servlet.http.HttpServletResponse;
24
25 import org.apache.commons.lang3.ClassUtils;
26 import org.apache.commons.lang3.NotImplementedException;
27 import org.springframework.context.ApplicationContext;
28 import org.springframework.http.ResponseEntity;
29 import org.springframework.web.bind.annotation.PathVariable;
30 import org.springframework.web.bind.annotation.RequestBody;
31 import org.springframework.web.bind.annotation.RequestMapping;
32 import org.springframework.web.bind.annotation.RequestParam;
33
34 import com.google.common.collect.ClassToInstanceMap;
35 import com.google.common.collect.MutableClassToInstanceMap;
36
37 import net.avcompris.commons3.api.User;
38 import net.avcompris.commons3.api.exception.ServiceException;
39 import net.avcompris.commons3.api.tests.AbstractApiTest;
40 import net.avcompris.commons3.api.tests.ControllerContext;
41 import net.avcompris.commons3.api.tests.ControllerContextUtils;
42 import net.avcompris.commons3.api.tests.TestsSpec.Data;
43 import net.avcompris.commons3.api.tests.TestsSpec.HttpMethod;
44 import net.avcompris.commons3.api.tests.TestsSpec.TestSpec;
45 import net.avcompris.commons3.client.SessionPropagator;
46 import net.avcompris.commons3.core.CorrelationService;
47 import net.avcompris.commons3.utils.Clock;
48 import net.avcompris.commons3.utils.ClockImpl;
49 import net.avcompris.commons3.web.AbstractController;
50
51 public abstract class AbstractServiceApiTest<T> extends AbstractApiTest {
52
53 private final ClassToInstanceMap<Object> implementations = MutableClassToInstanceMap.create();
54
55 private final List<Class<? extends AbstractController>> controllerClasses = new ArrayList<>();
56
57 protected AbstractServiceApiTest(final TestSpec spec,
58 @Nullable final String superadminAuthorization,
59 @SuppressWarnings("unchecked") final Class<? extends AbstractController>... controllerClasses) {
60
61 super(spec, superadminAuthorization);
62
63 for (final Class<? extends AbstractController> controllerClass : controllerClasses) {
64
65 this.controllerClasses.add(controllerClass);
66 }
67 }
68
69 protected abstract T getBeans(Clock clock) throws Exception;
70
71 protected final <U> void inject(final Class<U> serviceClass, final U serviceImpl) {
72
73 checkNotNull(serviceClass, "serviceClass");
74 checkNotNull(serviceImpl, "serviceImpl");
75
76 implementations.put(serviceClass, serviceImpl);
77 }
78
79 @Override
80 protected final StepExecutionResult execute(final int stepIndex, final StepExecution step) throws Exception {
81
82 @Nullable
83 final ControllerContext context = extractControllerContext(step.getHttpMethod(), step.getPath());
84
85 int statusCode = -1;
86
87 Object result = null;
88
89 if (context == null) {
90
91 statusCode = 404;
92
93 result = null;
94
95 } else {
96
97 final Class<? extends AbstractController> controllerClass = context.controllerClass;
98 final Method controllerMethod = context.controllerMethod;
99
100 final Constructor<?> constructor = extractConstructor(controllerClass);
101
102 final Class<?>[] constructorParameterTypes = constructor.getParameterTypes();
103
104 final Object[] constructorArgs = new Object[constructorParameterTypes.length];
105
106 for (int i = 0; i < constructorArgs.length; ++i) {
107
108 constructorArgs[i] = instantiate(context, step, constructorParameterTypes[i]);
109 }
110
111 final Object controller = constructor.newInstance(constructorArgs);
112
113 final Class<?>[] parameterTypes = controllerMethod.getParameterTypes();
114
115 final Object[] args = new Object[parameterTypes.length];
116
117 try {
118
119 for (int i = 0; i < args.length; ++i) {
120
121 args[i] = instantiate(context, step, parameterTypes[i],
122 controllerMethod.getParameterAnnotations()[i]);
123 }
124
125 } catch (final MalformedDataBeanException e) {
126
127 e.printStackTrace();
128
129 statusCode = e.getHttpErrorCode();
130
131 result = null;
132 }
133
134 try {
135
136 if (statusCode == -1) {
137
138 result = controllerMethod.invoke(controller, args);
139
140 statusCode = 200;
141 }
142
143 } catch (final InvocationTargetException e) {
144
145 if (e.getTargetException() instanceof ServiceException) {
146
147 final ServiceException serviceException = (ServiceException) e.getTargetException();
148
149 statusCode = serviceException.getHttpErrorCode();
150
151 } else {
152
153 throw e;
154 }
155 }
156
157 if (result != null && result instanceof ResponseEntity) {
158
159 final ResponseEntity<?> responseEntity = (ResponseEntity<?>) result;
160
161 statusCode = responseEntity.getStatusCodeValue();
162
163 result = responseEntity.getBody();
164 }
165 }
166
167 if (result != null) {
168
169 try {
170
171 validate(result);
172
173 } catch (final IllegalStateException e) {
174
175 throw e;
176
177 } catch (final IllegalArgumentException e) {
178
179
180 }
181 }
182
183 return new StepExecutionResult(statusCode, result);
184 }
185
186 @Nullable
187 private ControllerContext extractControllerContext(final HttpMethod httpMethod, final String path) {
188
189
190
191 for (final Class<? extends AbstractController> controllerClass : controllerClasses) {
192
193 for (final Method method : controllerClass.getMethods()) {
194
195 final RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
196
197 if (requestMapping == null) {
198 continue;
199 } else if (!httpMethod.name().contentEquals(requestMapping.method()[0].name())) {
200 continue;
201 }
202
203 final ControllerContext context = ControllerContextUtils
204 .extractControllerContext(httpMethod, path, controllerClass);
205
206 if (context != null) {
207
208 return context;
209 }
210 }
211 }
212
213 return null;
214 }
215
216 private static Constructor<?> extractConstructor(final Class<? extends AbstractController> controllerClass) {
217
218 for (final Constructor<?> constructor : controllerClass.getConstructors()) {
219
220 return constructor;
221 }
222
223 throw new IllegalStateException();
224 }
225
226 @Nullable
227 private static <T> T extractAnnotation(final Annotation[] annotations, final Class<T> annotationClass) {
228
229 for (final Annotation annotation : annotations) {
230
231 if (annotationClass.isInstance(annotation)) {
232
233 return annotationClass.cast(annotation);
234 }
235 }
236
237 return null;
238 }
239
240 private Object instantiate(final ControllerContext context, final StepExecution step, final Class<?> type,
241 final Annotation... annotations) throws MalformedDataBeanException {
242
243 checkNotNull(context, "context");
244 checkNotNull(type, "type");
245
246 @Nullable
247 final PathVariable pathVariable = extractAnnotation(annotations, PathVariable.class);
248
249 @Nullable
250 final RequestParam requestParam = extractAnnotation(annotations, RequestParam.class);
251
252 @Nullable
253 final RequestBody requestBody = extractAnnotation(annotations, RequestBody.class);
254
255 if (implementations.containsKey(type)) {
256
257 return implementations.get(type);
258
259 } else if (CorrelationService.class.equals(type)) {
260
261 return new CorrelationService() {
262
263 @Override
264 public String getCorrelationId(@Nullable final String correlationIdParam,
265 @Nullable final String correlationIdHeader) throws ServiceException {
266
267 return randomAlphanumeric(20);
268 }
269
270 @Override
271 public void purgeOlderThanSec(final String correlationId, final User user, final int seconds)
272 throws ServiceException {
273
274 throw new NotImplementedException("");
275 }
276 };
277
278 } else if (SessionPropagator.class.equals(type)) {
279
280 return new SessionPropagator();
281
282 } else if (Clock.class.equals(type)) {
283
284 return new ClockImpl();
285
286 } else if (String.class.equals(type)) {
287
288 if (pathVariable != null) {
289
290 return context.getVariable(pathVariable.name());
291
292 } else if (requestParam != null) {
293
294 if (requestParam.required() || context.hasVariable(requestParam.name())) {
295
296 return context.getVariable(requestParam.name());
297 }
298 }
299
300 return null;
301
302 } else if (Integer.class.equals(type)) {
303
304 return null;
305
306 } else if (HttpServletRequest.class.equals(type)) {
307
308 final HttpServletRequest request = mock(HttpServletRequest.class);
309
310 if (spec.getAuthentication() != null) {
311
312 when(request.getHeader("Authorization")).thenReturn(superadminAuthorization);
313 }
314
315 return request;
316
317 } else if (HttpServletResponse.class.equals(type)) {
318
319 final HttpServletResponse response = mock(HttpServletResponse.class);
320
321 return response;
322
323 } else if (requestBody != null) {
324
325 final Object dataBean = mock(type);
326
327 final Set<String> nonNullableGetterNames = newHashSet();
328
329 for (final Method method : type.getMethods()) {
330
331 if (method.getParameterCount() != 0
332 || method.getReturnType() == null
333 || method.getAnnotation(Nullable.class) != null) {
334 continue;
335 }
336
337 final String methodName = method.getName();
338
339 if ("getClass".contentEquals(methodName)
340 || "hashCode".contentEquals(methodName)
341 || "toString".contentEquals(methodName)) {
342 continue;
343 }
344
345 nonNullableGetterNames.add(methodName);
346 }
347
348 for (final Data data : step.getData()) {
349
350 final Method getter = extractGetter(type, data.getName());
351
352 nonNullableGetterNames.remove(getter.getName());
353
354 final Class<?> propertyType = getter.getReturnType();
355
356 final Class<?> wrapperType = propertyType.isPrimitive()
357 ? ClassUtils.primitiveToWrapper(propertyType)
358 : propertyType;
359
360 when(invoke(getter, dataBean)).thenReturn(cast(wrapperType, data.getValue()));
361 }
362
363 for (final String getterName : newHashSet(nonNullableGetterNames)) {
364
365 final Method getter = extractGetter(type, extractPropertyName(getterName));
366
367 final Class<?> propertyType = getter.getReturnType();
368
369 if (propertyType.isArray()) {
370
371 nonNullableGetterNames.remove(getter.getName());
372
373 when(invoke(getter, dataBean)).thenReturn(Array.newInstance(propertyType.getComponentType(), 0));
374 }
375 }
376
377 if (!nonNullableGetterNames.isEmpty()) {
378 throw new MalformedDataBeanException(
379 "There are non-nullable getters that were not invoked: " + join(nonNullableGetterNames, ", "));
380 }
381
382 return dataBean;
383
384 } else if (ApplicationContext.class.equals(type)) {
385
386 return mock(type);
387
388 } else if (type.isInterface()) {
389
390 return mock(type);
391
392 } else {
393
394 throw new NotImplementedException("type: " + type.getName());
395 }
396 }
397
398 private static String extractPropertyName(final String methodName) {
399
400 checkNotNull(methodName, "methodName");
401
402 if (methodName.startsWith("get")) {
403
404 return uncapitalize(methodName.substring(3));
405
406 } else if (methodName.startsWith("is")) {
407
408 return uncapitalize(methodName.substring(2));
409
410 } else {
411
412 throw new NotImplementedException("methodName: " + methodName);
413 }
414
415 }
416
417 private static <T> T cast(final Class<T> propertyType, final Object propertyValue) {
418
419 if (String.class.equals(propertyType)) {
420
421 return propertyType.cast(propertyValue);
422
423 } else if (Integer.class.equals(propertyType)) {
424
425 return propertyType.cast(Integer.valueOf(propertyValue.toString()));
426
427 } else if (Long.class.equals(propertyType)) {
428
429 return propertyType.cast(Long.valueOf(propertyValue.toString()));
430
431 } else if (Boolean.class.equals(propertyType)) {
432
433 return propertyType.cast(Boolean.valueOf(propertyValue.toString()));
434
435 } else if (propertyType.isEnum()) {
436
437 for (final T enumConstant : propertyType.getEnumConstants()) {
438
439 if (((Enum<?>) enumConstant).name().contentEquals(propertyValue.toString())) {
440
441 return enumConstant;
442 }
443 }
444
445 throw new IllegalArgumentException(
446 "propertyValue: " + propertyValue + ", for enum class: " + propertyType.getName());
447
448 } else if (String[].class.equals(propertyType)) {
449
450 return propertyType.cast(propertyValue);
451
452 } else {
453
454 throw new NotImplementedException("propertyType: " + propertyType);
455 }
456 }
457 }