diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index 75cd23026da..0e695f9e7f8 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -424,6 +424,60 @@ private IReadOnlyList GetClientParameters() protected override string BuildName() => _inputClient.IsExactName ? _inputClient.Name : _inputClient.Name.ToIdentifierName(); + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List(); + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements == null) + { + continue; + } + + if (method.CollectionDefinition != null) + { + dependencies.Add(method.CollectionDefinition.Type); + } + + if (method.ServiceMethod == null) + { + continue; + } + + AddInputTypeDependency(dependencies, method.ServiceMethod.Response.Type); + AddInputTypeDependency(dependencies, method.ServiceMethod.Exception?.Type); + foreach (var parameter in method.ServiceMethod.Parameters) + { + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var parameter in method.ServiceMethod.Operation.Parameters) + { + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var response in method.ServiceMethod.Operation.Responses) + { + AddInputTypeDependency(dependencies, response.BodyType); + foreach (var header in response.Headers) + { + AddInputTypeDependency(dependencies, header.Type); + } + } + } + + return dependencies; + } + + private static void AddInputTypeDependency(List dependencies, InputType? inputType) + { + var type = inputType == null ? null : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputType); + if (type != null) + { + dependencies.Add(type); + } + } + protected override FieldProvider[] BuildFields() { List fields = [EndpointField]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index ae617957bf5..590eaf2b935 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -217,6 +217,22 @@ private bool HasPagingOperationNameCollision(string operationName) protected override TypeSignatureModifiers BuildDeclarationModifiers() => TypeSignatureModifiers.Internal | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class; + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType }; + if (ItemModelType != null) + { + dependencies.Add(ItemModelType); + } + + foreach (var field in RequestFields) + { + dependencies.Add(field.Type); + } + + return dependencies; + } + protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 0d02ecba187..fbcb895c72d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -117,6 +117,10 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m protected override CSharpType? BuildBaseType() => _model.BaseType; + protected override IReadOnlyList BuildHelperDependencyNames() => _rawDataField != null || _additionalProperties.Value.Length > 0 + ? ["ChangeTrackingDictionary"] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() { if (_model.CanonicalView.Properties.Any(p => ScmModelProvider.IsFileBinaryContentType(p.Type))) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs index 88cb97b16e7..e390ab24707 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs @@ -51,6 +51,12 @@ protected override string BuildRelativeFilePath() return Path.Combine("src", "Generated", "Models", $"{Name}.Serialization.Multipart.cs"); } + protected override IReadOnlyList BuildHelperDependencyNames() => _model.Properties.Any( + prop => prop.WireInfo != null && !prop.WireInfo.IsRequired && + (prop.Type is { IsCollection: true, IsReadOnlyMemory: false } || prop.Type.IsDictionary)) + ? ["Optional"] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() => [new SuppressionStatement(null, Literal(ScmModelProvider.FileBinaryContentDiagnosticId), ScmModelProvider.ScmEvaluationTypeSuppressionJustification)]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index ec53be226f3..8cc5fea2368 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -78,6 +78,33 @@ protected override FieldProvider[] BuildFields() return [.. pipelineMessage20xClassifiersFields]; } + protected override IReadOnlyList BuildHelperDependencyNames() + { + var dependencies = new HashSet(StringComparer.Ordinal); + foreach (var serviceMethod in _inputClient.Methods) + { + foreach (var parameter in serviceMethod.Operation.Parameters) + { + if (parameter is not InputHeaderParameter and not InputQueryParameter) + { + continue; + } + + var type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(parameter.Type); + if (type?.IsDictionary == true) + { + dependencies.Add("ChangeTrackingDictionary"); + } + else if (type?.IsCollection == true) + { + dependencies.Add("ChangeTrackingList"); + } + } + } + + return [.. dependencies]; + } + protected override ScmMethodProvider[] BuildMethods() { List methods = new List(); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs new file mode 100644 index 00000000000..9647a678c28 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -0,0 +1,234 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.TypeSpec.Generator.Input; +using Microsoft.TypeSpec.Generator.Tests.Common; +using NUnit.Framework; + +namespace Microsoft.TypeSpec.Generator.ClientModel.Tests.PostProcessing +{ + public class ClientBodyDependencyPostProcessingTests + { + [Test] + public async Task OperationBodyParameterModelDoesNotBecomePublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([requestModel], [client], ["RequestBody"]); + } + + [Test] + public async Task OperationResponseBodyModelRemainsPublicAsRootOutputModel() + { + var responseModel = InputFactory.Model("ResponseBody"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel], [client], ["ResponseBody"]); + } + + [Test] + public async Task NestedBodyModelGraphDoesNotBecomePublic() + { + var nestedModel = InputFactory.Model("NestedToolParameter"); + var toolModel = InputFactory.Model( + "ToolConfig", + properties: [InputFactory.Property("Parameter", nestedModel)]); + var parameter = InputFactory.BodyParameter("tool", toolModel, isRequired: true); + var operation = InputFactory.Operation("Configure", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Configure", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([toolModel, nestedModel], [client], ["ToolConfig", "NestedToolParameter"]); + } + + [Test] + public async Task NonDiscriminatorDerivedBodyModelDoesNotBecomePublicFromPublicBase() + { + var baseTool = InputFactory.Model("BaseTool"); + var concreteTool = InputFactory.Model( + "ConcreteTool", + properties: [InputFactory.Property("Name", InputPrimitiveType.String)], + baseModel: baseTool); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseTool)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseTool, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertMixedModels( + [baseTool, concreteTool], + [client], + publicModelNames: ["BaseTool"], + internalModelNames: ["ConcreteTool"]); + } + + [Test] + public async Task CustomizedEnumSerializationProviderIsKeptWhenModelSerializationUsesEnum() + { + var statusEnum = InputFactory.StringEnum( + "Status", + [("Succeeded", "succeeded"), ("Failed", "failed")], + clientNamespace: "Sample"); + var resultModel = InputFactory.Model( + "OperationResult", + properties: [InputFactory.Property("Status", statusEnum, isRequired: true)], + @namespace: "Sample"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: resultModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(resultModel, [])); + var client = InputFactory.Client("TestClient", methods: [method], clientNamespace: "Sample"); + + await GenerateAndAssertFiles( + enums: [statusEnum], + models: [resultModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Custom", "Status.cs"), """ + namespace Sample; + + [CodeGenType("Status")] + public enum Status + { + Succeeded, + Failed + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Models", "Status.Serialization.cs")]); + } + + private static async Task GenerateAndAssertInternalModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: false); + + private static async Task GenerateAndAssertPublicModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: true); + + private static async Task GenerateAndAssertMixedModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + => await GenerateAndAssertModels(models, clients, publicModelNames, internalModelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames, + bool shouldBePublic) + => await GenerateAndAssertModels( + models, + clients, + shouldBePublic ? modelNames : [], + shouldBePublic ? [] : modelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + { + await GenerateAndAssertFiles( + enums: [], + models: models, + clients: clients, + customFiles: [], + publicModelNames: publicModelNames, + internalModelNames: internalModelNames, + expectedFiles: []); + } + + private static async Task GenerateAndAssertFiles( + InputEnumType[] enums, + InputModelType[] models, + InputClient[] clients, + (string Path, string Content)[] customFiles, + string[] expectedFiles, + string[] publicModelNames = null!, + string[] internalModelNames = null!) + { + publicModelNames ??= []; + internalModelNames ??= []; + + var outputPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Directory.CreateDirectory(outputPath); + try + { + foreach (var customFile in customFiles) + { + var customPath = Path.Combine(outputPath, customFile.Path); + Directory.CreateDirectory(Path.GetDirectoryName(customPath)!); + File.WriteAllText(customPath, customFile.Content); + } + + await MockHelpers.LoadMockGeneratorAsync( + inputEnums: () => enums, + inputModels: () => models, + clients: () => clients, + configuration: "{\"package-name\": \"Sample\", \"disable-xml-docs\": true}", + outputPath: outputPath); + + await new CSharpGen().ExecuteAsync(); + + foreach (var modelName in publicModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"public partial class {modelName}", text, $"{modelName} should be public."); + } + + foreach (var modelName in internalModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"internal partial class {modelName}", text, $"{modelName} should be internal."); + StringAssert.DoesNotContain($"public partial class {modelName}", text, $"{modelName} should not be public."); + } + + var modelFactoryPath = Path.Combine(outputPath, "src", "Generated", "SampleModelFactory.cs"); + if (File.Exists(modelFactoryPath)) + { + var modelFactoryText = File.ReadAllText(modelFactoryPath); + foreach (var modelName in publicModelNames) + { + StringAssert.Contains($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should be generated."); + } + + foreach (var modelName in internalModelNames) + { + StringAssert.DoesNotContain($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should not be generated."); + } + } + + foreach (var expectedFile in expectedFiles) + { + var filePath = Path.Combine(outputPath, expectedFile); + Assert.IsTrue(File.Exists(filePath), $"Expected generated file '{filePath}'."); + } + } + finally + { + if (Directory.Exists(outputPath)) + { + Directory.Delete(outputPath, recursive: true); + } + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs index 9148f659e43..ee97df96d29 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs @@ -34,7 +34,8 @@ public static async Task> LoadMockGeneratorAsync( Func>? apiVersions = null, string? configuration = null, Func? createCSharpTypeCore = null, - Func? createCSharpTypeCoreFallback = null) + Func? createCSharpTypeCoreFallback = null, + string? outputPath = null) { var mockGenerator = LoadMockGenerator( inputLiterals: inputLiterals, @@ -44,7 +45,8 @@ public static async Task> LoadMockGeneratorAsync( apiVersions: apiVersions, configuration: configuration, createCSharpTypeCore: createCSharpTypeCore, - createCSharpTypeCoreFallback: createCSharpTypeCoreFallback); + createCSharpTypeCoreFallback: createCSharpTypeCoreFallback, + outputPath: outputPath); var compilationResult = compilation == null ? null : await compilation(); var lastContractCompilationResult = lastContractCompilation == null ? null : await lastContractCompilation(); @@ -76,7 +78,8 @@ public static Mock LoadMockGenerator( Func? createOutputLibrary = null, bool includeXmlDocs = false, Func? createCSharpTypeCoreFallback = null, - Func? createModelCore = null) + Func? createModelCore = null, + string? outputPath = null) { IReadOnlyList inputNsApiVersions = apiVersions?.Invoke() ?? []; IReadOnlyList inputNsLiterals = inputLiterals?.Invoke() ?? []; @@ -150,7 +153,7 @@ public static Mock LoadMockGenerator( { configuration = "{\"disable-xml-docs\": false, \"package-name\": \"Sample.Namespace\"}"; } - object?[] parameters = [_configFilePath, configuration]; + object?[] parameters = [outputPath ?? _configFilePath, configuration]; var config = loadMethod?.Invoke(null, parameters); var mockGeneratorContext = new Mock(config!); var mockGeneratorInstance = new Mock(mockGeneratorContext.Object) { CallBase = true }; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index 9948fcff594..8ae87be33d1 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -103,6 +103,8 @@ await customCodeWorkspace.GetCompilationAsync(), // Add all the generated files to the workspace await Task.WhenAll(generateFilesTasks); + generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders); + LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index 4588b3c4839..b999674e9cf 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -83,6 +83,11 @@ public async Task AddInMemoryFile(TypeProvider type) await UpdateProject(document); } + internal void AnalyzeProviderReferenceMap(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.Analyze(providers, _project); + } + private async Task UpdateProject(Document document) { var root = await document.GetSyntaxRootAsync(); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index be96e11df59..e23db1086fb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -10,6 +10,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Simplification; +using Microsoft.CodeAnalysis.Text; namespace Microsoft.TypeSpec.Generator { @@ -131,16 +132,21 @@ public async Task InternalizeAsync(Project project) // first get all the declared symbols var definitions = await GetTypeSymbolsAsync(compilation, project, true); - // build the reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( + IEnumerable symbolsToInternalize; + if (ProviderReferenceMapAnalyzer.LatestResult is { } referenceMapResult && referenceMapResult.ProjectId == project.Id) + { + // ProviderReferenceMapAnalyzer replaces Roslyn reference-map construction for generated code. + // It still uses Roslyn-discovered roots for custom/shared code before this point. + symbolsToInternalize = GetSymbolsByName(definitions.DeclaredSymbols, referenceMapResult.InternalizeCandidates).ToArray(); + } + else + { + var referenceMap = await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( definitions.DeclaredSymbols, definitions.DeclaredNodesCache); - // get the root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // traverse all the root and recursively add all the things we met - var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); + var rootSymbols = await GetRootSymbolsAsync(project, definitions); + var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); + } var nodesToInternalize = new Dictionary(); foreach (var symbol in symbolsToInternalize) @@ -238,23 +244,30 @@ public async Task RemoveAsync(Project project) // find all the declarations, including non-public declared var definitions = await GetTypeSymbolsAsync(compilation, project, false); - // build reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + IEnumerable symbolsToRemove; + HashSet referencedSet; + if (ProviderReferenceMapAnalyzer.LatestResult is { } referenceMapResult && referenceMapResult.ProjectId == project.Id) + { + // The remove pass uses the same precomputed hybrid map to avoid scanning all generated + // documents with Roslyn while preserving custom-code references as roots. + symbolsToRemove = GetSymbolsByName(definitions.DeclaredSymbols, referenceMapResult.RemoveCandidates).ToArray(); + referencedSet = new HashSet(definitions.DeclaredSymbols.Except(symbolsToRemove), SymbolEqualityComparer.Default); + } + else + { + var referenceMap = await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( definitions.DeclaredSymbols, definitions.DocumentsCache); - // get root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal - // helpers that are required by the model factory. - if (_modelFactorySymbol != null) - rootSymbols.Add(_modelFactorySymbol); - // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); - var referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + var rootSymbols = await GetRootSymbolsAsync(project, definitions); + if (_modelFactorySymbol != null) + { + rootSymbols.Add(_modelFactorySymbol); + } - var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); + referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + } var nodesToRemove = new List(); foreach (var symbol in symbolsToRemove) @@ -334,6 +347,17 @@ private static IEnumerable GetReferencedTypes(T definition, return Enumerable.Empty(); } + private static IEnumerable GetSymbolsByName(IEnumerable symbols, HashSet names) + { + foreach (var symbol in symbols) + { + if (names.Contains(symbol.GetFullyQualifiedName())) + { + yield return symbol; + } + } + } + private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId) { var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); @@ -429,6 +453,12 @@ private async Task RemoveInvalidRefs(Project project) solution = await RemoveInvalidAttributes(solution, documentId); } + // Process each document for invalid XML doc cref attributes (with fresh semantic models) + foreach (var documentId in project.DocumentIds) + { + solution = await RemoveInvalidDocCrefs(solution, documentId); + } + return solution.GetProject(project.Id)!; } @@ -452,7 +482,14 @@ private async Task RemoveInvalidUsings(Solution solution, DocumentId d if (invalidUsings.Count > 0) { + var leadingTrivia = invalidUsings[0].GetLeadingTrivia(); cu = cu.RemoveNodes(invalidUsings, SyntaxRemoveOptions.KeepNoTrivia)!; + if (leadingTrivia.Count > 0) + { + var firstToken = cu.GetFirstToken(includeZeroWidth: true); + cu = cu.ReplaceToken(firstToken, firstToken.WithLeadingTrivia(leadingTrivia.AddRange(firstToken.LeadingTrivia))); + } + solution = solution.WithDocumentSyntaxRoot(documentId, cu); } @@ -533,6 +570,37 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && return solution; } + private async Task RemoveInvalidDocCrefs(Solution solution, DocumentId documentId) + { + var document = solution.GetDocument(documentId)!; + var root = await document.GetSyntaxRootAsync(); + var model = await document.GetSemanticModelAsync(); + + if (root == null || model == null) + return solution; + + var invalidSeeElements = root.DescendantTrivia(descendIntoTrivia: true) + .SelectMany(static trivia => trivia.GetStructure()?.DescendantNodes().OfType() ?? []) + .Where(element => string.Equals(element.Name.LocalName.ValueText, "see", StringComparison.Ordinal)) + .Where(element => element.Attributes.OfType().Any(attribute => model.GetSymbolInfo(attribute.Cref).Symbol == null)) + .ToArray(); + + if (invalidSeeElements.Length == 0) + return solution; + + var text = await document.GetTextAsync(); + var source = text.ToString(); + foreach (var element in invalidSeeElements) + { + var cref = element.Attributes.OfType().First().Cref.ToString(); + var colonIndex = cref.IndexOf(':'); + var replacement = colonIndex >= 0 ? cref.Substring(colonIndex + 1) : cref; + source = source.Replace(element.ToFullString(), replacement, StringComparison.Ordinal); + } + + return solution.WithDocumentText(documentId, SourceText.From(source, text.Encoding)); + } + private async Task> GetRootSymbolsAsync(Project project, TypeSymbols modelSymbols) { var result = new HashSet(SymbolEqualityComparer.Default); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs new file mode 100644 index 00000000000..b8dc6c56f9e --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -0,0 +1,819 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.FindSymbols; + +namespace Microsoft.TypeSpec.Generator +{ + internal static class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceMapResult? _latestResult; + + public static ProviderReferenceMapResult? LatestResult => _latestResult; + + public static void Analyze(IReadOnlyList providers, Project project) + { + var graph = BuildGraph(providers); + var publicGraph = BuildGraph(providers, publicOnly: true); + + // Generated-code dependencies come from providers. Custom code still needs Roslyn + // because arbitrary user C# can reference generated types in ways providers cannot see. + var customRoots = GetCustomCodeGeneratedTypeRoots(project, graph.Nodes); + + // Helper types are rooted after an initial reachability pass so unused infrastructure + // such as change-tracking dictionaries can still be removed when no reachable type needs them. + var internalizeReferences = CloneReferences(publicGraph.References); + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false); + var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable); + internalizeRoots.UnionWith(customRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + var internalizeHelperRoots = GetHelperRootNames(providers, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: true); + var internalizeCandidates = internalizeDeclaredNodes.Except(internalizeReachable, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(project, providers, graph); + + var removeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: true); + removeRoots.UnionWith(customRoots); + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + var removeHelperRoots = GetHelperRootNames(providers, graph.Nodes, removeReachableWithoutHelpers); + removeRoots.UnionWith(removeHelperRoots); + var removeReachable = GetReachableTypes(removeRoots, graph.References); + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(providers, graph.Nodes, publicOnly: false); + var removeCandidates = removeDeclaredNodes.Except(removeReachable, StringComparer.Ordinal).OrderBy(static name => name, StringComparer.Ordinal).ToArray(); + + var helperRoots = internalizeHelperRoots.Concat(removeHelperRoots).ToHashSet(StringComparer.Ordinal); + + _latestResult = new ProviderReferenceMapResult( + project.Id, + internalizeCandidates.ToHashSet(StringComparer.Ordinal), + removeCandidates.ToHashSet(StringComparer.Ordinal)); + } + + private static HashSet GetCustomCodeGeneratedTypeRoots(Project project, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return roots; + } + + foreach (var document in project.Documents) + { + if (GeneratedCodeWorkspace.IsGeneratedDocument(document) || GeneratedCodeWorkspace.IsGeneratedTestDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var model = compilation.GetSemanticModel(root.SyntaxTree); + foreach (var declaration in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetDeclaredSymbol(declaration) as ITypeSymbol, generatedTypeNames); + } + + foreach (var typeSyntax in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetTypeInfo(typeSyntax).Type, generatedTypeNames); + } + + foreach (var objectCreation in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetSymbolInfo(objectCreation).Symbol?.ContainingType, generatedTypeNames); + } + + foreach (var invocation in root.DescendantNodes().OfType()) + { + AddSymbolRoot(roots, model.GetSymbolInfo(invocation).Symbol?.ContainingType, generatedTypeNames); + } + } + + return roots; + } + + private static void AddSymbolRoot(HashSet roots, ITypeSymbol? symbol, HashSet generatedTypeNames) + { + if (symbol is not INamedTypeSymbol namedType) + { + return; + } + + AddMatchingName(roots, namedType.GetFullyQualifiedName(), generatedTypeNames); + foreach (var typeArgument in namedType.TypeArguments) + { + AddSymbolRoot(roots, typeArgument, generatedTypeNames); + } + } + + private static ProviderReferenceGraph BuildGraph(IReadOnlyList providers, bool publicOnly = false) + { + var generatedProviders = GetGeneratedProviders(providers); + var serializationProviderNamesByType = providers + .Where(static provider => provider.SerializationProviders.Count > 0) + .ToDictionary( + static provider => GetProviderTypeName(provider.Type), + static provider => provider.SerializationProviders + .Select(static serializationProvider => GetProviderTypeName(serializationProvider.Type)) + .ToArray(), + StringComparer.Ordinal); + var nodes = generatedProviders + .Select(static provider => GetProviderTypeName(provider.Type)) + .ToHashSet(StringComparer.Ordinal); + var references = nodes.ToDictionary(static name => name, _ => new HashSet(StringComparer.Ordinal), StringComparer.Ordinal); + + foreach (var provider in generatedProviders) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes, serializationProviderNamesByType); + AddTypeReference(references[current], provider.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationProviderNamesByType); + + if (IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) + { + continue; + } + + // Model factory signatures mention many models. The existing Roslyn post-processor + // removes factory methods for unreachable models, so model factory should only + // contribute helper dependencies, not model reachability edges. + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes, serializationProviderNamesByType); + } + + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes, serializationProviderNamesByType); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes, serializationProviderNamesByType); + } + + foreach (var property in provider.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(references[current], property.Type, nodes, serializationProviderNamesByType); + AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationProviderNamesByType); + AddAttributes(references[current], property.Attributes, nodes, serializationProviderNamesByType); + } + + foreach (var field in provider.Fields) + { + if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) + { + continue; + } + + AddTypeReference(references[current], field.Type, nodes, serializationProviderNamesByType); + AddAttributes(references[current], field.Attributes, nodes, serializationProviderNamesByType); + } + + foreach (var constructor in provider.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], constructor.Signature, nodes, serializationProviderNamesByType); + } + + foreach (var method in provider.Methods) + { + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], method.Signature, nodes, serializationProviderNamesByType); + AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationProviderNamesByType); + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static CSharpType? GetCollectionDefinitionType(MethodProvider method) + { + var property = method.GetType().GetProperty("CollectionDefinition"); + return property?.GetValue(method) is TypeProvider collectionDefinition + ? collectionDefinition.Type + : null; + } + + private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); + + private static Dictionary> CloneReferences(IReadOnlyDictionary> references) + { + return references.ToDictionary( + static item => item.Key, + static item => item.Value.ToHashSet(StringComparer.Ordinal), + StringComparer.Ordinal); + } + + private static void AddDerivedModelReferences( + IReadOnlyList providers, + HashSet nodes, + Dictionary> references, + HashSet publicBaseModels) + { + var addedReference = true; + while (addedReference) + { + addedReference = false; + foreach (var provider in providers.OfType()) + { + if (provider.DiscriminatorProperty == null) + { + continue; + } + + if (!provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + if (!publicBaseModels.Contains(providerName)) + { + continue; + } + + foreach (var derivedModel in provider.DerivedModels) + { + var before = references[providerName].Count; + AddTypeReference(references[providerName], derivedModel.Type, nodes); + var derivedName = GetProviderTypeName(derivedModel.Type); + if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) + { + addedReference = true; + } + } + } + } + } + + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + generatedProviders.Add(provider); + generatedProviders.AddRange(provider.SerializationProviders); + } + + return generatedProviders; + } + + private static void AddGeneratedBodyReferences(Project project, IReadOnlyList providers, ProviderReferenceGraph graph) + { + var compilation = project.GetCompilationAsync().GetAwaiter().GetResult(); + if (compilation == null) + { + return; + } + + foreach (var provider in GetBodyReferenceProviders(providers)) + { + if (IsModelFactoryProvider(provider)) + { + continue; + } + + if (!IsGeneratedBodyReferenceCandidate(provider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + + AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); + + if (provider.BodyDependencyTypes.Count > 0) + { + continue; + } + + var symbol = compilation.GetTypeByMetadataName(providerName); + if (symbol == null) + { + continue; + } + + if (!IsSerializationProvider(provider)) + { + AddGeneratedReferencesToHelper(project, compilation, graph, providerName, symbol); + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + foreach (var method in symbol.GetMembers().OfType()) + { + if (method.IsExtensionMethod) + { + AddGeneratedReferencesToHelper(project, compilation, graph, providerName, method); + } + } + } + } + + AddGeneratedBodyTypeReferences(project, compilation, graph, providerName, symbol); + } + } + + private static void AddProviderBodyDependencyTypes(HashSet references, IReadOnlyList dependencies, HashSet nodes) + { + foreach (var dependency in dependencies) + { + AddTypeReference(references, dependency, nodes); + } + } + + private static IReadOnlyList GetBodyReferenceProviders(IReadOnlyList providers) + { + var bodyReferenceProviders = new List(); + foreach (var provider in providers) + { + bodyReferenceProviders.Add(provider); + bodyReferenceProviders.AddRange(provider.SerializationProviders); + } + + return bodyReferenceProviders; + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return IsSerializationProvider(provider) || + relativePath.EndsWith("/Internal/ClientUriBuilder.cs", StringComparison.Ordinal) || + provider.BodyDependencyTypes.Count > 0; + } + + private static void AddGeneratedBodyTypeReferences(Project project, Compilation compilation, ProviderReferenceGraph graph, string ownerName, INamedTypeSymbol ownerSymbol) + { + foreach (var syntaxReference in ownerSymbol.DeclaringSyntaxReferences) + { + var document = project.GetDocument(syntaxReference.SyntaxTree); + if (document == null || !GeneratedCodeWorkspace.IsGeneratedDocument(document)) + { + continue; + } + + var root = syntaxReference.SyntaxTree.GetRoot(); + var semanticModel = compilation.GetSemanticModel(syntaxReference.SyntaxTree); + foreach (var typeSyntax in root.DescendantNodes().OfType()) + { + // Declaration names are the owner itself. The old Roslyn map captures references, + // not a declaration making itself reachable. + if (typeSyntax.Parent is BaseTypeDeclarationSyntax baseTypeDeclaration && baseTypeDeclaration.Identifier.Span == typeSyntax.Span) + { + continue; + } + + AddBodyTypeReference(graph.References[ownerName], semanticModel.GetTypeInfo(typeSyntax).Type, graph.Nodes); + } + } + } + + private static void AddBodyTypeReference(HashSet references, ITypeSymbol? symbol, HashSet nodes) + { + if (symbol is not INamedTypeSymbol namedType || namedType.TypeKind == TypeKind.Error) + { + return; + } + + AddMatchingName(references, namedType.GetFullyQualifiedName(), nodes); + if (namedType.TypeKind == TypeKind.Enum) + { + AddMatchingName(references, $"{namedType.Name}Extensions", nodes); + } + + foreach (var typeArgument in namedType.TypeArguments) + { + AddBodyTypeReference(references, typeArgument, nodes); + } + } + + private static void AddGeneratedReferencesToHelper(Project project, Compilation compilation, ProviderReferenceGraph graph, string helperName, ISymbol symbol) + { + foreach (var reference in SymbolFinder.FindReferencesAsync(symbol, project.Solution).GetAwaiter().GetResult()) + { + foreach (var location in reference.Locations) + { + var document = location.Document; + if (!GeneratedCodeWorkspace.IsGeneratedDocument(document)) + { + continue; + } + + var root = document.GetSyntaxRootAsync().GetAwaiter().GetResult(); + if (root == null) + { + continue; + } + + var node = root.FindNode(location.Location.SourceSpan); + var owner = node.AncestorsAndSelf().OfType().FirstOrDefault(); + if (owner == null) + { + continue; + } + + var semanticModel = compilation.GetSemanticModel(owner.SyntaxTree); + if (semanticModel.GetDeclaredSymbol(owner) is not INamedTypeSymbol ownerSymbol) + { + continue; + } + + var ownerName = ownerSymbol.GetFullyQualifiedName(); + if (graph.Nodes.Contains(ownerName)) + { + graph.References[ownerName].Add(helperName); + } + } + } + } + + private static HashSet GetRootNames(IReadOnlyList providers, HashSet nodes, HashSet helperRoots, bool includeModelFactory) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (IsClientProviderRoot(provider) || + IsKept(provider.Type, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + foreach (var root in generator.TypeFactory.UnionVariantTypesToKeep) + { + AddMatchingName(roots, root, nodes); + } + + foreach (var root in generator.AdditionalRootTypes) + { + AddMatchingName(roots, root, nodes); + } + + return roots; + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var excludedNames = generator.NonRootTypes; + return GetGeneratedProviders(providers) + .Where(provider => !IsModelFactoryProvider(provider)) + .Where(provider => !publicOnly || provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + .Select(provider => GetProviderTypeName(provider.Type)) + .Where(name => nodes.Contains(name)) + .Where(name => !excludedNames.Contains(name) && !excludedNames.Contains(GetSimpleName(name))) + .ToHashSet(StringComparer.Ordinal); + } + + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) => + roots.Contains(type.Name) || roots.Contains(GetProviderTypeName(type)) && nodes.Contains(GetProviderTypeName(type)); + + private static bool IsClientProviderRoot(TypeProvider provider) => + provider.RelativeFilePath.EndsWith("Client.cs", StringComparison.Ordinal); + + private static bool IsModelFactoryProvider(TypeProvider provider) + { + if (provider is ModelFactoryProvider) + { + return true; + } + + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith("ModelFactory.cs", StringComparison.Ordinal); + } + + private static HashSet GetHelperRootNames(IReadOnlyList providers, HashSet nodes, HashSet reachableTypes) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + AddHelperDependencies(roots, provider.HelperDependencyNames, nodes); + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + // Only factory methods for reachable models can instantiate collection helpers. + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddHelperDependencies(HashSet roots, IReadOnlyList dependencies, HashSet nodes) + { + foreach (var dependency in dependencies) + { + AddMatchingName(roots, dependency, nodes); + } + } + + private static bool IsSerializationProvider(TypeProvider provider) + { + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || + relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddMatchingName(roots, "ChangeTrackingList", nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddMatchingName(roots, "ChangeTrackingList", nodes); + } + + if (type.IsDictionary) + { + AddMatchingName(roots, "ChangeTrackingDictionary", nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + foreach (var node in nodes) + { + if (string.Equals(StripGenericArity(GetSimpleName(node)), name, StringComparison.Ordinal)) + { + target.Add(node); + } + } + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static void AddSignatureReferences( + HashSet references, + MethodSignatureBase signature, + HashSet nodes, + IReadOnlyDictionary serializationProviderNamesByType) + { + AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); + AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType); + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); + AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType); + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddAttributes( + HashSet references, + IReadOnlyList attributes, + HashSet nodes, + IReadOnlyDictionary serializationProviderNamesByType) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddTypeReference( + HashSet references, + CSharpType? type, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType = null) + { + if (type == null) + { + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) + { + foreach (var serializationProviderName in serializationProviderNames) + { + references.Add(serializationProviderName); + } + } + } + + AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes, serializationProviderNamesByType); + } + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs new file mode 100644 index 00000000000..e5623b1e7d2 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; +using Microsoft.CodeAnalysis; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed record ProviderReferenceMapResult( + ProjectId ProjectId, + HashSet InternalizeCandidates, + HashSet RemoveCandidates) + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 7a3047886b6..9762dfbf311 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -272,6 +272,12 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SerializationProviders => _serializationProviders ??= BuildSerializationProviders(); + internal IReadOnlyList HelperDependencyNames => BuildHelperDependencyNames(); + protected internal virtual IReadOnlyList BuildHelperDependencyNames() => []; + + internal IReadOnlyList BodyDependencyTypes => BuildBodyDependencyTypes(); + protected internal virtual IReadOnlyList BuildBodyDependencyTypes() => []; + private IReadOnlyList? _attributes; public IReadOnlyList Attributes diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs index c207c3e130b..120fafb9c57 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs @@ -192,11 +192,6 @@ protected internal TypeFactory() if (modelProvider != null) { - if (model.Access == "public") - { - CodeModelGenerator.Instance.AddTypeToKeep(modelProvider); - } - CSharpTypeMap[modelProvider.Type] = modelProvider; TypeProvidersByName[modelProvider.Type.Name] = modelProvider; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs index 28981148a4d..4100756ff15 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs @@ -60,6 +60,52 @@ public async Task RemovesInvalidUsings() CollectionAssert.Contains(usings, "System"); } + [Test] + public async Task RemovesInvalidUsingsKeepsFileHeader() + { + MockHelpers.LoadMockGenerator(); + var workspace = new AdhocWorkspace(); + var projectInfo = ProjectInfo.Create( + ProjectId.CreateNewId(), + VersionStamp.Create(), + name: "TestProj", + assemblyName: "TestProj", + language: LanguageNames.CSharp) + .WithMetadataReferences(new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location) + }); + + var project = workspace.AddProject(projectInfo); + project = AddGeneratedDocument(project, "RootClient.cs", "src", "Generated", "RootClient.cs", """ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using Missing.Namespace; + +namespace Sample +{ + public partial class RootClient + { + } +} +"""); + var postProcessor = new TestPostProcessor("RootClient.cs"); + + var resultProject = await postProcessor.RemoveAsync(project); + var doc = resultProject.Documents.Single(d => d.Name == "RootClient.cs"); + var text = (await doc.GetTextAsync()).ToString(); + + StringAssert.StartsWith("// Copyright (c) Microsoft Corporation. All rights reserved.", text); + StringAssert.Contains("// ", text); + StringAssert.Contains("#nullable disable", text); + StringAssert.DoesNotContain("using Missing.Namespace;", text); + } + [Test] public async Task DoesNotRemoveValidUsings() @@ -289,11 +335,249 @@ public async Task DoesNotRemoveValidAttributes() Assert.AreEqual(Helpers.GetExpectedFromFile().TrimEnd(), output, "The output should match the expected content."); } + [Test] + public async Task ModelFactoryMethodDoesNotKeepModelPublic() + { + MockHelpers.LoadMockGenerator(); + var workspace = new AdhocWorkspace(); + var projectInfo = ProjectInfo.Create( + ProjectId.CreateNewId(), + VersionStamp.Create(), + name: "TestProj", + assemblyName: "TestProj", + language: LanguageNames.CSharp) + .WithMetadataReferences(new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location) + }); + + var project = workspace.AddProject(projectInfo); + project = AddGeneratedDocument(project, "RootClient.cs", "src", "Generated", "RootClient.cs", """ +namespace Sample +{ + public partial class RootClient + { + } +} +"""); + project = AddGeneratedDocument(project, "UnusedModel.cs", "src", "Generated", "Models", "UnusedModel.cs", """ +namespace Sample.Models +{ + public partial class UnusedModel + { + } +} +"""); + project = AddGeneratedDocument(project, "SampleModelFactory.cs", "src", "Generated", "SampleModelFactory.cs", """ +namespace Sample +{ + public static partial class SampleModelFactory + { + public static Sample.Models.UnusedModel UnusedModel() => new Sample.Models.UnusedModel(); + } +} +"""); + + var postProcessor = new TestPostProcessor("RootClient.cs", modelFactoryFullName: "Sample.SampleModelFactory"); + + var resultProject = await postProcessor.InternalizeAsync(project); + var modelDoc = resultProject.Documents.Single(d => d.Name == "UnusedModel.cs"); + var modelRoot = await modelDoc.GetSyntaxRootAsync(); + var modelDeclaration = modelRoot!.DescendantNodes().OfType().Single(); + Assert.IsTrue(modelDeclaration.Modifiers.Any(m => m.Text == "internal")); + Assert.IsFalse(modelDeclaration.Modifiers.Any(m => m.Text == "public")); + + var modelFactoryDoc = resultProject.Documents.SingleOrDefault(d => d.Name == "SampleModelFactory.cs"); + if (modelFactoryDoc != null) + { + var modelFactoryRoot = await modelFactoryDoc.GetSyntaxRootAsync(); + Assert.IsFalse(modelFactoryRoot!.DescendantNodes().OfType().Any(), + "Model factory methods for internalized models should be removed."); + } + } + + [TestCase("RequestBodyModel")] + [TestCase("ResponseBodyModel")] + [TestCase("NestedToolConfig")] + [TestCase("BodyOnlyValueKind")] + public async Task BodyOnlyGeneratedTypesDoNotBecomePublic(string bodyOnlyTypeName) + { + var project = CreateProjectWithRootClient(); + var typeDeclaration = bodyOnlyTypeName == "BodyOnlyValueKind" + ? $$""" +namespace Sample.Models +{ + public enum {{bodyOnlyTypeName}} + { + Default + } +} +""" + : $$""" +namespace Sample.Models +{ + public partial class {{bodyOnlyTypeName}} + { + } +} +"""; + project = AddGeneratedDocument(project, $"{bodyOnlyTypeName}.cs", "src", "Generated", "Models", $"{bodyOnlyTypeName}.cs", typeDeclaration); + var factoryExpression = bodyOnlyTypeName == "BodyOnlyValueKind" + ? $"Sample.Models.{bodyOnlyTypeName}.Default" + : $"new Sample.Models.{bodyOnlyTypeName}()"; + project = AddGeneratedDocument(project, "SampleModelFactory.cs", "src", "Generated", "SampleModelFactory.cs", $$""" +namespace Sample +{ + public static partial class SampleModelFactory + { + public static Sample.Models.{{bodyOnlyTypeName}} {{bodyOnlyTypeName}}() => {{factoryExpression}}; + } +} +"""); + + var resultProject = await new TestPostProcessor("RootClient.cs", modelFactoryFullName: "Sample.SampleModelFactory").InternalizeAsync(project); + + await AssertGeneratedTypeInternal(resultProject, $"{bodyOnlyTypeName}.cs", bodyOnlyTypeName); + await AssertModelFactoryMethodRemoved(resultProject, bodyOnlyTypeName); + } + + [Test] + public async Task BodyOnlyNestedGraphDoesNotBecomePublic() + { + var project = CreateProjectWithRootClient(); + project = AddGeneratedDocument(project, "ToolConfig.cs", "src", "Generated", "Models", "ToolConfig.cs", """ +namespace Sample.Models +{ + public partial class ToolConfig + { + public ToolParameter Parameter { get; set; } + } +} +"""); + project = AddGeneratedDocument(project, "ToolParameter.cs", "src", "Generated", "Models", "ToolParameter.cs", """ +namespace Sample.Models +{ + public partial class ToolParameter + { + } +} +"""); + project = AddGeneratedDocument(project, "SampleModelFactory.cs", "src", "Generated", "SampleModelFactory.cs", """ +namespace Sample +{ + public static partial class SampleModelFactory + { + public static Sample.Models.ToolConfig ToolConfig() => new Sample.Models.ToolConfig(); + public static Sample.Models.ToolParameter ToolParameter() => new Sample.Models.ToolParameter(); + } +} +"""); + + var resultProject = await new TestPostProcessor("RootClient.cs", modelFactoryFullName: "Sample.SampleModelFactory").InternalizeAsync(project); + + await AssertGeneratedTypeInternal(resultProject, "ToolConfig.cs", "ToolConfig"); + await AssertGeneratedTypeInternal(resultProject, "ToolParameter.cs", "ToolParameter"); + await AssertModelFactoryMethodRemoved(resultProject, "ToolConfig"); + await AssertModelFactoryMethodRemoved(resultProject, "ToolParameter"); + } + + [Test] + public async Task ClientOptionsNestedEnumDoesNotBecomePublicWhenOptionsIsNotRooted() + { + var project = CreateProjectWithRootClient("public SampleClientOptions Options { get; }"); + project = AddGeneratedDocument(project, "SampleClientOptions.cs", "src", "Generated", "SampleClientOptions.cs", """ +namespace Sample +{ + public partial class SampleClientOptions + { + public enum ServiceVersion + { + V1 + } + } +} +"""); + + var resultProject = await new TestPostProcessor("RootClient.cs").InternalizeAsync(project); + + await AssertGeneratedTypePublic(resultProject, "SampleClientOptions.cs", "SampleClientOptions"); + var optionsDoc = resultProject.Documents.Single(d => d.Name == "SampleClientOptions.cs"); + var root = await optionsDoc.GetSyntaxRootAsync(); + var nestedEnum = root!.DescendantNodes().OfType().Single(); + Assert.IsFalse(nestedEnum.Modifiers.Any(m => m.Text == "public")); + } + + private static Project AddGeneratedDocument(Project project, string name, string folder1, string folder2, string fileName, string text) + => project.AddDocument(name, text, folders: [folder1, folder2], filePath: Path.Join(folder1, folder2, fileName)).Project; + + private static Project AddGeneratedDocument(Project project, string name, string folder1, string folder2, string folder3, string fileName, string text) + => project.AddDocument(name, text, folders: [folder1, folder2, folder3], filePath: Path.Join(folder1, folder2, folder3, fileName)).Project; + + private static Project CreateProjectWithRootClient(string members = "") + { + MockHelpers.LoadMockGenerator(); + var workspace = new AdhocWorkspace(); + var projectInfo = ProjectInfo.Create( + ProjectId.CreateNewId(), + VersionStamp.Create(), + name: "TestProj", + assemblyName: "TestProj", + language: LanguageNames.CSharp) + .WithMetadataReferences(new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location) + }); + + var project = workspace.AddProject(projectInfo); + return AddGeneratedDocument(project, "RootClient.cs", "src", "Generated", "RootClient.cs", $$""" +namespace Sample +{ + public partial class RootClient + { + {{members}} + } +} +"""); + } + + private static async Task AssertGeneratedTypeInternal(Project project, string documentName, string typeName) + { + var declaration = await GetTypeDeclaration(project, documentName, typeName); + Assert.IsTrue(declaration.Modifiers.Any(m => m.Text == "internal"), $"{typeName} should be internal."); + Assert.IsFalse(declaration.Modifiers.Any(m => m.Text == "public"), $"{typeName} should not be public."); + } + + private static async Task AssertGeneratedTypePublic(Project project, string documentName, string typeName) + { + var declaration = await GetTypeDeclaration(project, documentName, typeName); + Assert.IsTrue(declaration.Modifiers.Any(m => m.Text == "public"), $"{typeName} should be public."); + } + + private static async Task GetTypeDeclaration(Project project, string documentName, string typeName) + { + var document = project.Documents.Single(d => d.Name == documentName); + var root = await document.GetSyntaxRootAsync(); + return root!.DescendantNodes().OfType().Single(t => t.Identifier.Text == typeName); + } + + private static async Task AssertModelFactoryMethodRemoved(Project project, string methodName) + { + var modelFactoryDoc = project.Documents.SingleOrDefault(d => d.Name == "SampleModelFactory.cs"); + if (modelFactoryDoc == null) + { + return; + } + + var modelFactoryRoot = await modelFactoryDoc.GetSyntaxRootAsync(); + Assert.IsFalse(modelFactoryRoot!.DescendantNodes().OfType().Any(m => m.Identifier.Text == methodName), + $"Model factory method {methodName} should be removed."); + } + private class TestPostProcessor : PostProcessor { private readonly string _rootFile; - public TestPostProcessor(string rootFile, IEnumerable? nonRootTypes = null) : base([], additionalNonRootTypeNames: nonRootTypes) + public TestPostProcessor(string rootFile, IEnumerable? additionalRootTypeNames = null, IEnumerable? nonRootTypes = null, string? modelFactoryFullName = null) : base((additionalRootTypeNames ?? []).ToHashSet(), modelFactoryFullName: modelFactoryFullName, additionalNonRootTypeNames: nonRootTypes) { _rootFile = rootFile; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs index 1bdf4020167..c35f1968d76 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -177,6 +177,32 @@ public async Task OmitsModelFactoryMethodIfParamTypeInternal() Assert.IsNull(modelFactory); } + // This test validates that a derived model customized to be internal does not get a + // public model factory method just because its base model remains public. + [Test] + public async Task OmitsModelFactoryMethodIfDerivedModelTypeInternal() + { + var baseModel = InputFactory.Model( + "baseModel", + properties: [InputFactory.Property("BaseProp", InputPrimitiveType.String)]); + var derivedModel = InputFactory.Model( + "derivedModel", + properties: [InputFactory.Property("DerivedProp", InputPrimitiveType.String)], + baseModel: baseModel); + + var mockGenerator = await MockHelpers.LoadMockGeneratorAsync( + inputModelTypes: [baseModel, derivedModel], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + var modelFactory = mockGenerator.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + CollectionAssert.Contains(modelFactory!.Methods.Select(m => m.Signature.Name), "BaseModel"); + CollectionAssert.DoesNotContain(modelFactory.Methods.Select(m => m.Signature.Name), "DerivedModel"); + } + [TestCase(true)] [TestCase(false)] public async Task CanCustomizeModelFullConstructor(bool extraParameters) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs new file mode 100644 index 00000000000..bdb2034f5f0 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Sample.Models +{ + internal partial class DerivedModel + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs index 5937838822e..4b82c28bf77 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs @@ -1449,7 +1449,7 @@ await MockHelpers.LoadMockGeneratorAsync( } [Test] - public void PublicModelsAreIncludedInAdditionalRootTypes() + public void PublicModelsAreNotIncludedInAdditionalRootTypes() { var inputModel = InputFactory.Model( "MockInputModel", @@ -1462,7 +1462,7 @@ public void PublicModelsAreIncludedInAdditionalRootTypes() Assert.IsNotNull(modelProvider); var rootTypes = CodeModelGenerator.Instance.AdditionalRootTypes; - Assert.IsTrue(rootTypes.Contains("Sample.Models.MockInputModel")); + Assert.IsFalse(rootTypes.Contains("Sample.Models.MockInputModel")); } [Test]