mirror of
https://github.com/gusaul/grpcox.git
synced 2025-01-26 06:14:40 +00:00
342 lines
9.6 KiB
Go
342 lines
9.6 KiB
Go
package desc
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"sync"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
|
|
|
|
"github.com/jhump/protoreflect/internal"
|
|
)
|
|
|
|
var (
|
|
cacheMu sync.RWMutex
|
|
filesCache = map[string]*FileDescriptor{}
|
|
messagesCache = map[string]*MessageDescriptor{}
|
|
enumCache = map[reflect.Type]*EnumDescriptor{}
|
|
)
|
|
|
|
// LoadFileDescriptor creates a file descriptor using the bytes returned by
|
|
// proto.FileDescriptor. Descriptors are cached so that they do not need to be
|
|
// re-processed if the same file is fetched again later.
|
|
func LoadFileDescriptor(file string) (*FileDescriptor, error) {
|
|
return loadFileDescriptor(file, nil)
|
|
}
|
|
|
|
func loadFileDescriptor(file string, r *ImportResolver) (*FileDescriptor, error) {
|
|
f := getFileFromCache(file)
|
|
if f != nil {
|
|
return f, nil
|
|
}
|
|
cacheMu.Lock()
|
|
defer cacheMu.Unlock()
|
|
return loadFileDescriptorLocked(file, r)
|
|
}
|
|
|
|
func loadFileDescriptorLocked(file string, r *ImportResolver) (*FileDescriptor, error) {
|
|
f := filesCache[file]
|
|
if f != nil {
|
|
return f, nil
|
|
}
|
|
fd, err := internal.LoadFileDescriptor(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f, err = toFileDescriptorLocked(fd, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
putCacheLocked(file, f)
|
|
return f, nil
|
|
}
|
|
|
|
func toFileDescriptorLocked(fd *dpb.FileDescriptorProto, r *ImportResolver) (*FileDescriptor, error) {
|
|
deps := make([]*FileDescriptor, len(fd.GetDependency()))
|
|
for i, dep := range fd.GetDependency() {
|
|
resolvedDep := r.ResolveImport(fd.GetName(), dep)
|
|
var err error
|
|
deps[i], err = loadFileDescriptorLocked(resolvedDep, r)
|
|
if _, ok := err.(internal.ErrNoSuchFile); ok && resolvedDep != dep {
|
|
// try original path
|
|
deps[i], err = loadFileDescriptorLocked(dep, r)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return CreateFileDescriptor(fd, deps...)
|
|
}
|
|
|
|
func getFileFromCache(file string) *FileDescriptor {
|
|
cacheMu.RLock()
|
|
defer cacheMu.RUnlock()
|
|
return filesCache[file]
|
|
}
|
|
|
|
func putCacheLocked(filename string, fd *FileDescriptor) {
|
|
filesCache[filename] = fd
|
|
putMessageCacheLocked(fd.messages)
|
|
}
|
|
|
|
func putMessageCacheLocked(mds []*MessageDescriptor) {
|
|
for _, md := range mds {
|
|
messagesCache[md.fqn] = md
|
|
putMessageCacheLocked(md.nested)
|
|
}
|
|
}
|
|
|
|
// interface implemented by generated messages, which all have a Descriptor() method in
|
|
// addition to the methods of proto.Message
|
|
type protoMessage interface {
|
|
proto.Message
|
|
Descriptor() ([]byte, []int)
|
|
}
|
|
|
|
// LoadMessageDescriptor loads descriptor using the encoded descriptor proto returned by
|
|
// Message.Descriptor() for the given message type. If the given type is not recognized,
|
|
// then a nil descriptor is returned.
|
|
func LoadMessageDescriptor(message string) (*MessageDescriptor, error) {
|
|
return loadMessageDescriptor(message, nil)
|
|
}
|
|
|
|
func loadMessageDescriptor(message string, r *ImportResolver) (*MessageDescriptor, error) {
|
|
m := getMessageFromCache(message)
|
|
if m != nil {
|
|
return m, nil
|
|
}
|
|
|
|
pt := proto.MessageType(message)
|
|
if pt == nil {
|
|
return nil, nil
|
|
}
|
|
msg, err := messageFromType(pt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cacheMu.Lock()
|
|
defer cacheMu.Unlock()
|
|
return loadMessageDescriptorForTypeLocked(message, msg, r)
|
|
}
|
|
|
|
// LoadMessageDescriptorForType loads descriptor using the encoded descriptor proto returned
|
|
// by message.Descriptor() for the given message type. If the given type is not recognized,
|
|
// then a nil descriptor is returned.
|
|
func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) {
|
|
return loadMessageDescriptorForType(messageType, nil)
|
|
}
|
|
|
|
func loadMessageDescriptorForType(messageType reflect.Type, r *ImportResolver) (*MessageDescriptor, error) {
|
|
m, err := messageFromType(messageType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return loadMessageDescriptorForMessage(m, r)
|
|
}
|
|
|
|
// LoadMessageDescriptorForMessage loads descriptor using the encoded descriptor proto
|
|
// returned by message.Descriptor(). If the given type is not recognized, then a nil
|
|
// descriptor is returned.
|
|
func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) {
|
|
return loadMessageDescriptorForMessage(message, nil)
|
|
}
|
|
|
|
func loadMessageDescriptorForMessage(message proto.Message, r *ImportResolver) (*MessageDescriptor, error) {
|
|
// efficiently handle dynamic messages
|
|
type descriptorable interface {
|
|
GetMessageDescriptor() *MessageDescriptor
|
|
}
|
|
if d, ok := message.(descriptorable); ok {
|
|
return d.GetMessageDescriptor(), nil
|
|
}
|
|
|
|
name := proto.MessageName(message)
|
|
if name == "" {
|
|
return nil, nil
|
|
}
|
|
m := getMessageFromCache(name)
|
|
if m != nil {
|
|
return m, nil
|
|
}
|
|
|
|
cacheMu.Lock()
|
|
defer cacheMu.Unlock()
|
|
return loadMessageDescriptorForTypeLocked(name, message.(protoMessage), nil)
|
|
}
|
|
|
|
func messageFromType(mt reflect.Type) (protoMessage, error) {
|
|
if mt.Kind() != reflect.Ptr {
|
|
mt = reflect.PtrTo(mt)
|
|
}
|
|
m, ok := reflect.Zero(mt).Interface().(protoMessage)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to create message from type: %v", mt)
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
func loadMessageDescriptorForTypeLocked(name string, message protoMessage, r *ImportResolver) (*MessageDescriptor, error) {
|
|
m := messagesCache[name]
|
|
if m != nil {
|
|
return m, nil
|
|
}
|
|
|
|
fdb, _ := message.Descriptor()
|
|
fd, err := internal.DecodeFileDescriptor(name, fdb)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f, err := toFileDescriptorLocked(fd, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
putCacheLocked(fd.GetName(), f)
|
|
return f.FindSymbol(name).(*MessageDescriptor), nil
|
|
}
|
|
|
|
func getMessageFromCache(message string) *MessageDescriptor {
|
|
cacheMu.RLock()
|
|
defer cacheMu.RUnlock()
|
|
return messagesCache[message]
|
|
}
|
|
|
|
// interface implemented by all generated enums
|
|
type protoEnum interface {
|
|
EnumDescriptor() ([]byte, []int)
|
|
}
|
|
|
|
// NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name because
|
|
// it is not useful since protoc-gen-go does not expose the name anywhere in generated
|
|
// code or register it in a way that is it accessible for reflection code. This also
|
|
// means we have to cache enum descriptors differently -- we can only cache them as
|
|
// they are requested, as opposed to caching all enum types whenever a file descriptor
|
|
// is cached. This is because we need to know the generated type of the enums, and we
|
|
// don't know that at the time of caching file descriptors.
|
|
|
|
// LoadEnumDescriptorForType loads descriptor using the encoded descriptor proto returned
|
|
// by enum.EnumDescriptor() for the given enum type.
|
|
func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) {
|
|
return loadEnumDescriptorForType(enumType, nil)
|
|
}
|
|
|
|
func loadEnumDescriptorForType(enumType reflect.Type, r *ImportResolver) (*EnumDescriptor, error) {
|
|
// we cache descriptors using non-pointer type
|
|
if enumType.Kind() == reflect.Ptr {
|
|
enumType = enumType.Elem()
|
|
}
|
|
e := getEnumFromCache(enumType)
|
|
if e != nil {
|
|
return e, nil
|
|
}
|
|
enum, err := enumFromType(enumType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cacheMu.Lock()
|
|
defer cacheMu.Unlock()
|
|
return loadEnumDescriptorForTypeLocked(enumType, enum, r)
|
|
}
|
|
|
|
// LoadEnumDescriptorForEnum loads descriptor using the encoded descriptor proto
|
|
// returned by enum.EnumDescriptor().
|
|
func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) {
|
|
return loadEnumDescriptorForEnum(enum, nil)
|
|
}
|
|
|
|
func loadEnumDescriptorForEnum(enum protoEnum, r *ImportResolver) (*EnumDescriptor, error) {
|
|
et := reflect.TypeOf(enum)
|
|
// we cache descriptors using non-pointer type
|
|
if et.Kind() == reflect.Ptr {
|
|
et = et.Elem()
|
|
enum = reflect.Zero(et).Interface().(protoEnum)
|
|
}
|
|
e := getEnumFromCache(et)
|
|
if e != nil {
|
|
return e, nil
|
|
}
|
|
|
|
cacheMu.Lock()
|
|
defer cacheMu.Unlock()
|
|
return loadEnumDescriptorForTypeLocked(et, enum, r)
|
|
}
|
|
|
|
func enumFromType(et reflect.Type) (protoEnum, error) {
|
|
if et.Kind() != reflect.Int32 {
|
|
et = reflect.PtrTo(et)
|
|
}
|
|
e, ok := reflect.Zero(et).Interface().(protoEnum)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to create enum from type: %v", et)
|
|
}
|
|
return e, nil
|
|
}
|
|
|
|
func loadEnumDescriptorForTypeLocked(et reflect.Type, enum protoEnum, r *ImportResolver) (*EnumDescriptor, error) {
|
|
e := enumCache[et]
|
|
if e != nil {
|
|
return e, nil
|
|
}
|
|
|
|
fdb, path := enum.EnumDescriptor()
|
|
name := fmt.Sprintf("%v", et)
|
|
fd, err := internal.DecodeFileDescriptor(name, fdb)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// see if we already have cached "rich" descriptor
|
|
f, ok := filesCache[fd.GetName()]
|
|
if !ok {
|
|
f, err = toFileDescriptorLocked(fd, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
putCacheLocked(fd.GetName(), f)
|
|
}
|
|
|
|
ed := findEnum(f, path)
|
|
enumCache[et] = ed
|
|
return ed, nil
|
|
}
|
|
|
|
func getEnumFromCache(et reflect.Type) *EnumDescriptor {
|
|
cacheMu.RLock()
|
|
defer cacheMu.RUnlock()
|
|
return enumCache[et]
|
|
}
|
|
|
|
func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor {
|
|
if len(path) == 1 {
|
|
return fd.GetEnumTypes()[path[0]]
|
|
}
|
|
md := fd.GetMessageTypes()[path[0]]
|
|
for _, i := range path[1 : len(path)-1] {
|
|
md = md.GetNestedMessageTypes()[i]
|
|
}
|
|
return md.GetNestedEnumTypes()[path[len(path)-1]]
|
|
}
|
|
|
|
// LoadFieldDescriptorForExtension loads the field descriptor that corresponds to the given
|
|
// extension description.
|
|
func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) {
|
|
return loadFieldDescriptorForExtension(ext, nil)
|
|
}
|
|
|
|
func loadFieldDescriptorForExtension(ext *proto.ExtensionDesc, r *ImportResolver) (*FieldDescriptor, error) {
|
|
file, err := loadFileDescriptor(ext.Filename, r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
field, ok := file.FindSymbol(ext.Name).(*FieldDescriptor)
|
|
// make sure descriptor agrees with attributes of the ExtensionDesc
|
|
if !ok || !field.IsExtension() || field.GetOwner().GetFullyQualifiedName() != proto.MessageName(ext.ExtendedType) ||
|
|
field.GetNumber() != ext.Field {
|
|
return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name)
|
|
}
|
|
return field, nil
|
|
}
|