Source code for formulation_bench.download

from __future__ import annotations

import os
import tarfile
import urllib.request
from pathlib import Path

#: GitHub repo containing the dataset releases.
REPO = "henryrobbins/flare"

#: Name of the dataset tarball asset in the releases.
ASSET_NAME = "dataset.tar.gz"

#: The snapshot version of the dataset this package was built against.
#: The package is compatible with all dataset versions sharing the same major version.
DEFAULT_DATASET_VERSION = "dataset-v0.2.0"


def _default_cache_dir() -> Path:
    base = os.environ.get("FORMULATION_BENCH_CACHE")
    if base:
        return Path(base)
    xdg = os.environ.get("XDG_CACHE_HOME")
    root = Path(xdg) if xdg else Path.home() / ".cache"
    return root / "formulation_bench"


def _release_url(version: str) -> str:
    return f"https://github.com/{REPO}/releases/download/{version}/{ASSET_NAME}"


[docs] def download_dataset( version: str | None = None, cache_dir: str | Path | None = None, force: bool = False ) -> Path: """Download the FormulationBench dataset. A tarball is fetched from the GitHub release tagged ``version`` and extracted under ``<cache_dir>/<version>/``. Subsequent calls with the same ``version`` reuse the cached copy unless ``force=True``. Also see :doc:`/user_guide/download`. Parameters ---------- version : str, optional Release tag, e.g. ``"dataset-v0.2.0"``. Defaults to :data:`DEFAULT_DATASET_VERSION`, the snapshot version this package was built against. cache_dir : str or pathlib.Path, optional Cache root. Defaults to ``$FORMULATION_BENCH_CACHE`` or ``$XDG_CACHE_HOME/formulation_bench`` (``~/.cache/formulation_bench``). force : bool, default False Re-download and overwrite the cached copy. Returns ------- root : pathlib.Path Path to the extracted dataset root. Load the dataset with ``Dataset(root)``. Examples -------- Download the default version of the dataset (or load from cache):: >>> from formulation_bench import download_dataset >>> path = download_dataset() >>> path PosixPath('.../.cache/formulation_bench/dataset-v0.2.0/dataset') >>> from formulation_bench import Dataset >>> ds = Dataset(path) >>> sorted(ds.problems)[:5] [1, 2, 3, 4, 5] Reload the dataset from cache:: >>> path = download_dataset() >>> path PosixPath('.../.cache/formulation_bench/dataset-v0.2.0/dataset') Force re-download and overwrite the cached copy:: >>> path = download_dataset(force=True) >>> path PosixPath('.../.cache/formulation_bench/dataset-v0.2.0/dataset') Provide a custom cache directory:: >>> path = download_dataset(cache_dir="./custom_cache") >>> path PosixPath('custom_cache/dataset-v0.2.0/dataset') """ version = version or DEFAULT_DATASET_VERSION cache_root = Path(cache_dir) if cache_dir else _default_cache_dir() version_dir = cache_root / version extracted = version_dir / "dataset" if extracted.exists() and not force: return extracted version_dir.mkdir(parents=True, exist_ok=True) archive = version_dir / ASSET_NAME url = _release_url(version) urllib.request.urlretrieve(url, archive) # noqa: S310 with tarfile.open(archive, "r:gz") as tf: tf.extractall(version_dir, filter="data") archive.unlink(missing_ok=True) if not extracted.exists(): raise RuntimeError( f"expected {extracted} after extracting {url}; " "tarball layout may be wrong (top-level dir must be 'dataset/')" ) return extracted