Explorar el Código

End grpc server span in onComplete instead of close (#11170)

Lauri Tulmin hace 10 meses
padre
commit
c92955fa2f

+ 6 - 1
instrumentation/grpc-1.6/library/src/main/java/io/opentelemetry/instrumentation/grpc/v1_6/TracingServerInterceptor.java

@@ -67,6 +67,7 @@ final class TracingServerInterceptor implements ServerInterceptor {
       extends ForwardingServerCall.SimpleForwardingServerCall<REQUEST, RESPONSE> {
     private final Context context;
     private final GrpcRequest request;
+    private Status status;
 
     // Used by MESSAGE_ID_UPDATER
     @SuppressWarnings("UnusedVariable")
@@ -101,13 +102,13 @@ final class TracingServerInterceptor implements ServerInterceptor {
 
     @Override
     public void close(Status status, Metadata trailers) {
+      this.status = status;
       try {
         delegate().close(status, trailers);
       } catch (Throwable e) {
         instrumenter.end(context, request, status, e);
         throw e;
       }
-      instrumenter.end(context, request, status, status.getCause());
     }
 
     final class TracingServerCallListener
@@ -165,6 +166,10 @@ final class TracingServerInterceptor implements ServerInterceptor {
           instrumenter.end(context, request, Status.UNKNOWN, e);
           throw e;
         }
+        if (status == null) {
+          status = Status.UNKNOWN;
+        }
+        instrumenter.end(context, request, status, status.getCause());
       }
 
       @Override

+ 90 - 0
instrumentation/grpc-1.6/testing/src/main/java/io/opentelemetry/instrumentation/grpc/v1_6/AbstractGrpcStreamingTest.java

@@ -19,7 +19,9 @@ import io.grpc.Server;
 import io.grpc.ServerBuilder;
 import io.grpc.Status;
 import io.grpc.stub.StreamObserver;
+import io.opentelemetry.api.trace.Span;
 import io.opentelemetry.api.trace.SpanKind;
+import io.opentelemetry.api.trace.Tracer;
 import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension;
 import io.opentelemetry.instrumentation.testing.util.ThrowingRunnable;
 import io.opentelemetry.sdk.testing.assertj.EventDataAssert;
@@ -33,11 +35,13 @@ import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
 import org.junitpioneer.jupiter.cartesian.CartesianTest;
 
 public abstract class AbstractGrpcStreamingTest {
@@ -264,6 +268,92 @@ public abstract class AbstractGrpcStreamingTest {
                                                     (long) Status.Code.OK.value()))))));
   }
 
+  @Test
+  void grpcServerSpanEndsAfterChildSpan() throws Exception {
+    Tracer tracer = testing().getOpenTelemetry().getTracer("test");
+    AtomicBoolean serverSpanRecording = new AtomicBoolean();
+    CountDownLatch latch = new CountDownLatch(2);
+
+    BindableService greeter =
+        new GreeterGrpc.GreeterImplBase() {
+          @Override
+          public StreamObserver<Helloworld.Response> conversation(
+              StreamObserver<Helloworld.Response> observer) {
+            return new StreamObserver<Helloworld.Response>() {
+              Span span;
+
+              @Override
+              public void onNext(Helloworld.Response value) {
+                span = tracer.spanBuilder("child").startSpan();
+                observer.onNext(value);
+              }
+
+              @Override
+              public void onError(Throwable t) {
+                observer.onError(t);
+                span.end();
+              }
+
+              @Override
+              public void onCompleted() {
+                observer.onCompleted();
+                serverSpanRecording.set(Span.current().isRecording());
+                span.end();
+                latch.countDown();
+              }
+            };
+          }
+        };
+
+    Server server = configureServer(ServerBuilder.forPort(0).addService(greeter)).build().start();
+    ManagedChannel channel = createChannel(server);
+    closer.add(() -> channel.shutdownNow().awaitTermination(10, TimeUnit.SECONDS));
+    closer.add(() -> server.shutdownNow().awaitTermination());
+
+    GreeterGrpc.GreeterStub client = GreeterGrpc.newStub(channel).withWaitForReady();
+
+    StreamObserver<Helloworld.Response> observer2 =
+        client.conversation(
+            new StreamObserver<Helloworld.Response>() {
+              @Override
+              public void onNext(Helloworld.Response value) {}
+
+              @Override
+              public void onError(Throwable t) {}
+
+              @Override
+              public void onCompleted() {
+                latch.countDown();
+              }
+            });
+
+    Helloworld.Response message = Helloworld.Response.newBuilder().setMessage("message").build();
+    observer2.onNext(message);
+    observer2.onCompleted();
+
+    latch.await(10, TimeUnit.SECONDS);
+
+    // server span should end after child span
+    assertThat(serverSpanRecording).isTrue();
+
+    testing()
+        .waitAndAssertTraces(
+            trace ->
+                trace.hasSpansSatisfyingExactly(
+                    span ->
+                        span.hasName("example.Greeter/Conversation")
+                            .hasKind(SpanKind.CLIENT)
+                            .hasNoParent(),
+                    span ->
+                        span.hasName("example.Greeter/Conversation")
+                            .hasKind(SpanKind.SERVER)
+                            .hasParent(trace.getSpan(0)),
+                    span ->
+                        span.hasName("child")
+                            .hasKind(SpanKind.INTERNAL)
+                            .hasParent(trace.getSpan(1))));
+  }
+
   private ManagedChannel createChannel(Server server) throws Exception {
     ManagedChannelBuilder<?> channelBuilder =
         configureClient(ManagedChannelBuilder.forAddress("localhost", server.getPort()));