Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 98 additions & 32 deletions src/IMessage/Sources/IMDatabase/Database/IMDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ public final class IMDatabase {
// file watchers for `chat.db` and `chat.db-wal`; these need to be
// dynamically populated because the WAL can be deleted and (re)created at
// any time
private var fileWatchers = [FileWatcher]()
private var fileWatchers = [String: FileWatcher]()

private let listenerLock = Protected(())
private var directoryWatcher: FSEventsWatcher?
private var debouncer: Task<Void, Never>?

public var noisy = false
Expand Down Expand Up @@ -72,10 +74,9 @@ public final class IMDatabase {

deinit {
log.debug("being deallocated, stopping watchers and listeners if necessary")
for watcher in fileWatchers {
watcher.stopListeningIfNecessary()
listenerLock.withLock { _ in
stopListeningForChangesLocked()
}
debouncer?.cancel()
}
}

Expand All @@ -94,10 +95,50 @@ private extension IMDatabase {

public extension IMDatabase {
/// Starts the filesystem watchers unless a directory watcher is already
/// installed, in which case this is a no-op.
///
/// Runs entirely under `listenerLock`.
/// - Throws: Whatever `startListeningForChangesLocked()` throws.
func beginListeningForChanges() throws {
    try listenerLock.withLock { _ in
        // An existing directory watcher means we are already listening.
        if directoryWatcher != nil {
            return
        }
        try startListeningForChangesLocked()
    }
}

/// Tears down all watchers and the debouncer while holding `listenerLock`.
func stopListeningForChanges() {
    listenerLock.withLock { _ in stopListeningForChangesLocked() }
}

/// Stops and discards the directory watcher, the debouncer task, and every
/// per-file watcher.
///
/// The `Locked` suffix signals that callers in this file invoke it while
/// holding `listenerLock`.
private func stopListeningForChangesLocked() {
    if let activeDirectoryWatcher = directoryWatcher {
        activeDirectoryWatcher.stop()
        activeDirectoryWatcher.invalidate()
        directoryWatcher = nil
    }

    debouncer?.cancel()
    debouncer = nil

    fileWatchers.values.forEach { $0.stopListeningIfNecessary() }
    fileWatchers.removeAll()
}

/// Resets any existing watcher state and installs a fresh set of listeners.
///
/// If listener setup throws, everything that was partially installed is torn
/// down again before the error is rethrown.
/// - Throws: Whatever `setUpListeners(unthrottledChanges:)` throws.
private func startListeningForChangesLocked() throws {
    log.info("setting up filesystem watchers")

    // Start from a clean slate.
    stopListeningForChangesLocked()

    let changeTopic = Topic<Void>()

    // Roll back on any failure path out of this function.
    var didFinishSetup = false
    defer {
        if !didFinishSetup {
            stopListeningForChangesLocked()
        }
    }

    try setUpListeners(unthrottledChanges: changeTopic)
    didFinishSetup = true
}

private func setUpListeners(unthrottledChanges: Topic<Void>) throws {
// listen to ~/Library/Messages itself in order to respond to the WAL
// file being (re)created/deleted
let directoryWatcher = try FSEventsWatcher(watchingPath: messagesDataDirectory.path, latency: 1.0) { [weak self] _, event in
Expand All @@ -119,13 +160,16 @@ public extension IMDatabase {
log.debug("received FSEvent [\(event.id)] \(anonymizedPath) \(event.flags)")

do {
try ensureDatabaseFileWatchers(broadcastingTo: unthrottledChanges)
try listenerLock.withLock { _ in
try self.ensureDatabaseFileWatchers(broadcastingTo: unthrottledChanges)
}
} catch {
log.error("failed to ensure database file watchers in response to WAL file event: \(error)")
}
}
directoryWatcher.setDispatchQueue(fsEventsQueue)
try directoryWatcher.start()
self.directoryWatcher = directoryWatcher

try ensureDatabaseFileWatchers(broadcastingTo: unthrottledChanges)

Expand All @@ -141,34 +185,35 @@ public extension IMDatabase {
}

private func ensureDatabaseFileWatchers(broadcastingTo topic: Topic<Void>) throws {
if !fileWatchers.isEmpty {
let allWatchersHaveLinks = fileWatchers.allSatisfy { watcher in
do {
return try watcher.hasHardLinks() == true
} catch {
log.error("couldn't check if \(watcher) has hard links, assuming it does: \(error)")
return false
}
}

guard !allWatchersHaveLinks else {
log.debug("all file watchers have hard links, leaving them alone")
return
}
let desiredWatchFiles = [
DatabaseWatchFile(url: chatDatabaseFile(in: messagesDataDirectory), required: true),
DatabaseWatchFile(url: chatDatabaseWalFile(in: messagesDataDirectory), required: false),
]
let desiredWatchPaths = Set(desiredWatchFiles.map(\.url.path))

let staleWatchPaths = fileWatchers.keys.filter { !desiredWatchPaths.contains($0) }
for path in staleWatchPaths {
log.debug("purging stale FileWatcher for \(URL(fileURLWithPath: path).lastPathComponent)")
fileWatchers.removeValue(forKey: path)?.stopListeningIfNecessary()
}

log.debug("at least one file watcher lacks hard links, purging all of em (\(fileWatchers.count))")
// TODO: watchers stop listening in deinit, so maybe this is
// unnecessary assuming we have no refcycles
for watcher in fileWatchers {
watcher.stopListeningIfNecessary()
let unlinkedWatchPaths = fileWatchers.compactMap { path, watcher -> String? in
do {
return try watcher.hasHardLinks() == true ? nil : path
} catch {
log.error("couldn't check if \(watcher) has hard links, recreating it: \(error)")
return path
}
fileWatchers.removeAll()
}
for path in unlinkedWatchPaths {
log.debug("purging unlinked FileWatcher for \(URL(fileURLWithPath: path).lastPathComponent)")
fileWatchers.removeValue(forKey: path)?.stopListeningIfNecessary()
}

func watchFile(at path: URL) throws {
log.debug("setting up FileWatcher for \(path.lastPathComponent)")
func makeWatcher(for file: URL) throws -> FileWatcher {
log.debug("setting up FileWatcher for \(file.lastPathComponent)")

let watcher = FileWatcher(watching: path) { [weak self] _, event in
let watcher = FileWatcher(watching: file) { [weak self] _, event in
guard let self else { return }

if noisy {
Expand All @@ -178,16 +223,37 @@ public extension IMDatabase {
}

try watcher.beginListening()
fileWatchers.append(watcher)
return watcher
}

// watch `.db`/`.db-wal` files for changes
try watchFile(at: chatDatabaseFile(in: messagesDataDirectory))
try watchFile(at: chatDatabaseWalFile(in: messagesDataDirectory))
var newWatchers = [String: FileWatcher]()
for file in desiredWatchFiles where fileWatchers[file.url.path] == nil {
do {
newWatchers[file.url.path] = try makeWatcher(for: file.url)
} catch {
if file.required {
log.debug("failed to set up required database file watcher, cleaning up \(newWatchers.count) new watcher(s)")
for watcher in newWatchers.values {
watcher.stopListeningIfNecessary()
}
throw error
}
log.debug("could not watch optional \(file.url.lastPathComponent); will retry on directory events: \(error)")
}
}

for (path, watcher) in newWatchers {
fileWatchers[path] = watcher
}

log.debug("watcher count after ensuring: \(fileWatchers.count)")
}

/// A file that `ensureDatabaseFileWatchers(broadcastingTo:)` wants a
/// `FileWatcher` for.
private struct DatabaseWatchFile {
// Location of the file to watch; its `path` is used as the key in `fileWatchers`.
var url: URL
// When true, failing to set up the watcher is rethrown to the caller; when
// false the failure is only logged.
var required: Bool
}

private func broadcastDebouncedChanges(from topic: Topic<Void>) async throws {
var broadcaster: Task<Void, any Error>?

Expand Down
178 changes: 178 additions & 0 deletions src/IMessage/Sources/IMessage/DatabaseTickWaits.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import Foundation
import IMDatabase
import IMessageCore
import PlatformSDK

// Default `linkTimeout` for `DatabaseTickWaits.sentMessageIDs`: once at least
// one ID has been seen for a text message, how long to keep waiting for the
// full expected count before settling for what was found.
private let sentMessageLinkWaitTimeout: TimeInterval = 1.5

// Re-query at least this often even without a tick: FSEvents notifications can be
// dropped or coalesced, so a missed tick costs ~1s instead of the full timeout.
private let databaseTickBackstopInterval: TimeInterval = 1.0
// Default `minimumRequeryInterval` for `DatabaseTickWaits.loadedAttachment`:
// lower bound on the spacing between consecutive message loads.
private let loadedAttachmentMinimumRequeryInterval: TimeInterval = 0.25

enum DatabaseTickWaits {
typealias SentMessageID = (rowID: Int, guid: String)

/// Outcome of evaluating one query result inside `waitForDatabaseResult`.
private enum WaitResult<T> {
// The wait is over; the associated value is returned to the caller.
case finished(T)
// Not done yet; wait for a change (or the backstop) but no longer than this date.
case waitingUntil(Date)
}

/// Repeatedly runs `querySentMessageIDs` until the expected number of sent
/// message IDs shows up.
///
/// The expected count is `max(text.linkCount, 1)` when `text` is non-nil and
/// `1` otherwise. For a text message, a non-empty but incomplete result is
/// accepted once `linkTimeout` has elapsed.
///
/// - Parameters:
///   - text: The message text, if any.
///   - timeout: Overall time budget, measured from the call.
///   - changes: Topic whose broadcasts trigger a re-query.
///   - linkTimeout: Budget after which a partial text result is accepted.
///   - backstopInterval: Maximum time between re-queries without a broadcast.
///   - querySentMessageIDs: Produces the currently known sent message IDs.
/// - Returns: The sent message IDs that satisfied the wait.
/// - Throws: `ErrorMessage` on timeout, or any error from the query.
static func sentMessageIDs(
    text: String?,
    timeout: TimeInterval,
    changes: Topic<Void>,
    linkTimeout: TimeInterval = sentMessageLinkWaitTimeout,
    backstopInterval: TimeInterval = databaseTickBackstopInterval,
    querySentMessageIDs: @escaping @Sendable () throws -> [SentMessageID]
) async throws -> [SentMessageID] {
    let start = Date()
    let hardDeadline = start.addingTimeInterval(timeout)
    let softDeadline = start.addingTimeInterval(linkTimeout)
    let hasText = text != nil

    let expectedCount: Int
    if let text {
        expectedCount = max(text.linkCount, 1)
    } else {
        expectedCount = 1
    }

    return try await waitForDatabaseResult(
        changes: changes,
        backstopInterval: backstopInterval,
        query: { try querySentMessageIDs() },
        evaluate: { found in
            let now = Date()
            // A text message that has produced some, but possibly not all, IDs.
            let hasPartialTextResult = hasText && !found.isEmpty

            if found.count == expectedCount {
                return .finished(found)
            }
            if hasPartialTextResult, now >= softDeadline {
                return .finished(found)
            }
            if now >= hardDeadline {
                throw ErrorMessage("timed out waiting for sent messages")
            }

            // With a partial text result, wake up in time for the link deadline.
            return .waitingUntil(hasPartialTextResult ? min(hardDeadline, softDeadline) : hardDeadline)
        }
    )
}

/// Repeatedly runs `querySentThreadIDs` until none of the returned thread IDs
/// is `nil`, or until `timeout` has elapsed.
///
/// Unlike the other waits, hitting the deadline does not throw: the most
/// recent result is returned even if it still contains `nil` entries.
///
/// - Parameters:
///   - timeout: Overall time budget, measured from the call.
///   - changes: Topic whose broadcasts trigger a re-query.
///   - backstopInterval: Maximum time between re-queries without a broadcast.
///   - querySentThreadIDs: Produces the currently known thread IDs.
/// - Returns: The last queried thread IDs.
/// - Throws: Any error from the query.
static func sentThreadIDs(
    timeout: TimeInterval,
    changes: Topic<Void>,
    backstopInterval: TimeInterval = databaseTickBackstopInterval,
    querySentThreadIDs: @escaping @Sendable () throws -> [String?]
) async throws -> [String?] {
    let deadline = Date().addingTimeInterval(timeout)

    return try await waitForDatabaseResult(
        changes: changes,
        backstopInterval: backstopInterval,
        query: { try querySentThreadIDs() },
        evaluate: { ids in
            // Keep waiting only while something is missing and time remains.
            let stillMissing = ids.contains(nil)
            guard stillMissing, Date() < deadline else {
                return .finished(ids)
            }
            return .waitingUntil(deadline)
        }
    )
}

/// Repeatedly loads a message until none of its attachments reports
/// `loading == true`.
///
/// - Parameters:
///   - messageID: Identifier used in error messages.
///   - timeout: Overall time budget, measured from the call.
///   - changes: Topic whose broadcasts trigger a reload.
///   - backstopInterval: Maximum time between reloads without a broadcast.
///   - minimumRequeryInterval: Minimum spacing between consecutive reloads.
///   - loadMessage: Loads the current state of the message.
///   - terminalAttachmentFailureState: Returns a transfer state when the
///     attachment has failed for good, otherwise `nil`.
/// - Returns: The message once its attachments have finished loading.
/// - Throws: `ErrorMessage` when the message cannot be found, has no
///   attachments on the first read, reports a terminal failure state, or the
///   timeout elapses; also any error from the two closures.
static func loadedAttachment(
    messageID: String,
    timeout: TimeInterval,
    changes: Topic<Void>,
    backstopInterval: TimeInterval = databaseTickBackstopInterval,
    minimumRequeryInterval: TimeInterval = loadedAttachmentMinimumRequeryInterval,
    loadMessage: @escaping @Sendable () async throws -> PlatformSDK.Message?,
    terminalAttachmentFailureState: @escaping @Sendable () async throws -> Attachment.IMFileTransferState?
) async throws -> PlatformSDK.Message {
    let deadline = Date().addingTimeInterval(timeout)
    var hasValidatedFirstRead = false

    return try await waitForDatabaseResult(
        changes: changes,
        backstopInterval: backstopInterval,
        minimumRequeryInterval: minimumRequeryInterval,
        query: {
            try await loadMessage()
                .orThrow(ErrorMessage("Could not find message \(messageID)"))
        },
        evaluate: { message in
            let attachments = message.attachments ?? []

            // Only the very first read rejects a message without attachments.
            if !hasValidatedFirstRead {
                if attachments.isEmpty {
                    throw ErrorMessage("Message \(messageID) has no attachments")
                }
                hasValidatedFirstRead = true
            }

            let stillLoading = attachments.isEmpty || attachments.contains(where: { $0.loading == true })
            if !stillLoading {
                return .finished(message)
            }

            if let failureState = try await terminalAttachmentFailureState() {
                throw ErrorMessage("Attachment in message \(messageID) failed to load (transfer state: \(failureState.rawValue))")
            }

            if Date() >= deadline {
                throw ErrorMessage("Timed out waiting for attachment in message \(messageID) to load")
            }

            return .waitingUntil(deadline)
        }
    )
}

/// Core polling loop shared by the public waits.
///
/// Each pass subscribes to `changes`, runs `query`, and hands the result to
/// `evaluate`. If `evaluate` reports `.finished`, the value is returned;
/// otherwise the loop waits for a broadcast (or the backstop interval, or the
/// deadline) and then enforces `minimumRequeryInterval` before the next pass.
///
/// - Parameters:
///   - changes: Topic whose broadcasts trigger a re-query.
///   - backstopInterval: Maximum time to wait for a broadcast before re-querying anyway.
///   - minimumRequeryInterval: Minimum spacing between consecutive queries.
///   - query: Produces the current value.
///   - evaluate: Decides whether the value is final or when to give up waiting.
/// - Returns: The value `evaluate` accepted.
/// - Throws: `CancellationError` if the task is cancelled, or any error from
///   `query`, `evaluate`, or the sleeps.
private static func waitForDatabaseResult<T>(
    changes: Topic<Void>,
    backstopInterval: TimeInterval,
    minimumRequeryInterval: TimeInterval = 0,
    query: @escaping @Sendable () async throws -> T,
    evaluate: (T) async throws -> WaitResult<T>
) async throws -> T {
    while true {
        // Bail out promptly once cancelled. Without this, a cancelled task can
        // keep re-querying until the deadline: the stream child in
        // `waitForChange` may finish before the sleep child throws, so no
        // cancellation error would surface from the wait itself.
        try Task.checkCancellation()

        // Subscribe before querying so a change that lands while the query is
        // running is still observed by the wait below.
        // NOTE(review): when `evaluate` finishes on this pass the stream is
        // never iterated; whether `Topic` releases such a subscription depends
        // on `Topic.subscribe()`, which is not visible here. Confirm.
        let changeStream = changes.subscribe()
        let result = try await query()
        switch try await evaluate(result) {
        case let .finished(value):
            return value
        case let .waitingUntil(deadline):
            let earliestNextQuery = Date().addingTimeInterval(minimumRequeryInterval)
            try await waitForChange(on: changeStream, until: deadline, backstopInterval: backstopInterval)
            try await waitUntil(earliestNextQuery, cappedAt: deadline)
        }
    }
}
Comment on lines +131 to +150
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Subscription leak when query succeeds without waiting.

When evaluate returns .finished on the first query attempt, the subscription created at line 134 is never iterated—waitForChange is skipped. Since AsyncStream.onTermination only fires when the stream is iterated, finished, or the iterating task is cancelled, the continuation remains in Topic.subscriptions indefinitely.

Over time, these dangling subscriptions accumulate: each broadcast() will yield to orphaned continuations with .unbounded buffering, causing unbounded memory growth.

Proposed fix: Wrap stream in RAII-style cleanup

Introduce a small wrapper that ensures the stream is consumed/cancelled on scope exit:

+    private struct ScopedSubscription {
+        let stream: AsyncStream<Void>
+        private var iterator: AsyncStream<Void>.AsyncIterator?
+        
+        init(_ stream: AsyncStream<Void>) {
+            self.stream = stream
+        }
+        
+        mutating func consume() async {
+            if iterator == nil {
+                iterator = stream.makeAsyncIterator()
+            }
+            _ = await iterator?.next()
+        }
+    }
+
     private static func waitForDatabaseResult<T>(
         changes: Topic<Void>,
         backstopInterval: TimeInterval,
         query: `@escaping` `@Sendable` () async throws -> T,
         evaluate: (T) async throws -> WaitResult<T>
     ) async throws -> T {
         while true {
-            let changeStream = changes.subscribe()
+            var subscription = ScopedSubscription(changes.subscribe())
+            defer {
+                // Start iteration so onTermination fires when scope exits
+                Task { [subscription] in
+                    var sub = subscription
+                    _ = await sub.stream.makeAsyncIterator().next()
+                }
+            }
             let result = try await query()
             switch try await evaluate(result) {
             case let .finished(value):
                 return value
             case let .waitingUntil(deadline):
-                try await waitForChange(on: changeStream, until: deadline, backstopInterval: backstopInterval)
+                await subscription.consume()
+                try await waitForChange(on: subscription.stream, until: deadline, backstopInterval: backstopInterval)
             }
         }
     }

Alternatively, add explicit unsubscribe support to Topic (e.g., subscribe() -> (stream, unsubscribe: () -> Void)).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/IMessage/Sources/IMessage/DatabaseTickWaits.swift` around lines 127 -
143, waitForDatabaseResult creates a subscription via changes.subscribe() (bound
to changeStream) but returns immediately when evaluate(...) yields .finished,
leaving the AsyncStream continuation in Topic.subscriptions and leaking; fix by
introducing an RAII-style subscription guard (e.g., SubscriptionGuard) that
wraps changes.subscribe(), exposes the stream (e.g., guard.stream) and
guarantees on deinit or explicit close() that the underlying subscription is
cancelled/consumed, then use this guard in waitForDatabaseResult instead of a
raw changeStream so that when evaluate returns .finished the guard is
dropped/closed and the subscription is removed; ensure waitForChange continues
to accept guard.stream and that guard has an explicit close() called before
returning the finished value if necessary.


/// Waits until `stream` yields an element, `backstopInterval` elapses, or
/// `deadline` is reached, whichever comes first.
///
/// Returns immediately when `deadline` is already in the past.
///
/// - Parameters:
///   - stream: Change notifications to wait on.
///   - deadline: Latest point in time to wait until.
///   - backstopInterval: Upper bound on a single wait.
/// - Throws: `CancellationError` if the task is cancelled during the wait, or
///   any error from the sleep.
private static func waitForChange(on stream: AsyncStream<Void>, until deadline: Date, backstopInterval: TimeInterval) async throws {
    let remainingTime = deadline.timeIntervalSinceNow
    guard remainingTime > 0 else { return }

    let sleepTime = min(remainingTime, backstopInterval)

    // Race the next stream element against a sleep; the first child to finish
    // wins and the other is cancelled.
    try await withThrowingTaskGroup(of: Void.self) { group in
        group.addTask {
            var iterator = stream.makeAsyncIterator()
            _ = await iterator.next()
        }
        group.addTask {
            try await Task.sleep(forTimeInterval: sleepTime)
        }

        defer { group.cancelAll() }
        _ = try await group.next()
    }

    // On cancellation the stream child can finish before the sleep child
    // throws, which would make this function return as if a change had
    // arrived. Surface the cancellation deterministically instead.
    try Task.checkCancellation()
}

/// Sleeps until the earlier of `date` and `deadline`; returns immediately if
/// that moment has already passed.
///
/// - Throws: Any error from the sleep.
private static func waitUntil(_ date: Date, cappedAt deadline: Date) async throws {
    let target = date < deadline ? date : deadline
    let delay = target.timeIntervalSinceNow
    if delay > 0 {
        try await Task.sleep(forTimeInterval: delay)
    }
}
}
Loading