diff --git a/tracing/opentracing/endpoint.go b/tracing/opentracing/endpoint.go index ef7fa63..a193ef8 100644 --- a/tracing/opentracing/endpoint.go +++ b/tracing/opentracing/endpoint.go @@ -36,11 +36,15 @@ func TraceClient(tracer opentracing.Tracer, operationName string) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - parentSpan := opentracing.SpanFromContext(ctx) - clientSpan := tracer.StartSpan( - operationName, - opentracing.ChildOf(parentSpan.Context()), - ) + var clientSpan opentracing.Span + if parentSpan := opentracing.SpanFromContext(ctx); parentSpan != nil { + clientSpan = tracer.StartSpan( + operationName, + opentracing.ChildOf(parentSpan.Context()), + ) + } else { + clientSpan = tracer.StartSpan(operationName) + } defer clientSpan.Finish() otext.SpanKind.Set(clientSpan, otext.SpanKindRPCClient) ctx = opentracing.ContextWithSpan(ctx, clientSpan) diff --git a/tracing/opentracing/endpoint_test.go b/tracing/opentracing/endpoint_test.go index f7727b4..302717e 100644 --- a/tracing/opentracing/endpoint_test.go +++ b/tracing/opentracing/endpoint_test.go @@ -94,3 +94,24 @@ t.Errorf("Want ParentID %q, have %q", want, have) } } + +func TestTraceClientNoContextSpan(t *testing.T) { + tracer := mocktracer.New() + + // Empty/background context. + tracedEndpoint := kitot.TraceClient(tracer, "testOp")(endpoint.Nop) + if _, err := tracedEndpoint(context.Background(), struct{}{}); err != nil { + t.Fatal(err) + } + + // tracedEndpoint created a new Span. + finishedSpans := tracer.GetFinishedSpans() + if want, have := 1, len(finishedSpans); want != have { + t.Fatalf("Want %v span(s), found %v", want, have) + } + + endpointSpan := finishedSpans[0] + if want, have := "testOp", endpointSpan.OperationName; want != have { + t.Fatalf("Want %q, have %q", want, have) + } +}