From e364271f462d60e39b4b8591908cc3b9643388e1 Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Mon, 23 Mar 2026 15:14:03 +0100 Subject: [PATCH 1/6] Add tests for previously uncovered utility, dataset, and metadatable modules Add 128 new tests across 12 new test files targeting modules that previously had 0% or very low test coverage (excluding instrument_drivers). New test files: - tests/utils/test_types.py: numpy type tuple definitions and composition - tests/utils/test_abstractmethod.py: qcodes_abstractmethod decorator - tests/utils/test_deprecate.py: QCoDeSDeprecationWarning class - tests/utils/test_deep_update_utils.py: recursive dict merging - tests/utils/test_path_helpers.py: QCoDeS path resolution utilities - tests/utils/test_numpy_utils.py: ragged array conversion - tests/dataset/test_snapshot_utils.py: dataset snapshot diffing - tests/dataset/test_json_exporter.py: JSON linear/heatmap export - tests/dataset/test_export_config.py: export config get/set functions - tests/dataset/test_rundescribertypes.py: TypedDict versioned schemas - tests/dataset/test_sqlite_settings_extended.py: SQLite settings/limits - tests/test_metadatable_base.py: Metadatable and MetadatableWithName Non-driver test coverage: 64.7% -> 65.4% Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/dataset/test_export_config.py | 83 ++++++++ tests/dataset/test_json_exporter.py | 135 ++++++++++++ tests/dataset/test_rundescribertypes.py | 169 +++++++++++++++ tests/dataset/test_snapshot_utils.py | 129 ++++++++++++ .../dataset/test_sqlite_settings_extended.py | 66 ++++++ tests/test_metadatable_base.py | 145 +++++++++++++ tests/utils/test_abstractmethod.py | 77 +++++++ tests/utils/test_deep_update_utils.py | 115 ++++++++++ tests/utils/test_deprecate.py | 62 ++++++ tests/utils/test_numpy_utils.py | 85 ++++++++ tests/utils/test_path_helpers.py | 92 ++++++++ tests/utils/test_types.py | 197 ++++++++++++++++++ 12 files changed, 1355 insertions(+) create mode 100644 tests/dataset/test_export_config.py create mode 100644 tests/dataset/test_json_exporter.py create mode 100644 tests/dataset/test_rundescribertypes.py create mode 100644 tests/dataset/test_snapshot_utils.py create mode 100644 tests/dataset/test_sqlite_settings_extended.py create mode 100644 tests/test_metadatable_base.py create mode 100644 tests/utils/test_abstractmethod.py create mode 100644 tests/utils/test_deep_update_utils.py create mode 100644 tests/utils/test_deprecate.py create mode 100644 tests/utils/test_numpy_utils.py create mode 100644 tests/utils/test_path_helpers.py create mode 100644 tests/utils/test_types.py diff --git a/tests/dataset/test_export_config.py b/tests/dataset/test_export_config.py new file mode 100644 index 000000000000..1032e612b594 --- /dev/null +++ b/tests/dataset/test_export_config.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from qcodes.dataset.export_config import ( + DataExportType, + get_data_export_name_elements, + get_data_export_prefix, + get_data_export_type, + set_data_export_prefix, + set_data_export_type, +) + + +def test_data_export_type_enum_members() -> None: + assert DataExportType.NETCDF.value == "nc" + assert DataExportType.CSV.value == "csv" + assert len(DataExportType) == 2 + + +def test_get_data_export_type_with_string_netcdf() -> None: + result = get_data_export_type("NETCDF") + assert result is DataExportType.NETCDF + + +def test_get_data_export_type_with_string_csv() -> None: + result = get_data_export_type("CSV") + assert result is DataExportType.CSV + + +def test_get_data_export_type_case_insensitive() -> None: + assert get_data_export_type("netcdf") is DataExportType.NETCDF + assert get_data_export_type("csv") is DataExportType.CSV + assert get_data_export_type("Csv") is DataExportType.CSV + + +def test_get_data_export_type_with_enum_input() -> None: + result = get_data_export_type(DataExportType.NETCDF) + assert result is DataExportType.NETCDF + + result = get_data_export_type(DataExportType.CSV) + assert result is DataExportType.CSV + + +def test_get_data_export_type_with_none_returns_none() -> None: + # When config export_type is also None/empty, should return None + set_data_export_type(None) # type: ignore[arg-type] + result = get_data_export_type(None) + assert result is None + + +def test_get_data_export_type_with_invalid_string_returns_none() -> None: + result = get_data_export_type("nonexistent_format") + assert result is None + + +def test_set_and_get_data_export_prefix_roundtrip() -> None: + set_data_export_prefix("my_prefix_") + assert get_data_export_prefix() == "my_prefix_" + + set_data_export_prefix("") + assert get_data_export_prefix() == "" + + +def test_get_data_export_name_elements_returns_list() -> None: + result = get_data_export_name_elements() + assert isinstance(result, list) + + +def test_set_data_export_type_valid() -> None: + set_data_export_type("netcdf") + result = get_data_export_type() + assert result is DataExportType.NETCDF + + set_data_export_type("csv") + result = get_data_export_type() + assert result is DataExportType.CSV + + +def test_set_data_export_type_invalid_does_not_change_config() -> None: + set_data_export_type("netcdf") + set_data_export_type("invalid_type") + # Config should still have the previous valid value + result = get_data_export_type() + assert result is DataExportType.NETCDF diff --git a/tests/dataset/test_json_exporter.py b/tests/dataset/test_json_exporter.py new file mode 100644 index 000000000000..f2e6229d82e0 --- /dev/null +++ b/tests/dataset/test_json_exporter.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import copy +import json +from pathlib import Path + +import numpy as np + +from qcodes.dataset.json_exporter import ( + export_data_as_json_heatmap, + export_data_as_json_linear, + json_template_heatmap, + json_template_linear, +) + + +def test_json_template_linear_structure() -> None: + assert json_template_linear["type"] == "linear" + assert "x" in json_template_linear + assert "y" in json_template_linear + assert isinstance(json_template_linear["x"], dict) + assert isinstance(json_template_linear["y"], dict) + assert "data" in json_template_linear["x"] + assert "data" in json_template_linear["y"] + assert json_template_linear["x"]["is_setpoint"] is True + assert json_template_linear["y"]["is_setpoint"] is False + + +def test_json_template_heatmap_structure() -> None: + assert json_template_heatmap["type"] == "heatmap" + assert "x" in json_template_heatmap + assert "y" in json_template_heatmap + assert "z" in json_template_heatmap + assert isinstance(json_template_heatmap["x"], dict) + assert isinstance(json_template_heatmap["y"], dict) + assert isinstance(json_template_heatmap["z"], dict) + assert json_template_heatmap["x"]["is_setpoint"] is True + assert json_template_heatmap["y"]["is_setpoint"] is True + assert json_template_heatmap["z"]["is_setpoint"] is False + + +def test_export_linear_writes_correct_json(tmp_path: Path) -> None: + location = str(tmp_path / "linear.json") + state: dict = {"json": copy.deepcopy(json_template_linear)} + data = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]] + + export_data_as_json_linear(data, len(data), state, location) + + with open(location) as f: + result = json.load(f) + + assert result["type"] == "linear" + assert result["x"]["data"] == [1.0, 2.0, 3.0] + assert result["y"]["data"] == [10.0, 20.0, 30.0] + + +def test_export_linear_accumulates_data(tmp_path: Path) -> None: + location = str(tmp_path / "linear.json") + state: dict = {"json": copy.deepcopy(json_template_linear)} + + export_data_as_json_linear([[1.0, 10.0]], 1, state, location) + export_data_as_json_linear([[2.0, 20.0]], 2, state, location) + + with open(location) as f: + result = json.load(f) + + assert result["x"]["data"] == [1.0, 2.0] + assert result["y"]["data"] == [10.0, 20.0] + + +def test_export_linear_does_nothing_for_empty_data(tmp_path: Path) -> None: + location = str(tmp_path / "linear.json") + state: dict = {"json": copy.deepcopy(json_template_linear)} + + export_data_as_json_linear([], 0, state, location) + + assert not Path(location).exists() + + +def test_export_heatmap_writes_correct_json(tmp_path: Path) -> None: + location = str(tmp_path / "heatmap.json") + xlen = 2 + ylen = 3 + total = xlen * ylen + + state: dict = { + "json": copy.deepcopy(json_template_heatmap), + "data": { + "x": np.zeros(total), + "y": np.zeros(total), + "z": np.zeros(total), + "location": 0, + "xlen": xlen, + "ylen": ylen, + }, + } + + # 2x3 grid: x varies slowly, y varies fast + data = [ + [0.0, 0.0, 1.0], + [0.0, 1.0, 2.0], + [0.0, 2.0, 3.0], + [1.0, 0.0, 4.0], + [1.0, 1.0, 5.0], + [1.0, 2.0, 6.0], + ] + + export_data_as_json_heatmap(data, total, state, location) + + with open(location) as f: + result = json.load(f) + + assert result["type"] == "heatmap" + assert result["x"]["data"] == [0.0, 1.0] + assert result["y"]["data"] == [0.0, 1.0, 2.0] + assert result["z"]["data"] == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + + +def test_export_heatmap_does_nothing_for_empty_data(tmp_path: Path) -> None: + location = str(tmp_path / "heatmap.json") + state: dict = { + "json": copy.deepcopy(json_template_heatmap), + "data": { + "x": np.zeros(4), + "y": np.zeros(4), + "z": np.zeros(4), + "location": 0, + "xlen": 2, + "ylen": 2, + }, + } + + export_data_as_json_heatmap([], 0, state, location) + + assert not Path(location).exists() diff --git a/tests/dataset/test_rundescribertypes.py b/tests/dataset/test_rundescribertypes.py new file mode 100644 index 000000000000..1bccd33e1dae --- /dev/null +++ b/tests/dataset/test_rundescribertypes.py @@ -0,0 +1,169 @@ +""" +Tests for qcodes.dataset.descriptions.versioning.rundescribertypes. + +Verifies the TypedDict classes, inheritance relationships, type aliases, +and the RunDescriberDicts union. +""" + +from __future__ import annotations + +import typing + +from typing_extensions import get_annotations, get_original_bases + +from qcodes.dataset.descriptions.versioning.rundescribertypes import ( + InterDependencies_Dict, + InterDependenciesDict, + RunDescriberDicts, + RunDescriberV0Dict, + RunDescriberV1Dict, + RunDescriberV2Dict, + RunDescriberV3Dict, + Shapes, +) + +# --------------- Shapes type alias --------------- + + +def test_shapes_type_alias() -> None: + sample: Shapes = {"param": (1, 2, 3)} + assert sample["param"] == (1, 2, 3) + + +# --------------- InterDependenciesDict --------------- + + +def test_interdependencies_dict_instantiation() -> None: + d: InterDependenciesDict = {"paramspecs": ()} + assert d["paramspecs"] == () + + +# --------------- InterDependencies_Dict --------------- + + +def test_interdependencies_underscore_dict_instantiation() -> None: + d: InterDependencies_Dict = { + "parameters": {}, + "dependencies": {}, + "inferences": {}, + "standalones": [], + } + assert d["parameters"] == {} + assert d["standalones"] == [] + + +# --------------- RunDescriberV0Dict --------------- + + +def test_v0_dict_instantiation() -> None: + d: RunDescriberV0Dict = { + "version": 0, + "interdependencies": {"paramspecs": ()}, + } + assert d["version"] == 0 + + +# --------------- RunDescriberV1Dict --------------- + + +def test_v1_dict_instantiation() -> None: + d: RunDescriberV1Dict = { + "version": 1, + "interdependencies": { + "parameters": {}, + "dependencies": {}, + "inferences": {}, + "standalones": [], + }, + } + assert d["version"] == 1 + + +# --------------- RunDescriberV2Dict inherits from V0 --------------- + + +def test_v2_dict_inherits_from_v0() -> None: + # typing_extensions TypedDict flattens __bases__ to (dict,) at runtime; + # verify structural inheritance via __orig_bases__ and annotations. + assert RunDescriberV0Dict in get_original_bases(RunDescriberV2Dict) + # V2 should contain all V0 keys plus its own + v0_keys = set(get_annotations(RunDescriberV0Dict)) + v2_keys = set(get_annotations(RunDescriberV2Dict)) + assert v0_keys.issubset(v2_keys) + + +def test_v2_dict_instantiation() -> None: + d: RunDescriberV2Dict = { + "version": 2, + "interdependencies": {"paramspecs": ()}, + "interdependencies_": { + "parameters": {}, + "dependencies": {}, + "inferences": {}, + "standalones": [], + }, + } + assert d["version"] == 2 + assert "interdependencies_" in d + + +# --------------- RunDescriberV3Dict inherits from V2 --------------- + + +def test_v3_dict_inherits_from_v2() -> None: + assert RunDescriberV2Dict in get_original_bases(RunDescriberV3Dict) + v2_keys = set(get_annotations(RunDescriberV2Dict)) + v3_keys = set(get_annotations(RunDescriberV3Dict)) + assert v2_keys.issubset(v3_keys) + + +def test_v3_dict_inherits_from_v0_transitively() -> None: + # V3 inherits from V2 which inherits from V0 — all V0 keys present + v0_keys = set(get_annotations(RunDescriberV0Dict)) + v3_keys = set(get_annotations(RunDescriberV3Dict)) + assert v0_keys.issubset(v3_keys) + + +def test_v3_dict_instantiation() -> None: + d: RunDescriberV3Dict = { + "version": 3, + "interdependencies": {"paramspecs": ()}, + "interdependencies_": { + "parameters": {}, + "dependencies": {}, + "inferences": {}, + "standalones": [], + }, + "shapes": {"x": (10,)}, + } + assert d["version"] == 3 + assert d["shapes"] == {"x": (10,)} + + +def test_v3_dict_shapes_none() -> None: + d: RunDescriberV3Dict = { + "version": 3, + "interdependencies": {"paramspecs": ()}, + "interdependencies_": { + "parameters": {}, + "dependencies": {}, + "inferences": {}, + "standalones": [], + }, + "shapes": None, + } + assert d["shapes"] is None + + +# --------------- RunDescriberDicts union --------------- + + +def test_rundescriber_dicts_includes_all_versions() -> None: + args = typing.get_args(RunDescriberDicts) + expected = { + RunDescriberV0Dict, + RunDescriberV1Dict, + RunDescriberV2Dict, + RunDescriberV3Dict, + } + assert set(args) == expected diff --git a/tests/dataset/test_snapshot_utils.py b/tests/dataset/test_snapshot_utils.py new file mode 100644 index 000000000000..5f4a8853806a --- /dev/null +++ b/tests/dataset/test_snapshot_utils.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from qcodes.dataset.snapshot_utils import diff_param_snapshots +from qcodes.utils import ParameterDiff + + +def _make_mock_dataset(run_id: int, snapshot: dict | None) -> MagicMock: + ds = MagicMock() + ds.run_id = run_id + ds.snapshot = snapshot + return ds + + +def test_diff_param_snapshots_both_have_snapshots() -> None: + left_snapshot = { + "station": { + "parameters": { + "p1": {"value": 1.0}, + "p2": {"value": 2.0}, + } + } + } + right_snapshot = { + "station": { + "parameters": { + "p1": {"value": 1.0}, + "p3": {"value": 3.0}, + } + } + } + left = _make_mock_dataset(1, left_snapshot) + right = _make_mock_dataset(2, right_snapshot) + + result = diff_param_snapshots(left, right) + + assert isinstance(result, ParameterDiff) + assert result.left_only == {"p2": 2.0} + assert result.right_only == {"p3": 3.0} + assert result.changed == {} + + +def test_diff_param_snapshots_identical_snapshots() -> None: + snapshot = { + "station": { + "parameters": { + "p1": {"value": 1.0}, + } + } + } + left = _make_mock_dataset(1, snapshot) + right = _make_mock_dataset(2, snapshot) + + result = diff_param_snapshots(left, right) + + assert result.left_only == {} + assert result.right_only == {} + assert result.changed == {} + + +def test_diff_param_snapshots_changed_values() -> None: + left_snapshot = { + "station": { + "parameters": { + "p1": {"value": 1.0}, + } + } + } + right_snapshot = { + "station": { + "parameters": { + "p1": {"value": 99.0}, + } + } + } + left = _make_mock_dataset(1, left_snapshot) + right = _make_mock_dataset(2, right_snapshot) + + result = diff_param_snapshots(left, right) + + assert result.changed == {"p1": (1.0, 99.0)} + + +def test_diff_param_snapshots_raises_when_left_snapshot_is_none() -> None: + left = _make_mock_dataset(run_id=5, snapshot=None) + right = _make_mock_dataset( + run_id=6, + snapshot={"station": {"parameters": {"p1": {"value": 1.0}}}}, + ) + + with pytest.raises(RuntimeError, match="5"): + diff_param_snapshots(left, right) + + +def test_diff_param_snapshots_raises_when_right_snapshot_is_none() -> None: + left = _make_mock_dataset( + run_id=7, + snapshot={"station": {"parameters": {"p1": {"value": 1.0}}}}, + ) + right = _make_mock_dataset(run_id=8, snapshot=None) + + with pytest.raises(RuntimeError, match="8"): + diff_param_snapshots(left, right) + + +def test_diff_param_snapshots_raises_when_both_snapshots_are_none() -> None: + left = _make_mock_dataset(run_id=10, snapshot=None) + right = _make_mock_dataset(run_id=11, snapshot=None) + + # When both are None, the left dataset is identified as the empty one + with pytest.raises(RuntimeError, match="10"): + diff_param_snapshots(left, right) + + +def test_diff_param_snapshots_error_message_includes_run_id() -> None: + left = _make_mock_dataset(run_id=42, snapshot=None) + right = _make_mock_dataset( + run_id=99, + snapshot={"station": {"parameters": {}}}, + ) + + with pytest.raises(RuntimeError) as exc_info: + diff_param_snapshots(left, right) + + assert "42" in str(exc_info.value) + assert "snapshot" in str(exc_info.value).lower() diff --git a/tests/dataset/test_sqlite_settings_extended.py b/tests/dataset/test_sqlite_settings_extended.py new file mode 100644 index 000000000000..ff10aed658ec --- /dev/null +++ b/tests/dataset/test_sqlite_settings_extended.py @@ -0,0 +1,66 @@ +""" +Extended tests for qcodes.dataset.sqlite.settings beyond the minimal +checks in test_sqlitesettings.py. +""" + +from __future__ import annotations + +from qcodes.dataset.sqlite.settings import SQLiteSettings, _read_settings + +# --------------- _read_settings returns a 2-tuple of dicts --------------- + + +def test_read_settings_returns_two_dicts() -> None: + result = _read_settings() + assert isinstance(result, tuple) + assert len(result) == 2 + limits, settings = result + assert isinstance(limits, dict) + assert isinstance(settings, dict) + + +# --------------- settings dict --------------- + + +def test_settings_contains_version_key() -> None: + assert "VERSION" in SQLiteSettings.settings + + +def test_settings_version_is_string() -> None: + assert isinstance(SQLiteSettings.settings["VERSION"], str) + + +def test_settings_dict_is_non_empty() -> None: + assert len(SQLiteSettings.settings) >= 1 + + +# --------------- limits dict --------------- + + +EXPECTED_LIMIT_KEYS = { + "MAX_ATTACHED", + "MAX_COLUMN", + "MAX_COMPOUND_SELECT", + "MAX_EXPR_DEPTH", + "MAX_FUNCTION_ARG", + "MAX_LENGTH", + "MAX_LIKE_PATTERN_LENGTH", + "MAX_PAGE_COUNT", + "MAX_SQL_LENGTH", + "MAX_VARIABLE_NUMBER", +} + + +def test_limits_contains_expected_keys() -> None: + assert EXPECTED_LIMIT_KEYS.issubset(SQLiteSettings.limits.keys()) + + +def test_each_limit_value_is_int_or_str() -> None: + for key, value in SQLiteSettings.limits.items(): + assert isinstance(value, (int, str)), ( + f"Limit {key!r} should be int or str, got {type(value)}" + ) + + +def test_limits_has_ten_entries() -> None: + assert len(SQLiteSettings.limits) == 10 diff --git a/tests/test_metadatable_base.py b/tests/test_metadatable_base.py new file mode 100644 index 000000000000..304aa5c8a841 --- /dev/null +++ b/tests/test_metadatable_base.py @@ -0,0 +1,145 @@ +""" +Tests for qcodes.metadatable.metadatable_base covering branches +not exercised by test_metadata.py. +""" + +from __future__ import annotations + +from typing import Any + +from qcodes.metadatable import Metadatable +from qcodes.metadatable.metadatable_base import MetadatableWithName, Snapshot + +# --------------- helpers --------------- + + +class ConcreteMetadatableWithName(MetadatableWithName): + """Minimal concrete implementation for testing.""" + + def __init__( + self, + name: str, + full: str | None = None, + metadata: dict[str, Any] | None = None, + ): + self._short_name = name + self._full_name = full or name + super().__init__(metadata=metadata) + + @property + def short_name(self) -> str: + return self._short_name + + @property + def full_name(self) -> str: + return self._full_name + + +# --------------- Snapshot type alias --------------- + + +def test_snapshot_type_alias_is_dict_str_any() -> None: + assert Snapshot == dict[str, Any] + + +# --------------- Metadatable.__init__ --------------- + + +def test_init_with_none_metadata() -> None: + m = Metadatable(metadata=None) + assert m.metadata == {} + + +def test_init_with_no_arguments() -> None: + m = Metadatable() + assert m.metadata == {} + + +def test_init_with_metadata() -> None: + m = Metadatable(metadata={"key": "value"}) + assert m.metadata == {"key": "value"} + + +# --------------- snapshot without / with metadata --------------- + + +def test_snapshot_without_metadata_returns_base_only() -> None: + m = Metadatable() + snap = m.snapshot() + assert snap == {} + assert "metadata" not in snap + + +def test_snapshot_with_metadata_includes_metadata_key() -> None: + m = Metadatable(metadata={"x": 1}) + snap = m.snapshot() + assert "metadata" in snap + assert snap["metadata"] == {"x": 1} + + +def test_snapshot_metadata_removed_after_clear() -> None: + m = Metadatable(metadata={"a": 1}) + assert "metadata" in m.snapshot() + m.metadata.clear() + assert "metadata" not in m.snapshot() + + +# --------------- snapshot_base default --------------- + + +def test_snapshot_base_default_returns_empty_dict() -> None: + m = Metadatable() + assert m.snapshot_base() == {} + assert m.snapshot_base(update=True) == {} + assert m.snapshot_base(params_to_skip_update=["p1"]) == {} + + +# --------------- load_metadata deep_update --------------- + + +def test_load_metadata_deep_updates_nested_dicts() -> None: + m = Metadatable(metadata={"outer": {"a": 1, "b": 2}}) + m.load_metadata({"outer": {"b": 99, "c": 3}}) + assert m.metadata == {"outer": {"a": 1, "b": 99, "c": 3}} + + +def test_load_metadata_adds_new_top_level_keys() -> None: + m = Metadatable(metadata={"first": 1}) + m.load_metadata({"second": 2}) + assert m.metadata == {"first": 1, "second": 2} + + +# --------------- MetadatableWithName --------------- + + +def test_metadatable_with_name_has_abstract_methods() -> None: + # MetadatableWithName uses @abstractmethod for static analysis; + # verify the property descriptors are marked abstract. + for attr_name in ("short_name", "full_name"): + descriptor = getattr(MetadatableWithName, attr_name) + assert isinstance(descriptor, property) + assert getattr(descriptor.fget, "__isabstractmethod__", False) + + +def test_concrete_metadatable_with_name_short_name() -> None: + obj = ConcreteMetadatableWithName("sensor") + assert obj.short_name == "sensor" + + +def test_concrete_metadatable_with_name_full_name() -> None: + obj = ConcreteMetadatableWithName("sensor", full="instrument_sensor") + assert obj.full_name == "instrument_sensor" + + +def test_concrete_metadatable_with_name_inherits_metadata() -> None: + obj = ConcreteMetadatableWithName("s", metadata={"cal": True}) + assert obj.metadata == {"cal": True} + snap = obj.snapshot() + assert snap["metadata"] == {"cal": True} + + +def test_concrete_metadatable_with_name_snapshot_no_metadata() -> None: + obj = ConcreteMetadatableWithName("s") + snap = obj.snapshot() + assert snap == {} + assert "metadata" not in snap diff --git a/tests/utils/test_abstractmethod.py b/tests/utils/test_abstractmethod.py new file mode 100644 index 000000000000..8d6954c75eda --- /dev/null +++ b/tests/utils/test_abstractmethod.py @@ -0,0 +1,77 @@ +""" +Tests for qcodes.utils.abstractmethod - custom abstract method decorator. +""" + +from qcodes.utils.abstractmethod import qcodes_abstractmethod + + +def test_decorator_sets_attribute() -> None: + """Test that the decorator sets __qcodes_is_abstract_method__ to True.""" + + @qcodes_abstractmethod + def my_func() -> None: + pass + + assert hasattr(my_func, "__qcodes_is_abstract_method__") + assert my_func.__qcodes_is_abstract_method__ is True # type: ignore[attr-defined] + + +def test_decorator_returns_same_function() -> None: + """Test that the decorator returns the original function object.""" + + def my_func() -> None: + pass + + result = qcodes_abstractmethod(my_func) + assert result is my_func + + +def test_decorated_function_is_still_callable() -> None: + """Test that the decorated function can still be called.""" + + @qcodes_abstractmethod + def my_func(x: int) -> int: + return x * 2 + + assert my_func(5) == 10 + + +def test_class_with_qcodes_abstractmethod_can_be_instantiated() -> None: + """Test that unlike abc.abstractmethod, classes can still be instantiated.""" + + class MyClass: + @qcodes_abstractmethod + def my_method(self) -> str: + return "base" + + instance = MyClass() + assert instance.my_method() == "base" + + +def test_undecorated_function_lacks_attribute() -> None: + """Test that undecorated functions don't have the attribute.""" + + def regular_func() -> None: + pass + + assert not hasattr(regular_func, "__qcodes_is_abstract_method__") + + +def test_multiple_methods_decorated() -> None: + """Test that multiple methods in a class can be decorated independently.""" + + class MyClass: + @qcodes_abstractmethod + def method_a(self) -> None: + pass + + @qcodes_abstractmethod + def method_b(self) -> None: + pass + + def method_c(self) -> None: + pass + + assert hasattr(MyClass.method_a, "__qcodes_is_abstract_method__") + assert hasattr(MyClass.method_b, "__qcodes_is_abstract_method__") + assert not hasattr(MyClass.method_c, "__qcodes_is_abstract_method__") diff --git a/tests/utils/test_deep_update_utils.py b/tests/utils/test_deep_update_utils.py new file mode 100644 index 000000000000..e8d1f76f40a0 --- /dev/null +++ b/tests/utils/test_deep_update_utils.py @@ -0,0 +1,115 @@ +""" +Tests for qcodes.utils.deep_update_utils - recursive dict merging. +""" + +from qcodes.utils.deep_update_utils import deep_update + + +def test_simple_key_value_update() -> None: + """Test updating simple key-value pairs.""" + dest = {"a": 1, "b": 2} + update = {"b": 3} + result = deep_update(dest, update) + assert result["a"] == 1 + assert result["b"] == 3 + + +def test_nested_dict_merging() -> None: + """Test that nested dicts are merged recursively.""" + dest = {"a": {"x": 1, "y": 2}} + update = {"a": {"y": 3, "z": 4}} + result = deep_update(dest, update) + assert result["a"] == {"x": 1, "y": 3, "z": 4} + + +def test_deeply_nested_dict_merging() -> None: + """Test recursive merging multiple levels deep.""" + dest = {"a": {"b": {"c": 1, "d": 2}}} + update = {"a": {"b": {"d": 3, "e": 4}}} + result = deep_update(dest, update) + assert result["a"]["b"] == {"c": 1, "d": 3, "e": 4} + + +def test_lists_replaced_entirely() -> None: + """Test that lists are replaced completely, not merged.""" + dest = {"a": [1, 2, 3]} + update = {"a": [4, 5]} + result = deep_update(dest, update) + assert result["a"] == [4, 5] + + +def test_new_keys_added() -> None: + """Test that new keys from update are added to dest.""" + dest = {"a": 1} + update = {"b": 2, "c": 3} + result = deep_update(dest, update) + assert result == {"a": 1, "b": 2, "c": 3} + + +def test_non_dict_replaces_dict() -> None: + """Test that a non-dict value replaces a dict value.""" + dest = {"a": {"x": 1}} + update = {"a": "string_value"} + result = deep_update(dest, update) + assert result["a"] == "string_value" + + +def test_dict_replaces_non_dict() -> None: + """Test that a dict value replaces a non-dict value.""" + dest = {"a": "string_value"} + update = {"a": {"x": 1}} + result = deep_update(dest, update) + assert result["a"] == {"x": 1} + + +def test_returns_dest_dict() -> None: + """Test that deep_update returns the dest dict (mutated in place).""" + dest = {"a": 1} + update = {"b": 2} + result = deep_update(dest, update) + assert result is dest + + +def test_deep_copy_of_update_values() -> None: + """Test that mutations to the update dict don't affect dest.""" + inner_list = [1, 2, 3] + dest: dict = {} + update = {"a": inner_list} + deep_update(dest, update) + + inner_list.append(4) + assert dest["a"] == [1, 2, 3] + + +def test_deep_copy_of_nested_update_values() -> None: + """Test that deep copies are made for nested structures.""" + inner_dict = {"x": [1, 2]} + dest: dict = {"a": 1} + update = {"b": inner_dict} + deep_update(dest, update) + + inner_dict["x"].append(3) + assert dest["b"]["x"] == [1, 2] + + +def test_empty_update() -> None: + """Test that an empty update leaves dest unchanged.""" + dest = {"a": 1, "b": 2} + result = deep_update(dest, {}) + assert result == {"a": 1, "b": 2} + + +def test_empty_dest() -> None: + """Test updating an empty dest with values.""" + dest: dict = {} + update = {"a": 1, "b": {"c": 2}} + result = deep_update(dest, update) + assert result == {"a": 1, "b": {"c": 2}} + + +def test_none_values() -> None: + """Test that None values are handled correctly.""" + dest = {"a": 1} + update = {"a": None} + result = deep_update(dest, update) + assert result["a"] is None diff --git a/tests/utils/test_deprecate.py b/tests/utils/test_deprecate.py new file mode 100644 index 000000000000..8f12e796a106 --- /dev/null +++ b/tests/utils/test_deprecate.py @@ -0,0 +1,62 @@ +""" +Tests for qcodes.utils.deprecate - QCoDeSDeprecationWarning. +""" + +import warnings + +import pytest + +from qcodes.utils.deprecate import QCoDeSDeprecationWarning + + +def test_is_subclass_of_runtime_warning() -> None: + """Test that QCoDeSDeprecationWarning is a subclass of RuntimeWarning.""" + assert issubclass(QCoDeSDeprecationWarning, RuntimeWarning) + + +def test_is_not_subclass_of_deprecation_warning() -> None: + """Test that it is not a subclass of the standard DeprecationWarning.""" + assert not issubclass(QCoDeSDeprecationWarning, DeprecationWarning) + + +def test_can_be_raised_and_caught() -> None: + """Test that QCoDeSDeprecationWarning can be raised and caught.""" + with pytest.raises(QCoDeSDeprecationWarning, match="test message"): + raise QCoDeSDeprecationWarning("test message") + + +def test_can_be_caught_as_runtime_warning() -> None: + """Test that it can be caught as a RuntimeWarning.""" + with pytest.raises(RuntimeWarning): + raise QCoDeSDeprecationWarning("test message") + + +def test_can_be_used_with_warnings_warn() -> None: + """Test that it can be used with warnings.warn.""" + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + warnings.warn("deprecation message", QCoDeSDeprecationWarning, stacklevel=1) + + assert len(caught) == 1 + assert issubclass(caught[0].category, QCoDeSDeprecationWarning) + assert "deprecation message" in str(caught[0].message) + + +def test_instance_attributes() -> None: + """Test that the warning carries its message.""" + warning = QCoDeSDeprecationWarning("my message") + assert str(warning) == "my message" + assert isinstance(warning, RuntimeWarning) + + +def test_not_suppressed_by_default_warning_filters() -> None: + """Test that QCoDeSDeprecationWarning is visible with default filters. + + Standard DeprecationWarning is suppressed by default, but since this + inherits from RuntimeWarning it should not be. + """ + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("default") + warnings.warn("should be visible", QCoDeSDeprecationWarning, stacklevel=1) + + assert len(caught) == 1 diff --git a/tests/utils/test_numpy_utils.py b/tests/utils/test_numpy_utils.py new file mode 100644 index 000000000000..367460d55659 --- /dev/null +++ b/tests/utils/test_numpy_utils.py @@ -0,0 +1,85 @@ +""" +Tests for qcodes.utils.numpy_utils - numpy array conversion utilities. +""" + +import numpy as np + +from qcodes.utils.numpy_utils import list_of_data_to_maybe_ragged_nd_array + + +def test_regular_list_converts_to_array() -> None: + """Test that a simple list converts to a 1D numpy array.""" + data = [1, 2, 3] + result = list_of_data_to_maybe_ragged_nd_array(data) + np.testing.assert_array_equal(result, np.array([1, 2, 3])) + assert result.ndim == 1 + + +def test_nested_lists_same_length_create_2d_array() -> None: + """Test that nested lists of equal length create a 2D array.""" + data = [[1, 2], [3, 4], [5, 6]] + result = list_of_data_to_maybe_ragged_nd_array(data) + expected = np.array([[1, 2], [3, 4], [5, 6]]) + np.testing.assert_array_equal(result, expected) + assert result.shape == (3, 2) + + +def test_ragged_nested_lists_return_object_array() -> None: + """Test that ragged nested lists produce an object-dtype array.""" + data = [[1, 2], [3, 4, 5], [6]] + result = list_of_data_to_maybe_ragged_nd_array(data) + assert result.dtype == object + assert len(result) == 3 + + +def test_dtype_parameter_is_respected() -> None: + """Test that the dtype parameter is used for the output array.""" + data = [1, 2, 3] + result = list_of_data_to_maybe_ragged_nd_array(data, dtype=float) + assert result.dtype == np.float64 + np.testing.assert_array_equal(result, np.array([1.0, 2.0, 3.0])) + + +def test_empty_list() -> None: + """Test that an empty list converts to an empty array.""" + data: list = [] + result = list_of_data_to_maybe_ragged_nd_array(data) + assert len(result) == 0 + + +def test_single_element_list() -> None: + """Test that a single element list converts correctly.""" + data = [42] + result = list_of_data_to_maybe_ragged_nd_array(data) + np.testing.assert_array_equal(result, np.array([42])) + assert result.shape == (1,) + + +def test_list_of_floats() -> None: + """Test that a list of floats converts correctly.""" + data = [1.1, 2.2, 3.3] + result = list_of_data_to_maybe_ragged_nd_array(data) + np.testing.assert_array_almost_equal(result, np.array([1.1, 2.2, 3.3])) + + +def test_list_of_strings() -> None: + """Test that a list of strings converts to a string array.""" + data = ["a", "b", "c"] + result = list_of_data_to_maybe_ragged_nd_array(data) + np.testing.assert_array_equal(result, np.array(["a", "b", "c"])) + + +def test_3d_uniform_data() -> None: + """Test that uniformly nested 3D data creates a 3D array.""" + data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + result = list_of_data_to_maybe_ragged_nd_array(data) + assert result.shape == (2, 2, 2) + + +def test_ragged_array_preserves_inner_lists() -> None: + """Test that ragged array elements are preserved correctly.""" + data = [[1, 2, 3], [4, 5]] + result = list_of_data_to_maybe_ragged_nd_array(data) + assert result.dtype == object + assert list(result[0]) == [1, 2, 3] + assert list(result[1]) == [4, 5] diff --git a/tests/utils/test_path_helpers.py b/tests/utils/test_path_helpers.py new file mode 100644 index 000000000000..5254ffeef8e0 --- /dev/null +++ b/tests/utils/test_path_helpers.py @@ -0,0 +1,92 @@ +""" +Tests for qcodes.utils.path_helpers - path utility functions. +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pytest + +from qcodes.utils.path_helpers import ( + QCODES_USER_PATH_ENV, + get_qcodes_path, + get_qcodes_user_path, +) + + +def test_get_qcodes_path_returns_string() -> None: + """Test that get_qcodes_path returns a string.""" + result = get_qcodes_path() + assert isinstance(result, str) + + +def test_get_qcodes_path_ends_with_separator() -> None: + """Test that get_qcodes_path returns a path ending with os.sep.""" + result = get_qcodes_path() + assert result.endswith(os.sep) + + +def test_get_qcodes_path_contains_qcodes() -> None: + """Test that the returned path contains 'qcodes'.""" + result = get_qcodes_path() + assert "qcodes" in result.lower() + + +def test_get_qcodes_path_with_subfolder() -> None: + """Test that get_qcodes_path appends a subfolder.""" + result = get_qcodes_path("subdir") + assert result.endswith("subdir" + os.sep) + + +def test_get_qcodes_path_with_nested_subfolders() -> None: + """Test that get_qcodes_path appends multiple subfolder parts.""" + result = get_qcodes_path("subdir", "nested") + assert "subdir" in result + assert result.endswith("nested" + os.sep) + + +def test_get_qcodes_user_path_default(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that get_qcodes_user_path returns ~/.qcodes by default.""" + monkeypatch.delenv(QCODES_USER_PATH_ENV, raising=False) + result = get_qcodes_user_path() + expected = os.path.join(str(Path.home()), ".qcodes") + assert result == expected + + +def test_get_qcodes_user_path_respects_env_var( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that get_qcodes_user_path uses QCODES_USER_PATH env var.""" + custom_path = str(tmp_path / "custom_qcodes") + monkeypatch.setenv(QCODES_USER_PATH_ENV, custom_path) + result = get_qcodes_user_path() + assert result == custom_path + + +def test_get_qcodes_user_path_appends_file_parts( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that get_qcodes_user_path appends file parts.""" + custom_path = str(tmp_path / "custom_qcodes") + monkeypatch.setenv(QCODES_USER_PATH_ENV, custom_path) + result = get_qcodes_user_path("config.json") + assert result == os.path.join(custom_path, "config.json") + + +def test_get_qcodes_user_path_appends_nested_parts( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that get_qcodes_user_path appends multiple nested parts.""" + custom_path = str(tmp_path / "custom_qcodes") + monkeypatch.setenv(QCODES_USER_PATH_ENV, custom_path) + result = get_qcodes_user_path("subdir", "file.txt") + assert result == os.path.join(custom_path, "subdir", "file.txt") + + +def test_qcodes_user_path_env_constant() -> None: + """Test that the env variable constant has the expected value.""" + assert QCODES_USER_PATH_ENV == "QCODES_USER_PATH" diff --git a/tests/utils/test_types.py b/tests/utils/test_types.py new file mode 100644 index 000000000000..e4f6b6044d39 --- /dev/null +++ b/tests/utils/test_types.py @@ -0,0 +1,197 @@ +""" +Tests for qcodes.utils.types - numpy type tuples and aliases. +""" + +import numpy as np + +from qcodes.utils.types import ( + complex_types, + concrete_complex_types, + numpy_c_complex, + numpy_c_floats, + numpy_c_ints, + numpy_complex, + numpy_concrete_complex, + numpy_concrete_floats, + numpy_concrete_ints, + numpy_floats, + numpy_ints, + numpy_non_concrete_ints_instantiable, +) + + +def test_numpy_concrete_ints_contents() -> None: + """Test that numpy_concrete_ints contains the expected fixed-size int types.""" + expected = ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ) + assert numpy_concrete_ints == expected + + +def test_numpy_concrete_ints_length() -> None: + """Test that numpy_concrete_ints has 8 types.""" + assert len(numpy_concrete_ints) == 8 + + +def test_numpy_c_ints_contents() -> None: + """Test that numpy_c_ints contains the expected C-compatible int types.""" + expected = ( + np.uintp, + np.uintc, + np.intp, + np.intc, + np.short, + np.byte, + np.ushort, + np.ubyte, + np.longlong, + np.ulonglong, + ) + assert numpy_c_ints == expected + + +def test_numpy_c_ints_length() -> None: + """Test that numpy_c_ints has 10 types.""" + assert len(numpy_c_ints) == 10 + + +def test_numpy_non_concrete_ints_instantiable_contents() -> None: + """Test that numpy_non_concrete_ints_instantiable contains default int types.""" + expected = (np.int_, np.uint) + assert numpy_non_concrete_ints_instantiable == expected + + +def test_numpy_ints_is_combination() -> None: + """Test that numpy_ints is the concatenation of all int sub-tuples.""" + expected = numpy_concrete_ints + numpy_c_ints + numpy_non_concrete_ints_instantiable + assert numpy_ints == expected + + +def test_numpy_ints_length() -> None: + """Test that numpy_ints has the combined length of all int sub-tuples.""" + expected_len = ( + len(numpy_concrete_ints) + + len(numpy_c_ints) + + len(numpy_non_concrete_ints_instantiable) + ) + assert len(numpy_ints) == expected_len + + +def test_numpy_concrete_floats_contents() -> None: + """Test that numpy_concrete_floats contains fixed-size float types.""" + expected = (np.float16, np.float32, np.float64) + assert numpy_concrete_floats == expected + + +def test_numpy_c_floats_contents() -> None: + """Test that numpy_c_floats contains C-compatible float types.""" + expected = (np.half, np.single, np.double) + assert numpy_c_floats == expected + + +def test_numpy_floats_is_combination() -> None: + """Test that numpy_floats is the concatenation of float sub-tuples.""" + assert numpy_floats == numpy_concrete_floats + numpy_c_floats + + +def test_numpy_floats_length() -> None: + """Test that numpy_floats has the combined length of float sub-tuples.""" + assert len(numpy_floats) == len(numpy_concrete_floats) + len(numpy_c_floats) + + +def test_numpy_concrete_complex_contents() -> None: + """Test that numpy_concrete_complex contains fixed-size complex types.""" + expected = (np.complex64, np.complex128) + assert numpy_concrete_complex == expected + + +def test_numpy_c_complex_contents() -> None: + """Test that numpy_c_complex contains C-compatible complex types.""" + expected = (np.csingle, np.cdouble) + assert numpy_c_complex == expected + + +def test_numpy_complex_is_combination() -> None: + """Test that numpy_complex is the concatenation of complex sub-tuples.""" + assert numpy_complex == numpy_concrete_complex + numpy_c_complex + + +def test_concrete_complex_types_includes_python_complex() -> None: + """Test that concrete_complex_types includes numpy and Python complex.""" + assert complex in concrete_complex_types + for t in numpy_concrete_complex: + assert t in concrete_complex_types + + +def test_complex_types_includes_python_complex() -> None: + """Test that complex_types includes numpy and Python complex.""" + assert complex in complex_types + for t in numpy_concrete_complex: + assert t in complex_types + + +def test_all_int_types_are_numpy_integer_subclass() -> None: + """Test that all types in numpy_ints are subclasses of np.integer.""" + for t in numpy_ints: + assert issubclass(t, np.integer), f"{t} is not a subclass of np.integer" + + +def test_all_float_types_are_numpy_floating_subclass() -> None: + """Test that all types in numpy_floats are subclasses of np.floating.""" + for t in numpy_floats: + assert issubclass(t, np.floating), f"{t} is not a subclass of np.floating" + + +def test_all_complex_types_are_numpy_complexfloating_subclass() -> None: + """Test that all types in numpy_complex are subclasses of np.complexfloating.""" + for t in numpy_complex: + assert issubclass(t, np.complexfloating), ( + f"{t} is not a subclass of np.complexfloating" + ) + + +def test_concrete_int_instances() -> None: + """Test that instances of concrete int types can be created.""" + for t in numpy_concrete_ints: + val = t(42) + assert isinstance(val, np.integer) + + +def test_concrete_float_instances() -> None: + """Test that instances of concrete float types can be created.""" + for t in numpy_concrete_floats: + val = t(3.14) + assert isinstance(val, np.floating) + + +def test_concrete_complex_instances() -> None: + """Test that instances of concrete complex types can be created.""" + for t in numpy_concrete_complex: + val = t(1 + 2j) + assert isinstance(val, np.complexfloating) + + +def test_all_tuples_contain_types() -> None: + """Test that every element in every tuple is a type (class).""" + all_tuples = [ + numpy_concrete_ints, + numpy_c_ints, + numpy_non_concrete_ints_instantiable, + numpy_ints, + numpy_concrete_floats, + numpy_c_floats, + numpy_floats, + numpy_concrete_complex, + numpy_c_complex, + numpy_complex, + ] + for tup in all_tuples: + for t in tup: + assert isinstance(t, type), f"{t} is not a type" From 6f6729161907caf575fdbbf8638423930bdfa2a5 Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Tue, 24 Mar 2026 07:06:03 +0100 Subject: [PATCH 2/6] Bump typing_extensions for test --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a9745ddacc65..b73e5b9c73a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "pyvisa>=1.11.0, <1.17.0", "ruamel.yaml>=0.16.0,!=0.16.6", "tabulate>=0.9.0", - "typing_extensions>=4.6.0", + "typing_extensions>=4.13.0", "tqdm>=4.59.0", "uncertainties>=3.2.0", "versioningit>=2.2.1", From 61b3bbb842f1ee5cd7f33ba079b99119e341c71b Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Mon, 23 Mar 2026 20:30:43 +0100 Subject: [PATCH 3/6] Add extended tests for logger, configuration, monitor, and validators Add 207 new tests across 4 test files targeting modules with coverage between 43-56% (excluding instrument_drivers). New test files: - tests/test_logger_extended.py (28 tests): formatters, level name/code conversion, log file generation, LogCapture class, log_to_dataframe, logfile_to_dataframe, time_difference with comma separators - tests/test_config_extended.py (40 tests): DotDict dotted key access, nested set/get/contains/del, Config describe/add/getitem/save/load, update() recursive merge, schema validation edge cases - tests/test_monitor_extended.py (12 tests): _get_metadata with/without timestamps, unbound parameters, use_root_instrument flag, Monitor.show - tests/validators/test_validators_extended.py (127 tests): validate_all, range_str branches, base Validator, Nothing/Bool/Strings/Enum/OnOff/ Multiples/PermissiveMultiples/MultiType/MultiTypeOr/MultiTypeAnd/ Arrays/Lists/Sequence/Callable/Dict edge cases and error paths Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/test_config_extended.py | 303 ++++++++ tests/test_logger_extended.py | 371 ++++++++++ tests/test_monitor_extended.py | 170 +++++ tests/validators/test_validators_extended.py | 716 +++++++++++++++++++ 4 files changed, 1560 insertions(+) create mode 100644 tests/test_config_extended.py create mode 100644 tests/test_logger_extended.py create mode 100644 tests/test_monitor_extended.py create mode 100644 tests/validators/test_validators_extended.py diff --git a/tests/test_config_extended.py b/tests/test_config_extended.py new file mode 100644 index 000000000000..6bdf6f5e6cad --- /dev/null +++ b/tests/test_config_extended.py @@ -0,0 +1,303 @@ +"""Extended tests for the QCoDeS configuration module. + +Targets uncovered lines in qcodes/configuration/config.py including: +DotDict operations, Config.describe, Config.__getitem__, Config.add, +Config.save_config/save_schema, Config.__repr__, Config.__getattr__, +and the module-level update() function. +""" + +from __future__ import annotations + +import copy +import json +import logging + +import pytest + +import qcodes +from qcodes.configuration import Config +from qcodes.configuration.config import MISS_DESC, DotDict, update + +# --------------------------------------------------------------------------- +# DotDict tests +# --------------------------------------------------------------------------- + + +class TestDotDictInit: + def test_init_none(self) -> None: + d = DotDict(None) + assert len(d) == 0 + + def test_init_flat(self) -> None: + d = DotDict({"a": 1, "b": 2}) + assert d["a"] == 1 + assert d["b"] == 2 + + def test_init_nested_dict_becomes_dotdict(self) -> None: + d = DotDict({"outer": {"inner": 42}}) + assert isinstance(d["outer"], DotDict) + assert d["outer"]["inner"] == 42 + + +class TestDotDictGetItem: + def test_simple_key(self) -> None: + d = DotDict({"x": 10}) + assert d["x"] == 10 + + def test_dotted_key(self) -> None: + d = DotDict({"a": {"b": {"c": 99}}}) + assert d["a.b.c"] == 99 + + def test_missing_key_raises(self) -> None: + d = DotDict({"a": 1}) + with pytest.raises(KeyError): + d["nonexistent"] + + def test_missing_nested_key_raises(self) -> None: + d = DotDict({"a": {"b": 1}}) + with pytest.raises(KeyError): + d["a.z"] + + +class TestDotDictSetItem: + def test_simple_set(self) -> None: + d = DotDict() + d["key"] = "value" + assert d["key"] == "value" + + def test_dotted_set_creates_intermediates(self) -> None: + d = DotDict() + d["a.b.c"] = 42 + assert d["a"]["b"]["c"] == 42 + assert isinstance(d["a"], DotDict) + assert isinstance(d["a"]["b"], DotDict) + + def test_set_plain_dict_converts_to_dotdict(self) -> None: + d = DotDict() + d["x"] = {"nested": 1} + assert isinstance(d["x"], DotDict) + assert d["x"]["nested"] == 1 + + def test_overwrite_value(self) -> None: + d = DotDict({"a": {"b": 1}}) + d["a.b"] = 2 + assert d["a.b"] == 2 + + +class TestDotDictContains: + def test_simple_contains(self) -> None: + d = DotDict({"a": 1}) + assert "a" in d + assert "missing" not in d + + def test_dotted_contains(self) -> None: + d = DotDict({"a": {"b": {"c": 1}}}) + assert "a.b.c" in d + assert "a.b" in d + assert "a.b.z" not in d + + def test_non_string_key_returns_false(self) -> None: + d = DotDict({"a": 1}) + assert (123 in d) is False + + +class TestDotDictDeepCopy: + def test_deepcopy_returns_dotdict(self) -> None: + d = DotDict({"a": {"b": [1, 2, 3]}}) + d2 = copy.deepcopy(d) + assert isinstance(d2, DotDict) + assert d2["a"]["b"] == [1, 2, 3] + # mutating copy does not affect original + d2["a"]["b"].append(4) + assert d["a"]["b"] == [1, 2, 3] + + +class TestDotDictAttrAccess: + def test_getattr(self) -> None: + d = DotDict({"hello": "world"}) + assert d.hello == "world" + + def test_setattr(self) -> None: + d = DotDict() + d.foo = "bar" + assert d["foo"] == "bar" + + +# --------------------------------------------------------------------------- +# update() function tests +# --------------------------------------------------------------------------- + + +class TestUpdateFunction: + def test_simple_update(self) -> None: + d: dict = {"a": 1, "b": 2} + u = {"b": 3, "c": 4} + result = update(d, u) + assert result == {"a": 1, "b": 3, "c": 4} + assert result is d # in-place + + def test_nested_recursive_merge(self) -> None: + d: dict = {"x": {"y": 1, "z": 2}} + u = {"x": {"z": 99, "w": 100}} + result = update(d, u) + assert result["x"]["y"] == 1 + assert result["x"]["z"] == 99 + assert result["x"]["w"] == 100 + + def test_non_mapping_replaces(self) -> None: + d: dict = {"a": {"nested": True}} + u = {"a": "flat_now"} + result = update(d, u) + assert result["a"] == "flat_now" + + def test_new_nested_key(self) -> None: + d: dict = {} + u = {"a": {"b": 1}} + result = update(d, u) + assert result["a"]["b"] == 1 + + +# --------------------------------------------------------------------------- +# Config class tests +# --------------------------------------------------------------------------- + + +class TestConfigGetItem: + def test_access_top_level_section(self) -> None: + cfg = qcodes.config + core = cfg["core"] + assert isinstance(core, DotDict) + + def test_access_nested_key(self) -> None: + cfg = qcodes.config + val = cfg["core.db_debug"] + assert isinstance(val, bool) + + def test_missing_key_raises(self) -> None: + cfg = qcodes.config + with pytest.raises(KeyError): + cfg["nonexistent_section_xyz"] + + +class TestConfigDescribe: + def test_describe_known_key(self) -> None: + cfg = qcodes.config + desc = cfg.describe("core.db_debug") + assert isinstance(desc, str) + assert "Current value:" in desc + assert "Type:" in desc + assert "Default:" in desc + + def test_describe_user_section(self) -> None: + cfg = qcodes.config + cfg.add("testdesc", "val", "string", "My description", "default_val") + desc = cfg.describe("user.testdesc") + assert "My description" in desc + assert "val" in desc + assert "string" in desc + + +class TestConfigAdd: + def test_add_without_type(self) -> None: + cfg = qcodes.config + cfg.add("simple_key", "simple_val") + assert cfg.current_config is not None + assert cfg.current_config["user"]["simple_key"] == "simple_val" + + def test_add_with_type_only(self) -> None: + cfg = qcodes.config + cfg.add("typed_key", 42, "integer") + assert cfg.current_config is not None + assert cfg.current_config["user"]["typed_key"] == 42 + + def test_add_with_type_and_description(self) -> None: + cfg = qcodes.config + cfg.add("full_key", "hello", "string", "A full key", "default_hello") + assert cfg.current_config is not None + assert cfg.current_config["user"]["full_key"] == "hello" + # Verify schema was updated + assert cfg.current_schema is not None + user_props = cfg.current_schema["properties"]["user"]["properties"] + assert "full_key" in user_props + assert user_props["full_key"]["type"] == "string" + assert user_props["full_key"]["description"] == "A full key" + assert user_props["full_key"]["default"] == "default_hello" + + def test_add_description_without_type_warns( + self, caplog: pytest.LogCaptureFixture + ) -> None: + cfg = qcodes.config + with caplog.at_level(logging.WARNING, logger="qcodes.configuration.config"): + cfg.add("warn_key", "val", description="ignored desc") + assert MISS_DESC.strip() in caplog.text.strip() + + +class TestConfigRepr: + def test_repr_contains_current_info(self) -> None: + cfg = qcodes.config + r = repr(cfg) + assert "Current values:" in r + assert "Current paths:" in r + + +class TestConfigGetattr: + def test_getattr_delegates(self) -> None: + cfg = qcodes.config + # Accessing an attribute that exists in current_config + user = cfg.user + assert isinstance(user, DotDict) + + +class TestConfigSave: + def test_save_config(self, tmp_path) -> None: + cfg = Config() + path = str(tmp_path / "saved_config.json") + cfg.save_config(path) + with open(path) as f: + data = json.load(f) + assert "core" in data + + def test_save_schema(self, tmp_path) -> None: + cfg = Config() + path = str(tmp_path / "saved_schema.json") + cfg.save_schema(path) + with open(path) as f: + data = json.load(f) + assert "properties" in data + + def test_save_and_reload(self, tmp_path) -> None: + cfg = Config() + config_path = str(tmp_path / "roundtrip_config.json") + cfg.save_config(config_path) + loaded = Config.load_config(config_path) + assert isinstance(loaded, DotDict) + assert loaded["core"]["db_debug"] == cfg.current_config["core"]["db_debug"] + + +class TestConfigLoadConfig: + def test_load_config_returns_dotdict(self) -> None: + loaded = Config.load_config(Config.default_file_name) + assert isinstance(loaded, DotDict) + + def test_load_config_missing_raises(self) -> None: + with pytest.raises(FileNotFoundError): + Config.load_config("no_such_file_anywhere.json") + + +class TestConfigValidate: + def test_validate_uses_current_if_no_args(self) -> None: + cfg = Config() + # Should not raise - validates current_config against current_schema + cfg.validate() + + def test_validate_no_schema_raises(self) -> None: + cfg = Config() + cfg.current_schema = None + with pytest.raises(RuntimeError, match="Cannot validate"): + cfg.validate() + + def test_validate_no_config_uses_current(self) -> None: + cfg = Config() + # Passing schema but no json_config should use current_config + assert cfg.current_schema is not None + cfg.validate(schema=cfg.current_schema) diff --git a/tests/test_logger_extended.py b/tests/test_logger_extended.py new file mode 100644 index 000000000000..f25ae13619a6 --- /dev/null +++ b/tests/test_logger_extended.py @@ -0,0 +1,371 @@ +""" +Extended tests for ``qcodes.logger`` to improve coverage of +logger.py, log_analysis.py, and instrument_logger.py. +""" + +from __future__ import annotations + +import logging +import os +from copy import copy +from typing import TYPE_CHECKING + +import pandas as pd +import pytest + +from qcodes import logger +from qcodes.logger import ( + LogCapture, + flush_telemetry_traces, + get_console_handler, + get_file_handler, + get_level_code, + get_level_name, + get_log_file_name, + start_command_history_logger, + start_logger, +) +from qcodes.logger.log_analysis import ( + log_to_dataframe, + logfile_to_dataframe, + time_difference, +) +from qcodes.logger.logger import ( + FORMAT_STRING_DICT, + LOGGING_SEPARATOR, + PYTHON_LOG_NAME, + console_level, + generate_log_file_name, + get_formatter, + get_formatter_for_telemetry, +) + +if TYPE_CHECKING: + from collections.abc import Generator + from pathlib import Path + + +@pytest.fixture(autouse=True) +def cleanup_started_logger() -> Generator[None, None, None]: + """Cleanup state left by a test calling start_logger.""" + root_logger = logging.getLogger() + existing_handlers = copy(root_logger.handlers) + yield + post_test_handlers = copy(root_logger.handlers) + for handler in post_test_handlers: + if handler not in existing_handlers: + handler.close() + root_logger.removeHandler(handler) + logger.logger.file_handler = None + logger.logger.console_handler = None + + +# --------------------------------------------------------------------------- +# Tests for get_formatter / get_formatter_for_telemetry +# --------------------------------------------------------------------------- + + +def test_get_formatter_returns_formatter() -> None: + fmt = get_formatter() + assert isinstance(fmt, logging.Formatter) + fmt_str = fmt._fmt + assert fmt_str is not None + for key in FORMAT_STRING_DICT: + assert key in fmt_str + assert LOGGING_SEPARATOR in fmt_str + + +def test_get_formatter_for_telemetry() -> None: + fmt = get_formatter_for_telemetry() + assert isinstance(fmt, logging.Formatter) + fmt_str = fmt._fmt + assert fmt_str is not None + for key in ("message", "name", "funcName"): + assert key in fmt_str + # telemetry formatter should NOT contain asctime + assert "asctime" not in fmt_str + + +# --------------------------------------------------------------------------- +# Tests for get_console_handler / get_file_handler +# --------------------------------------------------------------------------- + + +def test_get_console_handler_before_start() -> None: + assert get_console_handler() is None + + +def test_get_file_handler_before_start() -> None: + assert get_file_handler() is None + + +def test_get_handlers_after_start() -> None: + start_logger() + assert get_console_handler() is not None + assert isinstance(get_console_handler(), logging.Handler) + assert get_file_handler() is not None + assert isinstance(get_file_handler(), logging.Handler) + + +# --------------------------------------------------------------------------- +# Tests for get_level_name +# --------------------------------------------------------------------------- + + +def test_get_level_name_from_int() -> None: + assert get_level_name(logging.DEBUG) == "DEBUG" + assert get_level_name(logging.WARNING) == "WARNING" + assert get_level_name(logging.ERROR) == "ERROR" + + +def test_get_level_name_from_str() -> None: + assert get_level_name("DEBUG") == "DEBUG" + assert get_level_name("INFO") == "INFO" + + +def test_get_level_name_invalid_type() -> None: + with pytest.raises(RuntimeError, match="get_level_name"): + get_level_name(3.14) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Tests for get_level_code +# --------------------------------------------------------------------------- + + +def test_get_level_code_from_str() -> None: + assert get_level_code("DEBUG") == logging.DEBUG + assert get_level_code("WARNING") == logging.WARNING + + +def test_get_level_code_from_int() -> None: + assert get_level_code(logging.DEBUG) == logging.DEBUG + assert get_level_code(logging.INFO) == logging.INFO + + +def test_get_level_code_invalid_type() -> None: + with pytest.raises(RuntimeError, match="get_level_code"): + get_level_code(3.14) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Tests for generate_log_file_name / get_log_file_name +# --------------------------------------------------------------------------- + + +def test_generate_log_file_name() -> None: + name = generate_log_file_name() + pid = str(os.getpid()) + assert pid in name + assert name.endswith(PYTHON_LOG_NAME) + # format is YYMMDD-PID-qcodes.log + parts = name.split("-") + assert len(parts) >= 3 + + +def test_get_log_file_name() -> None: + path = get_log_file_name() + assert isinstance(path, str) + assert path.endswith(PYTHON_LOG_NAME) + assert "logs" in path + + +# --------------------------------------------------------------------------- +# Tests for flush_telemetry_traces +# --------------------------------------------------------------------------- + + +def test_flush_telemetry_traces_no_op() -> None: + # When telemetry is not set up this should be a no-op + flush_telemetry_traces() + + +# --------------------------------------------------------------------------- +# Tests for start_command_history_logger +# --------------------------------------------------------------------------- + + +def test_start_command_history_logger_outside_ipython() -> None: + # Outside IPython, get_ipython() returns None so this should just warn + # and return without error. + start_command_history_logger() + + +# --------------------------------------------------------------------------- +# Tests for console_level +# --------------------------------------------------------------------------- + + +def test_console_level_without_handler_raises() -> None: + with pytest.raises(RuntimeError, match="Console handler is None"): + with console_level(logging.DEBUG): + pass + + +def test_console_level_with_handler() -> None: + start_logger() + handler = get_console_handler() + assert handler is not None + original_level = handler.level + with console_level(logging.DEBUG): + assert handler.level == logging.DEBUG + assert handler.level == original_level + + +# --------------------------------------------------------------------------- +# Tests for LogCapture +# --------------------------------------------------------------------------- + + +def test_log_capture_basic() -> None: + test_logger = logging.getLogger("test_log_capture_basic") + test_logger.setLevel(logging.DEBUG) + + with LogCapture(logger=test_logger, level=logging.DEBUG) as logs: + test_logger.debug("hello from capture") + + assert "hello from capture" in logs.value + + +def test_log_capture_multiple_messages() -> None: + test_logger = logging.getLogger("test_log_capture_multi") + test_logger.setLevel(logging.DEBUG) + + with LogCapture(logger=test_logger, level=logging.DEBUG) as logs: + test_logger.info("first message") + test_logger.warning("second message") + + assert "first message" in logs.value + assert "second message" in logs.value + + +def test_log_capture_level_filtering() -> None: + test_logger = logging.getLogger("test_log_capture_filter") + test_logger.setLevel(logging.DEBUG) + + with LogCapture(logger=test_logger, level=logging.WARNING) as logs: + test_logger.debug("should not appear") + test_logger.warning("should appear") + + assert "should not appear" not in logs.value + assert "should appear" in logs.value + + +def test_log_capture_restores_handlers() -> None: + test_logger = logging.getLogger("test_log_capture_restore") + test_logger.setLevel(logging.DEBUG) + dummy_handler = logging.StreamHandler() + test_logger.addHandler(dummy_handler) + handler_count_before = len(test_logger.handlers) + + with LogCapture(logger=test_logger): + test_logger.info("inside") + + assert len(test_logger.handlers) == handler_count_before + test_logger.removeHandler(dummy_handler) + + +# --------------------------------------------------------------------------- +# Tests for log_to_dataframe +# --------------------------------------------------------------------------- + + +def test_log_to_dataframe() -> None: + sep = LOGGING_SEPARATOR + columns = list(FORMAT_STRING_DICT.keys()) + log_line = sep.join( + [ + "2024-01-15 10:00:00,000", + "qcodes.logger", + "INFO", + "logger", + "start_logger", + "42", + "Test message", + ] + ) + df = log_to_dataframe([log_line]) + assert isinstance(df, pd.DataFrame) + assert list(df.columns) == columns + assert len(df) == 1 + assert df["message"].iloc[0] == "Test message" + + +def test_log_to_dataframe_skips_tracebacks() -> None: + sep = LOGGING_SEPARATOR + valid = sep.join( + [ + "2024-01-15 10:00:00,000", + "mod", + "ERROR", + "mod", + "func", + "1", + "error msg", + ] + ) + traceback_line = "Traceback (most recent call last):" + df = log_to_dataframe([valid, traceback_line]) + assert len(df) == 1 + + +def test_log_to_dataframe_empty() -> None: + df = log_to_dataframe(["Traceback line only"]) + assert isinstance(df, pd.DataFrame) + assert len(df) == 0 + + +# --------------------------------------------------------------------------- +# Tests for logfile_to_dataframe +# --------------------------------------------------------------------------- + + +def test_logfile_to_dataframe(tmp_path: Path) -> None: + sep = LOGGING_SEPARATOR + line = sep.join( + [ + "2024-06-01 12:00:00,000", + "qcodes.logger", + "DEBUG", + "logger", + "test_func", + "10", + "file log message", + ] + ) + logfile = tmp_path / "test.log" + logfile.write_text(line + "\n") + df = logfile_to_dataframe(str(logfile)) + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + assert df["message"].iloc[0].strip() == "file log message" + + +# --------------------------------------------------------------------------- +# Tests for time_difference +# --------------------------------------------------------------------------- + + +def test_time_difference_basic() -> None: + first = pd.Series(["2024-01-01 00:00:00.000", "2024-01-01 00:00:01.000"]) + second = pd.Series(["2024-01-01 00:00:01.000", "2024-01-01 00:00:03.000"]) + result = time_difference(first, second, use_first_series_labels=True) + assert isinstance(result, pd.Series) + assert len(result) == 2 + assert result.iloc[0] == pytest.approx(1.0, abs=0.01) + assert result.iloc[1] == pytest.approx(2.0, abs=0.01) + + +def test_time_difference_use_second_labels() -> None: + first = pd.Series(["2024-01-01 00:00:00.000"], index=["a"]) + second = pd.Series(["2024-01-01 00:00:05.000"], index=["b"]) + result = time_difference(first, second, use_first_series_labels=False) + assert list(result.index) == ["b"] + assert result.iloc[0] == pytest.approx(5.0, abs=0.01) + + +def test_time_difference_comma_separator() -> None: + first = pd.Series(["2024-01-01 00:00:00,000"]) + second = pd.Series(["2024-01-01 00:00:02,000"]) + result = time_difference(first, second) + assert result.iloc[0] == pytest.approx(2.0, abs=0.01) diff --git a/tests/test_monitor_extended.py b/tests/test_monitor_extended.py new file mode 100644 index 000000000000..e41810fb7d46 --- /dev/null +++ b/tests/test_monitor_extended.py @@ -0,0 +1,170 @@ +""" +Extended test suite for qcodes.monitor.monitor covering _get_metadata, +Monitor.show, Monitor TypeError on invalid parameters, and +Monitor.update_all. +""" + +from __future__ import annotations + +from unittest.mock import PropertyMock, patch + +import pytest + +from qcodes.instrument_drivers.mock_instruments import ( + DummyChannelInstrument, + DummyInstrument, +) +from qcodes.monitor.monitor import Monitor, _get_metadata +from qcodes.parameters import Parameter + +# --------------------------------------------------------------------------- +# _get_metadata - pure-function tests (no Monitor instance needed) +# --------------------------------------------------------------------------- + + +class TestGetMetadata: + """Tests for the ``_get_metadata`` helper function.""" + + def test_basic_structure(self) -> None: + """Returned dict must contain 'ts' and 'parameters' keys.""" + param = Parameter("p1", set_cmd=None, get_cmd=None, initial_value=42) + result = _get_metadata(param) + assert "ts" in result + assert "parameters" in result + assert isinstance(result["ts"], float) + assert isinstance(result["parameters"], list) + + def test_parameter_with_timestamp(self) -> None: + """A parameter that has been set should report a non-None ts.""" + param = Parameter("voltage", set_cmd=None, get_cmd=None, unit="V") + param(3.14) + result = _get_metadata(param) + unbound = [ + g for g in result["parameters"] if g["instrument"] == "Unbound Parameter" + ] + assert len(unbound) == 1 + meta = unbound[0]["parameters"][0] + assert meta["value"] == str(3.14) + assert meta["ts"] is not None + assert isinstance(meta["ts"], float) + assert meta["unit"] == "V" + + def test_parameter_without_timestamp(self) -> None: + """When get_timestamp() returns None, meta['ts'] must be None (line 75).""" + param = Parameter("fresh", set_cmd=None, get_cmd=None, initial_value=0) + # Force the timestamp to None via the cache to exercise line 75 + with patch.object( + type(param.cache), "timestamp", new_callable=PropertyMock, return_value=None + ): + result = _get_metadata(param) + unbound = [ + g for g in result["parameters"] if g["instrument"] == "Unbound Parameter" + ] + assert len(unbound) == 1 + meta = unbound[0]["parameters"][0] + assert meta["ts"] is None + + def test_unbound_parameter_grouping(self) -> None: + """Parameters not attached to an instrument go under 'Unbound Parameter'.""" + p1 = Parameter("alpha", set_cmd=None, get_cmd=None, initial_value=1) + p2 = Parameter("beta", set_cmd=None, get_cmd=None, initial_value=2) + result = _get_metadata(p1, p2) + instruments = [g["instrument"] for g in result["parameters"]] + assert instruments == ["Unbound Parameter"] + assert len(result["parameters"][0]["parameters"]) == 2 + + def test_use_root_instrument_true(self) -> None: + """With use_root_instrument=True, channel params are grouped by root.""" + instr = DummyChannelInstrument("MetaRootTest") + try: + result = _get_metadata( + instr.A.temperature, + instr.B.temperature, + use_root_instrument=True, + ) + instruments = [g["instrument"] for g in result["parameters"]] + # Both should be grouped under the single root instrument + assert len(instruments) == 1 + assert len(result["parameters"][0]["parameters"]) == 2 + finally: + instr.close() + + def test_use_root_instrument_false(self) -> None: + """With use_root_instrument=False, channel params are grouped by channel.""" + instr = DummyChannelInstrument("MetaChanTest") + try: + result = _get_metadata( + instr.A.temperature, + instr.B.temperature, + use_root_instrument=False, + ) + instruments = [g["instrument"] for g in result["parameters"]] + # Each channel is a separate instrument grouping + assert len(instruments) == 2 + finally: + instr.close() + + def test_instrument_bound_parameter(self) -> None: + """Parameters attached to a DummyInstrument use the instrument name.""" + instr = DummyInstrument("MetaDummy", gates=["ch1"]) + try: + result = _get_metadata(instr.ch1) + instruments = [g["instrument"] for g in result["parameters"]] + assert str(instr) in instruments + finally: + instr.close() + + def test_mixed_bound_and_unbound(self) -> None: + """Bound and unbound parameters appear in separate groups.""" + instr = DummyInstrument("MixedDummy", gates=["g1"]) + free = Parameter("free_param", set_cmd=None, get_cmd=None, initial_value=0) + try: + result = _get_metadata(instr.g1, free) + instruments = sorted(g["instrument"] for g in result["parameters"]) + assert "Unbound Parameter" in instruments + assert str(instr) in instruments + finally: + instr.close() + + def test_label_used_as_name(self) -> None: + """meta['name'] should equal parameter.label (or .name if no label).""" + p_with_label = Parameter( + "x", label="My Label", set_cmd=None, get_cmd=None, initial_value=0 + ) + p_without_label = Parameter("y", set_cmd=None, get_cmd=None, initial_value=0) + result = _get_metadata(p_with_label, p_without_label) + params = result["parameters"][0]["parameters"] + names = [p["name"] for p in params] + assert "My Label" in names + assert "y" in names + + def test_empty_parameters(self) -> None: + """Calling _get_metadata with no parameters returns an empty list.""" + result = _get_metadata() + assert result["parameters"] == [] + assert "ts" in result + + +# --------------------------------------------------------------------------- +# Monitor.show - static method (mock webbrowser to avoid opening a browser) +# --------------------------------------------------------------------------- + + +class TestMonitorShow: + def test_show_opens_browser(self) -> None: + """Monitor.show() should call webbrowser.open with correct URL.""" + with patch("qcodes.monitor.monitor.webbrowser.open") as mock_open: + Monitor.show() + mock_open.assert_called_once_with("http://localhost:3000") + + +# --------------------------------------------------------------------------- +# Monitor.__init__ - TypeError for non-Parameter arguments (line 162) +# --------------------------------------------------------------------------- + + +class TestMonitorTypeError: + def test_non_parameter_raises_type_error(self) -> None: + """Passing a non-Parameter object must raise TypeError.""" + with pytest.raises(TypeError, match="We can only monitor"): + Monitor("not_a_parameter") # type: ignore[arg-type] diff --git a/tests/validators/test_validators_extended.py b/tests/validators/test_validators_extended.py new file mode 100644 index 000000000000..c772cac9efc5 --- /dev/null +++ b/tests/validators/test_validators_extended.py @@ -0,0 +1,716 @@ +""" +Extended tests for the validators module to improve coverage. + +Covers edge cases and helper functions not fully exercised by existing tests: +validate_all, range_str, Validator base class, Nothing, Bool, Strings, Enum, +OnOff, Multiples, PermissiveMultiples, MultiType, MultiTypeOr, MultiTypeAnd, +Arrays, Lists, Sequence, Callable, Dict. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from qcodes.validators import ( + Anything, + Arrays, + Bool, + Dict, + Enum, + Ints, + Lists, + Multiples, + MultiType, + MultiTypeAnd, + MultiTypeOr, + Nothing, + Numbers, + OnOff, + PermissiveMultiples, + Sequence, + Strings, + Validator, +) +from qcodes.validators import Callable as CallableValidator +from qcodes.validators.validators import range_str, validate_all + + +# --------------------------------------------------------------------------- +# validate_all +# --------------------------------------------------------------------------- +class TestValidateAll: + def test_multiple_valid(self) -> None: + validate_all( + (Numbers(), 1), + (Strings(), "hello"), + (Ints(), 42), + ) + + def test_invalid_value_raises(self) -> None: + with pytest.raises((TypeError, ValueError)): + validate_all( + (Numbers(), 1), + (Strings(), 123), # not a string + ) + + def test_context_appears_in_error(self) -> None: + with pytest.raises((TypeError, ValueError), match="my context"): + validate_all( + (Strings(), 999), + context="my context", + ) + + def test_context_with_argument_index(self) -> None: + with pytest.raises((TypeError, ValueError), match="argument 0"): + validate_all( + (Ints(), "not_int"), + context="ctx", + ) + + def test_empty_args(self) -> None: + validate_all() # no validators -> no error + + +# --------------------------------------------------------------------------- +# range_str +# --------------------------------------------------------------------------- +class TestRangeStr: + def test_both_set_equal(self) -> None: + assert range_str(5, 5, "x") == " x=5" + + def test_both_set_different(self) -> None: + assert range_str(1, 10, "v") == " 1<=v<=10" + + def test_only_max(self) -> None: + assert range_str(None, 10, "v") == " v<=10" + + def test_only_min(self) -> None: + assert range_str(5, None, "v") == " v>=5" + + def test_neither(self) -> None: + assert range_str(None, None, "v") == "" + + +# --------------------------------------------------------------------------- +# Validator base class +# --------------------------------------------------------------------------- +class TestValidatorBase: + def test_validate_raises_not_implemented(self) -> None: + v = Validator() + with pytest.raises(NotImplementedError): + v.validate(42) + + def test_valid_values_empty_by_default(self) -> None: + v = Validator() + assert v.valid_values == () + + def test_is_numeric_false_by_default(self) -> None: + assert Validator.is_numeric is False + + +# --------------------------------------------------------------------------- +# Nothing +# --------------------------------------------------------------------------- +class TestNothing: + def test_validate_raises_runtime_error(self) -> None: + n = Nothing("disabled") + with pytest.raises(RuntimeError, match="disabled"): + n.validate(42) + + def test_validate_includes_context(self) -> None: + n = Nothing("broken") + with pytest.raises(RuntimeError, match="my_ctx"): + n.validate(0, context="my_ctx") + + def test_repr(self) -> None: + n = Nothing("some reason") + assert repr(n) == "" + + def test_reason_getter(self) -> None: + n = Nothing("test") + assert n.reason == "test" + + def test_reason_setter(self) -> None: + n = Nothing("old") + n.reason = "new" + assert n.reason == "new" + + def test_empty_reason_defaults(self) -> None: + n = Nothing("") + assert n.reason == "Nothing Validator" + + +# --------------------------------------------------------------------------- +# Bool +# --------------------------------------------------------------------------- +class TestBoolExtended: + def test_np_bool_accepted(self) -> None: + b = Bool() + b.validate(np.bool_(True)) + b.validate(np.bool_(False)) + + def test_non_bool_raises(self) -> None: + b = Bool() + with pytest.raises(TypeError, match="not Boolean"): + b.validate(1) + + def test_repr(self) -> None: + assert repr(Bool()) == "" + + def test_valid_values(self) -> None: + assert Bool().valid_values == (True, False) + + +# --------------------------------------------------------------------------- +# Strings +# --------------------------------------------------------------------------- +class TestStringsExtended: + def test_min_length_boundary(self) -> None: + s = Strings(min_length=3) + s.validate("abc") # exactly min_length + with pytest.raises(ValueError): + s.validate("ab") + + def test_max_length_boundary(self) -> None: + s = Strings(max_length=5) + s.validate("abcde") # exactly max_length + with pytest.raises(ValueError): + s.validate("abcdef") + + def test_invalid_min_length_type(self) -> None: + with pytest.raises( + TypeError, match="min_length must be a non-negative integer" + ): + Strings(min_length=-1) + + def test_invalid_min_length_float(self) -> None: + with pytest.raises( + TypeError, match="min_length must be a non-negative integer" + ): + Strings(min_length=1.5) # type: ignore[arg-type] + + def test_invalid_max_length(self) -> None: + with pytest.raises(TypeError, match="max_length must be a positive integer"): + Strings(max_length=0) + + def test_max_less_than_min(self) -> None: + with pytest.raises(TypeError, match="max_length must be a positive integer"): + Strings(min_length=5, max_length=3) + + def test_non_string_raises(self) -> None: + s = Strings() + with pytest.raises(TypeError, match="not a string"): + s.validate(42) + + def test_repr_with_constraints(self) -> None: + s = Strings(min_length=2, max_length=10) + r = repr(s) + assert "len" in r + assert "Strings" in r + + def test_repr_without_constraints(self) -> None: + s = Strings() + assert repr(s) == "" + + def test_properties(self) -> None: + s = Strings(min_length=3, max_length=50) + assert s.min_length == 3 + assert s.max_length == 50 + + def test_valid_values_with_min_length(self) -> None: + s = Strings(min_length=3) + # valid_values should be a string of length min_length + assert len(s.valid_values[0]) == 3 + + +# --------------------------------------------------------------------------- +# Enum +# --------------------------------------------------------------------------- +class TestEnumExtended: + def test_unhashable_raises_type_error(self) -> None: + e = Enum("a", "b") + with pytest.raises(TypeError): + e.validate([1, 2]) # list is unhashable + + def test_unhashable_error_includes_context(self) -> None: + e = Enum("a", "b") + with pytest.raises(TypeError, match="test_ctx"): + e.validate([1, 2], context="test_ctx") + + def test_values_returns_copy(self) -> None: + e = Enum("x", "y") + vals = e.values + vals.add("z") + assert "z" not in e.values + + def test_repr_format(self) -> None: + e = Enum("a") + r = repr(e) + assert r.startswith("") + + def test_no_values_raises(self) -> None: + with pytest.raises(TypeError, match="at least one value"): + Enum() + + +# --------------------------------------------------------------------------- +# OnOff +# --------------------------------------------------------------------------- +class TestOnOff: + def test_on_valid(self) -> None: + OnOff().validate("on") + + def test_off_valid(self) -> None: + OnOff().validate("off") + + def test_other_string_rejected(self) -> None: + with pytest.raises(ValueError): + OnOff().validate("yes") + + def test_non_string_rejected(self) -> None: + with pytest.raises((TypeError, ValueError)): + OnOff().validate(1) + + def test_valid_values(self) -> None: + assert set(OnOff().valid_values) == {"on", "off"} + + +# --------------------------------------------------------------------------- +# Multiples +# --------------------------------------------------------------------------- +class TestMultiplesExtended: + def test_zero_divisor_raises(self) -> None: + with pytest.raises(TypeError, match="positive integer"): + Multiples(divisor=0) + + def test_negative_divisor_raises(self) -> None: + with pytest.raises(TypeError, match="positive integer"): + Multiples(divisor=-3) + + def test_float_divisor_raises(self) -> None: + with pytest.raises(TypeError, match="positive integer"): + Multiples(divisor=2.5) # type: ignore[arg-type] + + def test_non_multiple_raises(self) -> None: + m = Multiples(divisor=3) + with pytest.raises(ValueError, match="not a multiple"): + m.validate(7) + + def test_valid_multiple(self) -> None: + m = Multiples(divisor=5) + m.validate(15) + m.validate(0) + + def test_repr_includes_divisor(self) -> None: + m = Multiples(divisor=4) + assert "Multiples of 4" in repr(m) + + def test_divisor_property(self) -> None: + m = Multiples(divisor=7) + assert m.divisor == 7 + + +# --------------------------------------------------------------------------- +# PermissiveMultiples +# --------------------------------------------------------------------------- +class TestPermissiveMultiplesExtended: + def test_int_divisor_int_value(self) -> None: + """With int divisor and int value, uses Multiples path.""" + pm = PermissiveMultiples(divisor=3) + pm.validate(9) + pm.validate(-6) + + def test_float_divisor_float_value(self) -> None: + pm = PermissiveMultiples(divisor=0.1) + pm.validate(0.3) + + def test_almost_multiple_within_precision(self) -> None: + pm = PermissiveMultiples(divisor=0.1, precision=1e-6) + pm.validate(0.30000000001) # within precision + + def test_not_multiple_raises(self) -> None: + pm = PermissiveMultiples(divisor=0.1, precision=1e-9) + with pytest.raises(ValueError, match="not a multiple"): + pm.validate(0.15) + + def test_zero_divisor_raises(self) -> None: + with pytest.raises(ValueError, match="zero"): + PermissiveMultiples(divisor=0) + + def test_zero_value_always_passes(self) -> None: + pm = PermissiveMultiples(divisor=7) + pm.validate(0) + + def test_repr(self) -> None: + pm = PermissiveMultiples(divisor=5) + r = repr(pm) + assert "PermissiveMultiples" in r + assert "5" in r + + def test_divisor_property(self) -> None: + pm = PermissiveMultiples(divisor=3) + assert pm.divisor == 3 + + def test_precision_property(self) -> None: + pm = PermissiveMultiples(divisor=2, precision=1e-6) + assert pm.precision == 1e-6 + + def test_precision_setter(self) -> None: + pm = PermissiveMultiples(divisor=2) + pm.precision = 1e-3 + assert pm.precision == 1e-3 + + def test_divisor_setter(self) -> None: + pm = PermissiveMultiples(divisor=2) + pm.divisor = 5 + assert pm.divisor == 5 + + def test_divisor_setter_zero_raises(self) -> None: + pm = PermissiveMultiples(divisor=2) + with pytest.raises(ValueError, match="zero"): + pm.divisor = 0 + + def test_is_numeric(self) -> None: + assert PermissiveMultiples(divisor=2).is_numeric is True + + def test_int_divisor_non_multiple_int(self) -> None: + """Int divisor + int value that is not a multiple -> error from Multiples.""" + pm = PermissiveMultiples(divisor=3) + with pytest.raises(ValueError): + pm.validate(7) + + def test_float_divisor_sets_mulval_none(self) -> None: + """Float divisor should not create a Multiples sub-validator.""" + pm = PermissiveMultiples(divisor=0.5) + assert pm._mulval is None + + +# --------------------------------------------------------------------------- +# MultiType +# --------------------------------------------------------------------------- +class TestMultiTypeExtended: + def test_or_all_pass(self) -> None: + mt = MultiType(Numbers(), Strings()) + mt.validate(42) + mt.validate("hello") + + def test_or_none_pass(self) -> None: + mt = MultiType(Numbers(), Strings()) + with pytest.raises(ValueError): + mt.validate([1, 2]) + + def test_and_all_pass(self) -> None: + mt = MultiType(Numbers(min_value=0), Numbers(max_value=10), combiner="AND") + mt.validate(5) + + def test_and_one_fails(self) -> None: + mt = MultiType(Numbers(min_value=0), Numbers(max_value=10), combiner="AND") + with pytest.raises(ValueError): + mt.validate(20) + + def test_no_validators_raises(self) -> None: + with pytest.raises(TypeError, match="at least one Validator"): + MultiType() + + def test_non_validator_arg_raises(self) -> None: + with pytest.raises(TypeError, match="each argument must be a Validator"): + MultiType("not_a_validator") # type: ignore[arg-type] + + def test_invalid_combiner_raises(self) -> None: + with pytest.raises(TypeError, match="combiner"): + MultiType(Numbers(), combiner="XOR") # type: ignore[arg-type] + + def test_is_numeric_with_numeric_sub(self) -> None: + mt = MultiType(Numbers(), Strings()) + assert mt.is_numeric is True + + def test_is_numeric_without_numeric_sub(self) -> None: + mt = MultiType(Strings(), Bool()) + assert not mt.is_numeric + + def test_repr_format(self) -> None: + mt = MultiType(Numbers(), Strings()) + r = repr(mt) + assert r.startswith("") + + def test_combiner_property(self) -> None: + mt = MultiType(Numbers(), combiner="AND") + assert mt.combiner == "AND" + + def test_validators_property(self) -> None: + n = Numbers() + s = Strings() + mt = MultiType(n, s) + assert mt.validators == (n, s) + + def test_valid_values_combined(self) -> None: + mt = MultiType(Numbers(min_value=0, max_value=10), Ints(min_value=0)) + assert len(mt.valid_values) > 0 + + +# --------------------------------------------------------------------------- +# MultiTypeOr / MultiTypeAnd +# --------------------------------------------------------------------------- +class TestMultiTypeOrAnd: + def test_or_repr(self) -> None: + mt = MultiTypeOr(Numbers(), Strings()) + r = repr(mt) + assert r.startswith("") + + def test_and_repr(self) -> None: + mt = MultiTypeAnd(Numbers(min_value=0), Numbers(max_value=100)) + r = repr(mt) + assert r.startswith("") + + def test_and_valid_values_empty(self) -> None: + mt = MultiTypeAnd(Numbers(), Ints()) + assert mt.valid_values == () + + def test_or_validates_first_match(self) -> None: + mt = MultiTypeOr(Numbers(), Strings()) + mt.validate(42) + mt.validate("hello") + + def test_and_validates_all(self) -> None: + mt = MultiTypeAnd(Numbers(min_value=0), Numbers(max_value=10)) + mt.validate(5) + with pytest.raises(ValueError): + mt.validate(-1) + + +# --------------------------------------------------------------------------- +# Arrays +# --------------------------------------------------------------------------- +class TestArraysExtended: + def test_unsupported_valid_type_raises(self) -> None: + with pytest.raises(TypeError, match="not supported"): + Arrays(valid_types=[str]) # type: ignore[list-item] + + def test_complex_with_min_raises(self) -> None: + with pytest.raises(TypeError, match="complex"): + Arrays(valid_types=[np.complexfloating], min_value=0) + + def test_complex_with_max_raises(self) -> None: + with pytest.raises(TypeError, match="complex"): + Arrays(valid_types=[np.complexfloating], max_value=10) + + def test_min_greater_than_max_raises(self) -> None: + with pytest.raises(TypeError, match="max_value must be bigger"): + Arrays(min_value=10, max_value=1) + + def test_callable_shape(self) -> None: + a = Arrays(shape=[lambda: 3, lambda: 2]) + arr = np.ones((3, 2)) + a.validate(arr) + + def test_shape_mismatch_raises(self) -> None: + a = Arrays(shape=[2, 3]) + with pytest.raises(ValueError, match="shape"): + a.validate(np.ones((4, 5))) + + def test_min_max_value_properties(self) -> None: + a = Arrays(min_value=0, max_value=10) + assert a.min_value == 0.0 + assert a.max_value == 10.0 + + def test_min_max_value_none(self) -> None: + a = Arrays() + assert a.min_value is None + assert a.max_value is None + + def test_validate_non_array_raises(self) -> None: + a = Arrays() + with pytest.raises(TypeError, match="not a numpy array"): + a.validate([1, 2, 3]) # type: ignore[arg-type] + + def test_validate_wrong_dtype_raises(self) -> None: + a = Arrays(valid_types=[np.integer]) + with pytest.raises(TypeError, match="is not any of"): + a.validate(np.array([1.0, 2.0])) + + def test_max_value_exceeded(self) -> None: + a = Arrays(min_value=0, max_value=5) + with pytest.raises(ValueError, match="all values must be between"): + a.validate(np.array([1, 2, 10])) + + def test_min_value_violated(self) -> None: + a = Arrays(min_value=0, max_value=10) + with pytest.raises(ValueError, match="all values must be between"): + a.validate(np.array([-5, 2, 3])) + + def test_repr(self) -> None: + a = Arrays(min_value=0, max_value=10, shape=[2, 3]) + r = repr(a) + assert "Arrays" in r + assert "shape" in r + + def test_shape_unevaluated_with_callable(self) -> None: + fn = lambda: 5 # noqa: E731 + a = Arrays(shape=[fn, 3]) + raw = a.shape_unevaluated + assert raw is not None + assert raw[0] is fn + assert raw[1] == 3 + + def test_is_numeric(self) -> None: + assert Arrays().is_numeric is True + + +# --------------------------------------------------------------------------- +# Lists +# --------------------------------------------------------------------------- +class TestListsExtended: + def test_non_list_raises(self) -> None: + lst = Lists() + with pytest.raises(TypeError, match="not a list"): + lst.validate((1, 2)) # type: ignore[arg-type] + + def test_with_element_validator(self) -> None: + lst = Lists(elt_validator=Ints()) + lst.validate([1, 2, 3]) + with pytest.raises(TypeError): + lst.validate(["a", "b"]) + + def test_repr_format(self) -> None: + lst = Lists(elt_validator=Ints()) + r = repr(lst) + assert "Lists" in r + assert "Ints" in r + + def test_elt_validator_property(self) -> None: + iv = Ints() + lst = Lists(elt_validator=iv) + assert lst.elt_validator is iv + + def test_default_elt_validator_is_anything(self) -> None: + lst = Lists() + assert isinstance(lst.elt_validator, Anything) + + def test_empty_list_valid(self) -> None: + lst = Lists(elt_validator=Ints()) + lst.validate([]) + + +# --------------------------------------------------------------------------- +# Sequence +# --------------------------------------------------------------------------- +class TestSequenceExtended: + def test_wrong_length_raises(self) -> None: + s = Sequence(length=3) + with pytest.raises(ValueError, match="length"): + s.validate([1, 2]) + + def test_correct_length_passes(self) -> None: + s = Sequence(length=2) + s.validate([1, 2]) + + def test_unsorted_when_require_sorted_raises(self) -> None: + s = Sequence(require_sorted=True) + with pytest.raises(ValueError, match="sorted"): + s.validate([3, 1, 2]) + + def test_sorted_passes(self) -> None: + s = Sequence(require_sorted=True) + s.validate([1, 2, 3]) + + def test_repr(self) -> None: + s = Sequence(length=5, require_sorted=True) + r = repr(s) + assert "Sequence" in r + assert "len: 5" in r + assert "sorted: True" in r + + def test_properties(self) -> None: + iv = Ints() + s = Sequence(elt_validator=iv, length=4, require_sorted=True) + assert s.elt_validator is iv + assert s.length == 4 + assert s.require_sorted is True + + def test_non_sequence_raises(self) -> None: + s = Sequence() + with pytest.raises(TypeError, match="not a sequence"): + s.validate(42) # type: ignore[arg-type] + + def test_tuple_accepted(self) -> None: + s = Sequence() + s.validate((1, 2, 3)) + + def test_with_element_validator(self) -> None: + s = Sequence(elt_validator=Ints()) + s.validate([1, 2, 3]) + with pytest.raises(TypeError): + s.validate(["a", "b"]) + + +# --------------------------------------------------------------------------- +# Callable +# --------------------------------------------------------------------------- +class TestCallableExtended: + def test_callable_passes(self) -> None: + c = CallableValidator() + c.validate(lambda: None) + c.validate(len) + + def test_non_callable_raises(self) -> None: + c = CallableValidator() + with pytest.raises(TypeError, match="not a callable"): + c.validate(42) + + def test_repr(self) -> None: + assert repr(CallableValidator()) == "" + + def test_valid_values_is_callable(self) -> None: + c = CallableValidator() + assert callable(c.valid_values[0]) + + +# --------------------------------------------------------------------------- +# Dict +# --------------------------------------------------------------------------- +class TestDictExtended: + def test_non_dict_raises(self) -> None: + d = Dict() + with pytest.raises(TypeError, match="not a dictionary"): + d.validate([1, 2]) # type: ignore[arg-type] + + def test_forbidden_key_raises_syntax_error(self) -> None: + d = Dict(allowed_keys=["a", "b"]) + with pytest.raises(SyntaxError, match="not in allowed keys"): + d.validate({"a": 1, "c": 2}) + + def test_allowed_keys_property(self) -> None: + d = Dict(allowed_keys=["x", "y"]) + assert d.allowed_keys == ["x", "y"] + + def test_allowed_keys_setter(self) -> None: + d = Dict() + assert d.allowed_keys is None + d.allowed_keys = ["a"] + assert d.allowed_keys == ["a"] + + def test_repr_with_keys(self) -> None: + d = Dict(allowed_keys=["a", "b"]) + r = repr(d) + assert "Dict" in r + assert "a" in r + + def test_repr_without_keys(self) -> None: + assert repr(Dict()) == "" + + def test_valid_dict_with_allowed_keys(self) -> None: + d = Dict(allowed_keys=["a", "b"]) + d.validate({"a": 1, "b": 2}) + + def test_any_dict_without_keys(self) -> None: + d = Dict() + d.validate({"anything": "goes", 42: True}) + + def test_valid_values(self) -> None: + d = Dict() + assert d.valid_values == ({0: 1},) From 9e2bab9cd37e93863ff969dada75eda0d3b655b7 Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Mon, 23 Mar 2026 20:54:32 +0100 Subject: [PATCH 4/6] Fix typechecking of tests --- tests/test_config_extended.py | 1 + tests/validators/test_validators_extended.py | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_config_extended.py b/tests/test_config_extended.py index 6bdf6f5e6cad..3855aba63d51 100644 --- a/tests/test_config_extended.py +++ b/tests/test_config_extended.py @@ -271,6 +271,7 @@ def test_save_and_reload(self, tmp_path) -> None: cfg.save_config(config_path) loaded = Config.load_config(config_path) assert isinstance(loaded, DotDict) + assert cfg.current_config is not None assert loaded["core"]["db_debug"] == cfg.current_config["core"]["db_debug"] diff --git a/tests/validators/test_validators_extended.py b/tests/validators/test_validators_extended.py index c772cac9efc5..387fc63a77ef 100644 --- a/tests/validators/test_validators_extended.py +++ b/tests/validators/test_validators_extended.py @@ -153,7 +153,7 @@ def test_np_bool_accepted(self) -> None: def test_non_bool_raises(self) -> None: b = Bool() with pytest.raises(TypeError, match="not Boolean"): - b.validate(1) + b.validate(1) # pyright: ignore[reportArgumentType] def test_repr(self) -> None: assert repr(Bool()) == "" @@ -201,7 +201,7 @@ def test_max_less_than_min(self) -> None: def test_non_string_raises(self) -> None: s = Strings() with pytest.raises(TypeError, match="not a string"): - s.validate(42) + s.validate(42) # pyright: ignore[reportArgumentType] def test_repr_with_constraints(self) -> None: s = Strings(min_length=2, max_length=10) @@ -231,12 +231,12 @@ class TestEnumExtended: def test_unhashable_raises_type_error(self) -> None: e = Enum("a", "b") with pytest.raises(TypeError): - e.validate([1, 2]) # list is unhashable + e.validate([1, 2]) # pyright: ignore[reportArgumentType] # list is unhashable def test_unhashable_error_includes_context(self) -> None: e = Enum("a", "b") with pytest.raises(TypeError, match="test_ctx"): - e.validate([1, 2], context="test_ctx") + e.validate([1, 2], context="test_ctx") # pyright: ignore[reportArgumentType] # list is unhashable def test_values_returns_copy(self) -> None: e = Enum("x", "y") @@ -271,7 +271,7 @@ def test_other_string_rejected(self) -> None: def test_non_string_rejected(self) -> None: with pytest.raises((TypeError, ValueError)): - OnOff().validate(1) + OnOff().validate(1) # pyright: ignore[reportArgumentType] def test_valid_values(self) -> None: assert set(OnOff().valid_values) == {"on", "off"} @@ -574,7 +574,7 @@ def test_with_element_validator(self) -> None: lst = Lists(elt_validator=Ints()) lst.validate([1, 2, 3]) with pytest.raises(TypeError): - lst.validate(["a", "b"]) + lst.validate(["a", "b"]) # pyright: ignore[reportArgumentType] def test_repr_format(self) -> None: lst = Lists(elt_validator=Ints()) @@ -660,7 +660,7 @@ def test_callable_passes(self) -> None: def test_non_callable_raises(self) -> None: c = CallableValidator() with pytest.raises(TypeError, match="not a callable"): - c.validate(42) + c.validate(42) # pyright: ignore[reportArgumentType] def test_repr(self) -> None: assert repr(CallableValidator()) == "" From 851e7b5baadb17ec902c052c56fc81119428d080 Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Tue, 24 Mar 2026 09:49:00 +0100 Subject: [PATCH 5/6] Add tests for parameters, instrument core, channels, and station Add 217 new tests across 7 test files targeting the largest remaining coverage gaps in non-driver modules. Parameters tests (71 tests): - tests/parameter/test_command_extended.py (17 tests): all call_by_str and call_cmd method combinations, arg count validation, NoCommandError - tests/parameter/test_function_extended.py (16 tests): validation, name properties, parsers, callable/string cmd, get_attrs - tests/parameter/test_grouped_parameter_extended.py (22 tests): DelegateGroup set/get with custom setter/getter, source_parameters, GroupedParameter repr and properties - tests/parameter/test_combined_parameter_extended.py (16 tests): combine(), aggregator, iter/len, snapshot_base, units deprecation Instrument core tests (146 tests): - tests/test_instrument_extended.py: write_raw/ask_raw NotImplementedError, close_all, find_instrument, exist/is_valid, repr, label, add_function, add_submodule, get_component, print_readable_snapshot, invalidate_cache, parent/ancestors/root_instrument, find_or_create_instrument - tests/test_channel_extended.py: InstrumentModule proxy methods, ChannelTuple operations (reversed/contains/add/index/count/get_by_name), ChannelList mutations and lock behavior, ChannelTupleValidator - tests/test_station_extended.py: snapshot_base, add/remove/get component, close_all_registered_instruments, Station.default handling Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../test_combined_parameter_extended.py | 133 ++++ tests/parameter/test_command_extended.py | 246 ++++++ tests/parameter/test_function_extended.py | 183 +++++ .../test_grouped_parameter_extended.py | 221 ++++++ tests/test_channel_extended.py | 742 ++++++++++++++++++ tests/test_instrument_extended.py | 673 ++++++++++++++++ tests/test_station_extended.py | 285 +++++++ 7 files changed, 2483 insertions(+) create mode 100644 tests/parameter/test_combined_parameter_extended.py create mode 100644 tests/parameter/test_command_extended.py create mode 100644 tests/parameter/test_function_extended.py create mode 100644 tests/parameter/test_grouped_parameter_extended.py create mode 100644 tests/test_channel_extended.py create mode 100644 tests/test_instrument_extended.py create mode 100644 tests/test_station_extended.py diff --git a/tests/parameter/test_combined_parameter_extended.py b/tests/parameter/test_combined_parameter_extended.py new file mode 100644 index 000000000000..e93ba24760be --- /dev/null +++ b/tests/parameter/test_combined_parameter_extended.py @@ -0,0 +1,133 @@ +"""Extended tests for qcodes.parameters.combined_parameter module.""" + +from __future__ import annotations + +import logging + +import numpy as np +import pytest + +from qcodes.parameters import Parameter +from qcodes.parameters.combined_parameter import CombinedParameter, combine + + +@pytest.fixture() +def two_params() -> list[Parameter]: + return [ + Parameter("x", set_cmd=None, get_cmd=None), + Parameter("y", set_cmd=None, get_cmd=None), + ] + + +class TestCombineFunction: + def test_combine_creates_combined_parameter( + self, two_params: list[Parameter] + ) -> None: + """combine() convenience function returns CombinedParameter.""" + cp = combine(*two_params, name="xy") + assert isinstance(cp, CombinedParameter) + assert cp.dimensionality == 2 + + def test_combine_with_label_and_unit(self, two_params: list[Parameter]) -> None: + """combine() passes label and unit through.""" + cp = combine(*two_params, name="xy", label="X and Y", unit="V") + assert cp.parameter.label == "X and Y" + assert cp.parameter.unit == "V" + + def test_combine_with_aggregator(self, two_params: list[Parameter]) -> None: + """combine() passes aggregator through.""" + cp = combine(*two_params, name="xy", aggregator=sum) + assert hasattr(cp, "aggregate") + + +class TestCombinedParameter: + def test_set_calls_parameter_sets(self, two_params: list[Parameter]) -> None: + """set() sets each parameter in order.""" + cp = CombinedParameter(two_params, name="xy") + swept = cp.sweep(np.array([[1.0, 2.0], [3.0, 4.0]])) + swept.set(0) + assert two_params[0]() == 1.0 + assert two_params[1]() == 2.0 + swept.set(1) + assert two_params[0]() == 3.0 + assert two_params[1]() == 4.0 + + def test_aggregate_with_aggregator(self, two_params: list[Parameter]) -> None: + """_aggregate calls the aggregator function.""" + + def my_agg(*vals: int) -> int: + return sum(vals) + + cp = CombinedParameter(two_params, name="xy", aggregator=my_agg) + result = cp._aggregate(1, 2, 3) + assert result == 6 + + def test_aggregate_without_aggregator(self, two_params: list[Parameter]) -> None: + """Without aggregator, _aggregate is not set as 'aggregate' attr.""" + cp = CombinedParameter(two_params, name="xy") + assert not hasattr(cp, "aggregate") + + def test_iter(self, two_params: list[Parameter]) -> None: + """__iter__ iterates over setpoint indices.""" + cp = CombinedParameter(two_params, name="xy") + swept = cp.sweep(np.array([[1, 2], [3, 4], [5, 6]])) + indices = list(swept) + assert indices == [0, 1, 2] + + def test_len(self, two_params: list[Parameter]) -> None: + """__len__ returns number of setpoints.""" + cp = CombinedParameter(two_params, name="xy") + swept = cp.sweep(np.array([[1, 2], [3, 4]])) + assert len(swept) == 2 + + def test_len_no_setpoints(self, two_params: list[Parameter]) -> None: + """__len__ returns 0 when no setpoints.""" + cp = CombinedParameter(two_params, name="xy") + assert len(cp) == 0 + + def test_snapshot_base(self, two_params: list[Parameter]) -> None: + """snapshot_base returns dict with expected keys.""" + cp = CombinedParameter(two_params, name="xy", label="combined", unit="mV") + snap = cp.snapshot_base() + assert snap["label"] == "combined" + assert snap["unit"] == "mV" + assert snap["full_name"] == "xy" + assert "__class__" in snap + assert "aggregator" in snap + + def test_snapshot_base_with_aggregator(self, two_params: list[Parameter]) -> None: + """snapshot_base includes aggregator repr.""" + cp = CombinedParameter(two_params, name="xy", aggregator=sum) + snap = cp.snapshot_base() + assert "sum" in snap["aggregator"] + + def test_units_deprecated( + self, two_params: list[Parameter], caplog: pytest.LogCaptureFixture + ) -> None: + """Passing units= triggers a deprecation warning log.""" + with caplog.at_level(logging.WARNING): + cp = CombinedParameter(two_params, name="xy", units="mV") + assert any("`units` is deprecated" in msg for msg in caplog.messages) + assert cp.parameter.unit == "mV" + + def test_units_deprecated_unit_takes_precedence( + self, two_params: list[Parameter], caplog: pytest.LogCaptureFixture + ) -> None: + """When both unit and units are given, unit takes precedence.""" + with caplog.at_level(logging.WARNING): + cp = CombinedParameter(two_params, name="xy", unit="V", units="mV") + assert cp.parameter.unit == "V" + + def test_invalid_name_raises(self, two_params: list[Parameter]) -> None: + """Invalid parameter name raises ValueError.""" + with pytest.raises(ValueError, match="valid identifier"): + CombinedParameter(two_params, name="invalid name") + + def test_sweep_multiple_arrays(self, two_params: list[Parameter]) -> None: + """sweep() with multiple 1D arrays.""" + cp = CombinedParameter(two_params, name="xy") + swept = cp.sweep(np.array([1, 2, 3]), np.array([4, 5, 6])) + assert len(swept) == 3 + swept.set(0) + assert two_params[0]() == 1 + assert two_params[1]() == 4 diff --git a/tests/parameter/test_command_extended.py b/tests/parameter/test_command_extended.py new file mode 100644 index 000000000000..2d1edb448bbe --- /dev/null +++ b/tests/parameter/test_command_extended.py @@ -0,0 +1,246 @@ +"""Extended tests for qcodes.parameters.command.Command covering all call_by_* methods.""" + +from __future__ import annotations + +import pytest + +from qcodes.parameters.command import Command, NoCommandError + + +def test_call_by_str_no_parsers() -> None: + """String cmd + exec_str, no parsers -> call_by_str.""" + results: list[str] = [] + + def exec_fn(cmd_str: str) -> str: + results.append(cmd_str) + return cmd_str + + cmd = Command(arg_count=1, cmd="SET {}", exec_str=exec_fn) + result = cmd(42) + assert result == "SET 42" + assert results == ["SET 42"] + + +def test_call_by_str_zero_args() -> None: + """String cmd with 0 args -> call_by_str with no formatting.""" + + def exec_fn(cmd_str: str) -> str: + return cmd_str + + cmd = Command(arg_count=0, cmd="*RST", exec_str=exec_fn) + assert cmd() == "*RST" + + +def test_call_by_str_parsed_out() -> None: + """String cmd + output_parser -> call_by_str_parsed_out.""" + + def exec_fn(cmd_str: str) -> str: + return cmd_str + + cmd = Command( + arg_count=1, + cmd="READ {}", + exec_str=exec_fn, + output_parser=lambda x: x.upper(), + ) + result = cmd("ch1") + assert result == "READ CH1" + + +def test_call_by_str_parsed_in() -> None: + """String cmd + single input_parser -> call_by_str_parsed_in.""" + + def exec_fn(cmd_str: str) -> str: + return cmd_str + + cmd = Command( + arg_count=1, + cmd="SET {}", + exec_str=exec_fn, + input_parser=lambda x: x * 2, + ) + result = cmd(5) + assert result == "SET 10" + + +def test_call_by_str_parsed_in_out() -> None: + """String cmd + input_parser + output_parser -> call_by_str_parsed_in_out.""" + + def exec_fn(cmd_str: str) -> str: + return cmd_str + + cmd = Command( + arg_count=1, + cmd="MEAS {}", + exec_str=exec_fn, + input_parser=lambda x: x + 1, + output_parser=lambda x: f"result:{x}", + ) + result = cmd(9) + assert result == "result:MEAS 10" + + +def test_call_by_str_parsed_in2() -> None: + """String cmd + multi-arg input_parser -> call_by_str_parsed_in2.""" + + def exec_fn(cmd_str: str) -> str: + return cmd_str + + def multi_parser(a: int, b: int) -> tuple[int, int]: + return (a * 10, b * 10) + + cmd = Command( + arg_count=2, + cmd="SET {} {}", + exec_str=exec_fn, + input_parser=multi_parser, + ) + result = cmd(3, 4) + assert result == "SET 30 40" + + +def test_call_by_str_parsed_in2_out() -> None: + """String cmd + multi-arg input_parser + output_parser.""" + + def exec_fn(cmd_str: str) -> str: + return cmd_str + + def multi_parser(a: int, b: int) -> tuple[int, int]: + return (a + 1, b + 1) + + cmd = Command( + arg_count=2, + cmd="CMD {} {}", + exec_str=exec_fn, + input_parser=multi_parser, + output_parser=lambda x: x.replace("CMD", "OUT"), + ) + result = cmd(0, 1) + assert result == "OUT 1 2" + + +def test_call_cmd_no_parsers() -> None: + """Callable cmd, no parsers -> direct call.""" + + def my_func(a: int) -> int: + return a * 3 + + cmd = Command(arg_count=1, cmd=my_func) + assert cmd(7) == 21 + + +def test_call_cmd_parsed_out() -> None: + """Callable cmd + output_parser -> call_cmd_parsed_out.""" + + def my_func(a: int) -> int: + return a + 1 + + cmd = Command( + arg_count=1, + cmd=my_func, + output_parser=lambda x: x * 100, + ) + assert cmd(5) == 600 + + +def test_call_cmd_parsed_in() -> None: + """Callable cmd + single input_parser -> call_cmd_parsed_in.""" + + def my_func(a: int) -> int: + return a + + cmd = Command( + arg_count=1, + cmd=my_func, + input_parser=lambda x: x + 10, + ) + assert cmd(5) == 15 + + +def test_call_cmd_parsed_in_out() -> None: + """Callable cmd + input_parser + output_parser -> call_cmd_parsed_in_out.""" + + def my_func(a: int) -> int: + return a * 2 + + cmd = Command( + arg_count=1, + cmd=my_func, + input_parser=lambda x: x + 1, + output_parser=lambda x: x + 100, + ) + # input_parser(3) = 4, my_func(4) = 8, output_parser(8) = 108 + assert cmd(3) == 108 + + +def test_call_cmd_parsed_in2() -> None: + """Callable cmd + multi-arg input_parser -> call_cmd_parsed_in2.""" + + def my_func(a: int, b: int) -> int: + return a + b + + def multi_parser(a: int, b: int) -> tuple[int, int]: + return (a * 10, b * 10) + + cmd = Command( + arg_count=2, + cmd=my_func, + input_parser=multi_parser, + ) + assert cmd(3, 4) == 70 + + +def test_call_cmd_parsed_in2_out() -> None: + """Callable cmd + multi-arg input_parser + output_parser.""" + + def my_func(a: int, b: int) -> int: + return a + b + + def multi_parser(a: int, b: int) -> tuple[int, int]: + return (a * 2, b * 3) + + cmd = Command( + arg_count=2, + cmd=my_func, + input_parser=multi_parser, + output_parser=lambda x: x * -1, + ) + # multi_parser(5, 10) = (10, 30), my_func(10, 30) = 40, output_parser(40) = -40 + assert cmd(5, 10) == -40 + + +def test_wrong_arg_count_raises_type_error() -> None: + """Calling with wrong number of args raises TypeError.""" + + def my_func(a: int) -> int: + return a + + cmd = Command(arg_count=1, cmd=my_func) + + with pytest.raises(TypeError, match="command takes exactly 1 args"): + cmd() + + with pytest.raises(TypeError, match="command takes exactly 1 args"): + cmd(1, 2) + + +def test_no_command_error_when_no_cmd() -> None: + """NoCommandError raised when no cmd and no no_cmd_function.""" + with pytest.raises(NoCommandError, match="no ``cmd`` provided"): + Command(arg_count=0, cmd=None) + + +def test_no_cmd_with_no_cmd_function() -> None: + """no_cmd_function is used as fallback when cmd is None.""" + + def fallback() -> str: + return "fallback_called" + + cmd = Command(arg_count=0, cmd=None, no_cmd_function=fallback) + assert cmd() == "fallback_called" + + +def test_str_cmd_without_exec_str_raises() -> None: + """String cmd with no exec_str raises TypeError.""" + with pytest.raises(TypeError, match="exec_str cannot be None"): + Command(arg_count=0, cmd="*RST", exec_str=None) diff --git a/tests/parameter/test_function_extended.py b/tests/parameter/test_function_extended.py new file mode 100644 index 000000000000..f894e28acc96 --- /dev/null +++ b/tests/parameter/test_function_extended.py @@ -0,0 +1,183 @@ +"""Extended tests for qcodes.parameters.function.Function.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from qcodes.parameters import Function +from qcodes.validators import Numbers, Strings + + +def test_function_with_callable_cmd_no_args() -> None: + """Function with a callable cmd and no args.""" + call_count = 0 + + def my_cmd() -> str: + nonlocal call_count + call_count += 1 + return "ok" + + func = Function("reset", call_cmd=my_cmd) + result = func() + assert result == "ok" + assert call_count == 1 + + +def test_function_with_callable_cmd_and_args() -> None: + """Function with args validation, callable cmd.""" + + def my_cmd(x: float) -> float: + return x * 2 + + func = Function("double", call_cmd=my_cmd, args=[Numbers(0, 100)]) + assert func(5) == 10 + + +def test_function_validate_wrong_arg_count() -> None: + """validate() raises TypeError when wrong number of args.""" + func = Function("noop", call_cmd=lambda x: x, args=[Numbers()]) + + with pytest.raises(TypeError, match="called with 0 args but requires 1"): + func.validate() + + with pytest.raises(TypeError, match="called with 2 args but requires 1"): + func.validate(1, 2) + + +def test_function_validate_wrong_type() -> None: + """validate() raises error when arg fails validation.""" + func = Function("typed", call_cmd=lambda x: x, args=[Numbers(0, 10)]) + + with pytest.raises(Exception): + func.validate(100) + + +def test_function_validate_passes() -> None: + """validate() succeeds with valid args.""" + func = Function("typed", call_cmd=lambda x: x, args=[Numbers(0, 10)]) + func.validate(5) + + +def test_function_args_must_be_validators() -> None: + """_set_args raises TypeError for non-Validator objects.""" + with pytest.raises(TypeError, match="all args must be Validator objects"): + Function("bad", call_cmd=lambda x: x, args=["not_a_validator"]) # type: ignore[list-item] + + +def test_function_short_name() -> None: + """short_name returns the function name.""" + func = Function("my_func", call_cmd=lambda: None) + assert func.short_name == "my_func" + + +def test_function_name_parts_no_instrument() -> None: + """name_parts returns [name] when no instrument.""" + func = Function("my_func", call_cmd=lambda: None) + assert func.name_parts == ["my_func"] + + +def test_function_name_parts_with_instrument_like_object() -> None: + """name_parts uses instrument.name_parts if available.""" + mock_instr = MagicMock() + mock_instr.name_parts = ["instr", "sub"] + mock_instr.write = MagicMock() + mock_instr.ask = MagicMock() + + func = Function("my_func", instrument=mock_instr, call_cmd=lambda: None) + assert func.name_parts == ["instr", "sub", "my_func"] + + +def test_function_name_parts_instrument_no_name_parts() -> None: + """name_parts falls back to instrument.name when name_parts is empty.""" + mock_instr = MagicMock() + mock_instr.name_parts = [] + mock_instr.name = "fallback_instr" + mock_instr.write = MagicMock() + + func = Function("my_func", instrument=mock_instr, call_cmd=lambda: None) + assert func.name_parts == ["fallback_instr", "my_func"] + + +def test_function_full_name() -> None: + """full_name joins name_parts with underscore.""" + mock_instr = MagicMock() + mock_instr.name_parts = ["dev", "ch1"] + mock_instr.write = MagicMock() + + func = Function("read", instrument=mock_instr, call_cmd=lambda: None) + assert func.full_name == "dev_ch1_read" + + +def test_function_get_attrs() -> None: + """get_attrs returns the expected attribute list.""" + func = Function("my_func", call_cmd=lambda: None) + assert func.get_attrs() == ["__doc__", "_args", "_arg_count"] + + +def test_function_call_method() -> None: + """call() wraps __call__.""" + + def my_cmd(x: int) -> int: + return x + 1 + + func = Function("inc", call_cmd=my_cmd, args=[Numbers()]) + assert func.call(5) == 6 + + +def test_function_docstring() -> None: + """Custom docstring is set on function.""" + func = Function("my_func", call_cmd=lambda: None, docstring="Custom doc") + assert func.__doc__ == "Custom doc" + + +def test_function_with_arg_parser_and_return_parser() -> None: + """Function with arg_parser and return_parser via callable cmd.""" + # When using a callable cmd, parsers are not applied by Function itself + # (they go through Command). We use a string cmd to test parsers fully. + mock_instr = MagicMock() + mock_instr.ask = MagicMock(return_value="42") + mock_instr.write = MagicMock() + + func = Function( + "measure", + instrument=mock_instr, + call_cmd="MEAS {}", + args=[Numbers()], + arg_parser=int, + return_parser=int, + ) + result = func(3.14) + assert result == 42 + mock_instr.ask.assert_called_once() + + +def test_function_multiple_args_validation() -> None: + """Function with multiple args validates each.""" + + def my_cmd(x: Any, y: Any) -> str: + return f"{x},{y}" + + func = Function( + "dual", + call_cmd=my_cmd, + args=[Numbers(0, 10), Strings()], + ) + result = func(5, "hello") + assert result == "5,hello" + + with pytest.raises(Exception): + func(5, 123) + + +def test_function_instrument_property() -> None: + """Instrument property returns the bound instrument.""" + func = Function("my_func", call_cmd=lambda: None) + assert func.instrument is None + + mock_instr = MagicMock() + mock_instr.write = MagicMock() + func2 = Function("my_func2", instrument=mock_instr, call_cmd=lambda: None) + assert func2.instrument is mock_instr diff --git a/tests/parameter/test_grouped_parameter_extended.py b/tests/parameter/test_grouped_parameter_extended.py new file mode 100644 index 000000000000..59999446f0b8 --- /dev/null +++ b/tests/parameter/test_grouped_parameter_extended.py @@ -0,0 +1,221 @@ +"""Extended tests for qcodes.parameters.grouped_parameter module.""" + +from __future__ import annotations + +import pytest + +from qcodes.parameters import Parameter +from qcodes.parameters.grouped_parameter import ( + DelegateGroup, + DelegateGroupParameter, + GroupedParameter, +) + + +def _make_source_and_delegate( + name: str, initial: float = 0.0 +) -> tuple[Parameter, DelegateGroupParameter]: + """Helper to create a source parameter and a DelegateGroupParameter.""" + source = Parameter(name=f"{name}_source", set_cmd=None, get_cmd=None) + source.set(initial) + delegate = DelegateGroupParameter(name=name, source=source) + return source, delegate + + +class TestDelegateGroupParameter: + def test_basic_creation(self) -> None: + """DelegateGroupParameter wraps a source parameter.""" + source = Parameter(name="src", set_cmd=None, get_cmd=None) + source.set(3.14) + dgp = DelegateGroupParameter(name="wrapped", source=source) + assert dgp.name == "wrapped" + assert dgp() == 3.14 + + def test_set_propagates_to_source(self) -> None: + """Setting DelegateGroupParameter updates the source.""" + source = Parameter(name="src", set_cmd=None, get_cmd=None) + dgp = DelegateGroupParameter(name="wrapped", source=source) + dgp.set(99.0) + assert source() == 99.0 + + +class TestDelegateGroup: + def test_get_without_custom_getter(self) -> None: + """get() returns namedtuple of parameter values by default.""" + _, d1 = _make_source_and_delegate("alpha", 1.0) + _, d2 = _make_source_and_delegate("beta", 2.0) + + group = DelegateGroup("my_group", parameters=[d1, d2]) + result = group.get() + assert result.alpha == 1.0 + assert result.beta == 2.0 + + def test_get_single_parameter_returns_scalar(self) -> None: + """get() with single parameter returns scalar, not namedtuple.""" + _, d1 = _make_source_and_delegate("only", 5.0) + group = DelegateGroup("single_group", parameters=[d1]) + result = group.get() + assert result == 5.0 + + def test_get_with_custom_getter(self) -> None: + """get() uses custom getter when provided.""" + _, d1 = _make_source_and_delegate("x", 1.0) + + def custom_getter() -> str: + return "custom_value" + + group = DelegateGroup("cg", parameters=[d1], getter=custom_getter) + assert group.get() == "custom_value" + + def test_set_with_dict(self) -> None: + """set() with a dict sets each parameter by name.""" + src_a, d_a = _make_source_and_delegate("a", 0.0) + src_b, d_b = _make_source_and_delegate("b", 0.0) + + group = DelegateGroup("dict_group", parameters=[d_a, d_b]) + group.set({"a": 10.0, "b": 20.0}) + assert src_a() == 10.0 + assert src_b() == 20.0 + + def test_set_single_value_without_setter(self) -> None: + """set() with a single value broadcasts to all parameters.""" + src_a, d_a = _make_source_and_delegate("a", 0.0) + src_b, d_b = _make_source_and_delegate("b", 0.0) + + group = DelegateGroup("broadcast_group", parameters=[d_a, d_b]) + group.set(42.0) + assert src_a() == 42.0 + assert src_b() == 42.0 + + def test_set_with_custom_setter(self) -> None: + """set() uses custom setter when provided.""" + captured: list[object] = [] + _, d1 = _make_source_and_delegate("x", 0.0) + + def custom_setter(value: object) -> None: + captured.append(value) + + group = DelegateGroup("cs", parameters=[d1], setter=custom_setter) + group.set("hello") + assert captured == ["hello"] + + def test_get_parameters(self) -> None: + """get_parameters() returns formatted result.""" + _, d1 = _make_source_and_delegate("p1", 3.0) + _, d2 = _make_source_and_delegate("p2", 7.0) + + group = DelegateGroup("gp", parameters=[d1, d2]) + result = group.get_parameters() + assert result.p1 == 3.0 + assert result.p2 == 7.0 + + def test_source_parameters(self) -> None: + """source_parameters returns tuple of source Parameter objects.""" + src_a, d_a = _make_source_and_delegate("a", 0.0) + src_b, d_b = _make_source_and_delegate("b", 0.0) + + group = DelegateGroup("sp_group", parameters=[d_a, d_b]) + sources = group.source_parameters + assert sources == (src_a, src_b) + + def test_custom_formatter(self) -> None: + """Custom formatter transforms get_parameters output.""" + _, d1 = _make_source_and_delegate("x", 2.0) + _, d2 = _make_source_and_delegate("y", 3.0) + + def my_fmt(x: float, y: float) -> float: + return x + y + + group = DelegateGroup("fmt", parameters=[d1, d2], formatter=my_fmt) + assert group.get() == 5.0 + + def test_custom_parameter_names(self) -> None: + """parameter_names overrides the default names from parameters.""" + _, d1 = _make_source_and_delegate("orig_name", 1.0) + + group = DelegateGroup( + "named_group", parameters=[d1], parameter_names=["custom_name"] + ) + assert "custom_name" in group.parameters + + def test_set_from_dict(self) -> None: + """_set_from_dict sets parameters by name from a dict.""" + src_a, d_a = _make_source_and_delegate("a", 0.0) + src_b, d_b = _make_source_and_delegate("b", 0.0) + + group = DelegateGroup("sfd", parameters=[d_a, d_b]) + group._set_from_dict({"a": 100.0, "b": 200.0}) + assert src_a() == 100.0 + assert src_b() == 200.0 + + +class TestGroupedParameter: + def test_basic_creation(self) -> None: + """GroupedParameter wraps a DelegateGroup.""" + _, d1 = _make_source_and_delegate("ch1", 5.0) + group = DelegateGroup("g", parameters=[d1]) + gp = GroupedParameter("grouped", group=group) + assert gp.name == "grouped" + assert gp.group is group + + def test_repr(self) -> None: + """__repr__ includes name and source parameters.""" + _src, d1 = _make_source_and_delegate("ch1", 0.0) + group = DelegateGroup("g", parameters=[d1]) + gp = GroupedParameter("my_grouped", group=group) + r = repr(gp) + assert "GroupedParameter" in r + assert "my_grouped" in r + assert "source_parameters" in r + + def test_parameters_property(self) -> None: + """Parameters property returns delegate parameters dict.""" + _, d1 = _make_source_and_delegate("a", 0.0) + _, d2 = _make_source_and_delegate("b", 0.0) + group = DelegateGroup("g", parameters=[d1, d2]) + gp = GroupedParameter("gp", group=group) + assert "a" in gp.parameters + assert "b" in gp.parameters + + def test_source_parameters_property(self) -> None: + """source_parameters property delegates to group.""" + src_a, d_a = _make_source_and_delegate("a", 0.0) + group = DelegateGroup("g", parameters=[d_a]) + gp = GroupedParameter("gp", group=group) + assert gp.source_parameters == (src_a,) + + def test_get_raw(self) -> None: + """get_raw returns formatted parameter values.""" + _, d1 = _make_source_and_delegate("v", 42.0) + group = DelegateGroup("g", parameters=[d1]) + gp = GroupedParameter("gp", group=group) + assert gp.get_raw() == 42.0 + + def test_set_raw(self) -> None: + """set_raw delegates to group.set.""" + src, d1 = _make_source_and_delegate("v", 0.0) + group = DelegateGroup("g", parameters=[d1]) + gp = GroupedParameter("gp", group=group) + gp.set_raw(99.0) + assert src() == 99.0 + + def test_missing_group_raises(self) -> None: + """GroupedParameter requires group kwarg.""" + with pytest.raises(TypeError, match="missing required keyword argument"): + GroupedParameter("bad") # type: ignore[call-arg] + + def test_label_and_unit_defaults(self) -> None: + """Default label is name, default unit is empty.""" + _, d1 = _make_source_and_delegate("v", 0.0) + group = DelegateGroup("g", parameters=[d1]) + gp = GroupedParameter("my_param", group=group) + assert gp.label == "my_param" + assert gp.unit == "" + + def test_custom_label_and_unit(self) -> None: + """Custom label and unit are set.""" + _, d1 = _make_source_and_delegate("v", 0.0) + group = DelegateGroup("g", parameters=[d1]) + gp = GroupedParameter("p", group=group, label="Voltage", unit="V") + assert gp.label == "Voltage" + assert gp.unit == "V" diff --git a/tests/test_channel_extended.py b/tests/test_channel_extended.py new file mode 100644 index 000000000000..b52ff57ba427 --- /dev/null +++ b/tests/test_channel_extended.py @@ -0,0 +1,742 @@ +""" +Extended tests for channel.py to improve coverage. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from qcodes.instrument import ( + ChannelList, + ChannelTuple, + InstrumentChannel, + InstrumentModule, +) +from qcodes.instrument.channel import ChannelTupleValidator +from qcodes.instrument_drivers.mock_instruments import ( + DummyChannel, + DummyChannelInstrument, +) + +if TYPE_CHECKING: + from collections.abc import Iterator + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(name="ch_instr", scope="function") +def _ch_instr() -> Iterator[DummyChannelInstrument]: + inst = DummyChannelInstrument(name="ch_ext") + try: + yield inst + finally: + inst.close() + + +@pytest.fixture(name="chan_tuple", scope="function") +def _chan_tuple(ch_instr: DummyChannelInstrument) -> ChannelTuple: + return ch_instr.channels + + +@pytest.fixture(name="mutable_list", scope="function") +def _mutable_list(ch_instr: DummyChannelInstrument) -> ChannelList: + """Create an unlocked ChannelList with some channels.""" + cl = ChannelList(ch_instr, "TestList", DummyChannel) + # Append two channels + chan_x = DummyChannel(ch_instr, "ChanX", "X") + chan_y = DummyChannel(ch_instr, "ChanY", "Y") + cl.append(chan_x) + cl.append(chan_y) + return cl + + +# --------------------------------------------------------------------------- +# InstrumentModule — __repr__, write/ask proxy, parent/root/name_parts +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_instrument_module_repr(ch_instr: DummyChannelInstrument) -> None: + chan_a = ch_instr.submodules["A"] + r = repr(chan_a) + assert "DummyChannel" in r + assert "ch_ext" in r + assert "ChanA" in r + + +@pytest.mark.serial +def test_instrument_module_write_proxy(ch_instr: DummyChannelInstrument) -> None: + """write() on a module should proxy to parent, which raises for DummyBase.""" + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + # DummyBase.get_idn exists but write_raw is not implemented on Instrument + # The chain: module.write -> parent.write -> parent.write_raw -> NotImplementedError + with pytest.raises(NotImplementedError): + chan_a.write("test_cmd") + + +@pytest.mark.serial +def test_instrument_module_ask_proxy(ch_instr: DummyChannelInstrument) -> None: + """ask() on a module should proxy to parent.""" + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + with pytest.raises(NotImplementedError): + chan_a.ask("test_cmd") + + +@pytest.mark.serial +def test_instrument_module_write_raw_proxy( + ch_instr: DummyChannelInstrument, +) -> None: + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + with pytest.raises(NotImplementedError): + chan_a.write_raw("raw_cmd") + + +@pytest.mark.serial +def test_instrument_module_ask_raw_proxy( + ch_instr: DummyChannelInstrument, +) -> None: + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + with pytest.raises(NotImplementedError): + chan_a.ask_raw("raw_cmd") + + +@pytest.mark.serial +def test_instrument_module_parent(ch_instr: DummyChannelInstrument) -> None: + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.parent is ch_instr + + +@pytest.mark.serial +def test_instrument_module_root_instrument( + ch_instr: DummyChannelInstrument, +) -> None: + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.root_instrument is ch_instr + + +@pytest.mark.serial +def test_instrument_module_name_parts(ch_instr: DummyChannelInstrument) -> None: + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.name_parts == ["ch_ext", "ChanA"] + + +# --------------------------------------------------------------------------- +# ChannelTuple — __reversed__, __contains__, __add__, index, count +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_reversed(chan_tuple: ChannelTuple) -> None: + channels = list(chan_tuple) + reversed_channels = list(reversed(chan_tuple)) + assert reversed_channels == list(reversed(channels)) + + +@pytest.mark.serial +def test_channel_tuple_contains( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + chan_a = ch_instr.submodules["A"] + assert chan_a in chan_tuple + # Create a channel not in the tuple + chan_new = DummyChannel(ch_instr, "ChanNew", "N") + assert chan_new not in chan_tuple + + +@pytest.mark.serial +def test_channel_tuple_add(ch_instr: DummyChannelInstrument) -> None: + """Adding two ChannelTuples should produce a combined tuple.""" + channels = ch_instr.channels + # Split into two halves + first = channels[0:3] + second = channels[3:6] + combined = first + second + assert len(combined) == 6 + + +@pytest.mark.serial +def test_channel_tuple_add_type_mismatch( + ch_instr: DummyChannelInstrument, +) -> None: + """Adding ChannelTuples of different types should raise.""" + + class OtherChannel(InstrumentChannel): + pass + + ct1 = ch_instr.channels + ct2 = ChannelTuple(ch_instr, "other", OtherChannel) + with pytest.raises(TypeError, match="same type"): + ct1 + ct2 + + +@pytest.mark.serial +def test_channel_tuple_add_different_parent() -> None: + """Adding ChannelTuples with different parents should raise.""" + instr1 = DummyChannelInstrument(name="parent1") + instr2 = DummyChannelInstrument(name="parent2") + try: + ct1 = instr1.channels[0:1] + ct2 = instr2.channels[0:1] + with pytest.raises(ValueError, match="same parent"): + ct1 + ct2 + finally: + instr1.close() + instr2.close() + + +@pytest.mark.serial +def test_channel_tuple_index( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + chan_a = ch_instr.submodules["A"] + idx = chan_tuple.index(chan_a) # type: ignore[arg-type] + assert idx == 0 + + +@pytest.mark.serial +def test_channel_tuple_count( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + chan_a = ch_instr.submodules["A"] + c = chan_tuple.count(chan_a) # type: ignore[arg-type] + assert c == 1 + + +# --------------------------------------------------------------------------- +# ChannelTuple — get_channels_by_name, get_channel_by_name +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_get_channels_by_name( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + subset = chan_tuple.get_channels_by_name("ChanA", "ChanB") + assert len(subset) == 2 + + +@pytest.mark.serial +def test_get_channels_by_name_empty_raises(chan_tuple: ChannelTuple) -> None: + with pytest.raises(TypeError, match="one or more names"): + chan_tuple.get_channels_by_name() + + +@pytest.mark.serial +def test_get_channel_by_name( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + chan = chan_tuple.get_channel_by_name("ChanA") + assert chan is ch_instr.submodules["A"] + + +# --------------------------------------------------------------------------- +# ChannelTuple — get_validator +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_get_validator( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + validator = chan_tuple.get_validator() + assert isinstance(validator, ChannelTupleValidator) + # Validate a channel that is in the tuple + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentChannel) + validator.validate(chan_a) + + +@pytest.mark.serial +def test_channel_tuple_validator_rejects( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + validator = chan_tuple.get_validator() + chan_new = DummyChannel(ch_instr, "ChanNew", "N") + with pytest.raises(ValueError, match="is not part of the expected channel list"): + validator.validate(chan_new) + + +@pytest.mark.serial +def test_channel_tuple_validator_requires_channel_tuple() -> None: + with pytest.raises(ValueError, match="must be a ChannelTuple"): + ChannelTupleValidator("not a channel tuple") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# ChannelTuple — repr, snapshot, name_parts, full_name, short_name +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_repr(chan_tuple: ChannelTuple) -> None: + r = repr(chan_tuple) + assert "ChannelTuple" in r + assert "DummyChannel" in r + + +@pytest.mark.serial +def test_channel_tuple_name_parts( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + assert chan_tuple.short_name == "TempSensors" + assert chan_tuple.name_parts == ["ch_ext", "TempSensors"] + assert chan_tuple.full_name == "ch_ext_TempSensors" + + +@pytest.mark.serial +def test_channel_tuple_snapshot(chan_tuple: ChannelTuple) -> None: + # The DummyChannelInstrument creates channels with snapshotable=False + snap = chan_tuple.snapshot_base(update=False) + assert "snapshotable" in snap + assert "__class__" in snap + + +@pytest.mark.serial +def test_channel_tuple_snapshotable() -> None: + """ChannelTuple with snapshotable=True should include channels in snapshot.""" + instr = DummyChannelInstrument(name="snap_ch") + try: + cl = ChannelList(instr, "SnapList", DummyChannel, snapshotable=True) + chan = DummyChannel(instr, "ChanSnap", "S") + cl.append(chan) + ct = cl.to_channel_tuple() + snap = ct.snapshot_base(update=False) + assert "channels" in snap + finally: + instr.close() + + +# --------------------------------------------------------------------------- +# ChannelTuple — print_readable_snapshot, invalidate_cache +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_print_readable_snapshot_not_snapshotable( + ch_instr: DummyChannelInstrument, + capsys: pytest.CaptureFixture[str], +) -> None: + """print_readable_snapshot on non-snapshotable tuple should not print channels.""" + ch_instr.channels.print_readable_snapshot(update=False) + captured = capsys.readouterr() + # snapshotable=False means nothing should be printed + assert captured.out == "" + + +@pytest.mark.serial +def test_channel_tuple_print_readable_snapshot_snapshotable( + capsys: pytest.CaptureFixture[str], +) -> None: + instr = DummyChannelInstrument(name="prs_ch") + try: + cl = ChannelList(instr, "SnapList", DummyChannel, snapshotable=True) + chan = DummyChannel(instr, "ChanPRS", "P") + cl.append(chan) + ct = cl.to_channel_tuple() + ct.print_readable_snapshot(update=False) + captured = capsys.readouterr() + assert "ChanPRS" in captured.out + finally: + instr.close() + + +@pytest.mark.serial +def test_channel_tuple_invalidate_cache( + ch_instr: DummyChannelInstrument, +) -> None: + chan_a = ch_instr.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + chan_a.parameters["temperature"].set(100) + chan_a.parameters["temperature"].get() + assert chan_a.parameters["temperature"].cache.valid + ch_instr.channels.invalidate_cache() + assert not chan_a.parameters["temperature"].cache.valid + + +# --------------------------------------------------------------------------- +# ChannelTuple — __getitem__ with slice and tuple index +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_getitem_slice(chan_tuple: ChannelTuple) -> None: + sliced = chan_tuple[1:3] + assert isinstance(sliced, ChannelTuple) + assert len(sliced) == 2 + + +@pytest.mark.serial +def test_channel_tuple_getitem_tuple_index(chan_tuple: ChannelTuple) -> None: + selected = chan_tuple[(0, 2, 4)] + assert isinstance(selected, ChannelTuple) + assert len(selected) == 3 + + +@pytest.mark.serial +def test_channel_tuple_getitem_int(chan_tuple: ChannelTuple) -> None: + single = chan_tuple[0] + assert isinstance(single, InstrumentModule) + + +# --------------------------------------------------------------------------- +# ChannelList — mutation operations +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_list_append( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + chan_z = DummyChannel(ch_instr, "ChanZ", "Z") + mutable_list.append(chan_z) + assert len(mutable_list) == 3 + assert chan_z in mutable_list + + +@pytest.mark.serial +def test_channel_list_append_wrong_type( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + mod = InstrumentModule(ch_instr, "notchan") + with pytest.raises(TypeError, match="same type"): + mutable_list.append(mod) # type: ignore[arg-type] + + +@pytest.mark.serial +def test_channel_list_extend( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + new_chans = [ + DummyChannel(ch_instr, "ChanE1", "E1"), + DummyChannel(ch_instr, "ChanE2", "E2"), + ] + mutable_list.extend(new_chans) + assert len(mutable_list) == 4 + + +@pytest.mark.serial +def test_channel_list_extend_wrong_type( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + with pytest.raises(TypeError, match="same type"): + mutable_list.extend([InstrumentModule(ch_instr, "bad")]) # type: ignore[list-item] + + +@pytest.mark.serial +def test_channel_list_insert( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + chan_ins = DummyChannel(ch_instr, "ChanIns", "I") + mutable_list.insert(0, chan_ins) + assert mutable_list[0] is chan_ins + assert len(mutable_list) == 3 + + +@pytest.mark.serial +def test_channel_list_insert_wrong_type( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + with pytest.raises(TypeError, match="same type"): + mutable_list.insert(0, InstrumentModule(ch_instr, "bad")) # type: ignore[arg-type] + + +@pytest.mark.serial +def test_channel_list_remove( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + first = mutable_list[0] + mutable_list.remove(first) + assert len(mutable_list) == 1 + assert first not in mutable_list + + +@pytest.mark.serial +def test_channel_list_clear(mutable_list: ChannelList) -> None: + mutable_list.clear() + assert len(mutable_list) == 0 + + +@pytest.mark.serial +def test_channel_list_delitem(mutable_list: ChannelList) -> None: + original_len = len(mutable_list) + del mutable_list[0] + assert len(mutable_list) == original_len - 1 + + +# --------------------------------------------------------------------------- +# ChannelList — locked operations raise AttributeError +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_locked_list_append(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "Locked", DummyChannel) + chan = DummyChannel(ch_instr, "ChanL", "L") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + cl.append(DummyChannel(ch_instr, "ChanL2", "L2")) + + +@pytest.mark.serial +def test_locked_list_extend(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "LockedE", DummyChannel) + chan = DummyChannel(ch_instr, "ChanLE", "LE") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + cl.extend([DummyChannel(ch_instr, "ChanLE2", "LE2")]) + + +@pytest.mark.serial +def test_locked_list_insert(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "LockedI", DummyChannel) + chan = DummyChannel(ch_instr, "ChanLI", "LI") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + cl.insert(0, DummyChannel(ch_instr, "ChanLI2", "LI2")) + + +@pytest.mark.serial +def test_locked_list_remove(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "LockedR", DummyChannel) + chan = DummyChannel(ch_instr, "ChanLR", "LR") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + cl.remove(chan) + + +@pytest.mark.serial +def test_locked_list_clear(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "LockedC", DummyChannel) + chan = DummyChannel(ch_instr, "ChanLC", "LC") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + cl.clear() + + +@pytest.mark.serial +def test_locked_list_delitem(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "LockedD", DummyChannel) + chan = DummyChannel(ch_instr, "ChanLD", "LD") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + del cl[0] + + +@pytest.mark.serial +def test_locked_list_setitem(ch_instr: DummyChannelInstrument) -> None: + cl = ChannelList(ch_instr, "LockedS", DummyChannel) + chan = DummyChannel(ch_instr, "ChanLS", "LS") + cl.append(chan) + cl.lock() + with pytest.raises(AttributeError, match="locked"): + cl[0] = DummyChannel(ch_instr, "ChanLS2", "LS2") + + +# --------------------------------------------------------------------------- +# ChannelList — get_validator on unlocked list raises +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_unlocked_list_get_validator(mutable_list: ChannelList) -> None: + with pytest.raises(AttributeError, match="Cannot create a validator"): + mutable_list.get_validator() + + +# --------------------------------------------------------------------------- +# ChannelList — lock / to_channel_tuple +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_list_lock(mutable_list: ChannelList) -> None: + mutable_list.lock() + with pytest.raises(AttributeError, match="locked"): + mutable_list.append( + DummyChannel( + mutable_list._parent, # type: ignore[arg-type] + "ChanLk", + "Lk", + ) + ) + + +@pytest.mark.serial +def test_channel_list_lock_idempotent(mutable_list: ChannelList) -> None: + """Locking an already-locked list should be a no-op.""" + mutable_list.lock() + mutable_list.lock() # should not raise + + +@pytest.mark.serial +def test_channel_list_to_channel_tuple(mutable_list: ChannelList) -> None: + ct = mutable_list.to_channel_tuple() + assert isinstance(ct, ChannelTuple) + assert not isinstance(ct, ChannelList) + assert len(ct) == len(mutable_list) + + +# --------------------------------------------------------------------------- +# ChannelList — repr +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_list_repr(mutable_list: ChannelList) -> None: + r = repr(mutable_list) + assert "ChannelList" in r + assert "DummyChannel" in r + + +# --------------------------------------------------------------------------- +# ChannelList — __setitem__ +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_list_setitem( + ch_instr: DummyChannelInstrument, mutable_list: ChannelList +) -> None: + new_chan = DummyChannel(ch_instr, "ChanSet", "S") + mutable_list[0] = new_chan + assert mutable_list[0] is new_chan + + +# --------------------------------------------------------------------------- +# ChannelList — constructed with existing channels (auto-locked) +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_list_init_with_channels( + ch_instr: DummyChannelInstrument, +) -> None: + """ChannelList created with a non-empty chan_list should auto-lock.""" + channels = [DummyChannel(ch_instr, f"Ch{i}", str(i)) for i in range(3)] + cl = ChannelList(ch_instr, "AutoLocked", DummyChannel, chan_list=channels) + assert cl._locked is True + with pytest.raises(AttributeError, match="locked"): + cl.append(DummyChannel(ch_instr, "ChNew", "N")) + + +# --------------------------------------------------------------------------- +# ChannelTupleValidator with unlocked ChannelList +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_validator_unlocked_list( + ch_instr: DummyChannelInstrument, +) -> None: + cl = ChannelList(ch_instr, "UnlockedV", DummyChannel) + chan = DummyChannel(ch_instr, "ChanUV", "UV") + cl.append(chan) + with pytest.raises(AttributeError, match="must be locked"): + ChannelTupleValidator(cl) + + +# --------------------------------------------------------------------------- +# ChannelTuple — multi_parameter / multi_function / __getattr__ +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_multi_parameter(chan_tuple: ChannelTuple) -> None: + """multi_parameter should return a MultiChannelInstrumentParameter.""" + mp = chan_tuple.multi_parameter("temperature") + assert mp is not None + + +@pytest.mark.serial +def test_channel_tuple_multi_parameter_nonexistent( + chan_tuple: ChannelTuple, +) -> None: + with pytest.raises(AttributeError, match="no parameter"): + chan_tuple.multi_parameter("nonexistent_param") + + +@pytest.mark.serial +def test_channel_tuple_multi_function(chan_tuple: ChannelTuple) -> None: + """multi_function should return a callable for functions on channels.""" + mf = chan_tuple.multi_function("log_my_name") + assert callable(mf) + + +@pytest.mark.serial +def test_channel_tuple_multi_function_callable(chan_tuple: ChannelTuple) -> None: + """multi_function should detect callables (methods) on channels.""" + mf = chan_tuple.multi_function("turn_on") + assert callable(mf) + mf() # should not raise + + +@pytest.mark.serial +def test_channel_tuple_multi_function_nonexistent( + chan_tuple: ChannelTuple, +) -> None: + with pytest.raises(AttributeError, match="no callable or function"): + chan_tuple.multi_function("nonexistent_func") + + +@pytest.mark.serial +def test_channel_tuple_multi_function_empty() -> None: + """multi_function on empty tuple raises AttributeError.""" + instr = DummyChannelInstrument(name="empty_mf") + try: + empty_ct = ChannelTuple(instr, "empty", DummyChannel) + with pytest.raises(AttributeError, match="no callable or function"): + empty_ct.multi_function("anything") + finally: + instr.close() + + +@pytest.mark.serial +def test_channel_tuple_getattr_parameter(chan_tuple: ChannelTuple) -> None: + """__getattr__ should return a multi-parameter for known parameters.""" + temp = chan_tuple.temperature # type: ignore[attr-defined] + assert temp is not None + + +@pytest.mark.serial +def test_channel_tuple_getattr_channel_by_name( + ch_instr: DummyChannelInstrument, chan_tuple: ChannelTuple +) -> None: + """__getattr__ should return a channel by short_name.""" + chan = chan_tuple.ChanA # type: ignore[attr-defined] + assert chan is ch_instr.submodules["A"] + + +@pytest.mark.serial +def test_channel_tuple_getattr_nonexistent(chan_tuple: ChannelTuple) -> None: + with pytest.raises(AttributeError, match="has no attribute"): + chan_tuple.nonexistent_attr_xyz # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# ChannelTuple — __dir__ +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_channel_tuple_dir(chan_tuple: ChannelTuple) -> None: + d = dir(chan_tuple) + assert "temperature" in d + assert "ChanA" in d diff --git a/tests/test_instrument_extended.py b/tests/test_instrument_extended.py new file mode 100644 index 000000000000..a72da7238435 --- /dev/null +++ b/tests/test_instrument_extended.py @@ -0,0 +1,673 @@ +""" +Extended tests for Instrument and InstrumentBase to improve coverage. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import pytest + +from qcodes.instrument import ( + Instrument, + InstrumentBase, + InstrumentModule, + find_or_create_instrument, +) +from qcodes.instrument_drivers.mock_instruments import ( + DummyChannelInstrument, + DummyInstrument, + MockMetaParabola, + MockParabola, +) +from qcodes.parameters import Function + +if TYPE_CHECKING: + from collections.abc import Iterator + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(name="dummy", scope="function") +def _dummy() -> Iterator[DummyInstrument]: + inst = DummyInstrument(name="ext_dummy", gates=["dac1", "dac2", "dac3"]) + try: + yield inst + finally: + inst.close() + + +@pytest.fixture(name="dummy_ch", scope="function") +def _dummy_ch() -> Iterator[DummyChannelInstrument]: + inst = DummyChannelInstrument(name="ext_dummy_ch") + try: + yield inst + finally: + inst.close() + + +# --------------------------------------------------------------------------- +# Instrument.write_raw / ask_raw raise NotImplementedError +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_write_raw_raises(dummy: DummyInstrument) -> None: + """write_raw on bare Instrument should raise NotImplementedError.""" + # DummyBase overrides get_idn but not write_raw/ask_raw, so + # the base Instrument.write_raw implementation applies. + bare = Instrument("bare_instr") + try: + with pytest.raises(NotImplementedError, match="has not defined a write method"): + bare.write_raw("cmd") + finally: + bare.close() + + +@pytest.mark.serial +def test_ask_raw_raises() -> None: + """ask_raw on bare Instrument should raise NotImplementedError.""" + bare = Instrument("bare_ask") + try: + with pytest.raises(NotImplementedError, match="has not defined an ask method"): + bare.ask_raw("cmd") + finally: + bare.close() + + +# --------------------------------------------------------------------------- +# Instrument.write and ask wrap errors +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_write_wraps_exception() -> None: + """write() should wrap underlying exceptions with context.""" + bare = Instrument("bare_write_wrap") + try: + with pytest.raises(NotImplementedError) as exc_info: + bare.write("SOMECMD") + assert "writing 'SOMECMD'" in str(exc_info.value.args) + finally: + bare.close() + + +@pytest.mark.serial +def test_ask_wraps_exception() -> None: + """ask() should wrap underlying exceptions with context.""" + bare = Instrument("bare_ask_wrap") + try: + with pytest.raises(NotImplementedError) as exc_info: + bare.ask("SOMECMD") + assert "asking 'SOMECMD'" in str(exc_info.value.args) + finally: + bare.close() + + +# --------------------------------------------------------------------------- +# Instrument.close_all +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_close_all() -> None: + """close_all should remove all registered instruments.""" + DummyInstrument(name="closeall1", gates=["g1"]) + DummyInstrument(name="closeall2", gates=["g2"]) + + assert Instrument.exist("closeall1") + assert Instrument.exist("closeall2") + + Instrument.close_all() + + assert not Instrument.exist("closeall1") + assert not Instrument.exist("closeall2") + + +# --------------------------------------------------------------------------- +# Instrument.find_instrument +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_find_instrument_by_name(dummy: DummyInstrument) -> None: + """find_instrument should return the instrument by name.""" + found = Instrument.find_instrument("ext_dummy") + assert found is dummy + + +@pytest.mark.serial +def test_find_instrument_not_found() -> None: + """find_instrument should raise KeyError for non-existent names.""" + with pytest.raises(KeyError, match="does not exist"): + Instrument.find_instrument("nonexistent_instrument_xyz") + + +@pytest.mark.serial +def test_find_instrument_wrong_class(dummy: DummyInstrument) -> None: + """find_instrument should raise TypeError when class doesn't match.""" + with pytest.raises(TypeError, match="was requested"): + Instrument.find_instrument("ext_dummy", instrument_class=MockParabola) + + +@pytest.mark.serial +def test_find_instrument_with_class(dummy: DummyInstrument) -> None: + """find_instrument with matching class should succeed.""" + found = Instrument.find_instrument("ext_dummy", instrument_class=DummyInstrument) + assert found is dummy + + +# --------------------------------------------------------------------------- +# Instrument.exist and is_valid +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_exist_true(dummy: DummyInstrument) -> None: + assert Instrument.exist("ext_dummy") is True + + +@pytest.mark.serial +def test_exist_false() -> None: + assert Instrument.exist("does_not_exist_xyz") is False + + +@pytest.mark.serial +def test_exist_with_class(dummy: DummyInstrument) -> None: + assert Instrument.exist("ext_dummy", instrument_class=DummyInstrument) is True + # exist() with wrong class will raise TypeError (not return False) + with pytest.raises(TypeError): + Instrument.exist("ext_dummy", instrument_class=MockParabola) + + +@pytest.mark.serial +def test_is_valid_open(dummy: DummyInstrument) -> None: + assert Instrument.is_valid(dummy) is True + + +@pytest.mark.serial +def test_is_valid_after_close() -> None: + inst = DummyInstrument(name="valid_test", gates=["g"]) + inst.close() + assert Instrument.is_valid(inst) is False + + +# --------------------------------------------------------------------------- +# Instrument.__repr__ and __del__ +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_repr(dummy: DummyInstrument) -> None: + r = repr(dummy) + assert "DummyInstrument" in r + assert "ext_dummy" in r + + +@pytest.mark.serial +def test_del_closes_instrument() -> None: + """__del__ should close the instrument without raising.""" + inst = DummyInstrument(name="del_test", gates=["g"]) + assert Instrument.exist("del_test") + inst.__del__() + assert not Instrument.exist("del_test") + + +# --------------------------------------------------------------------------- +# Instrument.close with connection attribute +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_close_with_connection() -> None: + """close() should call connection.close() if present.""" + + class FakeConnection: + closed = False + + def close(self) -> None: + self.closed = True + + inst = DummyInstrument(name="conn_test", gates=["g"]) + conn = FakeConnection() + inst.connection = conn # type: ignore[attr-defined] + inst.close() + assert conn.closed + + +# --------------------------------------------------------------------------- +# Instrument.instances / record_instance / remove_instance +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_instances_returns_list(dummy: DummyInstrument) -> None: + """instances() should return a list containing the instrument.""" + instances = DummyInstrument.instances() + assert dummy in instances + + +@pytest.mark.serial +def test_instances_empty_after_close() -> None: + inst = DummyInstrument(name="inst_empty_test", gates=["g"]) + inst.close() + assert inst not in DummyInstrument.instances() + + +@pytest.mark.serial +def test_record_instance_duplicate_name(dummy: DummyInstrument) -> None: + """Recording an instance with a duplicate name should raise.""" + with pytest.raises(KeyError, match="Another instrument has the name"): + DummyInstrument(name="ext_dummy", gates=["g"]) + + +# --------------------------------------------------------------------------- +# InstrumentBase.label +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_label_default(dummy: DummyInstrument) -> None: + """Default label should be the instrument name.""" + assert dummy.label == "ext_dummy" + + +@pytest.mark.serial +def test_label_set_get(dummy: DummyInstrument) -> None: + """Label property should be settable and gettable.""" + dummy.label = "My Custom Label" + assert dummy.label == "My Custom Label" + + +@pytest.mark.serial +def test_label_via_constructor() -> None: + """Label kwarg should be respected.""" + inst = DummyInstrument(name="label_test", gates=["g"], label="Custom") + try: + assert inst.label == "Custom" + finally: + inst.close() + + +# --------------------------------------------------------------------------- +# InstrumentBase.add_function +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_add_function(dummy: DummyInstrument) -> None: + dummy.add_function("rst", call_cmd="*RST") + assert "rst" in dummy.functions + assert isinstance(dummy.functions["rst"], Function) + + +@pytest.mark.serial +def test_add_function_duplicate(dummy: DummyInstrument) -> None: + dummy.add_function("rst2", call_cmd="*RST") + with pytest.raises(KeyError, match="Duplicate function name"): + dummy.add_function("rst2", call_cmd="*RST") + + +# --------------------------------------------------------------------------- +# InstrumentBase.add_submodule +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_add_submodule(dummy: DummyInstrument) -> None: + mod = InstrumentModule(dummy, "mymod") + result = dummy.add_submodule("mymod", mod) + assert result is mod + assert "mymod" in dummy.submodules + assert "mymod" in dummy.instrument_modules + + +@pytest.mark.serial +def test_add_submodule_duplicate(dummy: DummyInstrument) -> None: + mod1 = InstrumentModule(dummy, "dupmod") + dummy.add_submodule("dupmod", mod1) + mod2 = InstrumentModule(dummy, "dupmod2") + with pytest.raises(KeyError, match="Duplicate submodule name"): + dummy.add_submodule("dupmod", mod2) + + +@pytest.mark.serial +def test_add_submodule_non_metadatable(dummy: DummyInstrument) -> None: + with pytest.raises(TypeError, match="Submodules must be metadatable"): + dummy.add_submodule("bad", "not_a_submodule") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# InstrumentBase.get_component / _get_component_by_name +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_get_component_parameter(dummy: DummyInstrument) -> None: + comp = dummy.get_component("dac1") + assert comp is dummy.parameters["dac1"] + + +@pytest.mark.serial +def test_get_component_submodule(dummy_ch: DummyChannelInstrument) -> None: + comp = dummy_ch.get_component("A") + assert comp is dummy_ch.submodules["A"] + + +@pytest.mark.serial +def test_get_component_nested(dummy_ch: DummyChannelInstrument) -> None: + """Get a parameter within a submodule.""" + comp = dummy_ch.get_component("A_temperature") + assert comp is dummy_ch.submodules["A"].parameters["temperature"] # type: ignore[union-attr] + + +@pytest.mark.serial +def test_get_component_not_found(dummy: DummyInstrument) -> None: + with pytest.raises(KeyError): + dummy.get_component("nonexistent_thing") + + +# --------------------------------------------------------------------------- +# InstrumentBase.print_readable_snapshot +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_print_readable_snapshot( + dummy: DummyInstrument, capsys: pytest.CaptureFixture[str] +) -> None: + dummy.print_readable_snapshot(update=False) + captured = capsys.readouterr() + assert "ext_dummy:" in captured.out + assert "dac1" in captured.out + + +@pytest.mark.serial +def test_print_readable_snapshot_truncation( + dummy: DummyInstrument, capsys: pytest.CaptureFixture[str] +) -> None: + """Test that long lines are truncated to max_chars.""" + dummy.print_readable_snapshot(update=False, max_chars=40) + captured = capsys.readouterr() + for line in captured.out.split("\n"): + if line.startswith("-"): + continue + if line.strip() == "": + continue + # header lines and parameter lines + # Lines should be at most 40 chars (or contain "...") + if len(line) > 40: + assert line.endswith("...") + + +# --------------------------------------------------------------------------- +# InstrumentBase.invalidate_cache +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_invalidate_cache(dummy: DummyInstrument) -> None: + """invalidate_cache should mark parameters as stale.""" + dummy.dac1.set(42) + dummy.dac1.get() + assert dummy.dac1.cache.valid + dummy.invalidate_cache() + assert not dummy.dac1.cache.valid + + +@pytest.mark.serial +def test_invalidate_cache_with_submodules( + dummy_ch: DummyChannelInstrument, +) -> None: + """invalidate_cache should recurse into submodules.""" + chan_a = dummy_ch.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + chan_a.parameters["temperature"].set(100) + chan_a.parameters["temperature"].get() + assert chan_a.parameters["temperature"].cache.valid + dummy_ch.invalidate_cache() + assert not chan_a.parameters["temperature"].cache.valid + + +# --------------------------------------------------------------------------- +# InstrumentBase.parent / ancestors / root_instrument +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_parent_of_instrument(dummy: DummyInstrument) -> None: + """Top-level instruments should have parent=None.""" + assert dummy.parent is None + + +@pytest.mark.serial +def test_parent_of_module(dummy_ch: DummyChannelInstrument) -> None: + chan_a = dummy_ch.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.parent is dummy_ch + + +@pytest.mark.serial +def test_ancestors_of_instrument(dummy: DummyInstrument) -> None: + assert dummy.ancestors == (dummy,) + + +@pytest.mark.serial +def test_ancestors_of_module(dummy_ch: DummyChannelInstrument) -> None: + chan_a = dummy_ch.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.ancestors == (chan_a, dummy_ch) + + +@pytest.mark.serial +def test_root_instrument(dummy: DummyInstrument) -> None: + assert dummy.root_instrument is dummy + + +@pytest.mark.serial +def test_root_instrument_of_module(dummy_ch: DummyChannelInstrument) -> None: + chan_a = dummy_ch.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.root_instrument is dummy_ch + + +# --------------------------------------------------------------------------- +# InstrumentBase.name_parts / full_name / short_name +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_name_parts_instrument(dummy: DummyInstrument) -> None: + assert dummy.name_parts == ["ext_dummy"] + + +@pytest.mark.serial +def test_name_parts_module(dummy_ch: DummyChannelInstrument) -> None: + chan_a = dummy_ch.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.name_parts == ["ext_dummy_ch", "ChanA"] + + +@pytest.mark.serial +def test_full_name_module(dummy_ch: DummyChannelInstrument) -> None: + chan_a = dummy_ch.submodules["A"] + assert isinstance(chan_a, InstrumentModule) + assert chan_a.full_name == "ext_dummy_ch_ChanA" + + +@pytest.mark.serial +def test_short_name(dummy: DummyInstrument) -> None: + assert dummy.short_name == "ext_dummy" + + +@pytest.mark.serial +def test_name_equals_full_name(dummy: DummyInstrument) -> None: + assert dummy.name == dummy.full_name + + +# --------------------------------------------------------------------------- +# InstrumentBase.snapshot_base +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_snapshot_base(dummy: DummyInstrument) -> None: + snap = dummy.snapshot_base(update=False) + assert "parameters" in snap + assert "functions" in snap + assert "submodules" in snap + assert "__class__" in snap + assert "name" in snap + assert "label" in snap + assert "dac1" in snap["parameters"] + + +# --------------------------------------------------------------------------- +# InstrumentBase.validate_status +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_validate_status( + dummy: DummyInstrument, capsys: pytest.CaptureFixture[str] +) -> None: + """validate_status should not raise for valid parameters.""" + dummy.dac1.set(10) + dummy.validate_status(verbose=True) + captured = capsys.readouterr() + assert "dac1" in captured.out + + +# --------------------------------------------------------------------------- +# InstrumentBase._replace_hyphen and _is_valid_identifier +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_replace_hyphen() -> None: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + inst = DummyInstrument(name="my-inst", gates=["g"]) + try: + assert inst.name == "my_inst" + assert len(w) >= 1 + hyphen_warnings = [x for x in w if "Changed my-inst" in str(x.message)] + assert len(hyphen_warnings) >= 1 + finally: + inst.close() + + +@pytest.mark.serial +def test_invalid_identifier() -> None: + with pytest.raises(ValueError, match="invalid instrument identifier"): + DummyInstrument(name="123invalid", gates=["g"]) + + +# --------------------------------------------------------------------------- +# InstrumentBase deprecated __getitem__, set, get, call +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_deprecated_getitem(dummy: DummyInstrument) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + val = dummy["dac1"] # type: ignore[index] + assert val is dummy.parameters["dac1"] + + +@pytest.mark.serial +def test_deprecated_set_get(dummy: DummyInstrument) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + dummy.set("dac1", 42) # type: ignore[call-overload] + result = dummy.get("dac1") # type: ignore[call-overload] + assert result == 42 + + +@pytest.mark.serial +def test_deprecated_call(dummy: DummyInstrument) -> None: + dummy.add_function("noop", call_cmd="*OPC") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + # call() on a function backed by a string cmd will try to write, + # which raises NotImplementedError on DummyBase. That's fine — + # we just test that the deprecated `call` path is exercised. + try: + dummy.call("noop") # type: ignore[call-overload] + except NotImplementedError: + pass + + +# --------------------------------------------------------------------------- +# InstrumentBase.__getstate__ prevents pickling +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_getstate_raises(dummy: DummyInstrument) -> None: + with pytest.raises(RuntimeError, match="can not be pickled"): + dummy.__getstate__() + + +# --------------------------------------------------------------------------- +# find_or_create_instrument +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_find_or_create_new() -> None: + """find_or_create_instrument should create a new instrument.""" + inst = find_or_create_instrument(DummyInstrument, "foc_new", gates=["g1"]) + try: + assert isinstance(inst, DummyInstrument) + assert inst.name == "foc_new" + finally: + inst.close() + + +@pytest.mark.serial +def test_find_or_create_existing() -> None: + """find_or_create_instrument should find an existing instrument.""" + inst = DummyInstrument(name="foc_exist", gates=["g1"]) + try: + found = find_or_create_instrument(DummyInstrument, "foc_exist", gates=["g1"]) + assert found is inst + finally: + inst.close() + + +@pytest.mark.serial +def test_find_or_create_recreate() -> None: + """find_or_create_instrument with recreate=True should recreate.""" + inst = DummyInstrument(name="foc_recreate", gates=["g1"]) + new_inst = find_or_create_instrument( + DummyInstrument, "foc_recreate", gates=["g1"], recreate=True + ) + try: + assert new_inst is not inst + assert new_inst.name == "foc_recreate" + finally: + new_inst.close() + + +# --------------------------------------------------------------------------- +# MockMetaParabola (InstrumentBase, not Instrument) +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_meta_instrument_not_tracked() -> None: + """MockMetaParabola is InstrumentBase but not Instrument, not tracked.""" + p = MockParabola("meta_parabola_parent") + try: + m = MockMetaParabola("meta_test", p) + assert isinstance(m, InstrumentBase) + assert not isinstance(m, Instrument) + assert not Instrument.exist("meta_test") + finally: + p.close() diff --git a/tests/test_station_extended.py b/tests/test_station_extended.py new file mode 100644 index 000000000000..6995832a0991 --- /dev/null +++ b/tests/test_station_extended.py @@ -0,0 +1,285 @@ +""" +Extended tests for Station to improve coverage. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from qcodes.instrument import Instrument, InstrumentModule +from qcodes.instrument_drivers.mock_instruments import DummyInstrument +from qcodes.parameters import Parameter +from qcodes.station import Station + +if TYPE_CHECKING: + from collections.abc import Iterator + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(name="station", scope="function") +def _station() -> Iterator[Station]: + st = Station(default=True) + try: + yield st + finally: + Station.default = None + + +@pytest.fixture(name="dummy_instr", scope="function") +def _dummy_instr() -> Iterator[DummyInstrument]: + inst = DummyInstrument(name="st_dummy", gates=["dac1", "dac2"]) + try: + yield inst + finally: + inst.close() + + +# --------------------------------------------------------------------------- +# Station.default class attribute +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_station_default_set_on_init() -> None: + """New Station with default=True should set Station.default.""" + st = Station(default=True) + assert Station.default is st + Station.default = None + + +@pytest.mark.serial +def test_station_default_not_set() -> None: + """Station with default=False should not overwrite Station.default.""" + Station.default = None + _ = Station(default=False) + assert Station.default is None + + +# --------------------------------------------------------------------------- +# Station.add_component / remove_component +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_add_component(station: Station, dummy_instr: DummyInstrument) -> None: + name = station.add_component(dummy_instr) + assert name == "st_dummy" + assert "st_dummy" in station.components + + +@pytest.mark.serial +def test_add_component_custom_name( + station: Station, dummy_instr: DummyInstrument +) -> None: + name = station.add_component(dummy_instr, name="custom_name") + assert name == "custom_name" + assert "custom_name" in station.components + + +@pytest.mark.serial +def test_add_component_duplicate_raises( + station: Station, dummy_instr: DummyInstrument +) -> None: + station.add_component(dummy_instr) + with pytest.raises(RuntimeError, match="already registered"): + station.add_component(dummy_instr) + + +@pytest.mark.serial +def test_remove_component(station: Station, dummy_instr: DummyInstrument) -> None: + station.add_component(dummy_instr) + removed = station.remove_component("st_dummy") + assert removed is dummy_instr + assert "st_dummy" not in station.components + + +@pytest.mark.serial +def test_remove_component_not_found(station: Station) -> None: + with pytest.raises(KeyError, match="is not part of the station"): + station.remove_component("nonexistent_xyz") + + +# --------------------------------------------------------------------------- +# Station.__getitem__ +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_getitem(station: Station, dummy_instr: DummyInstrument) -> None: + station.add_component(dummy_instr) + assert station["st_dummy"] is dummy_instr + + +@pytest.mark.serial +def test_getitem_missing(station: Station) -> None: + with pytest.raises(KeyError): + station["nonexistent"] + + +# --------------------------------------------------------------------------- +# Station.snapshot_base +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_snapshot_base_empty(station: Station) -> None: + snap = station.snapshot_base(update=False) + assert "instruments" in snap + assert "parameters" in snap + assert "components" in snap + assert "config" in snap + + +@pytest.mark.serial +def test_snapshot_base_with_instrument( + station: Station, dummy_instr: DummyInstrument +) -> None: + station.add_component(dummy_instr) + snap = station.snapshot_base(update=False) + assert "st_dummy" in snap["instruments"] + + +@pytest.mark.serial +def test_snapshot_base_with_parameter(station: Station) -> None: + param = Parameter("standalone_param", set_cmd=None, get_cmd=None, initial_value=5) + station.add_component(param, name="my_param") + snap = station.snapshot_base(update=False) + assert "my_param" in snap["parameters"] + + +@pytest.mark.serial +def test_snapshot_base_with_other_component(station: Station) -> None: + """Non-instrument, non-parameter components go into 'components'.""" + # InstrumentModule is Metadatable but not Instrument or Parameter + instr = DummyInstrument(name="snap_parent_st", gates=["g"]) + try: + mod = InstrumentModule(instr, "modcomp") + station.add_component(mod, name="modcomp") + snap = station.snapshot_base(update=False) + assert "modcomp" in snap["components"] + finally: + instr.close() + + +@pytest.mark.serial +def test_snapshot_base_removes_closed_instrument(station: Station) -> None: + """Closed instruments should be removed from station during snapshot.""" + inst = DummyInstrument(name="closing_inst", gates=["g"]) + station.add_component(inst) + inst.close() + snap = station.snapshot_base(update=False) + assert "closing_inst" not in snap["instruments"] + assert "closing_inst" not in station.components + + +# --------------------------------------------------------------------------- +# Station.close_all_registered_instruments +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_close_all_registered_instruments(station: Station) -> None: + d1 = DummyInstrument(name="close_reg_1", gates=["g"]) + d2 = DummyInstrument(name="close_reg_2", gates=["g"]) + station.add_component(d1) + station.add_component(d2) + + station.close_all_registered_instruments() + + assert "close_reg_1" not in station.components + assert "close_reg_2" not in station.components + assert not Instrument.exist("close_reg_1") + assert not Instrument.exist("close_reg_2") + + +# --------------------------------------------------------------------------- +# Station.get_component +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_get_component_top_level( + station: Station, dummy_instr: DummyInstrument +) -> None: + station.add_component(dummy_instr) + comp = station.get_component("st_dummy") + assert comp is dummy_instr + + +@pytest.mark.serial +def test_get_component_parameter( + station: Station, dummy_instr: DummyInstrument +) -> None: + """get_component should resolve sub-components like parameters.""" + station.add_component(dummy_instr) + comp = station.get_component("st_dummy_dac1") + assert comp is dummy_instr.parameters["dac1"] + + +@pytest.mark.serial +def test_get_component_not_found(station: Station) -> None: + with pytest.raises(KeyError, match="is not part of the station"): + station.get_component("nonexistent_component") + + +@pytest.mark.serial +def test_get_component_non_instrumentbase(station: Station) -> None: + """get_component with remaining parts on a non-InstrumentBase should raise.""" + param = Parameter("toppar", set_cmd=None, get_cmd=None, initial_value=0) + station.add_component(param, name="toppar") + with pytest.raises(KeyError, match="no sub-component"): + station.get_component("toppar_something") + + +# --------------------------------------------------------------------------- +# Station with components in constructor +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_station_init_with_components() -> None: + d = DummyInstrument(name="init_comp_st", gates=["g"]) + try: + st = Station(d) + assert "init_comp_st" in st.components + finally: + d.close() + Station.default = None + + +# --------------------------------------------------------------------------- +# Station.delegate_attr_dicts — attribute access to components +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_station_delegate_attr(station: Station, dummy_instr: DummyInstrument) -> None: + """Station should delegate attribute access to components dict.""" + station.add_component(dummy_instr) + assert station.st_dummy is dummy_instr # type: ignore[attr-defined] + + +# --------------------------------------------------------------------------- +# Station.add_component with snapshot_exclude Parameter +# --------------------------------------------------------------------------- + + +@pytest.mark.serial +def test_add_component_snapshot_exclude_param(station: Station) -> None: + """snapshot_exclude parameters should not be in snapshot.""" + param = Parameter( + "hidden_param", + set_cmd=None, + get_cmd=None, + initial_value=42, + snapshot_exclude=True, + ) + station.add_component(param, name="hidden_param") + snap = station.snapshot_base(update=False) + assert "hidden_param" not in snap["parameters"] From 76f78cea8c56a62f89c490581992862e554981fd Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Tue, 24 Mar 2026 09:58:53 +0100 Subject: [PATCH 6/6] Fix typechecking of tests --- tests/parameter/test_combined_parameter_extended.py | 11 +++++++---- tests/test_channel_extended.py | 4 ++-- tests/test_instrument_extended.py | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/parameter/test_combined_parameter_extended.py b/tests/parameter/test_combined_parameter_extended.py index e93ba24760be..ed74dc6406f3 100644 --- a/tests/parameter/test_combined_parameter_extended.py +++ b/tests/parameter/test_combined_parameter_extended.py @@ -31,8 +31,9 @@ def test_combine_creates_combined_parameter( def test_combine_with_label_and_unit(self, two_params: list[Parameter]) -> None: """combine() passes label and unit through.""" cp = combine(*two_params, name="xy", label="X and Y", unit="V") - assert cp.parameter.label == "X and Y" - assert cp.parameter.unit == "V" + # cp.parameter is a parameter like object but these attributes are dynamically added + assert cp.parameter.label == "X and Y" # pyright: ignore[reportFunctionMemberAccess] + assert cp.parameter.unit == "V" # pyright: ignore[reportFunctionMemberAccess] def test_combine_with_aggregator(self, two_params: list[Parameter]) -> None: """combine() passes aggregator through.""" @@ -108,7 +109,8 @@ def test_units_deprecated( with caplog.at_level(logging.WARNING): cp = CombinedParameter(two_params, name="xy", units="mV") assert any("`units` is deprecated" in msg for msg in caplog.messages) - assert cp.parameter.unit == "mV" + # cp.parameter is a parameter like object but these attributes are dynamically added + assert cp.parameter.unit == "mV" # pyright: ignore[reportFunctionMemberAccess] def test_units_deprecated_unit_takes_precedence( self, two_params: list[Parameter], caplog: pytest.LogCaptureFixture @@ -116,7 +118,8 @@ def test_units_deprecated_unit_takes_precedence( """When both unit and units are given, unit takes precedence.""" with caplog.at_level(logging.WARNING): cp = CombinedParameter(two_params, name="xy", unit="V", units="mV") - assert cp.parameter.unit == "V" + # cp.parameter is a parameter like object but these attributes are dynamically added + assert cp.parameter.unit == "V" # pyright: ignore[reportFunctionMemberAccess] def test_invalid_name_raises(self, two_params: list[Parameter]) -> None: """Invalid parameter name raises ValueError.""" diff --git a/tests/test_channel_extended.py b/tests/test_channel_extended.py index b52ff57ba427..2b86e5c98297 100644 --- a/tests/test_channel_extended.py +++ b/tests/test_channel_extended.py @@ -178,7 +178,7 @@ class OtherChannel(InstrumentChannel): ct1 = ch_instr.channels ct2 = ChannelTuple(ch_instr, "other", OtherChannel) with pytest.raises(TypeError, match="same type"): - ct1 + ct2 + _ = ct1 + ct2 @pytest.mark.serial @@ -190,7 +190,7 @@ def test_channel_tuple_add_different_parent() -> None: ct1 = instr1.channels[0:1] ct2 = instr2.channels[0:1] with pytest.raises(ValueError, match="same parent"): - ct1 + ct2 + _ = ct1 + ct2 finally: instr1.close() instr2.close() diff --git a/tests/test_instrument_extended.py b/tests/test_instrument_extended.py index a72da7238435..4324a7e21f5a 100644 --- a/tests/test_instrument_extended.py +++ b/tests/test_instrument_extended.py @@ -116,8 +116,8 @@ def test_ask_wraps_exception() -> None: @pytest.mark.serial def test_close_all() -> None: """close_all should remove all registered instruments.""" - DummyInstrument(name="closeall1", gates=["g1"]) - DummyInstrument(name="closeall2", gates=["g2"]) + _ = DummyInstrument(name="closeall1", gates=["g1"]) + _ = DummyInstrument(name="closeall2", gates=["g2"]) assert Instrument.exist("closeall1") assert Instrument.exist("closeall2")