Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ struct ImportDescription: Equatable, Codable {
/// would be `@_spi(Secret) import Foo`.
var spi: String? = nil

/// The access modifier to apply to the import statement.
///
/// When set to `.public` or `.package`, the modifier is prepended to the
/// import statement (e.g. `public import Foo`).
var accessModifier: AccessModifier? = nil

/// Requirements for the `@preconcurrency` attribute.
var preconcurrency: PreconcurrencyRequirement = .never

Expand Down
17 changes: 15 additions & 2 deletions Sources/_OpenAPIGeneratorCore/Renderer/TextBasedRenderer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,26 @@ struct TextBasedRenderer: RendererProtocol {

/// Renders a single import statement.
func renderImport(_ description: ImportDescription) {
let accessModifierPrefix: String
switch description.accessModifier {
case .public:
accessModifierPrefix = renderedAccessModifier(.public) + " "
case .package:
accessModifierPrefix = renderedAccessModifier(.package) + " "
default:
accessModifierPrefix = ""
}

func render(preconcurrency: Bool) {
let spiPrefix = description.spi.map { "@_spi(\($0)) " } ?? ""
let preconcurrencyPrefix = preconcurrency ? "@preconcurrency " : ""
let attributePrefix = "\(preconcurrencyPrefix)\(spiPrefix)"
if let moduleTypes = description.moduleTypes {
for type in moduleTypes { writer.writeLine("\(preconcurrencyPrefix)\(spiPrefix)import \(type)") }
for type in moduleTypes {
writer.writeLine("\(attributePrefix)\(accessModifierPrefix)import \(type)")
}
} else {
writer.writeLine("\(preconcurrencyPrefix)\(spiPrefix)import \(description.moduleName)")
writer.writeLine("\(attributePrefix)\(accessModifierPrefix)import \(description.moduleName)")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ struct ClientFileTranslator: FileTranslator {

let topComment = self.topComment

let imports =
Constants.File.clientServerImports + config.additionalImports.map { ImportDescription(moduleName: $0) }
let imports = importDescriptions(adding: Constants.File.clientServerImports)

let clientMethodDecls = try OperationDescription.all(from: doc.paths, in: components, context: context)
.map(translateClientMethod(_:))
Expand Down
27 changes: 27 additions & 0 deletions Sources/_OpenAPIGeneratorCore/Translator/FileTranslator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,33 @@ extension FileTranslator {
var topComment: Comment {
.inline(([Constants.File.topComment] + config.additionalFileComments).joined(separator: "\n"))
}

/// Returns the imports for the generated file, with access modifiers applied.
///
/// The configured access modifier is propagated to the built-in imports so
/// that, under Swift 6's `InternalImportsByDefault` flag, generated
/// declarations can re-export the symbols they depend on.
/// `additionalImports` from the configuration are also given the same
/// access modifier so they are visible to consumers of the generated code.
/// - Parameter baseImports: the base set of imports for the file (e.g.
/// ``Constants/File/imports`` or ``Constants/File/clientServerImports``).
/// - Returns: An array of ``ImportDescription`` values with appropriate
/// access modifier set.
func importDescriptions(adding baseImports: [ImportDescription]) -> [ImportDescription] {
let accessModifier: AccessModifier?
switch config.access {
case .public, .package:
accessModifier = config.access
default:
accessModifier = nil
}
let allImports: [ImportDescription] = baseImports + config.additionalImports.map { ImportDescription(moduleName: $0) }
return allImports.map { original in
var description = original
description.accessModifier = accessModifier
return description
}
}
}

/// A set of configuration values for concrete file translators.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ struct ServerFileTranslator: FileTranslator {

let topComment = self.topComment

let imports =
Constants.File.clientServerImports + config.additionalImports.map { ImportDescription(moduleName: $0) }

let imports = importDescriptions(adding: Constants.File.clientServerImports)

let allOperations = try OperationDescription.all(from: doc.paths, in: components, context: context)

let (registerHandlersDecl, serverMethodDecls) = try translateRegisterHandlers(allOperations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct TypesFileTranslator: FileTranslator {

let topComment = self.topComment

let imports = Constants.File.imports + config.additionalImports.map { ImportDescription(moduleName: $0) }
let imports = importDescriptions(adding: Constants.File.imports)

let apiProtocol = try translateAPIProtocol(doc.paths)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,70 @@ final class Test_TextBasedRenderer: XCTestCase {
)
}

func testImportsWithPublicAccessModifier() throws {
try _test(
[ImportDescription(moduleName: "Foo", accessModifier: .public)],
renderedBy: TextBasedRenderer.renderImports,
rendersAs: #"""
public import Foo
"""#
)
}

func testImportsWithPackageAccessModifier() throws {
try _test(
[ImportDescription(moduleName: "Foo", accessModifier: .package)],
renderedBy: TextBasedRenderer.renderImports,
rendersAs: #"""
package import Foo
"""#
)
}

func testImportsWithInternalAccessModifier() throws {
try _test(
[ImportDescription(moduleName: "Foo", accessModifier: .internal)],
renderedBy: TextBasedRenderer.renderImports,
rendersAs: #"""
import Foo
"""#
)
}

func testImportsWithAccessModifierAndAttributes() throws {
try _test(
[ImportDescription(moduleName: "Foo", spi: "Secret", accessModifier: .public, preconcurrency: .always)],
renderedBy: TextBasedRenderer.renderImports,
rendersAs: #"""
@preconcurrency @_spi(Secret) public import Foo
"""#
)
}

func testImportsWithAccessModifierAndModuleTypes() throws {
try _test(
[ImportDescription(moduleName: "Foundation", moduleTypes: ["struct Foundation.URL"], accessModifier: .public)],
renderedBy: TextBasedRenderer.renderImports,
rendersAs: #"""
public import struct Foundation.URL
"""#
)
}

func testImportsWithAccessModifierAndPreconcurrencyOnOS() throws {
try _test(
[ImportDescription(moduleName: "Foo", accessModifier: .public, preconcurrency: .onOS(["Linux"]))],
renderedBy: TextBasedRenderer.renderImports,
rendersAs: #"""
#if os(Linux)
@preconcurrency public import Foo
#else
public import Foo
#endif
"""#
)
}

func testAccessModifiers() throws {
try _test(
.public,
Expand Down