diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 0d603d68..c39eb469 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -287,3 +287,56 @@ def test_searchsorted(data): except Exception as exc: ph.add_note(exc, repro_snippet) raise + + +@pytest.mark.min_version("2025.12") +@given(data=st.data()) +def test_searchsorted_with_scalars(data): + # TODO: deduplicate with test_searchorted above + + # 1. draw x1, sorter and side exactly the same as in test_searchsorted + x1_dtype = data.draw(st.sampled_from(dh.real_dtypes)) + _x1 = data.draw( + st.lists( + xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False), + min_size=1, + unique=True + ), + label="_x1", + ) + x1 = xp.asarray(_x1, dtype=x1_dtype) + if data.draw(st.booleans(), label="use sorter?"): + sorter = xp.argsort(x1) + else: + sorter = None + x1 = xp.sort(x1) + + kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) + + # 2. draw x2, a real-valued scalar + # - for a float-dtype x1 array, draw python ints or floats + # - for an integer-dtype x1 array, draw an in-range python int + dt_for_x2 = [x1.dtype] + if x1.dtype in dh.real_float_dtypes: + dt_for_x2 += [xp.int32] + + x2 = data.draw(hh.scalars(st.sampled_from(dt_for_x2), finite=True)) + + # 3. testing: similar to test_searchsorted, modulo `out.shape == ()` + repro_snippet = ph.format_snippet( + f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw = }" + ) + try: + out = xp.searchsorted(x1, x2, sorter=sorter, **kw) + + ph.assert_dtype( + "searchsorted", + in_dtype=[x1.dtype], #, x2.dtype + out_dtype=out.dtype, + expected=xp.__array_namespace_info__().default_dtypes()["indexing"], + ) + # TODO: values testing + ph.assert_shape("searchsorted", out_shape=out.shape, expected=()) + except Exception as exc: + ph.add_note(exc, repro_snippet) + raise