diff --git a/pkg/go/transformer/dsltojson.go b/pkg/go/transformer/dsltojson.go index 44b930dc..4bacb1bb 100644 --- a/pkg/go/transformer/dsltojson.go +++ b/pkg/go/transformer/dsltojson.go @@ -307,7 +307,9 @@ func (l *OpenFgaDslListener) EnterRelationDefDirectAssignment(_ *parser.Relation } func (l *OpenFgaDslListener) ExitRelationDefDirectAssignment(_ *parser.RelationDefDirectAssignmentContext) { - partialRewrite := &openfgav1.Userset{Userset: &openfgav1.Userset_This{}} + partialRewrite := &openfgav1.Userset{ + Userset: &openfgav1.Userset_This{This: &openfgav1.DirectUserset{}}, + } l.currentRelation.Rewrites = append(l.currentRelation.Rewrites, partialRewrite) } diff --git a/pkg/go/transformer/jsontodsl.go b/pkg/go/transformer/jsontodsl.go index e8b8bd5d..2c1c6f9f 100644 --- a/pkg/go/transformer/jsontodsl.go +++ b/pkg/go/transformer/jsontodsl.go @@ -26,14 +26,24 @@ func (v *DirectAssignmentValidator) occurrences() int { return v.occurred } +func isDirectAssignment(userset *openfgav1.Userset) bool { + if userset == nil { + return false + } + + _, ok := userset.GetUserset().(*openfgav1.Userset_This) + + return ok +} + func (v *DirectAssignmentValidator) isFirstPosition(userset *openfgav1.Userset) bool { //nolint:cyclop - if userset.GetThis() != nil { + if isDirectAssignment(userset) { return true } switch { case userset.GetDifference() != nil && userset.GetDifference().GetBase() != nil: - if userset.GetDifference().GetBase().GetThis() != nil { + if isDirectAssignment(userset.GetDifference().GetBase()) { return true } @@ -45,7 +55,7 @@ func (v *DirectAssignmentValidator) isFirstPosition(userset *openfgav1.Userset) // so even if it is not in the first position here, we're fine children := userset.GetIntersection().GetChild() for _, child := range children { - if child.GetThis() != nil { + if isDirectAssignment(child) { return true } } @@ -56,7 +66,7 @@ func (v *DirectAssignmentValidator) isFirstPosition(userset *openfgav1.Userset) children := userset.GetUnion().GetChild() if len(children) > 0 { for _, child := range children { - if child.GetThis() != nil { + if isDirectAssignment(child) { return true } } @@ -204,7 +214,7 @@ func parseSubRelation( typeRestrictions []*openfgav1.RelationReference, validator *DirectAssignmentValidator, ) (string, error) { - if relationDefinition.GetThis() != nil { + if isDirectAssignment(relationDefinition) { // Make sure we have no more than 1 reference for direct assignment in a given relation validator.incr() @@ -298,7 +308,7 @@ func prioritizeDirectAssignment(usersets []*openfgav1.Userset) []*openfgav1.User thisPosition := -1 for index, userset := range usersets { - if userset.GetThis() != nil { + if isDirectAssignment(userset) { thisPosition = index break diff --git a/pkg/go/transformer/jsontodsl_test.go b/pkg/go/transformer/jsontodsl_test.go index e0387327..ea000b07 100644 --- a/pkg/go/transformer/jsontodsl_test.go +++ b/pkg/go/transformer/jsontodsl_test.go @@ -8,6 +8,33 @@ import ( language "github.com/openfga/language/pkg/go/transformer" ) +func TestTransformJSONProtoToDSL(t *testing.T) { + t.Parallel() + + testCases, err := loadModuleTestCases() + require.NoError(t, err) + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + t.Parallel() + if len(testCase.Modules) == 0 || testCase.Skip || len(testCase.ExpectedErrors) > 0 { + t.Skip() + } + + model, err := language.TransformModuleFilesToModel(testCase.Modules, "1.2") + require.NoError(t, err) + + dsl, err := language.TransformJSONProtoToDSL(model, language.WithIncludeSourceInformation(false)) + require.NoError(t, err) + require.Equal(t, testCase.DSL, dsl) + + dslWithSrc, err := language.TransformJSONProtoToDSL(model, language.WithIncludeSourceInformation(true)) + require.NoError(t, err) + require.Equal(t, testCase.DSLWithSourceInfo, dslWithSrc) + }) + + } +} func TestJSONToDSLTransformer(t *testing.T) { t.Parallel()