diff --git a/encodings/fastlanes/src/bitpacking/compute/filter.rs b/encodings/fastlanes/src/bitpacking/compute/filter.rs index e8d1385e4d8..567bce2e78b 100644 --- a/encodings/fastlanes/src/bitpacking/compute/filter.rs +++ b/encodings/fastlanes/src/bitpacking/compute/filter.rs @@ -97,19 +97,26 @@ fn filter_primitive_without_patches( array: &BitPackedArray, selection: &Arc, ) -> VortexResult<(Buffer, Validity)> { - let values = filter_with_indices(array, selection.indices()); + let selection_buffer = selection.bit_buffer(); + + let values = filter_with_indices( + array, + selection_buffer.set_indices(), + selection_buffer.true_count(), + ); let validity = array.validity()?.filter(&Mask::Values(selection.clone()))?; Ok((values.freeze(), validity)) } -fn filter_with_indices( +fn filter_with_indices>( array: &BitPackedArray, - indices: &[usize], + indices: I, + indices_len: usize, ) -> BufferMut { let offset = array.offset() as usize; let bit_width = array.bit_width() as usize; - let mut values = BufferMut::with_capacity(indices.len()); + let mut values = BufferMut::with_capacity(indices_len); // Some re-usable memory to store per-chunk indices. let mut unpacked = [const { MaybeUninit::::uninit() }; 1024]; @@ -118,43 +125,39 @@ fn filter_with_indices( // Group the indices by the FastLanes chunk they belong to. let chunk_size = 128 * bit_width / size_of::(); - chunked_indices( - indices.iter().copied(), - offset, - |chunk_idx, indices_within_chunk| { - let packed = &packed_bytes[chunk_idx * chunk_size..][..chunk_size]; - - if indices_within_chunk.len() == 1024 { - // Unpack the entire chunk. - unsafe { - let values_len = values.len(); - values.set_len(values_len + 1024); - BitPacking::unchecked_unpack( - bit_width, - packed, - &mut values.as_mut_slice()[values_len..], - ); - } - } else if indices_within_chunk.len() > UNPACK_CHUNK_THRESHOLD { - // Unpack into a temporary chunk and then copy the values. 
- unsafe { - let dst: &mut [MaybeUninit] = &mut unpacked; - let dst: &mut [T] = std::mem::transmute(dst); - BitPacking::unchecked_unpack(bit_width, packed, dst); - } - values.extend_trusted( - indices_within_chunk - .iter() - .map(|&idx| unsafe { unpacked.get_unchecked(idx).assume_init() }), + chunked_indices(indices, offset, |chunk_idx, indices_within_chunk| { + let packed = &packed_bytes[chunk_idx * chunk_size..][..chunk_size]; + + if indices_within_chunk.len() == 1024 { + // Unpack the entire chunk. + unsafe { + let values_len = values.len(); + values.set_len(values_len + 1024); + BitPacking::unchecked_unpack( + bit_width, + packed, + &mut values.as_mut_slice()[values_len..], ); - } else { - // Otherwise, unpack each element individually. - values.extend_trusted(indices_within_chunk.iter().map(|&idx| unsafe { - BitPacking::unchecked_unpack_single(bit_width, packed, idx) - })); } - }, - ); + } else if indices_within_chunk.len() > UNPACK_CHUNK_THRESHOLD { + // Unpack into a temporary chunk and then copy the values. + unsafe { + let dst: &mut [MaybeUninit] = &mut unpacked; + let dst: &mut [T] = std::mem::transmute(dst); + BitPacking::unchecked_unpack(bit_width, packed, dst); + } + values.extend_trusted( + indices_within_chunk + .iter() + .map(|&idx| unsafe { unpacked.get_unchecked(idx).assume_init() }), + ); + } else { + // Otherwise, unpack each element individually. 
+ values.extend_trusted(indices_within_chunk.iter().map(|&idx| unsafe { + BitPacking::unchecked_unpack_single(bit_width, packed, idx) + })); + } + }); values } diff --git a/encodings/runend/src/compute/filter.rs b/encodings/runend/src/compute/filter.rs index 269d64b97f4..8af90e69301 100644 --- a/encodings/runend/src/compute/filter.rs +++ b/encodings/runend/src/compute/filter.rs @@ -41,7 +41,7 @@ impl FilterKernel for RunEndVTable { if runs_ratio < FILTER_TAKE_THRESHOLD || mask_values.true_count() < 25 { Ok(Some(take_indices_unchecked( array, - mask_values.indices(), + mask_values.bit_buffer().set_indices(), &Validity::NonNullable, )?)) } else { diff --git a/encodings/runend/src/compute/take.rs b/encodings/runend/src/compute/take.rs index d8bfdd3fc1b..f638e6261ec 100644 --- a/encodings/runend/src/compute/take.rs +++ b/encodings/runend/src/compute/take.rs @@ -49,14 +49,19 @@ impl TakeExecute for RunEndVTable { .collect::>>()? }); - take_indices_unchecked(array, &checked_indices, primitive_indices.validity()).map(Some) + take_indices_unchecked( + array, + checked_indices.into_iter(), + primitive_indices.validity(), + ) + .map(Some) } } /// Perform a take operation on a RunEndArray by binary searching for each of the indices. 
-pub fn take_indices_unchecked>( +pub fn take_indices_unchecked, I: Iterator>( array: &RunEndArray, - indices: &[T], + indices: I, validity: &Validity, ) -> VortexResult { let ends = array.ends().to_primitive(); @@ -66,7 +71,6 @@ pub fn take_indices_unchecked>( let physical_indices = match_each_integer_ptype!(ends.ptype(), |I| { let end_slices = ends.as_slice::(); let physical_indices_vec: Vec = indices - .iter() .map(|idx| idx.as_() + array.offset()) .map(|idx| { match ::from(idx) { diff --git a/encodings/sequence/src/compute/filter.rs b/encodings/sequence/src/compute/filter.rs index b7064a388bf..a6684477a6d 100644 --- a/encodings/sequence/src/compute/filter.rs +++ b/encodings/sequence/src/compute/filter.rs @@ -37,7 +37,7 @@ fn filter_impl(mul: T, base: T, mask: &Mask, validity: Validity) .values() .vortex_expect("FilterKernel precondition: mask is Mask::Values"); let mut buffer = BufferMut::::with_capacity(mask_values.true_count()); - buffer.extend(mask_values.indices().iter().map(|&idx| { + buffer.extend(mask_values.bit_buffer().set_indices().map(|idx| { let i = T::from_usize(idx).vortex_expect("all valid indices fit"); base + i * mul })); diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index 749dfcf7a84..63f56ea705e 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -44,7 +44,6 @@ use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_mask::AllOr; use vortex_mask::Mask; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -311,23 +310,16 @@ impl SparseArray { } else if mask.false_count() as f64 > (0.9 * mask.len() as f64) { // Array is dominated by NULL but has non-NULL values let non_null_values = filter(array, &mask)?; - let non_null_indices = match mask.indices() { - AllOr::All => { - // We already know that the mask is 90%+ false - unreachable!("Mask is mostly null") - } - AllOr::None => { - // we know there 
are some non-NULL values - unreachable!("Mask is mostly null but not all null") - } - AllOr::Some(values) => { - let buffer: Buffer = values - .iter() - .map(|&v| v.try_into().vortex_expect("indices must fit in u32")) - .collect(); - - buffer.into_array() - } + let non_null_indices = if let Some(mask_values) = mask.values() { + let buffer: Buffer = mask_values + .bit_buffer() + .set_indices() + .map(|v| v.try_into().vortex_expect("indices must fit in u32")) + .collect(); + + buffer.into_array() + } else { + unreachable!() }; return Ok(SparseArray::try_new( @@ -370,7 +362,11 @@ impl SparseArray { // All values are equal to the top value return Ok(fill_array); } - Mask::Values(values) => values.indices().iter().map(|v| *v as u64).collect(), + Mask::Values(values) => values + .bit_buffer() + .set_indices() + .map(|v| v as u64) + .collect(), }; SparseArray::try_new(indices.into_array(), non_top_values, array.len(), fill) diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index 8c665c8882e..d712ba406da 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -251,27 +251,25 @@ fn collect_valid_primitive(parray: &PrimitiveArray) -> VortexResult VortexResult<(ByteBuffer, Vec)> { let mask = vbv.validity_mask()?; - let buffer_and_value_byte_indices = match mask.bit_buffer() { - AllOr::None => (Buffer::empty(), Vec::new()), - _ => { - let mut buffer = BufferMut::with_capacity( - usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer") - + mask.true_count() * size_of::(), - ); - let mut value_byte_indices = Vec::new(); - vbv.with_iterator(|iterator| { - // by flattening, we should omit nulls - for value in iterator.flatten() { - value_byte_indices.push(buffer.len()); - // here's where we write the string lengths - buffer - .extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter()); - buffer.extend_from_slice(value); - } - Ok::<_, VortexError>(()) - })?; - (buffer.freeze(), value_byte_indices) - } + let 
buffer_and_value_byte_indices = if mask.all_false() { + (Buffer::empty(), Vec::new()) + } else { + let mut buffer = BufferMut::with_capacity( + usize::try_from(vbv.nbytes()).vortex_expect("must fit into buffer") + + mask.true_count() * size_of::(), + ); + let mut value_byte_indices = Vec::new(); + vbv.with_iterator(|iterator| { + // by flattening, we should omit nulls + for value in iterator.flatten() { + value_byte_indices.push(buffer.len()); + // here's where we write the string lengths + buffer.extend_trusted(ViewLen::try_from(value.len())?.to_le_bytes().into_iter()); + buffer.extend_from_slice(value); + } + Ok::<_, VortexError>(()) + })?; + (buffer.freeze(), value_byte_indices) }; Ok(buffer_and_value_byte_indices) } @@ -719,7 +717,9 @@ impl ZstdArray { Ok(primitive.into_array()) } DType::Binary(_) | DType::Utf8(_) => { - match slice_validity.to_mask(slice_n_rows).indices() { + let mask = slice_validity.to_mask(slice_n_rows); + + match mask.bit_buffer() { AllOr::All => { // the decompressed buffer is a bunch of interleaved u32 lengths // and strings of those lengths, we need to reconstruct the @@ -745,7 +745,7 @@ impl ZstdArray { slice_n_rows, ) .into_array()), - AllOr::Some(valid_indices) => { + AllOr::Some(mask_bits) => { // the decompressed buffer is a bunch of interleaved u32 lengths // and strings of those lengths, we need to reconstruct the // views into those strings by passing through the buffer. 
@@ -755,8 +755,9 @@ impl ZstdArray { ); let mut views = BufferMut::::zeroed(slice_n_rows); - for (view, index) in valid_views.into_iter().zip_eq(valid_indices) { - views[*index] = view + for (view, index) in valid_views.into_iter().zip_eq(mask_bits.set_indices()) + { + views[index] = view } // SAFETY: we properly construct the views inside `reconstruct_views` diff --git a/vortex-array/src/arrays/bool/compute/filter.rs b/vortex-array/src/arrays/bool/compute/filter.rs index 82887d8799d..84ad36e109c 100644 --- a/vortex-array/src/arrays/bool/compute/filter.rs +++ b/vortex-array/src/arrays/bool/compute/filter.rs @@ -7,7 +7,6 @@ use vortex_buffer::get_bit; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; use crate::ArrayRef; use crate::ExecutionCtx; @@ -32,17 +31,18 @@ impl FilterKernel for BoolVTable { .values() .vortex_expect("AllTrue and AllFalse are handled by filter fn"); - let buffer = match mask_values.threshold_iter(FILTER_SLICES_DENSITY_THRESHOLD) { - MaskIter::Indices(indices) => filter_indices( + let buffer = if mask_values.density() >= FILTER_SLICES_DENSITY_THRESHOLD { + filter_slices( &array.to_bit_buffer(), mask.true_count(), - indices.iter().copied(), - ), - MaskIter::Slices(slices) => filter_slices( + mask_values.bit_buffer().set_slices(), + ) + } else { + filter_indices( &array.to_bit_buffer(), mask.true_count(), - slices.iter().copied(), - ), + mask_values.bit_buffer().set_indices(), + ) }; Ok(Some(BoolArray::new(buffer, validity).into_array())) diff --git a/vortex-array/src/arrays/chunked/compute/filter.rs b/vortex-array/src/arrays/chunked/compute/filter.rs index 2e31adcc0aa..b1685779455 100644 --- a/vortex-array/src/arrays/chunked/compute/filter.rs +++ b/vortex-array/src/arrays/chunked/compute/filter.rs @@ -5,7 +5,6 @@ use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; use crate::Array; use 
crate::ArrayRef; @@ -32,12 +31,11 @@ impl FilterKernel for ChunkedVTable { .values() .vortex_expect("AllTrue and AllFalse are handled by filter fn"); - // Based on filter selectivity, we take the values between a range of slices, or - // we take individual indices. - let chunks = match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - MaskIter::Indices(indices) => filter_indices(array, indices.iter().copied()), - MaskIter::Slices(slices) => filter_slices(array, slices.iter().copied()), - }?; + let chunks = if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { + filter_slices(array, mask_values.bit_buffer().set_slices())? + } else { + filter_indices(array, mask_values.bit_buffer().set_indices())? + }; // SAFETY: Filter operation preserves the dtype of each chunk. // All filtered chunks maintain the same dtype as the original array. diff --git a/vortex-array/src/arrays/chunked/compute/mask.rs b/vortex-array/src/arrays/chunked/compute/mask.rs index 7c6aa7b87ea..40f3a279570 100644 --- a/vortex-array/src/arrays/chunked/compute/mask.rs +++ b/vortex-array/src/arrays/chunked/compute/mask.rs @@ -5,10 +5,9 @@ use itertools::Itertools as _; use vortex_buffer::BitBuffer; use vortex_buffer::BitBufferMut; use vortex_dtype::DType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_mask::AllOr; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_scalar::Scalar; use super::filter::ChunkFilter; @@ -31,13 +30,12 @@ use crate::validity::Validity; impl MaskKernel for ChunkedVTable { fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { let new_dtype = array.dtype().as_nullable(); - let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - AllOr::All => unreachable!("handled in top-level mask"), - AllOr::None => unreachable!("handled in top-level mask"), - AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), - AllOr::Some(MaskIter::Slices(slices)) => 
{ - mask_slices(array, slices.iter().cloned(), &new_dtype) - } + let mask_values = mask.values().vortex_expect("handled in top-level mask"); + // Dense masks are processed as contiguous slices, sparse masks as individual indices. + let new_chunks = if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { + mask_slices(array, mask_values.bit_buffer().set_slices(), &new_dtype) + } else { + mask_indices(array, mask_values.bit_buffer().set_indices(), &new_dtype) }?; debug_assert_eq!(new_chunks.len(), array.nchunks()); debug_assert_eq!( @@ -52,7 +50,7 @@ register_kernel!(MaskKernelAdapter(ChunkedVTable).lift()); fn mask_indices( array: &ChunkedArray, - indices: &[usize], + indices: impl Iterator, new_dtype: &DType, ) -> VortexResult> { let mut new_chunks = Vec::with_capacity(array.nchunks()); @@ -61,7 +59,7 @@ fn mask_indices( let chunk_offsets = array.chunk_offsets(); - for &set_index in indices { + for set_index in indices { let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets)?; if chunk_id != current_chunk_id { let chunk = array.chunk(current_chunk_id).clone(); diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 209fc5feb5c..b4bac8979d7 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -62,7 +62,6 @@ mod tests { use rstest::rstest; use vortex_buffer::buffer; use vortex_dtype::Nullability; - use vortex_mask::AllOr; use vortex_scalar::Scalar; use crate::Array; @@ -86,7 +85,6 @@ mod tests { .into_array(), ) .unwrap(); - let valid_indices: &[usize] = &[1usize]; assert_eq!( &array.dtype().with_nullability(Nullability::Nullable), taken.dtype() @@ -98,10 +96,14 @@ mod tests { Validity::from_iter([false, true, false]) ) ); - assert_eq!( - taken.validity_mask().unwrap().indices(), - AllOr::Some(valid_indices) - ); + let mask = taken.validity_mask().unwrap(); + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); + assert_eq!(indices, 
[1]); } #[test] @@ -118,7 +120,7 @@ mod tests { taken.to_primitive(), PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid) ); - assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All); + assert!(taken.validity_mask().unwrap().all_true()); } #[rstest] diff --git a/vortex-array/src/arrays/dict/array.rs b/vortex-array/src/arrays/dict/array.rs index de616417ce1..3b576f93bd2 100644 --- a/vortex-array/src/arrays/dict/array.rs +++ b/vortex-array/src/arrays/dict/array.rs @@ -243,8 +243,6 @@ mod test { use vortex_dtype::UnsignedPType; use vortex_error::VortexExpect; use vortex_error::VortexResult; - use vortex_error::vortex_panic; - use vortex_mask::AllOr; use crate::Array; use crate::ArrayRef; @@ -271,9 +269,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [0, 2, 4]); } @@ -289,9 +290,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [0]); } @@ -311,9 +315,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [2, 4]); } @@ -329,9 +336,12 @@ mod test { ) .unwrap(); let mask = dict.validity_mask().unwrap(); - let AllOr::Some(indices) = mask.indices() else { - vortex_panic!("Expected indices from mask") - }; + let indices: Vec = mask + .values() + .expect("Expected 
values from mask") + .bit_buffer() + .set_indices() + .collect(); assert_eq!(indices, [0, 2, 4]); } diff --git a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs index f892f90921f..39247025cbd 100644 --- a/vortex-array/src/arrays/filter/execute/fixed_size_list.rs +++ b/vortex-array/src/arrays/filter/execute/fixed_size_list.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use vortex_error::VortexExpect; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskValues; use crate::arrays::FixedSizeListArray; @@ -83,31 +82,22 @@ pub fn filter_fixed_size_list( fn compute_mask_for_fsl_elements(selection_mask: &MaskValues, list_size: usize) -> Mask { let expanded_len = selection_mask.len() * list_size; - // Use threshold_iter to choose the optimal representation based on density. - let expanded_slices = match selection_mask.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - MaskIter::Slices(slices) => { - // Expand a dense mask (represented as slices) by scaling each slice by `list_size`. - slices - .iter() - .map(|&(start, end)| (start * list_size, end * list_size)) - .collect() - } - MaskIter::Indices(indices) => { - // Expand a sparse mask (represented as indices) by duplicating each index `list_size` - // times. - // - // Note that in the worst case, it is possible that we create only a few slices with a - // small range (for example, when list_size <= 2). This could be further optimized, - // but we choose simplicity for now. 
- indices - .iter() - .map(|&idx| { - let start = idx * list_size; - let end = (idx + 1) * list_size; - (start, end) - }) - .collect() - } + let expanded_slices = if selection_mask.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + selection_mask + .bit_buffer() + .set_slices() + .map(|(start, end)| (start * list_size, end * list_size)) + .collect() + } else { + selection_mask + .bit_buffer() + .set_indices() + .map(|idx| { + let start = idx * list_size; + let end = (idx + 1) * list_size; + (start, end) + }) + .collect() }; Mask::from_slices(expanded_len, expanded_slices) diff --git a/vortex-array/src/arrays/list/compute/filter.rs b/vortex-array/src/arrays/list/compute/filter.rs index 0393c65a817..fe49ea3bb24 100644 --- a/vortex-array/src/arrays/list/compute/filter.rs +++ b/vortex-array/src/arrays/list/compute/filter.rs @@ -11,7 +11,6 @@ use vortex_dtype::IntegerPType; use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskValues; use crate::ArrayRef; @@ -43,27 +42,24 @@ pub fn element_mask_from_offsets( let mut mask_builder = BitBufferMut::with_capacity(len); - match selection.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - MaskIter::Slices(slices) => { - // Dense iteration: process ranges of consecutive selected lists. - for &(start, end) in slices { - // Optimization: for dense ranges, we can process the elements mask more efficiently. - let elems_start = offsets[start].as_() - first_offset; - let elems_end = offsets[end].as_() - first_offset; + if selection.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + // Dense iteration: process ranges of consecutive selected lists. + for (start, end) in selection.bit_buffer().set_slices() { + // Optimization: for dense ranges, we can process the elements mask more efficiently. 
+ let elems_start = offsets[start].as_() - first_offset; + let elems_end = offsets[end].as_() - first_offset; - // Process the entire range of elements at once. - process_element_range(elems_start, elems_end, &mut mask_builder); - } + // Process the entire range of elements at once. + process_element_range(elems_start, elems_end, &mut mask_builder); } - MaskIter::Indices(indices) => { - // Sparse iteration: process individual selected lists. - for &idx in indices { - let list_start = offsets[idx].as_() - first_offset; - let list_end = offsets[idx + 1].as_() - first_offset; - - // Process the elements for this list. - process_element_range(list_start, list_end, &mut mask_builder); - } + } else { + // Sparse iteration: process individual selected lists. + for idx in selection.bit_buffer().set_indices() { + let list_start = offsets[idx].as_() - first_offset; + let list_end = offsets[idx + 1].as_() - first_offset; + + // Process the elements for this list. + process_element_range(list_start, list_end, &mut mask_builder); } } @@ -127,8 +123,8 @@ impl FilterKernel for ListVTable { let mut offset = O::zero(); unsafe { new_offsets.push_unchecked(offset) }; - for idx in selection.indices() { - let size = offsets[idx + 1] - offsets[*idx]; + for idx in selection.bit_buffer().set_indices() { + let size = offsets[idx + 1] - offsets[idx]; offset += size; unsafe { new_offsets.push_unchecked(offset) }; } diff --git a/vortex-array/src/arrays/primitive/array/top_value.rs b/vortex-array/src/arrays/primitive/array/top_value.rs index 0b67de41027..1e115babe47 100644 --- a/vortex-array/src/arrays/primitive/array/top_value.rs +++ b/vortex-array/src/arrays/primitive/array/top_value.rs @@ -8,7 +8,6 @@ use vortex_dtype::NativePType; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_mask::AllOr; use vortex_mask::Mask; use vortex_scalar::PValue; use vortex_utils::aliases::hash_map::HashMap; @@ -41,20 +40,19 @@ where { let mut 
distinct_values: HashMap, usize, FxBuildHasher> = HashMap::with_hasher(FxBuildHasher); - match mask.indices() { - AllOr::All => { - for value in values.iter().copied() { - *distinct_values.entry(NativeValue(value)).or_insert(0) += 1; - } + + if let Some(mask_values) = mask.values() { + for i in mask_values.bit_buffer().set_indices() { + *distinct_values + .entry(NativeValue(unsafe { *values.get_unchecked(i) })) + .or_insert(0) += 1 } - AllOr::None => unreachable!("All invalid arrays should be handled earlier"), - AllOr::Some(idxs) => { - for &i in idxs { - *distinct_values - .entry(NativeValue(unsafe { *values.get_unchecked(i) })) - .or_insert(0) += 1 - } + } else if mask.all_true() { + for value in values.iter().copied() { + *distinct_values.entry(NativeValue(value)).or_insert(0) += 1; } + } else { + unreachable!("All invalid arrays should be handled earlier") } let (&top_value, &top_count) = distinct_values diff --git a/vortex-array/src/arrays/varbin/compute/filter.rs b/vortex-array/src/arrays/varbin/compute/filter.rs index 24d961e03c1..1c3e1eb15dd 100644 --- a/vortex-array/src/arrays/varbin/compute/filter.rs +++ b/vortex-array/src/arrays/varbin/compute/filter.rs @@ -12,7 +12,6 @@ use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; -use vortex_mask::MaskIter; use crate::ArrayRef; use crate::ExecutionCtx; @@ -36,21 +35,28 @@ impl FilterKernel for VarBinVTable { } fn filter_select_var_bin(arr: &VarBinArray, mask: &Mask) -> VortexResult { - match mask + let mask_values = mask .values() - .vortex_expect("AllTrue and AllFalse are handled by filter fn") - .threshold_iter(0.5) - { - MaskIter::Indices(indices) => { - filter_select_var_bin_by_index(arr, indices, mask.true_count()) - } - MaskIter::Slices(slices) => filter_select_var_bin_by_slice(arr, slices, mask.true_count()), + .vortex_expect("AllTrue and AllFalse are handled by filter fn"); + + if mask_values.density() >= 0.5 { + filter_select_var_bin_by_slice( + 
arr, + mask_values.bit_buffer().set_slices(), + mask.true_count(), + ) + } else { + filter_select_var_bin_by_index( + arr, + mask_values.bit_buffer().set_indices(), + mask.true_count(), + ) } } fn filter_select_var_bin_by_slice( values: &VarBinArray, - mask_slices: &[(usize, usize)], + mask_slices: impl Iterator, selection_count: usize, ) -> VortexResult { let offsets = values.offsets().to_primitive(); @@ -70,7 +76,7 @@ fn filter_select_var_bin_by_slice_primitive_offset( dtype: DType, offsets: &[O], data: &[u8], - mask_slices: &[(usize, usize)], + mask_slices: impl Iterator, logical_validity: Mask, selection_count: usize, ) -> VortexResult @@ -81,15 +87,15 @@ where let mut builder = VarBinBuilder::::with_capacity(selection_count); match logical_validity.bit_buffer() { AllOr::All => { - mask_slices.iter().for_each(|(start, end)| { - update_non_nullable_slice(data, offsets, &mut builder, *start, *end) + mask_slices.for_each(|(start, end)| { + update_non_nullable_slice(data, offsets, &mut builder, start, end) }); } AllOr::None => { builder.append_n_nulls(selection_count); } AllOr::Some(validity) => { - for (start, end) in mask_slices.iter().copied() { + for (start, end) in mask_slices { let null_sl = validity.slice(start..end); if null_sl.true_count() == null_sl.len() { update_non_nullable_slice(data, offsets, &mut builder, start, end) @@ -148,7 +154,7 @@ fn update_non_nullable_slice( fn filter_select_var_bin_by_index( values: &VarBinArray, - mask_indices: &[usize], + mask_indices: impl Iterator, selection_count: usize, ) -> VortexResult { let offsets = values.offsets().to_primitive(); @@ -168,13 +174,13 @@ fn filter_select_var_bin_by_index_primitive_offset( dtype: DType, offsets: &[O], data: &[u8], - mask_indices: &[usize], + mask_indices: impl Iterator, // TODO(ngates): pass LogicalValidity instead validity: Validity, selection_count: usize, ) -> VortexResult { let mut builder = VarBinBuilder::::with_capacity(selection_count); - for idx in 
mask_indices.iter().copied() { + for idx in mask_indices { if validity.is_valid(idx)? { let (start, end) = ( offsets[idx].to_usize().ok_or_else(|| { @@ -219,7 +225,7 @@ mod test { ], DType::Utf8(NonNullable), ); - let buf = filter_select_var_bin_by_index(&arr, &[0, 2], 2).unwrap(); + let buf = filter_select_var_bin_by_index(&arr, [0, 2].into_iter(), 2).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec!["hello", "filter"])); } @@ -237,7 +243,8 @@ mod test { DType::Utf8(NonNullable), ); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3), (4, 5)], 3).unwrap(); + let buf = + filter_select_var_bin_by_slice(&arr, [(0, 1), (2, 3), (4, 5)].into_iter(), 3).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec!["hello", "filter", "filter3"])); } @@ -262,7 +269,7 @@ mod test { ); let arr = VarBinArray::try_new(offsets, bytes, DType::Utf8(Nullable), validity).unwrap(); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 3), (4, 6)], 5).unwrap(); + let buf = filter_select_var_bin_by_slice(&arr, [(0, 3), (4, 6)].into_iter(), 5).unwrap(); assert_arrays_eq!( buf, @@ -287,7 +294,7 @@ mod test { let validity = Validity::Array(BoolArray::from_iter([false, true, true]).into_array()); let arr = VarBinArray::try_new(offsets, bytes, DType::Utf8(Nullable), validity).unwrap(); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3)], 2).unwrap(); + let buf = filter_select_var_bin_by_slice(&arr, [(0, 1), (2, 3)].into_iter(), 2).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec![None, Some("two")])); } @@ -304,7 +311,7 @@ mod test { ) .unwrap(); - let buf = filter_select_var_bin_by_slice(&arr, &[(0, 1), (2, 3)], 2).unwrap(); + let buf = filter_select_var_bin_by_slice(&arr, [(0, 1), (2, 3)].into_iter(), 2).unwrap(); assert_arrays_eq!(buf, VarBinArray::from(vec![None::<&str>, None])); } diff --git a/vortex-array/src/arrays/varbinview/compute/zip.rs b/vortex-array/src/arrays/varbinview/compute/zip.rs index 9718c996c38..3b8629a7c9a 100644 --- 
a/vortex-array/src/arrays/varbinview/compute/zip.rs +++ b/vortex-array/src/arrays/varbinview/compute/zip.rs @@ -57,57 +57,57 @@ impl ZipKernel for VarBinViewVTable { let true_validity = if_true.validity_mask()?; let false_validity = if_false.validity_mask()?; - match mask.slices() { - AllOr::All => push_range( + if let Some(values) = mask.values() { + let mut pos = 0; + for (start, end) in values.bit_buffer().set_slices() { + if pos < start { + push_range( + if_false, + &false_lookup, + &false_validity, + pos..start, + &mut views_builder, + &mut validity_builder, + ); + } + push_range( + if_true, + &true_lookup, + &true_validity, + start..end, + &mut views_builder, + &mut validity_builder, + ); + pos = end; + } + if pos < len { + push_range( + if_false, + &false_lookup, + &false_validity, + pos..len, + &mut views_builder, + &mut validity_builder, + ); + } + } else if mask.all_true() { + push_range( if_true, &true_lookup, &true_validity, 0..len, &mut views_builder, &mut validity_builder, - ), - AllOr::None => push_range( + ) + } else { + push_range( if_false, &false_lookup, &false_validity, 0..len, &mut views_builder, &mut validity_builder, - ), - AllOr::Some(slices) => { - let mut pos = 0; - for (start, end) in slices { - if pos < *start { - push_range( - if_false, - &false_lookup, - &false_validity, - pos..*start, - &mut views_builder, - &mut validity_builder, - ); - } - push_range( - if_true, - &true_lookup, - &true_validity, - *start..*end, - &mut views_builder, - &mut validity_builder, - ); - pos = *end; - } - if pos < len { - push_range( - if_false, - &false_lookup, - &false_validity, - pos..len, - &mut views_builder, - &mut validity_builder, - ); - } - } + ) } let validity = validity_builder.finish_with_nullability(dtype.nullability()); diff --git a/vortex-array/src/compute/zip.rs b/vortex-array/src/compute/zip.rs index 3b83e968bcf..e119d508ea6 100644 --- a/vortex-array/src/compute/zip.rs +++ b/vortex-array/src/compute/zip.rs @@ -9,7 +9,6 @@ use 
vortex_error::VortexError; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; -use vortex_mask::AllOr; use vortex_mask::Mask; use super::ComputeFnVTable; @@ -225,19 +224,19 @@ fn zip_impl_with_builder( mask: &Mask, mut builder: Box, ) -> VortexResult { - match mask.slices() { - AllOr::All => Ok(if_true.to_array()), - AllOr::None => Ok(if_false.to_array()), - AllOr::Some(slices) => { - for (start, end) in slices { - builder.extend_from_array(&if_false.slice(builder.len()..*start)?); - builder.extend_from_array(&if_true.slice(*start..*end)?); - } - if builder.len() < if_false.len() { - builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); - } - Ok(builder.finish()) + if let Some(values) = mask.values() { + for (start, end) in values.bit_buffer().set_slices() { + builder.extend_from_array(&if_false.slice(builder.len()..start)?); + builder.extend_from_array(&if_true.slice(start..end)?); + } + if builder.len() < if_false.len() { + builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?); } + Ok(builder.finish()) + } else if mask.all_true() { + Ok(if_true.to_array()) + } else { + Ok(if_false.to_array()) } } diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index afa02def88c..61bf16b3c22 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -582,17 +582,17 @@ impl Patches { ); } - match mask.indices() { + match mask.bit_buffer() { AllOr::All => Ok(Some(self.clone())), AllOr::None => Ok(None), - AllOr::Some(mask_indices) => { + AllOr::Some(mask_bits) => { let flat_indices = self.indices().to_primitive(); match_each_unsigned_integer_ptype!(flat_indices.ptype(), |I| { filter_patches_with_mask( flat_indices.as_slice::(), self.offset(), self.values(), - mask_indices, + &mask_bits.set_indices().collect::>(), ) }) } diff --git a/vortex-compute/src/filter/bitbuffer.rs b/vortex-compute/src/filter/bitbuffer.rs index 22f294233d7..522a4c5d4bd 100644 --- 
a/vortex-compute/src/filter/bitbuffer.rs +++ b/vortex-compute/src/filter/bitbuffer.rs @@ -38,7 +38,16 @@ impl Filter for &BitBuffer { "Selection mask length must equal the mask length" ); - self.filter(mask_values.indices()) + let bools = self.inner().as_slice(); + let bit_offset = self.offset(); + + BitBufferMut::from_iter( + mask_values + .bit_buffer() + .set_indices() + .map(|idx| get_bit(bools, bit_offset + idx)), + ) + .freeze() } } @@ -125,7 +134,15 @@ impl Filter for &mut BitBufferMut { ); // BitBufferMut filtering always uses indices for simplicity. - self.filter(mask_values.indices()) + let bools = self.inner().as_slice(); + let bit_offset = self.offset(); + + *self = BitBufferMut::from_iter( + mask_values + .bit_buffer() + .set_indices() + .map(|idx| get_bit(bools, bit_offset + idx)), + ); } } diff --git a/vortex-compute/src/filter/buffer.rs b/vortex-compute/src/filter/buffer.rs index a70c9e22908..d6331131f5e 100644 --- a/vortex-compute/src/filter/buffer.rs +++ b/vortex-compute/src/filter/buffer.rs @@ -106,6 +106,7 @@ impl Filter> for &Buffer { impl Filter for &mut BufferMut where for<'a> &'a mut [T]: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/slice.rs b/vortex-compute/src/filter/slice.rs index c8f1e89a3bb..06c029511fc 100644 --- a/vortex-compute/src/filter/slice.rs +++ b/vortex-compute/src/filter/slice.rs @@ -16,7 +16,6 @@ use vortex_buffer::BitView; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskValues; use crate::filter::Filter; @@ -46,9 +45,22 @@ impl Filter for &[T] { "Selection mask length must equal the buffer length" ); - match mask_values.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - MaskIter::Indices(indices) => self.filter(indices), - MaskIter::Slices(slices) => self.filter(slices), + if mask_values.density() >= FILTER_SLICES_SELECTIVITY_THRESHOLD { + // High density: use slices (contiguous ranges) + + let mut out = 
BufferMut::::empty(); + for (start, end) in mask_values.bit_buffer().set_slices() { + out.extend_from_slice(&self[start..end]); + } + + out.freeze() + } else { + // Low density: stream indices directly from the bitmap without allocating an index vector. + let mut out = BufferMut::::with_capacity(mask_values.true_count()); + for idx in mask_values.bit_buffer().set_indices() { + out.push(self[idx]); + } + out.freeze() } } } diff --git a/vortex-compute/src/filter/slice_mut.rs b/vortex-compute/src/filter/slice_mut.rs index fae1b8f68f4..15b8cceeb80 100644 --- a/vortex-compute/src/filter/slice_mut.rs +++ b/vortex-compute/src/filter/slice_mut.rs @@ -43,7 +43,27 @@ impl Filter for &mut [T] { // We choose to _always_ use slices here because iterating over indices will have strictly // more loop iterations than slices (more branches), and the overhead over batched // `ptr::copy(len)` is not that high. - self.filter(mask_values.slices()) + let mut write_pos = 0; + + // For each range in the selection, copy all of the elements to the current write position. + for (start, end) in mask_values.bit_buffer().set_slices() { + // Note that we could add an if statement here that checks `if start != write_pos`, but + // it's probably better to just avoid the branch misprediction. + let len = end - start; + + // SAFETY: Slices should be within bounds. 
+ unsafe { + ptr::copy( + self.as_ptr().add(start), + self.as_mut_ptr().add(write_pos), + len, + ) + }; + + write_pos += len; + } + + &mut self[..write_pos] } } diff --git a/vortex-compute/src/filter/vector/binaryview.rs b/vortex-compute/src/filter/vector/binaryview.rs index 39da497bb34..f4773aac6ac 100644 --- a/vortex-compute/src/filter/vector/binaryview.rs +++ b/vortex-compute/src/filter/vector/binaryview.rs @@ -18,6 +18,7 @@ impl Filter for &BinaryViewVector where for<'a> &'a Mask: Filter, for<'a> &'a Buffer: Filter>, + M: ?Sized, { type Output = BinaryViewVector; @@ -34,6 +35,7 @@ impl Filter for &mut BinaryViewVectorMut where for<'a> &'a mut MaskMut: Filter, for<'a> &'a mut BufferMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/bool.rs b/vortex-compute/src/filter/vector/bool.rs index deec0dbcd25..34dfc1f25a8 100644 --- a/vortex-compute/src/filter/vector/bool.rs +++ b/vortex-compute/src/filter/vector/bool.rs @@ -16,6 +16,7 @@ impl Filter for &BoolVector where for<'a> &'a BitBuffer: Filter, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = BoolVector; @@ -34,6 +35,7 @@ impl Filter for &mut BoolVectorMut where for<'a> &'a mut BitBufferMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/decimal.rs b/vortex-compute/src/filter/vector/decimal.rs index 12263c03189..1fe2a9cd00a 100644 --- a/vortex-compute/src/filter/vector/decimal.rs +++ b/vortex-compute/src/filter/vector/decimal.rs @@ -21,6 +21,7 @@ where for<'a> &'a DVector: Filter>, for<'a> &'a DVector: Filter>, for<'a> &'a DVector: Filter>, + M: ?Sized, { type Output = DecimalVector; @@ -37,6 +38,7 @@ where for<'a> &'a mut DVectorMut: Filter, for<'a> &'a mut DVectorMut: Filter, for<'a> &'a mut DVectorMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/dvector.rs b/vortex-compute/src/filter/vector/dvector.rs index f0371af0497..2425dcde00b 100644 --- 
a/vortex-compute/src/filter/vector/dvector.rs +++ b/vortex-compute/src/filter/vector/dvector.rs @@ -17,6 +17,7 @@ impl Filter for &DVector where for<'a> &'a Buffer: Filter>, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = DVector; @@ -32,6 +33,7 @@ impl Filter for &mut DVectorMut where for<'a> &'a mut BufferMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/fixed_size_list.rs b/vortex-compute/src/filter/vector/fixed_size_list.rs index c9fd449738a..f94a098e57a 100644 --- a/vortex-compute/src/filter/vector/fixed_size_list.rs +++ b/vortex-compute/src/filter/vector/fixed_size_list.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use vortex_mask::Mask; -use vortex_mask::MaskIter; use vortex_mask::MaskMut; use vortex_vector::Vector; use vortex_vector::VectorMut; @@ -27,7 +26,8 @@ const MASK_EXPANSION_DENSITY_THRESHOLD: f64 = 0.05; impl Filter for &FixedSizeListVector where for<'a> &'a Mask: Filter, - for<'a> &'a Vector: Filter, + for<'a> &'a Vector: Filter<[(usize, usize)], Output = Vector>, + M: ?Sized, { type Output = FixedSizeListVector; @@ -40,7 +40,7 @@ where let elements_mask = compute_fsl_elements_mask(&filtered_validity, list_size as usize); // Filter the child elements vector. - self.elements().as_ref().filter(&elements_mask) + self.elements().as_ref().filter(elements_mask.as_slice()) } else { debug_assert!( self.elements().is_empty(), @@ -68,7 +68,8 @@ where impl Filter for &mut FixedSizeListVectorMut where for<'a> &'a mut MaskMut: Filter, - for<'a> &'a mut VectorMut: Filter, + for<'a> &'a mut VectorMut: Filter<[(usize, usize)], Output = ()>, + M: ?Sized, { type Output = (); @@ -93,7 +94,7 @@ where // SAFETY: The expanded mask has the correct length (`validity.len() * list_size`), // which maintains the invariant after filtering. 
unsafe { - self.elements_mut().filter(&elements_mask); + self.elements_mut().filter(elements_mask.as_slice()); } debug_assert_eq!( @@ -137,41 +138,30 @@ where /// `list_size` times. /// /// The output [`Mask`] is guaranteed to have a length equal to `selection_mask.len() * list_size`. -fn compute_fsl_elements_mask(selection_mask: &Mask, list_size: usize) -> Mask { - let expanded_len = selection_mask.len() * list_size; +fn compute_fsl_elements_mask(selection_mask: &Mask, list_size: usize) -> Vec<(usize, usize)> { + // let expanded_len = selection_mask.len() * list_size; let values = match selection_mask { - Mask::AllTrue(_) => return Mask::AllTrue(expanded_len), - Mask::AllFalse(_) => return Mask::AllFalse(expanded_len), + Mask::AllTrue(_) => return vec![(0, selection_mask.len() * list_size)], + Mask::AllFalse(_) => return vec![], Mask::Values(values) => values, }; - // Use threshold_iter to choose the optimal representation based on density. - let expanded_slices = match values.threshold_iter(MASK_EXPANSION_DENSITY_THRESHOLD) { - MaskIter::Slices(slices) => { - // Expand a dense mask (represented as slices) by scaling each slice by `list_size`. - slices - .iter() - .map(|&(start, end)| (start * list_size, end * list_size)) - .collect() - } - MaskIter::Indices(indices) => { - // Expand a sparse mask (represented as indices) by duplicating each index `list_size` - // times. - // - // Note that in the worst case, it is possible that we create only a few slices with a - // small range (for example, when list_size <= 2). This could be further optimized, - // but we choose simplicity for now. 
- indices - .iter() - .map(|&idx| { - let start = idx * list_size; - let end = (idx + 1) * list_size; - (start, end) - }) - .collect() - } - }; - - Mask::from_slices(expanded_len, expanded_slices) + if values.density() >= MASK_EXPANSION_DENSITY_THRESHOLD { + values + .bit_buffer() + .set_slices() + .map(|(start, end)| (start * list_size, end * list_size)) + .collect() + } else { + values + .bit_buffer() + .set_indices() + .map(|idx| { + let start = idx * list_size; + let end = (idx + 1) * list_size; + (start, end) + }) + .collect() + } } diff --git a/vortex-compute/src/filter/vector/list.rs b/vortex-compute/src/filter/vector/list.rs index 3b9615c2f05..f87d43b2cd7 100644 --- a/vortex-compute/src/filter/vector/list.rs +++ b/vortex-compute/src/filter/vector/list.rs @@ -18,6 +18,7 @@ impl Filter for &ListViewVector where for<'a> &'a PrimitiveVector: Filter, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = ListViewVector; @@ -37,6 +38,7 @@ impl Filter for &mut ListViewVectorMut where for<'a> &'a mut PrimitiveVectorMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/mod.rs b/vortex-compute/src/filter/vector/mod.rs index 040b2e06b33..cf4ba9687c3 100644 --- a/vortex-compute/src/filter/vector/mod.rs +++ b/vortex-compute/src/filter/vector/mod.rs @@ -74,6 +74,14 @@ impl Filter for &mut VectorMut { } } +impl Filter<[(usize, usize)]> for &Vector { + type Output = Vector; + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + match_each_vector!(self, |v| { v.filter(selection).into() }) + } +} + impl Filter> for &Vector { type Output = Vector; @@ -89,3 +97,11 @@ impl Filter> for &mut VectorMut { match_each_vector_mut!(self, |v| { v.filter(selection) }) } } + +impl Filter<[(usize, usize)]> for &mut VectorMut { + type Output = (); + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + match_each_vector_mut!(self, |v| { v.filter(selection) }) + } +} diff --git 
a/vortex-compute/src/filter/vector/null.rs b/vortex-compute/src/filter/vector/null.rs index a4df6f4c519..e2a60879fc8 100644 --- a/vortex-compute/src/filter/vector/null.rs +++ b/vortex-compute/src/filter/vector/null.rs @@ -16,6 +16,34 @@ impl Filter for &NullVector { } } +impl Filter<[(usize, usize)]> for &NullVector { + type Output = NullVector; + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + // Each `(start, end)` pair is a half-open range, so it selects `end - start` rows. + NullVector::new( + selection + .iter() + .map(|(start, end)| end - start) + .sum::<usize>(), + ) + } +} + +impl Filter<[(usize, usize)]> for &mut NullVectorMut { + type Output = (); + + fn filter(self, selection: &[(usize, usize)]) -> Self::Output { + // Each `(start, end)` pair is a half-open range, so it selects `end - start` rows. + *self = NullVectorMut::new( + selection + .iter() + .map(|(start, end)| end - start) + .sum::<usize>(), + ) + } +} + impl Filter> for &NullVector { type Output = NullVector; diff --git a/vortex-compute/src/filter/vector/primitive.rs b/vortex-compute/src/filter/vector/primitive.rs index 82b50c1adbd..799a85c4d01 100644 --- a/vortex-compute/src/filter/vector/primitive.rs +++ b/vortex-compute/src/filter/vector/primitive.rs @@ -26,6 +26,7 @@ where for<'a> &'a PVector: Filter>, for<'a> &'a PVector: Filter>, for<'a> &'a PVector: Filter>, + M: ?Sized, { type Output = PrimitiveVector; @@ -47,6 +48,7 @@ where for<'a> &'a mut PVectorMut: Filter, for<'a> &'a mut PVectorMut: Filter, for<'a> &'a mut PVectorMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-compute/src/filter/vector/pvector.rs b/vortex-compute/src/filter/vector/pvector.rs index fa8e9b64635..1210271f5cb 100644 --- a/vortex-compute/src/filter/vector/pvector.rs +++ b/vortex-compute/src/filter/vector/pvector.rs @@ -17,6 +17,7 @@ impl Filter for &PVector where for<'a> &'a Buffer: Filter>, for<'a> &'a Mask: Filter, + M: ?Sized, { type Output = PVector; @@ -34,6 +35,7 @@ impl Filter for &mut PVectorMut where for<'a> &'a mut BufferMut: Filter, for<'a> &'a mut MaskMut: Filter, + M: ?Sized, { type Output = (); diff --git 
a/vortex-compute/src/filter/vector/struct_.rs b/vortex-compute/src/filter/vector/struct_.rs index 67a057bf691..b9584049748 100644 --- a/vortex-compute/src/filter/vector/struct_.rs +++ b/vortex-compute/src/filter/vector/struct_.rs @@ -18,6 +18,7 @@ impl Filter for &StructVector where for<'a> &'a Mask: Filter, for<'a> &'a Vector: Filter, + M: ?Sized, { type Output = StructVector; @@ -40,6 +41,7 @@ impl Filter for &mut StructVectorMut where for<'a> &'a mut MaskMut: Filter, for<'a> &'a mut VectorMut: Filter, + M: ?Sized, { type Output = (); diff --git a/vortex-cuda/src/kernel/encodings/zstd.rs b/vortex-cuda/src/kernel/encodings/zstd.rs index d4b68937047..06eecf79b1a 100644 --- a/vortex-cuda/src/kernel/encodings/zstd.rs +++ b/vortex-cuda/src/kernel/encodings/zstd.rs @@ -25,7 +25,7 @@ use vortex_dtype::DType; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; -use vortex_mask::AllOr; +use vortex_mask::Mask; use vortex_nvcomp::sys::nvcompStatus_t; use vortex_nvcomp::zstd as nvcomp_zstd; use vortex_zstd::ZstdArray; @@ -282,8 +282,8 @@ async fn decode_zstd(array: ZstdArray, ctx: &mut CudaExecutionCtx) -> VortexResu let sliced_validity = validity.slice(slice_start..slice_stop)?; - match sliced_validity.to_mask(slice_stop - slice_start).indices() { - AllOr::All => { + match sliced_validity.to_mask(slice_stop - slice_start) { + Mask::AllTrue(_) => { let all_views = vortex_zstd::reconstruct_views(&host_buffer); let sliced_views = all_views.slice(slice_value_idx_start..slice_value_idx_stop); diff --git a/vortex-mask/src/intersect_by_rank.rs b/vortex-mask/src/intersect_by_rank.rs index efce6f17dbd..0e8ead98a38 100644 --- a/vortex-mask/src/intersect_by_rank.rs +++ b/vortex-mask/src/intersect_by_rank.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use crate::AllOr; use crate::Mask; impl Mask { @@ -29,21 +28,25 @@ impl Mask { pub fn intersect_by_rank(&self, mask: &Mask) 
-> Mask { assert_eq!(self.true_count(), mask.len()); - match (self.indices(), mask.indices()) { - (AllOr::All, _) => mask.clone(), - (_, AllOr::All) => self.clone(), - (AllOr::None, _) | (_, AllOr::None) => Self::new_false(self.len()), + match (self, mask) { + (Mask::AllTrue(_), _) => mask.clone(), + (_, Mask::AllTrue(_)) => self.clone(), + (Mask::AllFalse(_), _) | (_, Mask::AllFalse(_)) => Self::new_false(self.len()), + (Mask::Values(self_values), Mask::Values(mask_values)) => { + let self_indices = self_values.bit_buffer().set_indices().collect::>(); - (AllOr::Some(self_indices), AllOr::Some(mask_indices)) => { Self::from_indices( self.len(), - mask_indices - .iter() - .map(|idx| + mask_values + .bit_buffer() + .set_indices() + .map(|idx| { + // SAFETY: // This is verified as safe because we know that the indices are less than the // mask.len() and we known mask.len() <= self.len(), // implied by `self.true_count() == mask.len()`. - unsafe{*self_indices.get_unchecked(*idx)}) + unsafe { *self_indices.get_unchecked(idx) } + }) .collect(), ) } @@ -135,10 +138,15 @@ mod test { ) { let result = base_mask.intersect_by_rank(&rank_mask); - match result.indices() { - crate::AllOr::All => assert_eq!(expected_indices.len(), result.len()), - crate::AllOr::None => assert!(expected_indices.is_empty()), - crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), + match result { + Mask::AllTrue(_) => assert_eq!(expected_indices.len(), result.len()), + Mask::AllFalse(_) => assert!(expected_indices.is_empty()), + Mask::Values(mask_value) => { + assert_eq!( + mask_value.bit_buffer().set_indices().collect::>(), + &expected_indices[..] 
+ ) + } } } @@ -188,12 +196,16 @@ mod test { ) { let base = Mask::from_indices(10, base_indices); let rank = Mask::from_iter(rank_pattern); - let result = base.intersect_by_rank(&rank); - match result.indices() { - crate::AllOr::Some(indices) => assert_eq!(indices, &expected_indices[..]), - crate::AllOr::None => assert!(expected_indices.is_empty()), - _ => panic!("Unexpected result"), + match base.intersect_by_rank(&rank) { + Mask::AllTrue(n) => assert_eq!(n, expected_indices.len()), + Mask::AllFalse(_) => assert!(expected_indices.is_empty()), + Mask::Values(mask_values) => { + assert_eq!( + mask_values.bit_buffer().set_indices().collect::>(), + &expected_indices[..] + ) + } } } } diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index 1f87eb70f46..548da47f286 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -21,13 +21,13 @@ use std::fmt::Formatter; use std::ops::Bound; use std::ops::RangeBounds; use std::sync::Arc; -use std::sync::OnceLock; use itertools::Itertools; pub use mask_mut::*; use vortex_buffer::BitBuffer; use vortex_buffer::BitBufferMut; use vortex_buffer::set_bit_unchecked; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_panic; @@ -128,13 +128,6 @@ impl Default for Mask { pub struct MaskValues { buffer: BitBuffer, - // We cached the indices and slices representations, since it can be faster than iterating - // the bit-mask over and over again. - #[cfg_attr(feature = "serde", serde(skip))] - indices: OnceLock>, - #[cfg_attr(feature = "serde", serde(skip))] - slices: OnceLock>, - // Pre-computed values. 
true_count: usize, // i.e., the fraction of values that are true @@ -177,8 +170,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer, - indices: Default::default(), - slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -208,8 +199,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - indices: OnceLock::from(indices), - slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -237,8 +226,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - indices: Default::default(), - slices: Default::default(), true_count, density: true_count as f64 / len as f64, })) @@ -271,8 +258,6 @@ impl Mask { Self::Values(Arc::new(MaskValues { buffer: buf.freeze(), - indices: Default::default(), - slices: OnceLock::from(slices), true_count, density: true_count as f64 / len as f64, })) @@ -411,15 +396,7 @@ impl Mask { match &self { Self::AllTrue(len) => (*len > 0).then_some(0), Self::AllFalse(_) => None, - Self::Values(values) => { - if let Some(indices) = values.indices.get() { - return indices.first().copied(); - } - if let Some(slices) = values.slices.get() { - return slices.first().map(|(start, _)| *start); - } - values.buffer.set_indices().next() - } + Self::Values(values) => values.buffer.set_indices().next(), } } @@ -435,7 +412,12 @@ impl Mask { Self::AllTrue(_) => n, Self::AllFalse(_) => unreachable!("no true values in all-false mask"), // TODO(joe): optimize this function - Self::Values(values) => values.indices()[n], + Self::Values(values) => values + .bit_buffer() + .set_indices() + .take(n + 1) + .last() + .vortex_expect("validated within range"), } } @@ -498,36 +480,6 @@ impl Mask { } } - /// Return the indices representation of the mask. 
- #[inline] - pub fn indices(&self) -> AllOr<&[usize]> { - match &self { - Self::AllTrue(_) => AllOr::All, - Self::AllFalse(_) => AllOr::None, - Self::Values(values) => AllOr::Some(values.indices()), - } - } - - /// Return the slices representation of the mask. - #[inline] - pub fn slices(&self) -> AllOr<&[(usize, usize)]> { - match &self { - Self::AllTrue(_) => AllOr::All, - Self::AllFalse(_) => AllOr::None, - Self::Values(values) => AllOr::Some(values.slices()), - } - } - - /// Return an iterator over either indices or slices of the mask based on a density threshold. - #[inline] - pub fn threshold_iter(&self, threshold: f64) -> AllOr> { - match &self { - Self::AllTrue(_) => AllOr::All, - Self::AllFalse(_) => AllOr::None, - Self::Values(values) => AllOr::Some(values.threshold_iter(threshold)), - } - } - /// Return [`MaskValues`] if the mask is not all true or all false. #[inline] pub fn values(&self) -> Option<&MaskValues> { @@ -673,55 +625,6 @@ impl MaskValues { self.buffer.value(index) } - /// Constructs an indices vector from one of the other representations. - pub fn indices(&self) -> &[usize] { - self.indices.get_or_init(|| { - if self.true_count == 0 { - return vec![]; - } - - if self.true_count == self.len() { - return (0..self.len()).collect(); - } - - if let Some(slices) = self.slices.get() { - let mut indices = Vec::with_capacity(self.true_count); - indices.extend(slices.iter().flat_map(|(start, end)| *start..*end)); - debug_assert!(indices.is_sorted()); - assert_eq!(indices.len(), self.true_count); - return indices; - } - - let mut indices = Vec::with_capacity(self.true_count); - indices.extend(self.buffer.set_indices()); - debug_assert!(indices.is_sorted()); - assert_eq!(indices.len(), self.true_count); - indices - }) - } - - /// Constructs a slices vector from one of the other representations. 
- #[inline] - pub fn slices(&self) -> &[(usize, usize)] { - self.slices.get_or_init(|| { - if self.true_count == self.len() { - return vec![(0, self.len())]; - } - - self.buffer.set_slices().collect() - }) - } - - /// Return an iterator over either indices or slices of the mask based on a density threshold. - #[inline] - pub fn threshold_iter(&self, threshold: f64) -> MaskIter<'_> { - if self.density >= threshold { - MaskIter::Slices(self.slices()) - } else { - MaskIter::Indices(self.indices()) - } - } - /// Extracts the internal [`BitBuffer`]. pub(crate) fn into_buffer(self) -> BitBuffer { self.buffer diff --git a/vortex-mask/src/tests.rs b/vortex-mask/src/tests.rs index a13827abbc3..9b757e37864 100644 --- a/vortex-mask/src/tests.rs +++ b/vortex-mask/src/tests.rs @@ -9,7 +9,6 @@ use vortex_buffer::BitBuffer; use crate::AllOr; use crate::Mask; -use crate::MaskIter; // Basic mask creation and properties tests #[test] @@ -18,8 +17,8 @@ fn mask_all_true() { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 5); assert_eq!(mask.density(), 1.0); - assert_eq!(mask.indices(), AllOr::All); - assert_eq!(mask.slices(), AllOr::All); + // assert_eq!(mask.indices(), AllOr::All); + // assert_eq!(mask.slices(), AllOr::All); assert_eq!(mask.bit_buffer(), AllOr::All,); } @@ -29,8 +28,8 @@ fn mask_all_false() { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 0); assert_eq!(mask.density(), 0.0); - assert_eq!(mask.indices(), AllOr::None); - assert_eq!(mask.slices(), AllOr::None); + // assert_eq!(mask.indices(), AllOr::None); + // assert_eq!(mask.slices(), AllOr::None); assert_eq!(mask.bit_buffer(), AllOr::None,); } @@ -46,8 +45,8 @@ fn mask_from() { assert_eq!(mask.len(), 5); assert_eq!(mask.true_count(), 3); assert_eq!(mask.density(), 0.6); - assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..])); - assert_eq!(mask.slices(), AllOr::Some(&[(0, 1), (2, 4)][..])); + // assert_eq!(mask.indices(), AllOr::Some(&[0, 2, 3][..])); + // assert_eq!(mask.slices(), 
AllOr::Some(&[(0, 1), (2, 4)][..])); assert_eq!( mask.bit_buffer(), AllOr::Some(&BitBuffer::from_iter([true, false, true, true, false])) @@ -251,27 +250,27 @@ fn test_mask_values() { assert!(!values.value(1)); } -#[test] -fn test_mask_values_threshold_iter() { - let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); - let values = mask.values().unwrap(); +// #[test] +// fn test_mask_values_threshold_iter() { +// let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); +// let values = mask.values().unwrap(); - // With low threshold, should prefer indices - match values.threshold_iter(0.7) { - MaskIter::Indices(indices) => { - assert_eq!(indices, &[0, 2, 3]); - } - _ => panic!("Expected indices iterator"), - } +// // With low threshold, should prefer indices +// match values.threshold_iter(0.7) { +// MaskIter::Indices(indices) => { +// assert_eq!(indices, &[0, 2, 3]); +// } +// _ => panic!("Expected indices iterator"), +// } - // With high threshold, should prefer slices - match values.threshold_iter(0.5) { - MaskIter::Slices(slices) => { - assert_eq!(slices, &[(0, 1), (2, 4)]); - } - _ => panic!("Expected slices iterator"), - } -} +// // With high threshold, should prefer slices +// match values.threshold_iter(0.5) { +// MaskIter::Slices(slices) => { +// assert_eq!(slices, &[(0, 1), (2, 4)]); +// } +// _ => panic!("Expected slices iterator"), +// } +// } #[test] fn test_mask_values_is_empty() { @@ -476,65 +475,65 @@ fn test_mask_from_slices_overlapping() { Mask::from_slices(5, vec![(0, 3), (2, 4)]); // Overlapping ranges } -// Threshold iterator tests -#[test] -fn test_mask_threshold_iter() { - let all_true = Mask::new_true(5); - assert!(matches!(all_true.threshold_iter(0.5), AllOr::All)); +// // Threshold iterator tests +// #[test] +// fn test_mask_threshold_iter() { +// let all_true = Mask::new_true(5); +// assert!(matches!(all_true.threshold_iter(0.5), AllOr::All)); - let all_false = Mask::new_false(5); 
- assert!(matches!(all_false.threshold_iter(0.5), AllOr::None)); +// let all_false = Mask::new_false(5); +// assert!(matches!(all_false.threshold_iter(0.5), AllOr::None)); - let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); - if let AllOr::Some(MaskIter::Indices(indices)) = mask.threshold_iter(0.7) { - assert_eq!(indices, &[0, 2, 3]); - } else { - panic!("Expected indices iterator"); - } -} +// let mask = Mask::from_buffer(BitBuffer::from_iter([true, false, true, true, false])); +// if let AllOr::Some(MaskIter::Indices(indices)) = mask.threshold_iter(0.7) { +// assert_eq!(indices, &[0, 2, 3]); +// } else { +// panic!("Expected indices iterator"); +// } +// } // Caching tests -#[test] -fn test_mask_indices_caching() { - // Test that indices are properly cached - let mask = Mask::from_slices(10, vec![(0, 3), (5, 7), (9, 10)]); - - // First call should compute indices - let indices1 = mask.indices(); - // Second call should return cached value - let indices2 = mask.indices(); - - match (indices1, indices2) { - (AllOr::Some(i1), AllOr::Some(i2)) => { - assert_eq!(i1, i2); - assert_eq!(i1, &[0, 1, 2, 5, 6, 9]); - // Verify they're the same reference (cached) - assert!(std::ptr::eq(i1, i2)); - } - _ => panic!("Expected Some variant"), - } -} - -#[test] -fn test_mask_slices_caching() { - // Test that slices are properly cached - let mask = Mask::from_indices(10, vec![0, 1, 2, 5, 6, 9]); - - // First call should compute slices - let slices1 = mask.slices(); - // Second call should return cached value - let slices2 = mask.slices(); - - match (slices1, slices2) { - (AllOr::Some(s1), AllOr::Some(s2)) => { - assert_eq!(s1, s2); - assert_eq!(s1, &[(0, 3), (5, 7), (9, 10)]); - // Verify they're the same reference (cached) - assert!(std::ptr::eq(s1, s2)); - } - _ => panic!("Expected Some variant"), - } -} +// #[test] +// fn test_mask_indices_caching() { +// // Test that indices are properly cached +// let mask = Mask::from_slices(10, vec![(0, 
3), (5, 7), (9, 10)]); + +// // First call should compute indices +// let indices1 = mask.indices(); +// // Second call should return cached value +// let indices2 = mask.indices(); + +// match (indices1, indices2) { +// (AllOr::Some(i1), AllOr::Some(i2)) => { +// assert_eq!(i1, i2); +// assert_eq!(i1, &[0, 1, 2, 5, 6, 9]); +// // Verify they're the same reference (cached) +// assert!(std::ptr::eq(i1, i2)); +// } +// _ => panic!("Expected Some variant"), +// } +// } + +// #[test] +// fn test_mask_slices_caching() { +// // Test that slices are properly cached +// let mask = Mask::from_indices(10, vec![0, 1, 2, 5, 6, 9]); + +// // First call should compute slices +// let slices1 = mask.slices(); +// // Second call should return cached value +// let slices2 = mask.slices(); + +// match (slices1, slices2) { +// (AllOr::Some(s1), AllOr::Some(s2)) => { +// assert_eq!(s1, s2); +// assert_eq!(s1, &[(0, 3), (5, 7), (9, 10)]); +// // Verify they're the same reference (cached) +// assert!(std::ptr::eq(s1, s2)); +// } +// _ => panic!("Expected Some variant"), +// } +// } // AllOr tests #[test] @@ -610,52 +609,52 @@ fn test_mask_properties( assert!((mask.density() - expected_density).abs() < 1e-10); } -#[rstest] -#[case::indices(vec![0, 2, 4], vec![(0, 1), (2, 3), (4, 5)])] -#[case::consecutive(vec![0, 1, 2], vec![(0, 3)])] -#[case::gap(vec![0, 1, 4, 5], vec![(0, 2), (4, 6)])] -#[case::single(vec![3], vec![(3, 4)])] -fn test_indices_to_slices_conversion( - #[case] indices: Vec, - #[case] expected_slices: Vec<(usize, usize)>, -) { - let mask = Mask::from_indices(10, indices.clone()); - - // Check indices - if let AllOr::Some(actual_indices) = mask.indices() { - assert_eq!(actual_indices, &indices[..]); - } else { - panic!("Expected Some variant for indices"); - } - - // Check slices - if let AllOr::Some(actual_slices) = mask.slices() { - assert_eq!(actual_slices, &expected_slices[..]); - } else { - panic!("Expected Some variant for slices"); - } -} - -#[rstest] 
-#[case::empty_intersection(vec![0, 2, 4], vec![1, 3, 5], vec![])] -#[case::full_intersection(vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5])] -#[case::partial_intersection(vec![0, 1, 2, 3], vec![2, 3, 4, 5], vec![2, 3])] -#[case::subset_left(vec![1, 2], vec![0, 1, 2, 3], vec![1, 2])] -#[case::subset_right(vec![0, 1, 2, 3], vec![1, 2], vec![1, 2])] -fn test_intersection_indices( - #[case] left: Vec, - #[case] right: Vec, - #[case] expected: Vec, -) { - let mask = Mask::from_intersection_indices(10, left.into_iter(), right.into_iter()); - - match mask.indices() { - AllOr::Some(indices) if expected.is_empty() => assert!(indices.is_empty()), - AllOr::Some(indices) => assert_eq!(indices, &expected[..]), - AllOr::None if expected.is_empty() => {} - AllOr::None | AllOr::All => panic!("Unexpected result for intersection"), - } -} +// #[rstest] +// #[case::indices(vec![0, 2, 4], vec![(0, 1), (2, 3), (4, 5)])] +// #[case::consecutive(vec![0, 1, 2], vec![(0, 3)])] +// #[case::gap(vec![0, 1, 4, 5], vec![(0, 2), (4, 6)])] +// #[case::single(vec![3], vec![(3, 4)])] +// fn test_indices_to_slices_conversion( +// #[case] indices: Vec, +// #[case] expected_slices: Vec<(usize, usize)>, +// ) { +// let mask = Mask::from_indices(10, indices.clone()); + +// // Check indices +// if let AllOr::Some(actual_indices) = mask.indices() { +// assert_eq!(actual_indices, &indices[..]); +// } else { +// panic!("Expected Some variant for indices"); +// } + +// // Check slices +// if let AllOr::Some(actual_slices) = mask.slices() { +// assert_eq!(actual_slices, &expected_slices[..]); +// } else { +// panic!("Expected Some variant for slices"); +// } +// } + +// #[rstest] +// #[case::empty_intersection(vec![0, 2, 4], vec![1, 3, 5], vec![])] +// #[case::full_intersection(vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5])] +// #[case::partial_intersection(vec![0, 1, 2, 3], vec![2, 3, 4, 5], vec![2, 3])] +// #[case::subset_left(vec![1, 2], vec![0, 1, 2, 3], vec![1, 2])] +// #[case::subset_right(vec![0, 1, 2, 
3], vec![1, 2], vec![1, 2])] +// fn test_intersection_indices( +// #[case] left: Vec, +// #[case] right: Vec, +// #[case] expected: Vec, +// ) { +// let mask = Mask::from_intersection_indices(10, left.into_iter(), right.into_iter()); + +// match mask.indices() { +// AllOr::Some(indices) if expected.is_empty() => assert!(indices.is_empty()), +// AllOr::Some(indices) => assert_eq!(indices, &expected[..]), +// AllOr::None if expected.is_empty() => {} +// AllOr::None | AllOr::All => panic!("Unexpected result for intersection"), +// } +// } // Concat operation tests #[test] diff --git a/vortex-scan/src/selection.rs b/vortex-scan/src/selection.rs index d2fb46d4bb0..66e0dbd2472 100644 --- a/vortex-scan/src/selection.rs +++ b/vortex-scan/src/selection.rs @@ -166,13 +166,17 @@ fn indices_range(range: &Range, row_indices: &[u64]) -> Option mod tests { use vortex_buffer::Buffer; + fn collect_indices(mask: &vortex_mask::Mask) -> Vec { + mask.values().unwrap().bit_buffer().set_indices().collect() + } + #[test] fn test_row_mask_all() { let selection = super::Selection::IncludeByIndex(Buffer::from_iter(vec![1, 3, 5, 7])); let range = 1..8; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2, 4, 6]); } #[test] @@ -181,7 +185,7 @@ mod tests { let range = 3..6; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2]); } #[test] @@ -190,7 +194,7 @@ mod tests { let range = 3..5; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]); + assert_eq!(collect_indices(row_mask.mask()), &[0]); } #[test] @@ -217,7 +221,7 @@ mod tests { let range = 0..5; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0]); + assert_eq!(collect_indices(row_mask.mask()), &[0]); } 
#[cfg(feature = "roaring")] @@ -238,7 +242,7 @@ mod tests { let range = 1..8; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2, 4, 6]); } #[test] @@ -253,7 +257,7 @@ mod tests { let range = 3..6; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2]); } #[test] @@ -299,7 +303,7 @@ mod tests { let row_mask = selection.row_mask(&range); // Should exclude indices 1, 3, 5, so we get 0, 2, 4, 6 - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 2, 4, 6]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 2, 4, 6]); } #[test] @@ -344,7 +348,7 @@ mod tests { let row_mask = selection.row_mask(&range); // Should exclude 5, 6, 7 (mapped to 0, 1, 2), keep 8, 9 (mapped to 3, 4) - assert_eq!(row_mask.mask().values().unwrap().indices(), &[3, 4]); + assert_eq!(collect_indices(row_mask.mask()), &[3, 4]); } #[test] @@ -377,7 +381,7 @@ mod tests { let range = 0..100; let row_mask = selection.row_mask(&range); - assert_eq!(row_mask.mask().values().unwrap().indices(), &[0, 99]); + assert_eq!(collect_indices(row_mask.mask()), &[0, 99]); } #[test] @@ -393,7 +397,7 @@ mod tests { // Should include 15-19 (mapped to 0-4) and 30-34 (mapped to 15-19) let expected: Vec = (0..5).chain(15..20).collect(); - assert_eq!(row_mask.mask().values().unwrap().indices(), &expected); + assert_eq!(collect_indices(row_mask.mask()), expected); } #[test] @@ -443,8 +447,8 @@ mod tests { let roaring_mask = roaring_selection.row_mask(&range); assert_eq!( - buffer_mask.mask().values().unwrap().indices(), - roaring_mask.mask().values().unwrap().indices() + collect_indices(buffer_mask.mask()), + collect_indices(roaring_mask.mask()) ); } @@ -467,8 +471,8 @@ mod tests { let roaring_mask = roaring_selection.row_mask(&range); assert_eq!( - 
buffer_mask.mask().values().unwrap().indices(), - roaring_mask.mask().values().unwrap().indices() + collect_indices(buffer_mask.mask()), + collect_indices(roaring_mask.mask()) ); } }