class CxgClient:
    """HTTP client for the CELLxGene Discover API with transparent caching.

    Dataset and collection listings are served from the local cache
    (``load_*_cache`` / ``write_*_cache``) while the entry is within the
    configured TTL. Otherwise they are re-fetched, falling back to the stale
    cache with a warning when the fetch fails.

    The client can be used as a context manager. An ``httpx.Client`` created
    here is closed on exit; one supplied by the caller is left open.
    """

    def __init__(
        self,
        base_url: str | None = None,
        client: httpx.Client | None = None,
    ) -> None:
        """Create a client.

        Args:
            base_url: API root. When omitted it is resolved, in order, from
                the ``CXG_API_BASE_URL`` environment variable, the
                ``api.base_url`` config value, and ``DEFAULT_BASE_URL``.
                Empty values in the environment or config count as unset.
            client: Optional pre-built ``httpx.Client``. The default headers
                are applied to it, but :meth:`close` will not close it.
        """
        if base_url is None:
            # ``or None`` so an empty env var falls through to the config,
            # matching the truthiness check applied to the config value below.
            base_url = os.environ.get("CXG_API_BASE_URL") or None
        if base_url is None:
            from cxg.config import load_config

            cfg = load_config()
            if cfg.api.base_url:
                base_url = cfg.api.base_url
        if base_url is None:
            base_url = DEFAULT_BASE_URL
        self.base_url = base_url.rstrip("/")

        headers = default_headers()
        self._owns_client = client is None
        if client is None:
            client = httpx.Client(headers=headers)
        else:
            # Caller-supplied clients still need our default headers.
            client.headers.update(headers)
        self._client = client

    def __enter__(self) -> "CxgClient":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client if this instance created it."""
        if self._owns_client:
            self._client.close()

    def _url(self, path: str) -> str:
        """Join ``path`` onto ``base_url`` with a single separating slash."""
        return f"{self.base_url}/{path.lstrip('/')}"

    def _get(
        self,
        url: str,
        *,
        timeout: float,
        not_found_error: bool = False,
    ) -> httpx.Response:
        """Issue a GET request, converting failures into ``CxgError``.

        Args:
            url: Absolute request URL.
            timeout: Timeout passed straight to ``httpx.Client.get``.
            not_found_error: Raise ``CxgNotFoundError`` for a 404 instead of
                the generic ``CxgError``.

        Raises:
            CxgNotFoundError: On a 404 when ``not_found_error`` is set.
            CxgError: On transport errors or any other error status.
        """
        try:
            response = self._client.get(url, timeout=timeout)
        except httpx.HTTPError as exc:
            raise CxgError(f"Request failed for {url}: {exc}") from exc
        if not_found_error and response.status_code == 404:
            raise CxgNotFoundError(f"Resource not found at {url}")
        if response.is_error:
            body = response.text[:200].strip()
            raise CxgError(f"API error {response.status_code} for {url}: {body}")
        return response

    @staticmethod
    def _decode_json(url: str, response: httpx.Response) -> Any:
        """Parse the response body as JSON, wrapping failures in ``CxgError``."""
        try:
            return response.json()
        except ValueError as exc:
            raise CxgError(f"Malformed JSON response from {url}") from exc

    def _request_json(self, path: str, *, timeout: float = DEFAULT_TIMEOUT) -> Any:
        """GET ``path`` and return the decoded JSON body.

        Raises:
            CxgNotFoundError: If the API responds with 404.
            CxgError: On transport errors, other error statuses, or bad JSON.
        """
        url = self._url(path)
        response = self._get(url, timeout=timeout, not_found_error=True)
        return self._decode_json(url, response)

    @staticmethod
    def _fresh_cached_list(cached: Any, key: str, *, refresh: bool) -> list[dict[str, Any]] | None:
        """Return ``cached[key]`` if the cache may be used without fetching.

        ``cached`` is the payload from a ``load_*_cache`` call (falsy when
        there is no cache). Returns ``None`` -- meaning "fetch from the API" --
        when there is no cache, ``refresh`` is set, the TTL is 0, the entry
        has expired, or the cached value is not a list.
        """
        if not cached or refresh:
            return None
        ttl = get_cache_ttl()
        if ttl == 0:
            return None
        from cxg.cache import _parse_timestamp, is_expired  # local import avoids export noise

        if is_expired(_parse_timestamp(cached.get("fetched_at")), ttl):
            return None
        items = cached.get(key, [])
        return items if isinstance(items, list) else None

    @staticmethod
    def _stale_cached_list(cached: Any, key: str) -> list[dict[str, Any]] | None:
        """Return cached items to serve after a failed fetch, or ``None``.

        The "using cached data" warning is printed only when cached items are
        actually returned, so it never precedes a re-raised error.
        """
        if not cached:
            return None
        items = cached.get(key, [])
        if not isinstance(items, list):
            return None
        from cxg.cache import _parse_timestamp
        from cxg.output import format_age

        fetched_at = _parse_timestamp(cached.get("fetched_at"))
        console.print(
            f"[yellow]Using cached data from {format_age(fetched_at)}. "
            "API is unreachable.[/yellow]",
            highlight=False,
        )
        return items

    @staticmethod
    def _fetch_status(cached: Any, noun: str, *, refresh: bool) -> str:
        """Spinner text explaining why ``noun`` is being fetched."""
        if not cached:
            return f"Fetching {noun} for the first time (cached for future runs)…"
        if refresh:
            return f"Refreshing {noun}…"
        return f"Updating cached {noun}…"

    def _fetch_datasets_remote(self) -> list[dict[str, Any]]:
        """Download the dataset listing, drop invalid records, and cache it.

        Non-dict records are skipped silently. Dicts with neither a
        ``dataset_id`` nor an ``id`` are skipped with a warning. Returns
        whatever ``write_datasets_cache`` returns for the kept records.
        """
        url = self._url("datasets")
        response = self._get(url, timeout=DATASETS_TIMEOUT)
        payload = self._decode_json(url, response)
        datasets = _extract_list_payload(
            payload,
            "datasets",
            error_message="Malformed datasets response: expected a list",
        )
        valid: list[dict[str, Any]] = []
        for dataset in datasets:
            if not isinstance(dataset, dict):
                continue
            if not (dataset.get("dataset_id") or dataset.get("id")):
                console.print(
                    "[yellow]Skipping dataset record missing dataset_id from "
                    "API response.[/yellow]",
                    highlight=False,
                )
                continue
            valid.append(dataset)
        # The raw payload size is handed to the cache writer alongside the
        # records (presumably stored as metadata -- TODO confirm).
        return write_datasets_cache(valid, len(response.content))

    def get_datasets(self, *, refresh: bool = False) -> list[dict[str, Any]]:
        """Fetch all datasets, using cache when available.

        Args:
            refresh: Bypass the cache and fetch from the API. If that fetch
                fails, an existing cached list is still used as a fallback.

        Returns:
            The dataset records, either fresh or from the cache. After a failed
            fetch, a cached list is returned (with a warning) if one exists.

        Raises:
            CxgError: Re-raised from the fetch when no usable cache exists.
        """
        cached = load_datasets_cache()
        fresh = self._fresh_cached_list(cached, "datasets", refresh=refresh)
        if fresh is not None:
            return fresh
        status = self._fetch_status(cached, "datasets", refresh=refresh)
        try:
            with console.status(status, spinner="dots"):
                return self._fetch_datasets_remote()
        except CxgError:
            stale = self._stale_cached_list(cached, "datasets")
            if stale is None:
                raise
            return stale

    def get_dataset(self, dataset_id: str, *, refresh: bool = False) -> dict[str, Any]:
        """Look up a single dataset by ID from the (cached) dataset list.

        A record matches when its ``dataset_id`` (or, failing that, ``id``)
        stringifies to ``dataset_id``.

        Args:
            dataset_id: The dataset identifier.
            refresh: Bypass the cache and fetch from the API.

        Raises:
            CxgNotFoundError: If no dataset matches the given ID.
        """
        for dataset in self.get_datasets(refresh=refresh):
            if str(dataset.get("dataset_id") or dataset.get("id")) == dataset_id:
                return dataset
        raise CxgNotFoundError(
            f"Dataset `{dataset_id}` not found. Check the ID with `cxg dataset list`."
        )

    def _fetch_collections_remote(self) -> list[dict[str, Any]]:
        """Download the public collection listing and cache it.

        Non-dict entries are dropped. Returns whatever
        ``write_collections_cache`` returns for the kept entries.
        """
        url = self._url("collections?visibility=PUBLIC")
        response = self._get(url, timeout=DEFAULT_TIMEOUT)
        payload = self._decode_json(url, response)
        collections = _extract_list_payload(
            payload,
            "collections",
            error_message="Malformed collections response",
        )
        valid = [item for item in collections if isinstance(item, dict)]
        return write_collections_cache(valid, len(response.content))

    def get_collections(self, *, refresh: bool = False) -> list[dict[str, Any]]:
        """Fetch all public collections, using cache when available.

        Args:
            refresh: Bypass the cache and fetch from the API. If that fetch
                fails, an existing cached list is still used as a fallback.

        Returns:
            The collection records, either fresh or from the cache. After a
            failed fetch, a cached list is returned (with a warning) if one exists.

        Raises:
            CxgError: Re-raised from the fetch when no usable cache exists.
        """
        cached = load_collections_cache()
        fresh = self._fresh_cached_list(cached, "collections", refresh=refresh)
        if fresh is not None:
            return fresh
        status = self._fetch_status(cached, "collections", refresh=refresh)
        try:
            with console.status(status, spinner="dots"):
                return self._fetch_collections_remote()
        except CxgError:
            stale = self._stale_cached_list(cached, "collections")
            if stale is None:
                raise
            return stale

    def get_collection(self, collection_id: str) -> dict[str, Any]:
        """Fetch a single collection by ID directly from the API (uncached).

        Raises:
            CxgNotFoundError: If the API responds with 404.
            CxgError: On request failure or a non-object JSON response.
        """
        from urllib.parse import quote

        # Percent-encode the ID so characters like '/', '?' or '#' cannot
        # change which resource is requested.
        path = f"collections/{quote(collection_id, safe='')}"
        payload = self._request_json(path, timeout=DEFAULT_TIMEOUT)
        if not isinstance(payload, dict):
            raise CxgError("Malformed collection response")
        return payload