diff --git a/celext/lib.go b/celext/lib.go index 444f09b..63dfe8e 100644 --- a/celext/lib.go +++ b/celext/lib.go @@ -29,7 +29,9 @@ import ( "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/ext" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/dynamicpb" ) // DefaultEnv produces a cel.Env with the necessary cel.EnvOption and @@ -49,6 +51,23 @@ func DefaultEnv(useUTC bool) (*cel.Env, error) { ) } +// RequiredCELEnvOptions returns the options required to have expressions which +// rely on the provided descriptor. +func RequiredCELEnvOptions(fieldDesc protoreflect.FieldDescriptor) []cel.EnvOption { + if fieldDesc.IsMap() { + return append( + RequiredCELEnvOptions(fieldDesc.MapKey()), + RequiredCELEnvOptions(fieldDesc.MapValue())..., + ) + } + if fieldDesc.Kind() == protoreflect.MessageKind { + return []cel.EnvOption{ + cel.Types(dynamicpb.NewMessage(fieldDesc.Message())), + } + } + return nil +} + // lib is the collection of functions and settings required by protovalidate // beyond the standard definitions of the CEL Specification: // diff --git a/internal/expression/lookups.go b/celext/lookups.go similarity index 78% rename from internal/expression/lookups.go rename to celext/lookups.go index 5c26695..cb24d70 100644 --- a/internal/expression/lookups.go +++ b/celext/lookups.go @@ -12,51 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package expression +package celext import ( "github.com/google/cel-go/cel" "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/dynamicpb" ) -// ProtoKindToCELType maps a protoreflect.Kind to a compatible cel.Type. -func ProtoKindToCELType(kind protoreflect.Kind) *cel.Type { - switch kind { - case - protoreflect.FloatKind, - protoreflect.DoubleKind: - return cel.DoubleType - case - protoreflect.Int32Kind, - protoreflect.Int64Kind, - protoreflect.Sint32Kind, - protoreflect.Sint64Kind, - protoreflect.Sfixed32Kind, - protoreflect.Sfixed64Kind, - protoreflect.EnumKind: - return cel.IntType - case - protoreflect.Uint32Kind, - protoreflect.Uint64Kind, - protoreflect.Fixed32Kind, - protoreflect.Fixed64Kind: - return cel.UintType - case protoreflect.BoolKind: - return cel.BoolType - case protoreflect.StringKind: - return cel.StringType - case protoreflect.BytesKind: - return cel.BytesType - case - protoreflect.MessageKind, - protoreflect.GroupKind: - return cel.DynType - default: - return cel.DynType - } -} - // ProtoFieldToCELType resolves the CEL value type for the provided // FieldDescriptor. If generic is true, the specific subtypes of map and // repeated fields will be replaced with cel.DynType. If forItems is true, the @@ -92,22 +54,42 @@ func ProtoFieldToCELType(fieldDesc protoreflect.FieldDescriptor, generic, forIte return cel.ObjectType(string(fqn)) } } - return ProtoKindToCELType(fieldDesc.Kind()) + return protoKindToCELType(fieldDesc.Kind()) } -// RequiredCELEnvOptions returns the options required to have expressions which -// rely on the provided descriptor. -func RequiredCELEnvOptions(fieldDesc protoreflect.FieldDescriptor) []cel.EnvOption { - if fieldDesc.IsMap() { - return append( - RequiredCELEnvOptions(fieldDesc.MapKey()), - RequiredCELEnvOptions(fieldDesc.MapValue())..., - ) - } - if fieldDesc.Kind() == protoreflect.MessageKind { - return []cel.EnvOption{ - cel.Types(dynamicpb.NewMessage(fieldDesc.Message())), - } +// protoKindToCELType maps a protoreflect.Kind to a compatible cel.Type. +func protoKindToCELType(kind protoreflect.Kind) *cel.Type { + switch kind { + case + protoreflect.FloatKind, + protoreflect.DoubleKind: + return cel.DoubleType + case + protoreflect.Int32Kind, + protoreflect.Int64Kind, + protoreflect.Sint32Kind, + protoreflect.Sint64Kind, + protoreflect.Sfixed32Kind, + protoreflect.Sfixed64Kind, + protoreflect.EnumKind: + return cel.IntType + case + protoreflect.Uint32Kind, + protoreflect.Uint64Kind, + protoreflect.Fixed32Kind, + protoreflect.Fixed64Kind: + return cel.UintType + case protoreflect.BoolKind: + return cel.BoolType + case protoreflect.StringKind: + return cel.StringType + case protoreflect.BytesKind: + return cel.BytesType + case + protoreflect.MessageKind, + protoreflect.GroupKind: + return cel.DynType + default: + return cel.DynType } - return nil } diff --git a/internal/expression/lookups_test.go b/celext/lookups_test.go similarity index 70% rename from internal/expression/lookups_test.go rename to celext/lookups_test.go index f9fe452..8ee382f 100644 --- a/internal/expression/lookups_test.go +++ b/celext/lookups_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package expression +package celext import ( "testing" @@ -89,6 +89,40 @@ func TestCache_GetCELType(t *testing.T) { } } +func TestProtoKindToCELType(t *testing.T) { + t.Parallel() + + tests := map[protoreflect.Kind]*cel.Type{ + protoreflect.FloatKind: cel.DoubleType, + protoreflect.DoubleKind: cel.DoubleType, + protoreflect.Int32Kind: cel.IntType, + protoreflect.Int64Kind: cel.IntType, + protoreflect.Uint32Kind: cel.UintType, + protoreflect.Uint64Kind: cel.UintType, + protoreflect.Sint32Kind: cel.IntType, + protoreflect.Sint64Kind: cel.IntType, + protoreflect.Fixed32Kind: cel.UintType, + protoreflect.Fixed64Kind: cel.UintType, + protoreflect.Sfixed32Kind: cel.IntType, + protoreflect.Sfixed64Kind: cel.IntType, + protoreflect.BoolKind: cel.BoolType, + protoreflect.StringKind: cel.StringType, + protoreflect.BytesKind: cel.BytesType, + protoreflect.EnumKind: cel.IntType, + protoreflect.MessageKind: cel.DynType, + protoreflect.GroupKind: cel.DynType, + protoreflect.Kind(0): cel.DynType, + } + + for k, ty := range tests { + kind, typ := k, ty + t.Run(kind.String(), func(t *testing.T) { + t.Parallel() + assert.Equal(t, typ, protoKindToCELType(kind)) + }) + } +} + func getFieldDesc(t *testing.T, msg proto.Message, fld protoreflect.Name) protoreflect.FieldDescriptor { t.Helper() desc := msg.ProtoReflect().Descriptor().Fields().ByName(fld) diff --git a/internal/constraints/cache.go b/internal/constraints/cache.go index 02a3b9e..7d07123 100644 --- a/internal/constraints/cache.go +++ b/internal/constraints/cache.go @@ -17,6 +17,7 @@ package constraints import ( "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate/priv" + "github.com/bufbuild/protovalidate-go/celext" "github.com/bufbuild/protovalidate-go/internal/errors" "github.com/bufbuild/protovalidate-go/internal/expression" "github.com/google/cel-go/cel" @@ -114,7 +115,7 @@ func (c *Cache) prepareEnvironment( ) (*cel.Env, error) { env, err := env.Extend( cel.Types(rules.Interface()), - cel.Variable("this", expression.ProtoFieldToCELType(fieldDesc, true, forItems)), + cel.Variable("this", celext.ProtoFieldToCELType(fieldDesc, true, forItems)), cel.Variable("rules", cel.ObjectType(string(rules.Descriptor().FullName()))), ) diff --git a/internal/constraints/lookups_test.go b/internal/constraints/lookups_test.go index 152201d..6fa9423 100644 --- a/internal/constraints/lookups_test.go +++ b/internal/constraints/lookups_test.go @@ -17,8 +17,6 @@ package constraints import ( "testing" - "github.com/bufbuild/protovalidate-go/internal/expression" - "github.com/google/cel-go/cel" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" @@ -54,37 +52,3 @@ func TestExpectedWrapperConstraints(t *testing.T) { }) } } - -func TestProtoKindToCELType(t *testing.T) { - t.Parallel() - - tests := map[protoreflect.Kind]*cel.Type{ - protoreflect.FloatKind: cel.DoubleType, - protoreflect.DoubleKind: cel.DoubleType, - protoreflect.Int32Kind: cel.IntType, - protoreflect.Int64Kind: cel.IntType, - protoreflect.Uint32Kind: cel.UintType, - protoreflect.Uint64Kind: cel.UintType, - protoreflect.Sint32Kind: cel.IntType, - protoreflect.Sint64Kind: cel.IntType, - protoreflect.Fixed32Kind: cel.UintType, - protoreflect.Fixed64Kind: cel.UintType, - protoreflect.Sfixed32Kind: cel.IntType, - protoreflect.Sfixed64Kind: cel.IntType, - protoreflect.BoolKind: cel.BoolType, - protoreflect.StringKind: cel.StringType, - protoreflect.BytesKind: cel.BytesType, - protoreflect.EnumKind: cel.IntType, - protoreflect.MessageKind: cel.DynType, - protoreflect.GroupKind: cel.DynType, - protoreflect.Kind(0): cel.DynType, - } - - for k, ty := range tests { - kind, typ := k, ty - t.Run(kind.String(), func(t *testing.T) { - t.Parallel() - assert.Equal(t, typ, expression.ProtoKindToCELType(kind)) - }) - } -} diff --git a/internal/evaluator/builder.go b/internal/evaluator/builder.go index b525711..1057cf9 100644 --- a/internal/evaluator/builder.go +++ b/internal/evaluator/builder.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" + "github.com/bufbuild/protovalidate-go/celext" "github.com/bufbuild/protovalidate-go/internal/constraints" "github.com/bufbuild/protovalidate-go/internal/errors" "github.com/bufbuild/protovalidate-go/internal/expression" @@ -278,9 +279,9 @@ func (bldr *Builder) processFieldExpressions( return nil } - celTyp := expression.ProtoFieldToCELType(fieldDesc, false, false) + celTyp := celext.ProtoFieldToCELType(fieldDesc, false, false) opts := append( - expression.RequiredCELEnvOptions(fieldDesc), + celext.RequiredCELEnvOptions(fieldDesc), cel.Variable("this", celTyp), ) compiledExpressions, err := expression.Compile(exprs, bldr.env, opts...)