Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ keywords = ["Kaggle", "API"]
requires-python = ">= 3.11"
dependencies = [
"bleach",
"kagglesdk >= 0.1.24, < 1.0", # sync with kagglehub
"kagglesdk >= 0.1.25, < 1.0", # sync with kagglehub
"python-slugify",
"requests",
"python-dateutil",
Expand Down
2 changes: 1 addition & 1 deletion requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jupyter-core==5.9.1
# via nbformat
jupytext==1.19.1
# via kaggle (pyproject.toml)
kagglesdk==0.1.24
kagglesdk==0.1.25
# via kaggle (pyproject.toml)
markdown-it-py==4.0.0
# via
Expand Down
47 changes: 46 additions & 1 deletion src/kaggle/api/kaggle_api_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
ApiSaveKernelResponse,
ApiKernelMetadata,
ApiDeleteKernelRequest,
ApiGetAcceleratorQuotaStatisticsRequest,
)
from kagglesdk.kernels.types.kernels_enums import KernelWorkerStatus, KernelsListSortType, KernelsListViewType
from kagglesdk.models.types.model_api_service import (
Expand Down Expand Up @@ -4099,6 +4100,50 @@ def kernels_list_cli(
else:
print("Not found")

def quota_view(self):
"""Fetches the current user's weekly GPU and TPU accelerator quota.

Returns:
An ApiGetAcceleratorQuotaStatisticsResponse with quota_refresh_time,
gpu_quota, and tpu_quota fields.
"""
with self.build_kaggle_client() as kaggle:
return kaggle.kernels.kernels_api_client.get_accelerator_quota_statistics(
ApiGetAcceleratorQuotaStatisticsRequest()
)

def quota_view_cli(self, csv_display=False):
"""A client wrapper for quota_view.

Args:
csv_display: If True, print comma-separated values instead of a table.
"""
response = self.quota_view()
refresh = response.quota_refresh_time.isoformat() if response.quota_refresh_time else ""
rows = []
for name, quota in (("GPU", response.gpu_quota), ("TPU", response.tpu_quota)):
if quota is None:
continue
used_hours = quota.time_used.total_seconds() / 3600
total_hours = quota.total_time_allowed.total_seconds() / 3600
rows.append(
SimpleNamespace(
resource=name,
used=f"{used_hours:.2f}h",
remaining=f"{max(0.0, total_hours - used_hours):.2f}h",
total=f"{total_hours:.2f}h",
refresh_at=refresh,
)
)
if not rows:
print("No quota information available")
return
fields = ["resource", "used", "remaining", "total", "refreshAt"]
if csv_display:
self.print_csv(rows, fields)
else:
self.print_table(rows, fields)

def kernels_list_files(self, kernel, page_token=None, page_size=20):
"""Lists files for a kernel.

Expand Down Expand Up @@ -7776,7 +7821,7 @@ def increment(self, length):

from pprint import pprint
from inspect import getmembers
from types import FunctionType
from types import FunctionType, SimpleNamespace


def attributes(obj):
Expand Down
10 changes: 10 additions & 0 deletions src/kaggle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main() -> None:
parse_benchmarks(subparsers)
parse_config(subparsers)
parse_auth(subparsers)
parse_quota(subparsers)
args = parser.parse_args()
command_args = {}
command_args.update(vars(args))
Expand Down Expand Up @@ -1713,6 +1714,12 @@ def parse_auth(subparsers) -> None:
parser_auth_revoke_token.set_defaults(func=api.auth_revoke_token)


def parse_quota(subparsers) -> None:
parser_quota = subparsers.add_parser("quota", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_quota)
parser_quota.add_argument("-v", "--csv", dest="csv_display", action="store_true", help=Help.param_csv)
parser_quota.set_defaults(func=api.quota_view_cli)


# ------------------------------------------------------------------
# Shared helpers for discussion topics across entity types
# ------------------------------------------------------------------
Expand Down Expand Up @@ -1825,6 +1832,7 @@ class Help(object):
"b",
"config",
"auth",
"quota",
]
competitions_choices = [
"list",
Expand Down Expand Up @@ -1924,6 +1932,7 @@ class Help(object):
+ "}"
)
kaggle += "\nauth {" + ", ".join(auth_choices) + "}"
kaggle += "\nquota"

group_competitions = "Commands related to Kaggle competitions"
group_datasets = "Commands related to Kaggle datasets"
Expand All @@ -1937,6 +1946,7 @@ class Help(object):
group_benchmarks_tasks = "Commands related to benchmark tasks"
group_config = "Configuration settings"
group_auth = "Commands related to authentication"
group_quota = "Show the current user's weekly GPU and TPU accelerator quota"

# Entity topics commands (shared across entity types)
command_entity_topics_show = "Display a topic with all its comments in tree form"
Expand Down
132 changes: 132 additions & 0 deletions tests/test_quota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# coding=utf-8
import io
import sys
import unittest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

sys.path.insert(0, "..")

from kaggle.api.kaggle_api_extended import KaggleApi


def _mock_quota(used_hours, total_hours):
quota = MagicMock()
quota.time_used = timedelta(hours=used_hours)
quota.total_time_allowed = timedelta(hours=total_hours)
return quota


def _build_response(gpu=None, tpu=None, refresh_time=None):
response = MagicMock()
response.gpu_quota = gpu
response.tpu_quota = tpu
response.quota_refresh_time = refresh_time
return response


class TestQuota(unittest.TestCase):
"""Tests for the quota_view and quota_view_cli methods."""

def setUp(self):
self.api = KaggleApi.__new__(KaggleApi)

@patch.object(KaggleApi, "build_kaggle_client")
def test_quota_view_returns_response(self, mock_client):
expected = _build_response(gpu=_mock_quota(5, 30), tpu=_mock_quota(0, 20))
mock_kaggle = MagicMock()
mock_kaggle.kernels.kernels_api_client.get_accelerator_quota_statistics.return_value = expected
mock_client.return_value.__enter__ = MagicMock(return_value=mock_kaggle)
mock_client.return_value.__exit__ = MagicMock(return_value=False)

result = self.api.quota_view()
self.assertIs(result, expected)

@patch.object(KaggleApi, "quota_view")
def test_quota_view_cli_table(self, mock_view):
mock_view.return_value = _build_response(
gpu=_mock_quota(5, 30),
tpu=_mock_quota(2, 20),
refresh_time=datetime(2026, 6, 1, tzinfo=timezone.utc),
)

captured = io.StringIO()
sys.stdout = captured
try:
self.api.quota_view_cli()
finally:
sys.stdout = sys.__stdout__

output = captured.getvalue()
self.assertIn("GPU", output)
self.assertIn("TPU", output)
self.assertIn("5.00h", output)
self.assertIn("25.00h", output) # GPU remaining: 30 - 5
self.assertIn("18.00h", output) # TPU remaining: 20 - 2
self.assertIn("2026-06-01", output)

@patch.object(KaggleApi, "quota_view")
def test_quota_view_cli_csv(self, mock_view):
mock_view.return_value = _build_response(
gpu=_mock_quota(5, 30),
tpu=_mock_quota(2, 20),
)

captured = io.StringIO()
sys.stdout = captured
try:
self.api.quota_view_cli(csv_display=True)
finally:
sys.stdout = sys.__stdout__

lines = [line for line in captured.getvalue().splitlines() if line]
self.assertEqual(lines[0], "resource,used,remaining,total,refreshAt")
self.assertEqual(len(lines), 3)
self.assertTrue(lines[1].startswith("GPU,"))
self.assertTrue(lines[2].startswith("TPU,"))

@patch.object(KaggleApi, "quota_view")
def test_quota_view_cli_skips_missing_accelerator(self, mock_view):
mock_view.return_value = _build_response(gpu=_mock_quota(1, 30), tpu=None)

captured = io.StringIO()
sys.stdout = captured
try:
self.api.quota_view_cli()
finally:
sys.stdout = sys.__stdout__

output = captured.getvalue()
self.assertIn("GPU", output)
self.assertNotIn("TPU", output)

@patch.object(KaggleApi, "quota_view")
def test_quota_view_cli_no_quotas(self, mock_view):
mock_view.return_value = _build_response(gpu=None, tpu=None)

captured = io.StringIO()
sys.stdout = captured
try:
self.api.quota_view_cli()
finally:
sys.stdout = sys.__stdout__

self.assertIn("No quota information available", captured.getvalue())

@patch.object(KaggleApi, "quota_view")
def test_quota_view_cli_clamps_negative_remaining(self, mock_view):
# User over their quota — remaining should be 0, not negative.
mock_view.return_value = _build_response(gpu=_mock_quota(35, 30))

captured = io.StringIO()
sys.stdout = captured
try:
self.api.quota_view_cli()
finally:
sys.stdout = sys.__stdout__

self.assertIn("0.00h", captured.getvalue())


if __name__ == "__main__":
unittest.main()
Loading