"""unidep - Unified Conda and Pip requirements management.
This module provides utility functions used throughout the package.
"""
from __future__ import annotations
import ast
import codecs
import configparser
import contextlib
import importlib.util
import io
import platform
import re
import sys
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Literal, NamedTuple, Sequence, cast
from packaging.markers import Marker, default_environment
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import (
InvalidSdistFilename,
InvalidWheelFilename,
canonicalize_name,
parse_sdist_filename,
parse_wheel_filename,
)
from unidep._version import __version__
from unidep.platform_definitions import (
PEP508_MARKERS,
Platform,
Selector,
Spec,
platforms_from_selector,
validate_selector,
)
if sys.version_info >= (3, 11):
import tomllib
else: # pragma: no cover
import tomli as tomllib
[docs]
def escape_unicode(string: str) -> str:
    r"""Interpret backslash escape sequences such as ``\n``, ``\t`` or ``\u00e9``.

    The ``unicode_escape`` codec works on bytes and decodes every byte that is
    not part of an escape sequence as Latin-1. Handing it a ``str`` directly
    makes Python encode the text as UTF-8 first, which turns non-ASCII input
    into mojibake (``"é"`` would come back as ``"Ã©"``). Encoding to Latin-1
    with ``backslashreplace`` instead maps each Latin-1 character to its own
    byte and every other character to a ``\uXXXX``/``\UXXXXXXXX`` escape, both
    of which the codec decodes back to the original character.
    """
    return string.encode("latin-1", "backslashreplace").decode("unicode_escape")
[docs]
def is_pip_installable(folder: str | Path) -> bool: # pragma: no cover
"""Determine if the project is pip installable.
Checks for existence of setup.py or [build-system] in pyproject.toml.
If the `toml` library is available, it is used to parse the `pyproject.toml` file.
If the `toml` library is not available, the function checks for the existence of
a line starting with "[build-system]". This does not handle the case where
[build-system] is inside of a multi-line literal string.
"""
path = Path(folder)
if (path / "setup.py").exists():
return True
pyproject_path = path / "pyproject.toml"
if pyproject_path.exists():
with pyproject_path.open("rb") as file:
pyproject_data = tomllib.load(file)
return "build-system" in pyproject_data
return False
[docs]
def build_pep508_environment_marker(
    platforms: Sequence[Platform | tuple[Platform, ...]],
) -> str:
    """Generate a PEP 508 selector for a list of platforms.

    If the sorted tuple of *platforms* is itself a key of ``PEP508_MARKERS``,
    that combined marker is returned. Otherwise the markers of the individual
    entries that have one are joined with ``" or "`` in sorted order; entries
    without a marker are skipped (an empty string results if none match).
    """
    sorted_platforms = tuple(sorted(platforms))
    if sorted_platforms in PEP508_MARKERS:
        return PEP508_MARKERS[sorted_platforms]  # type: ignore[index]
    # ``sorted_platforms`` is already ordered, so it is iterated directly; the
    # loop variable is named so it does not shadow the ``platform`` module.
    return " or ".join(
        PEP508_MARKERS[platform_name]
        for platform_name in sorted_platforms
        if platform_name in PEP508_MARKERS
    )
[docs]
class ParsedPackageStr(NamedTuple):
    """A package name and version pinning.

    Returned by ``parse_package_str``; ``pin`` and ``selector`` are ``None``
    when the corresponding part is absent from the parsed string.
    """

    # Package name as written; may carry an ``[extras]`` suffix or be path-like.
    name: str
    # Version pin text that followed the name (e.g. ``>=1.2``), or None.
    pin: str | None = None
    # can be of type `Selector` but also space separated string of `Selector`s
    selector: str | None = None
[docs]
def parse_package_str(package_str: str) -> ParsedPackageStr:
    """Split *package_str* into package name, version pin, and platform selector.

    The expected shape is ``<name>[extras] <pin>:<selector>`` where the pin and
    the ``:<selector>`` suffix are optional. The selector may contain several
    space-separated selectors; each one is validated.

    Raises
    ------
    ValueError
        If *package_str* does not have the expected shape.
    """
    # The name part also accepts paths and ``[extras]``; paths may not contain
    # spaces or brackets.
    name_pattern = r"[a-zA-Z0-9_.\-/]+(\[[a-zA-Z0-9_.,\-]+\])?"
    version_pin_pattern = r".*?"
    selector_pattern = r"[a-z0-9\s]+"
    pattern = rf"({name_pattern})\s*({version_pin_pattern})?(:({selector_pattern}))?$"
    found = re.match(pattern, package_str)
    if not found:
        msg = f"Invalid package string: '{package_str}'"
        raise ValueError(msg)
    # Groups: name (with extras), extras alone, pin, ":selector", selector.
    raw_name, _extras, raw_pin, _colon_selector, raw_selector = found.groups()
    pin = raw_pin.strip() if raw_pin else None
    selector = raw_selector.strip() if raw_selector else None
    if selector is not None:
        for single_selector in selector.split():
            validate_selector(cast("Selector", single_selector))
    return ParsedPackageStr(raw_name.strip(), pin, selector)
[docs]
def package_name_from_setup_cfg(file_path: Path) -> str:
    """Read the package name from ``setup.cfg`` metadata.

    Returns the ``name`` option of the ``[metadata]`` section. The file is
    read as UTF-8, the encoding setuptools expects for ``setup.cfg``, rather
    than the locale's default encoding, so non-ASCII metadata parses the same
    on every platform.

    Raises
    ------
    KeyError
        If there is no ``[metadata]`` section or it has no ``name`` option.
        ``ConfigParser.read`` silently skips files it cannot open, so a
        missing file also ends up here.
    """
    config = configparser.ConfigParser()
    config.read(file_path, encoding="utf-8")
    name = config.get("metadata", "name", fallback=None)
    if name is None:
        msg = "Could not find the package name in the setup.cfg file."
        raise KeyError(msg)
    return name
[docs]
def package_name_from_setup_py(file_path: Path) -> str:
"""Read the package name from a simple ``setup.py`` AST."""
with file_path.open() as f:
file_content = f.read()
tree = ast.parse(file_content)
def _string_literal(node: ast.expr) -> str | None:
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return node.value
return None
class SetupVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self.package_name: str | None = None
def visit_Call(self, node: ast.Call) -> None: # noqa: N802
if isinstance(node.func, ast.Name) and node.func.id == "setup":
for keyword in node.keywords:
if keyword.arg == "name":
self.package_name = _string_literal(keyword.value)
if self.package_name is not None:
return
visitor = SetupVisitor()
visitor.visit(tree)
if visitor.package_name is None:
msg = "Could not find the package name in the setup.py file."
raise KeyError(msg)
return visitor.package_name
[docs]
def package_name_from_pyproject_toml(file_path: Path) -> str:
    """Return the project name declared in ``pyproject.toml``.

    ``[project].name`` (PEP 621) is preferred; ``[tool.poetry].name`` is the
    fallback. A missing key moves on to the next location; any other lookup
    error (e.g. a table that is not a mapping) propagates.

    Raises
    ------
    KeyError
        If neither location defines a name.
    """
    with file_path.open("rb") as handle:
        parsed = tomllib.load(handle)
    for key_path in (("project", "name"), ("tool", "poetry", "name")):
        node = parsed
        try:
            for key in key_path:
                node = node[key]
        except KeyError:
            continue
        return node
    msg = f"Could not find the package name in the pyproject.toml file: {parsed}."
    raise KeyError(msg)
def _package_name_from_archive_filename(path: Path) -> str | None:
    """Return the distribution name encoded in a wheel or sdist filename.

    Recognizes wheels (``*.whl``) and source distributions in both the
    standard ``*.tar.gz`` format (PEP 625) and the legacy ``*.zip`` format.
    Returns ``None`` for any other path, or when the filename does not follow
    the wheel/sdist naming convention.
    """
    filename = path.name
    if filename.endswith(".whl"):
        try:
            name, _, _, _ = parse_wheel_filename(filename)
        except InvalidWheelFilename:
            return None
        return name
    # ``Path.suffix`` reports only ``.gz`` for ``pkg-1.0.tar.gz``, so match on
    # the full filename instead.
    if filename.endswith((".tar.gz", ".zip")):
        try:
            name, _ = parse_sdist_filename(filename)
        except InvalidSdistFilename:
            return None
        return name
    return None
def _maybe_package_name_from_path(path: Path) -> str | None:
    """Best-effort package name lookup that never falls back to ``path.name``.

    Tried in order: the wheel/sdist filename of *path*, then ``pyproject.toml``,
    ``setup.cfg`` and ``setup.py`` inside *path*. A metadata file that exists
    but raises one of the error types listed for it is skipped and the next
    source is tried. Returns ``None`` when no source yields a name.
    """
    from_archive = _package_name_from_archive_filename(path)
    if from_archive is not None:
        return from_archive
    # (metadata file, reader, errors meaning "no usable name in this file")
    metadata_readers = (
        (
            "pyproject.toml",
            package_name_from_pyproject_toml,
            (KeyError, OSError, TypeError, UnicodeError, tomllib.TOMLDecodeError),
        ),
        (
            "setup.cfg",
            package_name_from_setup_cfg,
            (KeyError, OSError, UnicodeError, configparser.Error),
        ),
        (
            "setup.py",
            package_name_from_setup_py,
            (KeyError, OSError, SyntaxError, UnicodeError, ValueError),
        ),
    )
    for filename, reader, tolerated in metadata_readers:
        metadata_file = path / filename
        if not metadata_file.exists():
            continue
        with contextlib.suppress(*tolerated):
            return reader(metadata_file)
    return None
[docs]
def package_name_from_path(path: Path) -> str:
    """Get the package name from ``pyproject.toml``, ``setup.cfg``, or ``setup.py``.

    Wheel/sdist archive filenames are understood as well. When no name can be
    determined, the last component of *path* (``path.name``) is returned.
    """
    detected = _maybe_package_name_from_path(path)
    return path.name if detected is None else detected
# Python versions used when approximating marker overlap for direct-reference
# validation: the full supported runtime floor (3.7) through near-future
# releases (3.20).
_SUPPORTED_MARKER_PYTHON_VERSIONS = tuple(map("3.{}".format, range(7, 21)))


def _platform_marker_environment(
    os_name: str,
    platform_machine: str,
    platform_system: str,
    sys_platform: str,
) -> dict[str, str]:
    """Return the PEP 508 marker variables that identify one conda platform."""
    return {
        "os_name": os_name,
        "platform_machine": platform_machine,
        "platform_system": platform_system,
        "sys_platform": sys_platform,
    }


# Marker-variable overrides for each supported conda platform.
_PLATFORM_MARKER_ENVIRONMENTS: dict[Platform, dict[str, str]] = {
    "linux-64": _platform_marker_environment("posix", "x86_64", "Linux", "linux"),
    "linux-aarch64": _platform_marker_environment(
        "posix",
        "aarch64",
        "Linux",
        "linux",
    ),
    "linux-ppc64le": _platform_marker_environment(
        "posix",
        "ppc64le",
        "Linux",
        "linux",
    ),
    "osx-64": _platform_marker_environment("posix", "x86_64", "Darwin", "darwin"),
    "osx-arm64": _platform_marker_environment("posix", "arm64", "Darwin", "darwin"),
    "win-64": _platform_marker_environment("nt", "AMD64", "Windows", "win32"),
}
def _marker_targets(marker: Marker | None) -> frozenset[tuple[Platform, str]]:
    """Return supported platform/Python targets where a requirement can apply.

    Each target is a ``(platform, python_version)`` pair drawn from
    ``_PLATFORM_MARKER_ENVIRONMENTS`` and ``_SUPPORTED_MARKER_PYTHON_VERSIONS``.
    Without a marker every pair is a target. Otherwise the marker is evaluated
    once per pair against a synthetic environment with ``extra`` set to ``""``
    and ``python_full_version`` set to ``"<python_version>.0"``.
    """
    if marker is None:
        return frozenset(
            (platform_name, python_version)
            for platform_name in _PLATFORM_MARKER_ENVIRONMENTS
            for python_version in _SUPPORTED_MARKER_PYTHON_VERSIONS
        )
    # ``default_environment()`` takes no arguments, so its result is the same
    # for every pair; build it once instead of once per (platform, Python).
    base_environment = default_environment()
    targets: set[tuple[Platform, str]] = set()
    for platform_name, env_updates in _PLATFORM_MARKER_ENVIRONMENTS.items():
        for python_version in _SUPPORTED_MARKER_PYTHON_VERSIONS:
            # Later keys win: platform overrides replace the host defaults,
            # then ``extra`` and the Python version fields are pinned.
            environment = {
                **base_environment,
                **env_updates,
                "extra": "",
                "python_version": python_version,
                "python_full_version": f"{python_version}.0",
            }
            if marker.evaluate(environment):
                targets.add((platform_name, python_version))
    return frozenset(targets)
def _parsed_direct_reference(
    requirement: str,
) -> tuple[str, str, str, frozenset[tuple[Platform, str]]] | None:
    """Return the canonical package name, display name, URL, and scope.

    Returns ``None`` when *requirement* cannot be parsed as a PEP 508
    requirement or is not a direct reference (has no ``@ url`` part). The
    scope is the set of platform/Python targets its marker allows.
    """
    try:
        parsed = Requirement(requirement)
    except InvalidRequirement:
        return None
    source_url = parsed.url
    if source_url is None:
        return None
    canonical = canonicalize_name(parsed.name)
    scope = _marker_targets(parsed.marker)
    return canonical, parsed.name, source_url, scope
[docs]
def detect_conflicting_direct_references(
    requirements: list[str],
    *,
    context: str,
) -> list[str]:
    """Deduplicate direct references and fail on conflicting sources.

    Requirements are stripped and exact duplicates are dropped, keeping the
    first occurrence. Two direct references (``name @ url``) with the same
    canonical name but different URLs conflict when their markers share at
    least one platform/Python target. This catches cases like two ``file://``
    URLs for the same package name before pip/uv reports a lower-level
    resolver error.

    Returns the deduplicated requirement strings in input order.

    Raises
    ------
    RuntimeError
        On a conflict; the message lists the clashing requirements and ends
        with ``While {context}.``.
    """
    kept: list[str] = []
    already_seen: set[str] = set()
    # canonical name -> source URL -> [(targets, requirement string), ...]
    sources_by_name: dict[
        str,
        dict[str, list[tuple[frozenset[tuple[Platform, str]], str]]],
    ] = defaultdict(lambda: defaultdict(list))
    for raw_requirement in requirements:
        requirement = raw_requirement.strip()
        if requirement in already_seen:
            continue
        already_seen.add(requirement)
        reference = _parsed_direct_reference(requirement)
        if reference is not None:
            canonical_name, display_name, url, scope = reference
            urls = sources_by_name[canonical_name]
            clashes: list[str] = []
            for other_url, entries in urls.items():
                if other_url == url:
                    continue
                for other_scope, other_requirement in entries:
                    # Different sources only clash where both can apply.
                    if other_scope & scope:
                        clashes.append(other_requirement)
            if clashes:
                msg = format_duplicate_package_sources_message(
                    display_name,
                    [*clashes, requirement],
                )
                msg += f"\n\nWhile {context}."
                raise RuntimeError(msg)
            urls[url].append((scope, requirement))
        kept.append(requirement)
    return kept
[docs]
def detect_conflicting_direct_reference_groups(
    requirement_groups: dict[str, list[str]],
    *,
    context: str,
) -> dict[str, list[str]]:
    """Validate direct references across multiple dependency groups.

    Each group is first deduplicated and checked on its own with
    ``detect_conflicting_direct_references``. Direct references are then
    compared across groups: the same canonical name with a different URL and
    an overlapping platform/Python scope is a conflict.

    Returns the per-group deduplicated lists, keyed and ordered like the input.

    Raises
    ------
    RuntimeError
        On a conflict; entries in the message are prefixed with their group
        name and the message ends with ``While {context}.``.
    """
    cleaned_groups: dict[str, list[str]] = {}
    for name, group_requirements in requirement_groups.items():
        cleaned_groups[name] = detect_conflicting_direct_references(
            group_requirements,
            context=context,
        )
    # canonical name -> source URL -> [(targets, group, requirement), ...]
    registry: dict[
        str,
        dict[str, list[tuple[frozenset[tuple[Platform, str]], str, str]]],
    ] = defaultdict(lambda: defaultdict(list))
    for group_name, group_requirements in cleaned_groups.items():
        for requirement in group_requirements:
            reference = _parsed_direct_reference(requirement)
            if reference is None:
                continue
            canonical_name, display_name, url, scope = reference
            by_url = registry[canonical_name]
            clashes: list[str] = []
            for other_url, entries in by_url.items():
                if other_url == url:
                    continue
                for other_scope, other_group, other_requirement in entries:
                    if other_scope & scope:
                        clashes.append(f"{other_group}: {other_requirement}")
            if clashes:
                msg = format_duplicate_package_sources_message(
                    display_name,
                    [*clashes, f"{group_name}: {requirement}"],
                )
                msg += f"\n\nWhile {context}."
                raise RuntimeError(msg)
            by_url[url].append((scope, group_name, requirement))
    return cleaned_groups
def _direct_reference_from_local_path(path: Path) -> str | None:
    """Return a direct-reference requirement string for a local package path.

    The result has the form ``"<name> @ <file URI of the resolved path>"``;
    ``None`` is returned when no package name can be read for the path.
    """
    absolute_path = path.resolve()
    detected_name = _maybe_package_name_from_path(absolute_path)
    if detected_name is None:
        return None
    return f"{detected_name} @ {absolute_path.as_uri()}"
[docs]
def direct_references_from_local_paths(paths: list[Path]) -> list[str]:
    """Build direct-reference strings for local paths with readable metadata.

    Paths without a detectable package name are skipped; identical references
    are kept once, in first-seen order.
    """
    candidates = (_direct_reference_from_local_path(path) for path in paths)
    # dict keys keep insertion order, giving an order-preserving dedup.
    unique = dict.fromkeys(ref for ref in candidates if ref is not None)
    return list(unique)
[docs]
def pip_requirement_strings(
    requirements: dict[str, list[Spec]],
    *,
    platforms: list[Platform] | None = None,
) -> list[str]:
    """Return pip requirement strings filtered to the requested platforms.

    Only specs with ``which == "pip"`` are considered.

    - A spec whose ``platforms()`` is ``None`` is emitted without an
      environment marker, i.e. it is kept for every platform.
    - A platform-restricted spec is intersected with *platforms* (when
      given). It is dropped if no platform remains; otherwise it is emitted
      with a PEP 508 marker for the remaining platforms.

    Identical strings are emitted once, in first-seen order.
    """
    pip_requirements: list[str] = []
    seen: set[str] = set()
    for specs in requirements.values():
        for spec in specs:
            if spec.which != "pip":
                continue
            matched_platforms = spec.platforms()
            # ``None`` (no selector) and ``[]`` (no overlap) mean different
            # things, so the emptiness check only applies to actual lists.
            if matched_platforms is not None:
                if platforms is not None:
                    matched_platforms = cast(
                        "list[Platform]",
                        sorted(set(matched_platforms).intersection(platforms)),
                    )
                if not matched_platforms:
                    # Nothing left to target; emitting it would also produce a
                    # requirement with a possibly empty ``; `` marker.
                    continue
            requirement = spec.name_with_pin(is_pip=True)
            if matched_platforms is not None:
                selector = build_pep508_environment_marker(matched_platforms)
                requirement = f"{requirement}; {selector}"
            if requirement in seen:
                continue
            pip_requirements.append(requirement)
            seen.add(requirement)
    return pip_requirements
[docs]
def detect_duplicate_local_package_paths(paths: list[Path]) -> None:
    """Raise when multiple local paths map to the same distribution name.

    Paths are resolved and grouped by canonicalized package name; paths
    without a detectable name are ignored.

    Raises
    ------
    RuntimeError
        If any name maps to more than one distinct resolved path. The message
        lists each such name (sorted) followed by its paths.
    """
    paths_by_name: dict[str, list[Path]] = defaultdict(list)
    for raw_path in paths:
        absolute_path = raw_path.resolve()
        detected_name = _maybe_package_name_from_path(absolute_path)
        if detected_name is None:
            continue
        bucket = paths_by_name[canonicalize_name(detected_name)]
        if absolute_path not in bucket:
            bucket.append(absolute_path)
    # Names are unique keys, so sorting the pairs sorts by name only.
    clashes = sorted(
        (name, found) for name, found in paths_by_name.items() if len(found) > 1
    )
    if not clashes:
        return
    detail_lines: list[str] = []
    for name, found in clashes:
        detail_lines.append(f"- {name}")
        detail_lines.extend(f" - {path}" for path in found)
    msg = format_cli_diagnostic(
        "Multiple local packages resolve to the same distribution name.",
        why=[
            "pip and uv may treat these paths as conflicting sources for one"
            " package and fail with duplicate file URL errors",
        ],
        fixes=[
            "keep only one local checkout for each package name",
            "use `use: pypi` or `use: skip` in `local_dependencies` to exclude"
            " vendored copies",
        ],
    )
    detected_paths = "\n".join(detail_lines)
    msg += f"\n\nDetected paths:\n{detected_paths}"
    raise RuntimeError(msg)
def _simple_warning_format(
message: Warning | str,
category: type[Warning], # noqa: ARG001
filename: str,
lineno: int,
line: str | None = None, # noqa: ARG001
) -> str: # pragma: no cover
"""Format warnings without code context."""
return (
f"---------------------\n"
f"⚠️ *** WARNING *** ⚠️\n"
f"{message}\n"
f"Location: {filename}:{lineno}\n"
f"---------------------\n"
)
[docs]
def warn(
    message: str | Warning,
    category: type[Warning] = UserWarning,
    stacklevel: int = 1,
) -> None:
    """Emit a warning with a custom format specific to this package.

    ``warnings.formatwarning`` is swapped for the package formatter only for
    the duration of the call and is always restored afterwards. *stacklevel*
    is bumped by one so the reported location skips this wrapper.
    """
    saved_formatter = warnings.formatwarning
    try:
        warnings.formatwarning = _simple_warning_format
        warnings.warn(message, category, stacklevel=1 + stacklevel)
    finally:
        warnings.formatwarning = saved_formatter
def _cli_diagnostic_sections(
*,
detected: dict[str, str] | None,
why: list[str] | None,
fixes: list[str] | None,
tips: list[str] | None,
) -> list[tuple[str, list[str]]]:
"""Build ordered diagnostic sections for CLI messages."""
sections: list[tuple[str, list[str]]] = []
if detected:
sections.append(
("Detected:", [f"{key}: {value}" for key, value in detected.items()]),
)
if why:
sections.append(("Why this matters:", why))
if fixes:
sections.append(("Do this:", fixes))
if tips:
sections.append(("Tip:", tips))
return sections
def _format_cli_diagnostic_plain(
    summary: str,
    sections: list[tuple[str, list[str]]],
    prefix: str,
) -> str:
    """Render a diagnostic with plain text only.

    The first line is ``"<prefix> <summary>"``. Each section follows after a
    blank line as its heading plus one ``"- item"`` line per item.
    """
    output = [f"{prefix} {summary}"]
    for heading, items in sections:
        output.append("")
        output.append(heading)
        output.extend(f"- {item}" for item in items)
    return "\n".join(output)
def _rich_available() -> bool:
    """Report whether the optional ``rich`` package can be imported.

    Only the module spec is looked up; ``rich`` itself is not imported here.
    """
    rich_spec = importlib.util.find_spec("rich")
    return rich_spec is not None
def _format_cli_diagnostic_with_rich(
    summary: str,
    sections: list[tuple[str, list[str]]],
    prefix: str,
) -> str:
    """Render a diagnostic with Rich while preserving string return semantics.

    The summary line and each ``(heading, items)`` section are drawn inside a
    rounded ``Panel`` on an in-memory ``Console``. The recorded output is
    returned as plain text (``styles=False``, trailing whitespace stripped),
    so callers get an ordinary string. Rich is imported lazily inside this
    function.
    """
    from rich import box
    from rich.console import Console, Group
    from rich.panel import Panel
    from rich.text import Text

    # Colour for the prefix, bullets and panel border, chosen from the prefix.
    border_style = _diagnostic_border_style(prefix)
    renderables = []
    summary_line = Text()
    summary_line.append(f"{prefix} ", style=f"bold {border_style}")
    summary_line.append(summary, style="bold")
    renderables.append(summary_line)
    for heading, items in sections:
        # Empty line before every section heading.
        renderables.append(Text())
        renderables.append(Text(heading, style="bold cyan"))
        for item in items:
            bullet_line = Text()
            bullet_line.append("• ", style=border_style)
            bullet_line.append(item)
            renderables.append(bullet_line)
    # Plain-text mirror of the panel content, used only to size the console.
    content_lines = [f"{prefix} {summary}"]
    for heading, items in sections:
        content_lines.append(heading)
        content_lines.extend(f"• {item}" for item in items)
    # The console writes into a StringIO (nothing reaches the terminal) and
    # records output for ``export_text``. Width is at least 60 columns, or the
    # longest mirrored line plus 4 (presumably border + 1-column padding on
    # each side, matching ``padding=(0, 1)``).
    # NOTE(review): len() counts code points, not terminal cells, so wide
    # characters (emoji, CJK) may be under-measured — confirm wrapping is OK.
    console = Console(
        file=io.StringIO(),
        record=True,
        width=max(60, max(len(line) for line in content_lines) + 4),
        color_system=None,
        highlight=False,
    )
    console.print(
        Panel.fit(
            Group(*renderables),
            border_style=border_style,
            box=box.ROUNDED,
            padding=(0, 1),
        ),
        soft_wrap=True,
    )
    return console.export_text(styles=False).rstrip()
def _diagnostic_border_style(prefix: str) -> str:
    """Map a diagnostic prefix to a Rich color.

    Warning prefix -> ``"yellow"``, information prefix -> ``"cyan"``, anything
    else -> ``"red"``.
    """
    colors = {
        "⚠️": "yellow",
        "\N{INFORMATION SOURCE}\N{VARIATION SELECTOR-16}": "cyan",
    }
    return colors.get(prefix, "red")
[docs]
def split_path_and_extras(input_str: str | Path) -> tuple[Path, list[str]]:
"""Parse a string of the form `path/to/file[extra1,extra2]` into parts.
Returns a tuple of the `pathlib.Path` and a list of extras
"""
if isinstance(input_str, Path):
input_str = str(input_str)
if not input_str: # Check for empty string
return Path(), []
pattern = r"^(.+?)(?:\[([^\[\]]+)\])?$"
match = re.search(pattern, input_str)
if match is None: # pragma: no cover
# I don't think this is possible, but just in case
return Path(), []
path = Path(match.group(1))
extras = match.group(2)
if not extras:
return path, []
extras = [extra.strip() for extra in extras.split(",")]
return path, extras
# Allowed values for ``LocalDependency.use``. "local" is the default there;
# "pypi" and "skip" are the modes this module suggests for excluding a
# vendored local copy.
LocalDependencyUse = Literal["local", "pypi", "skip"]
[docs]
class LocalDependency(NamedTuple):
    """A local dependency with optional PyPI alternative and `use` mode."""

    # The local dependency as written (presumably a path, possibly with
    # ``[extras]`` — confirm against callers).
    local: str
    # Optional PyPI alternative to the local dependency.
    pypi: str | None = None
    # How the dependency is used; one of ``LocalDependencyUse``, default "local".
    use: LocalDependencyUse = "local"
[docs]
def parse_folder_or_filename(folder_or_file: str | Path) -> PathWithExtras:
    """Get the path to `requirements.yaml` or `pyproject.toml` file.

    *folder_or_file* may carry an ``[extras]`` suffix. For a directory,
    ``requirements.yaml`` is preferred, then a ``pyproject.toml`` that has
    unidep configuration. Any other existing path is returned as-is.

    Raises
    ------
    FileNotFoundError
        If the path does not exist, or a directory contains neither file.
    """
    base, extras = split_path_and_extras(folder_or_file)
    candidate = Path(base)
    if not candidate.is_dir():
        if not candidate.exists():
            msg = f"File `{candidate}` not found."
            raise FileNotFoundError(msg)
        return PathWithExtras(candidate, extras)
    yaml_file = candidate / "requirements.yaml"
    if yaml_file.exists():
        return PathWithExtras(yaml_file, extras)
    toml_file = candidate / "pyproject.toml"
    if toml_file.exists() and unidep_configured_in_toml(toml_file):
        return PathWithExtras(toml_file, extras)
    msg = (
        f"File `{yaml_file}` or `{toml_file}` (with unidep configuration)"
        f" not found in `{base}`."
    )
    raise FileNotFoundError(msg)
[docs]
def defaultdict_to_dict(d: defaultdict | Any) -> dict:
"""Convert (nested) defaultdict to (nested) dict."""
if isinstance(d, defaultdict):
d = {key: defaultdict_to_dict(value) for key, value in d.items()}
return d
[docs]
def get_package_version(package_name: str) -> str | None:
"""Returns the version of the given package.
Parameters
----------
package_name
The name of the package to find the version of.
Returns
-------
The version of the package, or None if the package is not found.
"""
if sys.version_info >= (3, 8):
import importlib.metadata
try:
return importlib.metadata.version(package_name)
except importlib.metadata.PackageNotFoundError:
return None
else: # pragma: no cover
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
import pkg_resources
try:
return pkg_resources.get_distribution(package_name).version
except pkg_resources.DistributionNotFound:
return None