1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.forgerock.http.servlet;
18
import static java.util.Collections.list;
import static org.forgerock.http.handler.Handlers.asDescribableHandler;
import static org.forgerock.http.handler.Handlers.chainOf;
import static org.forgerock.http.handler.Handlers.internalServerErrorHandler;
import static org.forgerock.http.io.IO.newBranchingInputStream;
import static org.forgerock.http.io.IO.newTemporaryStorage;
import static org.forgerock.http.protocol.Responses.newInternalServerError;
import static org.forgerock.http.routing.UriRouterContext.uriRouterContext;
import static org.forgerock.util.Utils.closeSilently;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Locale;
import java.util.ServiceLoader;

import jakarta.servlet.ServletConfig;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import org.forgerock.http.ApiProducer;
import org.forgerock.http.DescribedHttpApplication;
import org.forgerock.http.HttpApplication;
import org.forgerock.http.HttpApplicationException;
import org.forgerock.http.filter.TransactionIdInboundFilter;
import org.forgerock.http.handler.DescribableHandler;
import org.forgerock.http.io.Buffer;
import org.forgerock.http.protocol.Request;
import org.forgerock.http.protocol.Response;
import org.forgerock.http.protocol.Status;
import org.forgerock.http.routing.UriRouterContext;
import org.forgerock.http.session.Session;
import org.forgerock.http.session.SessionContext;
import org.forgerock.http.util.CaseInsensitiveSet;
import org.forgerock.http.util.Uris;
import org.forgerock.services.context.AttributesContext;
import org.forgerock.services.context.ClientContext;
import org.forgerock.services.context.Context;
import org.forgerock.services.context.RequestAuditContext;
import org.forgerock.services.context.RootContext;
import org.forgerock.util.Factory;
import org.forgerock.util.promise.NeverThrowsException;
import org.forgerock.util.promise.Promise;
import org.forgerock.util.promise.ResultHandler;
import org.forgerock.util.promise.RuntimeExceptionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.swagger.models.Swagger;
73
74
75
76
77
78
79
80
81
82
83
84
85
86 public final class HttpFrameworkServlet extends HttpServlet {
87
88 private static final Logger logger = LoggerFactory.getLogger(HttpFrameworkServlet.class);
89 private static final long serialVersionUID = 3524182656424860912L;
90
91
92 private static final String SERVLET_REQUEST_X509_ATTRIBUTE = "jakarta.servlet.request.X509Certificate";
93
94
95 private static final CaseInsensitiveSet NON_ENTITY_METHODS = new CaseInsensitiveSet(
96 Arrays.asList("GET", "HEAD", "TRACE"));
97
98
99
100
101
102
103
104 public static final String ROUTING_BASE_INIT_PARAM_NAME = "routing-base";
105
106 private ServletVersionAdapter adapter;
107 private HttpApplication application;
108 private Factory<Buffer> storage;
109 private DescribableHandler handler;
110 private ServletRoutingBase routingBase;
111
112
113
114
115 public HttpFrameworkServlet() {
116 }
117
118
119
120
121
122
123
124 public HttpFrameworkServlet(HttpApplication application) {
125 this.application = application;
126 }
127
128 @Override
129 @SuppressWarnings("unchecked")
130 public void init() throws ServletException {
131 adapter = getAdapter(getServletContext());
132 routingBase = selectRoutingBase(getServletConfig());
133 if (application == null) {
134 HttpApplicationLoader applicationLoader = getApplicationLoader(getServletConfig());
135 application = getApplication(applicationLoader, getServletConfig());
136 }
137 storage = application.getBufferFactory();
138 if (storage == null) {
139 final File tmpDir = (File) getServletContext().getAttribute(ServletContext.TEMPDIR);
140 storage = newTemporaryStorage(tmpDir);
141 }
142 try {
143 this.handler = chainOf(asDescribableHandler(application.start()), new TransactionIdInboundFilter());
144 if (application instanceof DescribedHttpApplication) {
145 ApiProducer<Swagger> apiProducer = ((DescribedHttpApplication) application).getApiProducer();
146 this.handler.api(apiProducer);
147 }
148 } catch (HttpApplicationException e) {
149 logger.error("Error while starting the application.", e);
150 handler = asDescribableHandler(internalServerErrorHandler(e));
151 }
152 }
153
154 private ServletVersionAdapter getAdapter(ServletContext servletContext) throws ServletException {
155 switch (servletContext.getMajorVersion()) {
156 case 1:
157
158 throw new ServletException("Unsupported Servlet version "
159 + servletContext.getMajorVersion());
160 case 2:
161 return new Servlet2Adapter();
162 default:
163 return new Servlet3Adapter();
164 }
165 }
166
167 private ServletRoutingBase selectRoutingBase(ServletConfig servletConfig) throws ServletException {
168 String routingModeParam = servletConfig.getInitParameter(ROUTING_BASE_INIT_PARAM_NAME);
169 if (routingModeParam == null) {
170 return ServletRoutingBase.SERVLET_PATH;
171 }
172 try {
173 return ServletRoutingBase.valueOf(routingModeParam.toUpperCase());
174 } catch (IllegalArgumentException e) {
175 throw new ServletException("Invalid routing mode: " + routingModeParam);
176 }
177 }
178
179 private HttpApplicationLoader getApplicationLoader(ServletConfig config) throws ServletException {
180 String applicationLoaderParam = config.getInitParameter("application-loader");
181 if (applicationLoaderParam == null) {
182 return HttpApplicationLoader.SERVICE_LOADER;
183 }
184 try {
185 return HttpApplicationLoader.valueOf(applicationLoaderParam.toUpperCase());
186 } catch (IllegalArgumentException e) {
187 throw new ServletException("Invalid HTTP application loader: " + applicationLoaderParam);
188 }
189 }
190
191 private HttpApplication getApplication(HttpApplicationLoader applicationLoader, ServletConfig config)
192 throws ServletException {
193 return applicationLoader.load(config);
194 }
195
196 @Override
197 protected void service(final HttpServletRequest req, final HttpServletResponse resp)
198 throws ServletException, IOException {
199 final Session session = new ServletSession(req);
200 final SessionContext sessionContext = new SessionContext(new RootContext(), session);
201
202 final Request request;
203 final UriRouterContext uriRouterContext;
204 try {
205 request = createRequest(req);
206 uriRouterContext = createRouterContext(sessionContext, req, request);
207 } catch (URISyntaxException e) {
208 Response response = new Response(Status.BAD_REQUEST);
209 response.setEntity(e.getMessage());
210 writeResponse(response, resp, sessionContext);
211 return;
212 }
213
214 final AttributesContext attributesContext = new AttributesContext(new RequestAuditContext(uriRouterContext));
215
216
217
218
219
220 Enumeration<String> attributeNames = req.getAttributeNames();
221 while (attributeNames.hasMoreElements()) {
222 String attributeName = attributeNames.nextElement();
223 attributesContext.getAttributes().put(attributeName, req.getAttribute(attributeName));
224 }
225
226
227
228 attributesContext.getAttributes().put(HttpServletRequest.class.getName(), req);
229 attributesContext.getAttributes().put(HttpServletResponse.class.getName(), resp);
230
231 Context context = createClientContext(attributesContext, req);
232
233
234 final ServletSynchronizer sync = adapter.createServletSynchronizer(req, resp);
235 try {
236 final Promise<Response, NeverThrowsException> promise =
237 handler.handle(context, request)
238 .thenOnResult(new ResultHandler<Response>() {
239 @Override
240 public void handleResult(Response response) {
241 writeResponse(request, response, resp, sessionContext, sync);
242 }
243 })
244 .thenOnRuntimeException(new RuntimeExceptionHandler() {
245 @Override
246 public void handleRuntimeException(RuntimeException e) {
247 logger.error("RuntimeException caught", e);
248 writeResponse(request, newInternalServerError(), resp, sessionContext, sync);
249 }
250 });
251
252 sync.setAsyncListener(new Runnable() {
253 @Override
254 public void run() {
255 promise.cancel(true);
256 }
257 });
258 } catch (Throwable throwable) {
259
260
261
262
263 logger.error("Throwable caught", throwable);
264 writeResponse(request, newInternalServerError(), resp, sessionContext, sync);
265 }
266
267 try {
268 sync.awaitIfNeeded();
269 } catch (InterruptedException e) {
270 throw new ServletException("Awaiting asynchronous request was interrupted.", e);
271 }
272 }
273
274 private Request createRequest(HttpServletRequest req) throws IOException, URISyntaxException {
275
276 Request request = new Request();
277 request.setMethod(req.getMethod());
278
279
280
281 request.setUri(Uris.createNonStrict(req.getScheme(),
282 null,
283 req.getServerName(),
284 req.getServerPort(),
285 req.getRequestURI(),
286 req.getQueryString(),
287 null));
288
289
290 for (Enumeration<String> e = req.getHeaderNames(); e.hasMoreElements();) {
291 String name = e.nextElement();
292 request.getHeaders().add(name, list(req.getHeaders(name)));
293 }
294
295
296 if ((req.getContentLength() > 0 || req.getHeader("Transfer-Encoding") != null)
297 && !NON_ENTITY_METHODS.contains(request.getMethod())) {
298 request.setEntity(newBranchingInputStream(req.getInputStream(), storage));
299 }
300
301 return request;
302 }
303
304 private ClientContext createClientContext(Context parent, HttpServletRequest req) {
305 return ClientContext.buildExternalClientContext(parent)
306 .remoteUser(req.getRemoteUser())
307 .remoteAddress(req.getRemoteAddr())
308 .remotePort(req.getRemotePort())
309 .certificates((X509Certificate[]) req.getAttribute(SERVLET_REQUEST_X509_ATTRIBUTE))
310 .userAgent(req.getHeader("User-Agent"))
311 .secure("https".equalsIgnoreCase(req.getScheme()))
312 .localAddress(req.getLocalAddr())
313 .localPort(req.getLocalPort())
314 .build();
315 }
316
317 private UriRouterContext createRouterContext(Context parent, HttpServletRequest req, final Request request)
318 throws URISyntaxException {
319 String matchedUri = routingBase.extractMatchedUri(req);
320 final String requestURI = req.getRequestURI();
321 String remaining = requestURI.substring(requestURI.indexOf(matchedUri) + matchedUri.length());
322 return uriRouterContext(parent).matchedUri(matchedUri).remainingUri(remaining)
323 .originalUri(request.getUri().asURI()).build();
324 }
325
326 private void writeResponse(Request request, Response response, HttpServletResponse servletResponse,
327 SessionContext sessionContext, ServletSynchronizer synchronizer) {
328 try {
329 writeResponse(response, servletResponse, sessionContext);
330 } finally {
331 closeSilently(request);
332 synchronizer.signalAndComplete();
333 }
334 }
335
336 private void writeResponse(final Response response, final HttpServletResponse servletResponse,
337 final SessionContext sessionContext) {
338 try {
339
340
341
342
343
344 if (response != null) {
345
346 servletResponse.setStatus(response.getStatus().getCode());
347
348
349 sessionContext.getSession().save(response);
350
351
352 for (String name : response.getHeaders().keySet()) {
353 for (String value : response.getHeaders().get(name).getValues()) {
354 if (value != null && value.length() > 0) {
355 servletResponse.addHeader(name, value);
356 }
357 }
358 }
359
360
361 response.getEntity().copyRawContentTo(servletResponse.getOutputStream());
362 }
363 } catch (IOException e) {
364 logger.error("Failed to write response", e);
365 } finally {
366 closeSilently(response);
367 }
368 }
369
370 @Override
371 public void destroy() {
372 application.stop();
373 }
374 }