diff --git a/src/compat/function/throttle.spec.ts b/src/compat/function/throttle.spec.ts index 93d05ceb3..a30d4b24b 100644 --- a/src/compat/function/throttle.spec.ts +++ b/src/compat/function/throttle.spec.ts @@ -55,6 +55,51 @@ describe('throttle', () => { expect(callCount).toBe(2); }); + it('should match lodash timing for repeated default calls', () => { + vi.useFakeTimers(); + vi.setSystemTime(0); + + try { + const calls: Array<[number, number]> = []; + const throttled = throttle((value: number) => { + calls.push([value, Date.now()]); + }, 50); + + throttled(0); + vi.advanceTimersByTime(20); + throttled(20); + vi.advanceTimersByTime(20); + throttled(40); + vi.advanceTimersByTime(20); + throttled(60); + vi.advanceTimersByTime(20); + throttled(80); + vi.advanceTimersByTime(30); + + expect(calls).toEqual([ + [0, 0], + [40, 50], + [80, 110], + ]); + + vi.advanceTimersByTime(190); + throttled(300); + vi.advanceTimersByTime(20); + throttled(320); + vi.advanceTimersByTime(30); + + expect(calls).toEqual([ + [0, 0], + [40, 50], + [80, 110], + [300, 300], + [320, 350], + ]); + } finally { + vi.useRealTimers(); + } + }); + it('should not trigger a trailing call when invoked once', async () => { let callCount = 0; const throttled = throttle(() => { diff --git a/src/compat/function/throttle.ts b/src/compat/function/throttle.ts index a41aef35e..dd1f05984 100644 --- a/src/compat/function/throttle.ts +++ b/src/compat/function/throttle.ts @@ -1,4 +1,4 @@ -import { debounce, DebouncedFunc, DebouncedFuncLeading } from './debounce.ts'; +import type { DebouncedFunc, DebouncedFuncLeading } from './debounce.ts'; interface ThrottleSettings { /** @@ -122,11 +122,126 @@ export function throttle any>( throttleMs = 0, options: ThrottleSettings = {} ): DebouncedFunc { + if (options == null || typeof options !== 'object') { + options = {}; + } + const { leading = true, trailing = true } = options; - return debounce(func, throttleMs, { - leading, - maxWait: throttleMs, - trailing, - }); + let result: ReturnType | undefined = undefined; + let lastArgs: Parameters | undefined = undefined; + let lastThis: unknown = undefined; + let lastCallTime: number | undefined = undefined; + let lastInvokeTime = 0; + let timeoutId: ReturnType | undefined = undefined; + + const invoke = (time: number) => { + const args = lastArgs; + const thisArg = lastThis; + + lastArgs = lastThis = undefined; + lastInvokeTime = time; + result = func.apply(thisArg, args!); + + return result; + }; + + const shouldInvoke = (time: number) => { + if (lastCallTime === undefined) { + return true; + } + + const timeSinceLastCall = time - lastCallTime; + const timeSinceLastInvoke = time - lastInvokeTime; + + return timeSinceLastCall >= throttleMs || timeSinceLastCall < 0 || timeSinceLastInvoke >= throttleMs; + }; + + const trailingEdge = (time: number) => { + timeoutId = undefined; + + if (trailing && lastArgs != null) { + return invoke(time); + } + + lastArgs = lastThis = undefined; + return result; + }; + + const remainingWait = (time: number) => { + const timeSinceLastCall = time - (lastCallTime ?? 0); + const timeSinceLastInvoke = time - lastInvokeTime; + + return Math.min(throttleMs - timeSinceLastCall, throttleMs - timeSinceLastInvoke); + }; + + const timerExpired = () => { + const time = Date.now(); + + if (shouldInvoke(time)) { + return trailingEdge(time); + } + + timeoutId = setTimeout(timerExpired, remainingWait(time)); + }; + + const leadingEdge = (time: number) => { + lastInvokeTime = time; + timeoutId = setTimeout(timerExpired, throttleMs); + + if (leading) { + return invoke(time); + } + + return result; + }; + + const throttled = function (this: any, ...args: Parameters) { + const time = Date.now(); + const isInvoking = shouldInvoke(time); + + lastArgs = args; + // eslint-disable-next-line @typescript-eslint/no-this-alias + lastThis = this; + lastCallTime = time; + + if (isInvoking) { + if (timeoutId === undefined) { + return leadingEdge(time); + } + + clearTimeout(timeoutId); + timeoutId = setTimeout(timerExpired, throttleMs); + + return invoke(time); + } + + if (timeoutId === undefined) { + timeoutId = setTimeout(timerExpired, throttleMs); + } + + return result; + }; + + throttled.cancel = () => { + if (timeoutId !== undefined) { + clearTimeout(timeoutId); + } + + lastInvokeTime = 0; + lastArgs = undefined; + lastCallTime = undefined; + lastThis = undefined; + timeoutId = undefined; + }; + + throttled.flush = () => { + if (timeoutId === undefined) { + return result; + } + + return trailingEdge(Date.now()); + }; + + return throttled; } diff --git a/src/function/throttle.spec.ts b/src/function/throttle.spec.ts index e8ffd7128..650b0b77b 100644 --- a/src/function/throttle.spec.ts +++ b/src/function/throttle.spec.ts @@ -14,7 +14,7 @@ describe('throttle', () => { expect(func).toHaveBeenCalledTimes(1); }); - it('should execute the function immediately if not called within the wait time', async () => { + it('should execute a pending trailing call before the next throttle period', async () => { const func = vi.fn(); const throttleMs = 500; const throttledFunc = throttle(func, throttleMs); @@ -29,9 +29,9 @@ describe('throttle', () => { expect(func).toHaveBeenCalledTimes(1); await delay(throttleMs / 2 + 1); - expect(func).toHaveBeenCalledTimes(1); + expect(func).toHaveBeenCalledTimes(2); - throttledFunc(); // should be executed + throttledFunc(); // should be scheduled for the next period expect(func).toHaveBeenCalledTimes(2); await delay(throttleMs / 2 - 1); @@ -41,9 +41,6 @@ describe('throttle', () => { expect(func).toHaveBeenCalledTimes(2); await delay(throttleMs / 2 + 1); - expect(func).toHaveBeenCalledTimes(2); - - throttledFunc(); // should be executed expect(func).toHaveBeenCalledTimes(3); }); @@ -85,6 +82,45 @@ describe('throttle', () => { expect(func).toBeCalledTimes(2); }); + it('should invoke at each throttle period during repeated calls', () => { + vi.useFakeTimers(); + vi.setSystemTime(0); + + try { + const calls: Array<[number, number]> = []; + const throttleMs = 50; + const throttled = throttle((value: number) => { + calls.push([value, Date.now()]); + }, throttleMs); + + throttled(0); + vi.advanceTimersByTime(20); + throttled(20); + vi.advanceTimersByTime(20); + throttled(40); + vi.advanceTimersByTime(10); + + expect(calls).toEqual([ + [0, 0], + [40, 50], + ]); + + vi.advanceTimersByTime(10); + throttled(60); + vi.advanceTimersByTime(20); + throttled(80); + vi.advanceTimersByTime(20); + + expect(calls).toEqual([ + [0, 0], + [40, 50], + [80, 100], + ]); + } finally { + vi.useRealTimers(); + } + }); + it('should be able to abort initial invocation', async () => { const throttleMs = 50; const func = vi.fn(); diff --git a/src/function/throttle.ts b/src/function/throttle.ts index 1cb5f991d..c2d0f6293 100644 --- a/src/function/throttle.ts +++ b/src/function/throttle.ts @@ -1,5 +1,3 @@ -import { debounce } from './debounce.ts'; - export interface ThrottleOptions { /** * An optional AbortSignal to cancel the throttled function. @@ -53,36 +51,106 @@ export function throttle void>( throttleMs: number, { signal, edges = ['leading', 'trailing'] }: ThrottleOptions = {} ): ThrottledFunction { - let pendingAt: number | null = null; + const leading = edges.includes('leading'); + const trailing = edges.includes('trailing'); + + let lastInvokeTime: number | null = null; + let timeoutId: ReturnType | null = null; + let pendingThis: unknown = undefined; + let pendingArgs: Parameters | null = null; + + const clearTimer = () => { + if (timeoutId !== null) { + clearTimeout(timeoutId); + timeoutId = null; + } + }; + + const clearPending = () => { + pendingThis = undefined; + pendingArgs = null; + }; + + const invoke = (time: number) => { + if (pendingArgs !== null) { + const args = pendingArgs; + const thisArg = pendingThis; + clearPending(); + lastInvokeTime = time; + func.apply(thisArg, args); + } + }; + + const schedule = (delay: number) => { + if (timeoutId === null) { + timeoutId = setTimeout(() => { + timeoutId = null; + invoke(Date.now()); + }, delay); + } + }; + + const cancel = () => { + clearTimer(); + clearPending(); + lastInvokeTime = null; + }; + + const flush = () => { + clearTimer(); + invoke(Date.now()); + }; - const debounced = debounce( - function (this: any, ...args: Parameters) { - pendingAt = Date.now(); - func.apply(this, args); - }, - throttleMs, - { signal, edges } - ); + const setPending = (thisArg: unknown, args: Parameters) => { + pendingThis = thisArg; + pendingArgs = args; + }; const throttled = function (this: any, ...args: Parameters) { - if (pendingAt == null) { - pendingAt = Date.now(); + if (signal?.aborted || (!leading && !trailing)) { + return; + } + + const now = Date.now(); + + if (lastInvokeTime === null) { + if (leading) { + setPending(this, args); + invoke(now); + } else { + setPending(this, args); + schedule(throttleMs); + } + + return; } - if (Date.now() - pendingAt >= throttleMs) { - pendingAt = Date.now(); - func.apply(this, args); + const remaining = throttleMs - (now - lastInvokeTime); + + if (remaining <= 0) { + clearTimer(); + + if (leading) { + setPending(this, args); + invoke(now); + } else { + setPending(this, args); + schedule(throttleMs); + } - debounced.cancel(); - debounced.schedule(); return; } - debounced.apply(this, args); + if (trailing) { + setPending(this, args); + schedule(remaining); + } }; - throttled.cancel = debounced.cancel; - throttled.flush = debounced.flush; + throttled.cancel = cancel; + throttled.flush = flush; + + signal?.addEventListener('abort', cancel, { once: true }); return throttled; }