diff --git a/app/api/og/route.test.ts b/app/api/og/route.test.ts index 803be23ae..a7da6c41c 100644 --- a/app/api/og/route.test.ts +++ b/app/api/og/route.test.ts @@ -122,4 +122,17 @@ describe('OG Route', () => { expect(res).toBeDefined(); expect(res.status).toBe(200); }); + it('returns 429 when rate limit is exceeded', async () => { + const { RateLimiter } = await import('@/lib/rate-limit'); + vi.spyOn(RateLimiter.prototype, 'check').mockResolvedValueOnce(false); + + const req = new NextRequest('http://localhost/api/og?user=octocat', { + headers: { 'x-forwarded-for': '1.2.3.4' }, + }); + + const res = await GET(req); + expect(res.status).toBe(429); + const data = await res.json(); + expect(data.error).toBe('Too many requests. Please try again later.'); + }); }); diff --git a/app/api/og/route.tsx b/app/api/og/route.tsx index 0eebbb677..87d73ba7c 100644 --- a/app/api/og/route.tsx +++ b/app/api/og/route.tsx @@ -6,7 +6,10 @@ import { ogParamsSchema } from '@/lib/validations'; import { themes } from '@/lib/svg/themes'; import { fetchGitHubContributions } from '@/lib/github'; import { calculateStreak } from '@/lib/calculate'; +import { getClientIp } from '@/utils/getClientIp'; +import { RateLimiter } from '@/lib/rate-limit'; +const ogRateLimiter = new RateLimiter(30, 60_000, 1); const appUrl = process.env.NEXT_PUBLIC_SITE_URL || (process.env.VERCEL_URL ? `https://${process.env.VERCEL_URL}` : 'https://commitpulse.vercel.app'); @@ -35,6 +38,20 @@ function getLuminance(hex: string) { } export async function GET(req: NextRequest) { + const ip = getClientIp(req); + const rateLimitKey = + ip && ip !== 'unknown' ? ip : `unknown:${req.headers.get('user-agent') ?? 'no-agent'}`; + + if (!(await ogRateLimiter.check(rateLimitKey))) { + return new Response(JSON.stringify({ error: 'Too many requests. Please try again later.' }), { + status: 429, + headers: { + 'Content-Type': 'application/json', + 'Cache-Control': 'no-store', + }, + }); + } + const { searchParams } = new URL(req.url); const parseResult = ogParamsSchema.safeParse(Object.fromEntries(searchParams.entries()));