mirror of
https://github.com/gusaul/grpcox.git
synced 2024-12-27 19:30:10 +00:00
390 lines
12 KiB
Go
390 lines
12 KiB
Go
|
package grpcurl
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/golang/protobuf/jsonpb"
|
||
|
"github.com/golang/protobuf/proto"
|
||
|
"github.com/jhump/protoreflect/desc"
|
||
|
"github.com/jhump/protoreflect/dynamic"
|
||
|
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
|
||
|
"github.com/jhump/protoreflect/grpcreflect"
|
||
|
"golang.org/x/net/context"
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/metadata"
|
||
|
"google.golang.org/grpc/status"
|
||
|
)
|
||
|
|
||
|
// InvocationEventHandler is a bag of callbacks for handling events that occur in the course
|
||
|
// of invoking an RPC. The handler also provides request data that is sent. The callbacks are
|
||
|
// generally called in the order they are listed below.
|
||
|
type InvocationEventHandler interface {
|
||
|
// OnResolveMethod is called with a descriptor of the method that is being invoked.
|
||
|
OnResolveMethod(*desc.MethodDescriptor)
|
||
|
// OnSendHeaders is called with the request metadata that is being sent.
|
||
|
OnSendHeaders(metadata.MD)
|
||
|
// OnReceiveHeaders is called when response headers have been received.
|
||
|
OnReceiveHeaders(metadata.MD)
|
||
|
// OnReceiveResponse is called for each response message received.
|
||
|
OnReceiveResponse(proto.Message)
|
||
|
// OnReceiveTrailers is called when response trailers and final RPC status have been received.
|
||
|
OnReceiveTrailers(*status.Status, metadata.MD)
|
||
|
}
|
||
|
|
||
|
// RequestMessageSupplier is a function that is invoked to obtain the next
// request message for a gRPC operation, returning the message data as bytes
// or a non-nil error. This type will be removed in a future release.
//
// Deprecated: This is only used with the deprecated InvokeRpc. Instead, use
// RequestSupplier with InvokeRPC.
type RequestMessageSupplier func() (data []byte, err error)
|
||
|
|
||
|
// InvokeRpc uses the given gRPC connection to invoke the given method. This function is deprecated
|
||
|
// and will be removed in a future release. It just delegates to the similarly named InvokeRPC
|
||
|
// method, whose signature is only slightly different.
|
||
|
//
|
||
|
// Deprecated: use InvokeRPC instead.
|
||
|
func InvokeRpc(ctx context.Context, source DescriptorSource, cc *grpc.ClientConn, methodName string,
|
||
|
headers []string, handler InvocationEventHandler, requestData RequestMessageSupplier) error {
|
||
|
|
||
|
return InvokeRPC(ctx, source, cc, methodName, headers, handler, func(m proto.Message) error {
|
||
|
// New function is almost identical, but the request supplier function works differently.
|
||
|
// So we adapt the logic here to maintain compatibility.
|
||
|
data, err := requestData()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return jsonpb.Unmarshal(bytes.NewReader(data), m)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// RequestSupplier is a function that is called to populate messages for a gRPC operation. The
|
||
|
// function should populate the given message or return a non-nil error. If the supplier has no
|
||
|
// more messages, it should return io.EOF. When it returns io.EOF, it should not in any way
|
||
|
// modify the given message argument.
|
||
|
type RequestSupplier func(proto.Message) error
|
||
|
|
||
|
// InvokeRPC uses the given gRPC channel to invoke the given method. The given descriptor source
|
||
|
// is used to determine the type of method and the type of request and response message. The given
|
||
|
// headers are sent as request metadata. Methods on the given event handler are called as the
|
||
|
// invocation proceeds.
|
||
|
//
|
||
|
// The given requestData function supplies the actual data to send. It should return io.EOF when
|
||
|
// there is no more request data. If the method being invoked is a unary or server-streaming RPC
|
||
|
// (e.g. exactly one request message) and there is no request data (e.g. the first invocation of
|
||
|
// the function returns io.EOF), then an empty request message is sent.
|
||
|
//
|
||
|
// If the requestData function and the given event handler coordinate or share any state, they should
|
||
|
// be thread-safe. This is because the requestData function may be called from a different goroutine
|
||
|
// than the one invoking event callbacks. (This only happens for bi-directional streaming RPCs, where
|
||
|
// one goroutine sends request messages and another consumes the response messages).
|
||
|
func InvokeRPC(ctx context.Context, source DescriptorSource, ch grpcdynamic.Channel, methodName string,
|
||
|
headers []string, handler InvocationEventHandler, requestData RequestSupplier) error {
|
||
|
|
||
|
md := MetadataFromHeaders(headers)
|
||
|
|
||
|
svc, mth := parseSymbol(methodName)
|
||
|
if svc == "" || mth == "" {
|
||
|
return fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", methodName)
|
||
|
}
|
||
|
dsc, err := source.FindSymbol(svc)
|
||
|
if err != nil {
|
||
|
if isNotFoundError(err) {
|
||
|
return fmt.Errorf("target server does not expose service %q", svc)
|
||
|
}
|
||
|
return fmt.Errorf("failed to query for service descriptor %q: %v", svc, err)
|
||
|
}
|
||
|
sd, ok := dsc.(*desc.ServiceDescriptor)
|
||
|
if !ok {
|
||
|
return fmt.Errorf("target server does not expose service %q", svc)
|
||
|
}
|
||
|
mtd := sd.FindMethodByName(mth)
|
||
|
if mtd == nil {
|
||
|
return fmt.Errorf("service %q does not include a method named %q", svc, mth)
|
||
|
}
|
||
|
|
||
|
handler.OnResolveMethod(mtd)
|
||
|
|
||
|
// we also download any applicable extensions so we can provide full support for parsing user-provided data
|
||
|
var ext dynamic.ExtensionRegistry
|
||
|
alreadyFetched := map[string]bool{}
|
||
|
if err = fetchAllExtensions(source, &ext, mtd.GetInputType(), alreadyFetched); err != nil {
|
||
|
return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetInputType().GetFullyQualifiedName(), err)
|
||
|
}
|
||
|
if err = fetchAllExtensions(source, &ext, mtd.GetOutputType(), alreadyFetched); err != nil {
|
||
|
return fmt.Errorf("error resolving server extensions for message %s: %v", mtd.GetOutputType().GetFullyQualifiedName(), err)
|
||
|
}
|
||
|
|
||
|
msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
|
||
|
req := msgFactory.NewMessage(mtd.GetInputType())
|
||
|
|
||
|
handler.OnSendHeaders(md)
|
||
|
ctx = metadata.NewOutgoingContext(ctx, md)
|
||
|
|
||
|
stub := grpcdynamic.NewStubWithMessageFactory(ch, msgFactory)
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
defer cancel()
|
||
|
|
||
|
if mtd.IsClientStreaming() && mtd.IsServerStreaming() {
|
||
|
return invokeBidi(ctx, stub, mtd, handler, requestData, req)
|
||
|
} else if mtd.IsClientStreaming() {
|
||
|
return invokeClientStream(ctx, stub, mtd, handler, requestData, req)
|
||
|
} else if mtd.IsServerStreaming() {
|
||
|
return invokeServerStream(ctx, stub, mtd, handler, requestData, req)
|
||
|
} else {
|
||
|
return invokeUnary(ctx, stub, mtd, handler, requestData, req)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func invokeUnary(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||
|
requestData RequestSupplier, req proto.Message) error {
|
||
|
|
||
|
err := requestData(req)
|
||
|
if err != nil && err != io.EOF {
|
||
|
return fmt.Errorf("error getting request data: %v", err)
|
||
|
}
|
||
|
if err != io.EOF {
|
||
|
// verify there is no second message, which is a usage error
|
||
|
err := requestData(req)
|
||
|
if err == nil {
|
||
|
return fmt.Errorf("method %q is a unary RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
|
||
|
} else if err != io.EOF {
|
||
|
return fmt.Errorf("error getting request data: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Now we can actually invoke the RPC!
|
||
|
var respHeaders metadata.MD
|
||
|
var respTrailers metadata.MD
|
||
|
resp, err := stub.InvokeRpc(ctx, md, req, grpc.Trailer(&respTrailers), grpc.Header(&respHeaders))
|
||
|
|
||
|
stat, ok := status.FromError(err)
|
||
|
if !ok {
|
||
|
// Error codes sent from the server will get printed differently below.
|
||
|
// So just bail for other kinds of errors here.
|
||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||
|
}
|
||
|
|
||
|
handler.OnReceiveHeaders(respHeaders)
|
||
|
|
||
|
if stat.Code() == codes.OK {
|
||
|
handler.OnReceiveResponse(resp)
|
||
|
}
|
||
|
|
||
|
handler.OnReceiveTrailers(stat, respTrailers)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func invokeClientStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||
|
requestData RequestSupplier, req proto.Message) error {
|
||
|
|
||
|
// invoke the RPC!
|
||
|
str, err := stub.InvokeRpcClientStream(ctx, md)
|
||
|
|
||
|
// Upload each request message in the stream
|
||
|
var resp proto.Message
|
||
|
for err == nil {
|
||
|
err = requestData(req)
|
||
|
if err == io.EOF {
|
||
|
resp, err = str.CloseAndReceive()
|
||
|
break
|
||
|
}
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("error getting request data: %v", err)
|
||
|
}
|
||
|
|
||
|
err = str.SendMsg(req)
|
||
|
if err == io.EOF {
|
||
|
// We get EOF on send if the server says "go away"
|
||
|
// We have to use CloseAndReceive to get the actual code
|
||
|
resp, err = str.CloseAndReceive()
|
||
|
break
|
||
|
}
|
||
|
|
||
|
req.Reset()
|
||
|
}
|
||
|
|
||
|
// finally, process response data
|
||
|
stat, ok := status.FromError(err)
|
||
|
if !ok {
|
||
|
// Error codes sent from the server will get printed differently below.
|
||
|
// So just bail for other kinds of errors here.
|
||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||
|
}
|
||
|
|
||
|
if respHeaders, err := str.Header(); err == nil {
|
||
|
handler.OnReceiveHeaders(respHeaders)
|
||
|
}
|
||
|
|
||
|
if stat.Code() == codes.OK {
|
||
|
handler.OnReceiveResponse(resp)
|
||
|
}
|
||
|
|
||
|
handler.OnReceiveTrailers(stat, str.Trailer())
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func invokeServerStream(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||
|
requestData RequestSupplier, req proto.Message) error {
|
||
|
|
||
|
err := requestData(req)
|
||
|
if err != nil && err != io.EOF {
|
||
|
return fmt.Errorf("error getting request data: %v", err)
|
||
|
}
|
||
|
if err != io.EOF {
|
||
|
// verify there is no second message, which is a usage error
|
||
|
err := requestData(req)
|
||
|
if err == nil {
|
||
|
return fmt.Errorf("method %q is a server-streaming RPC, but request data contained more than 1 message", md.GetFullyQualifiedName())
|
||
|
} else if err != io.EOF {
|
||
|
return fmt.Errorf("error getting request data: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Now we can actually invoke the RPC!
|
||
|
str, err := stub.InvokeRpcServerStream(ctx, md, req)
|
||
|
|
||
|
if respHeaders, err := str.Header(); err == nil {
|
||
|
handler.OnReceiveHeaders(respHeaders)
|
||
|
}
|
||
|
|
||
|
// Download each response message
|
||
|
for err == nil {
|
||
|
var resp proto.Message
|
||
|
resp, err = str.RecvMsg()
|
||
|
if err != nil {
|
||
|
if err == io.EOF {
|
||
|
err = nil
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
handler.OnReceiveResponse(resp)
|
||
|
}
|
||
|
|
||
|
stat, ok := status.FromError(err)
|
||
|
if !ok {
|
||
|
// Error codes sent from the server will get printed differently below.
|
||
|
// So just bail for other kinds of errors here.
|
||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||
|
}
|
||
|
|
||
|
handler.OnReceiveTrailers(stat, str.Trailer())
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func invokeBidi(ctx context.Context, stub grpcdynamic.Stub, md *desc.MethodDescriptor, handler InvocationEventHandler,
|
||
|
requestData RequestSupplier, req proto.Message) error {
|
||
|
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
defer cancel()
|
||
|
|
||
|
// invoke the RPC!
|
||
|
str, err := stub.InvokeRpcBidiStream(ctx, md)
|
||
|
|
||
|
var wg sync.WaitGroup
|
||
|
var sendErr atomic.Value
|
||
|
|
||
|
defer wg.Wait()
|
||
|
|
||
|
if err == nil {
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
|
||
|
// Concurrently upload each request message in the stream
|
||
|
var err error
|
||
|
for err == nil {
|
||
|
err = requestData(req)
|
||
|
|
||
|
if err == io.EOF {
|
||
|
err = str.CloseSend()
|
||
|
break
|
||
|
}
|
||
|
if err != nil {
|
||
|
err = fmt.Errorf("error getting request data: %v", err)
|
||
|
cancel()
|
||
|
break
|
||
|
}
|
||
|
|
||
|
err = str.SendMsg(req)
|
||
|
|
||
|
req.Reset()
|
||
|
}
|
||
|
|
||
|
if err != nil {
|
||
|
sendErr.Store(err)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
if respHeaders, err := str.Header(); err == nil {
|
||
|
handler.OnReceiveHeaders(respHeaders)
|
||
|
}
|
||
|
|
||
|
// Download each response message
|
||
|
for err == nil {
|
||
|
var resp proto.Message
|
||
|
resp, err = str.RecvMsg()
|
||
|
if err != nil {
|
||
|
if err == io.EOF {
|
||
|
err = nil
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
handler.OnReceiveResponse(resp)
|
||
|
}
|
||
|
|
||
|
if se, ok := sendErr.Load().(error); ok && se != io.EOF {
|
||
|
err = se
|
||
|
}
|
||
|
|
||
|
stat, ok := status.FromError(err)
|
||
|
if !ok {
|
||
|
// Error codes sent from the server will get printed differently below.
|
||
|
// So just bail for other kinds of errors here.
|
||
|
return fmt.Errorf("grpc call for %q failed: %v", md.GetFullyQualifiedName(), err)
|
||
|
}
|
||
|
|
||
|
handler.OnReceiveTrailers(stat, str.Trailer())
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// notFoundError is the error type used to report that a named element could not
// be found. Its underlying string is the complete error message.
type notFoundError string

// Error implements the error interface by returning the message verbatim.
func (e notFoundError) Error() string {
	return string(e)
}

// notFound builds a notFoundError whose message names the kind of element that
// was looked up and the name that could not be found.
func notFound(kind, name string) error {
	msg := fmt.Sprintf("%s not found: %s", kind, name)
	return notFoundError(msg)
}
|
||
|
|
||
|
func isNotFoundError(err error) bool {
|
||
|
if grpcreflect.IsElementNotFoundError(err) {
|
||
|
return true
|
||
|
}
|
||
|
_, ok := err.(notFoundError)
|
||
|
return ok
|
||
|
}
|
||
|
|
||
|
// parseSymbol splits a combined "service/method" or "service.method" string into
// its service and method parts. The split happens at the last slash if there is
// one, and otherwise at the last dot. If the input contains neither separator,
// two empty strings are returned.
func parseSymbol(svcAndMethod string) (string, string) {
	// Separators are tried in order of precedence: a dot is only considered
	// when the input contains no slash at all.
	for _, sep := range []string{"/", "."} {
		if idx := strings.LastIndex(svcAndMethod, sep); idx >= 0 {
			return svcAndMethod[:idx], svcAndMethod[idx+1:]
		}
	}
	return "", ""
}
|