AbstractServlet3Test.groovy 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. /*
  2. * Copyright The OpenTelemetry Authors
  3. * SPDX-License-Identifier: Apache-2.0
  4. */
  5. import io.opentelemetry.instrumentation.test.AgentTestTrait
  6. import io.opentelemetry.instrumentation.test.asserts.TraceAssert
  7. import io.opentelemetry.instrumentation.test.base.HttpServerTest
  8. import io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint
  9. import io.opentelemetry.javaagent.bootstrap.servlet.ExperimentalSnippetHolder
  10. import io.opentelemetry.testing.internal.armeria.common.AggregatedHttpRequest
  11. import javax.servlet.Servlet
  12. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.AUTH_REQUIRED
  13. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.CAPTURE_HEADERS
  14. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.CAPTURE_PARAMETERS
  15. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.ERROR
  16. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.EXCEPTION
  17. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.INDEXED_CHILD
  18. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.NOT_FOUND
  19. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.QUERY_PARAM
  20. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.REDIRECT
  21. import static io.opentelemetry.instrumentation.testing.junit.http.ServerEndpoint.SUCCESS
  22. abstract class AbstractServlet3Test<SERVER, CONTEXT> extends HttpServerTest<SERVER> implements AgentTestTrait {
  23. @Override
  24. URI buildAddress() {
  25. return new URI("http://localhost:$port$contextPath/")
  26. }
  27. // FIXME: Add authentication tests back in...
  28. // @Shared
  29. // protected String user = "user"
  30. // @Shared
  31. // protected String pass = "password"
  32. abstract Class<Servlet> servlet()
  33. abstract void addServlet(CONTEXT context, String path, Class<Servlet> servlet)
  34. public static final ServerEndpoint HTML_PRINT_WRITER =
  35. new ServerEndpoint("HTML_PRINT_WRITER", "htmlPrintWriter",
  36. 200,
  37. "<!DOCTYPE html>\n"
  38. + "<html lang=\"en\">\n"
  39. + "<head>\n"
  40. + " <meta charset=\"UTF-8\">\n"
  41. + " <title>Title</title>\n"
  42. + "</head>\n"
  43. + "<body>\n"
  44. + "<p>test works</p>\n"
  45. + "</body>\n"
  46. + "</html>")
  47. public static final ServerEndpoint HTML_SERVLET_OUTPUT_STREAM =
  48. new ServerEndpoint("HTML_SERVLET_OUTPUT_STREAM", "htmlServletOutputStream",
  49. 200,
  50. "<!DOCTYPE html>\n"
  51. + "<html lang=\"en\">\n"
  52. + "<head>\n"
  53. + " <meta charset=\"UTF-8\">\n"
  54. + " <title>Title</title>\n"
  55. + "</head>\n"
  56. + "<body>\n"
  57. + "<p>test works</p>\n"
  58. + "</body>\n"
  59. + "</html>")
  60. protected void setupServlets(CONTEXT context) {
  61. def servlet = servlet()
  62. addServlet(context, SUCCESS.path, servlet)
  63. addServlet(context, QUERY_PARAM.path, servlet)
  64. addServlet(context, ERROR.path, servlet)
  65. addServlet(context, EXCEPTION.path, servlet)
  66. addServlet(context, REDIRECT.path, servlet)
  67. addServlet(context, AUTH_REQUIRED.path, servlet)
  68. addServlet(context, INDEXED_CHILD.path, servlet)
  69. addServlet(context, CAPTURE_HEADERS.path, servlet)
  70. addServlet(context, CAPTURE_PARAMETERS.path, servlet)
  71. addServlet(context, HTML_PRINT_WRITER.path, servlet)
  72. addServlet(context, HTML_SERVLET_OUTPUT_STREAM.path, servlet)
  73. }
  74. protected ServerEndpoint lastRequest
  75. @Override
  76. AggregatedHttpRequest request(ServerEndpoint uri, String method) {
  77. lastRequest = uri
  78. super.request(uri, method)
  79. }
  80. @Override
  81. String expectedHttpRoute(ServerEndpoint endpoint) {
  82. switch (endpoint) {
  83. case NOT_FOUND:
  84. return getContextPath() + "/*"
  85. default:
  86. return super.expectedHttpRoute(endpoint)
  87. }
  88. }
  89. @Override
  90. boolean testCapturedRequestParameters() {
  91. true
  92. }
  93. boolean errorEndpointUsesSendError() {
  94. true
  95. }
  96. @Override
  97. boolean hasResponseCustomizer(ServerEndpoint endpoint) {
  98. true
  99. }
  100. @Override
  101. boolean hasResponseSpan(ServerEndpoint endpoint) {
  102. endpoint == REDIRECT || (endpoint == ERROR && errorEndpointUsesSendError())
  103. }
  104. @Override
  105. void responseSpan(TraceAssert trace, int index, Object parent, String method, ServerEndpoint endpoint) {
  106. switch (endpoint) {
  107. case REDIRECT:
  108. redirectSpan(trace, index, parent)
  109. break
  110. case ERROR:
  111. sendErrorSpan(trace, index, parent)
  112. break
  113. }
  114. }
  115. def "snippet injection with ServletOutputStream"() {
  116. setup:
  117. ExperimentalSnippetHolder.setSnippet("\n <script type=\"text/javascript\"> Test </script>")
  118. def request = request(HTML_SERVLET_OUTPUT_STREAM, "GET")
  119. def response = client.execute(request).aggregate().join()
  120. expect:
  121. response.status().code() == HTML_SERVLET_OUTPUT_STREAM.status
  122. String result = "<!DOCTYPE html>\n" +
  123. "<html lang=\"en\">\n" +
  124. "<head>\n" +
  125. " <script type=\"text/javascript\"> Test </script>\n" +
  126. " <meta charset=\"UTF-8\">\n" +
  127. " <title>Title</title>\n" +
  128. "</head>\n" +
  129. "<body>\n" +
  130. "<p>test works</p>\n" +
  131. "</body>\n" +
  132. "</html>"
  133. response.contentUtf8() == result
  134. response.headers().contentLength() == result.length()
  135. }
  136. def "snippet injection with PrintWriter"() {
  137. setup:
  138. ExperimentalSnippetHolder.setSnippet("\n <script type=\"text/javascript\"> Test </script>")
  139. def request = request(HTML_PRINT_WRITER, "GET")
  140. def response = client.execute(request).aggregate().join()
  141. expect:
  142. response.status().code() == HTML_PRINT_WRITER.status
  143. String result = "<!DOCTYPE html>\n" +
  144. "<html lang=\"en\">\n" +
  145. "<head>\n" +
  146. " <script type=\"text/javascript\"> Test </script>\n" +
  147. " <meta charset=\"UTF-8\">\n" +
  148. " <title>Title</title>\n" +
  149. "</head>\n" +
  150. "<body>\n" +
  151. "<p>test works</p>\n" +
  152. "</body>\n" +
  153. "</html>"
  154. response.contentUtf8() == result
  155. response.headers().contentLength() == result.length()
  156. }
  157. }