Clean up getContigMergeOfInnerSize #5936
base: main
Conversation
!test

Review updated until commit cb8be65

Description
Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Breaking API Change
Force-pushed from 78310cd to ba2f27d

!test

Force-pushed from f3a9608 to f184027

!test

Force-pushed from f184027 to b7345f7

!test

Force-pushed from aaa42fd to 12cc04e

!test
```cpp
      getProjectedExtent(id), commonOrConstExtent(ca_map_, id));
}

void ContiguousInnerDimensionsMapper::addProjectedExtent(
```
```cpp
  // Ordering of dimensions is important in this analysis, if an ordering is
  // contiguous in the reference, but not the target tensor views, then we
  // cannot consider that a contiguous merge dimension for vectorization.
  auto projected_logical = projectId(filtered_ids, logical_domain);
```
projected_logical gives me the wrong impression that the whole logical domain is projected. In fact, it's still as filtered as filtered_ids.
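One possible rename along these lines, shown only as an illustration (the name is hypothetical, not from the PR):

```cpp
// Illustrative only: keeps the projectId call from the diff above, but the
// name signals that the result covers only the IDs that survived
// filtered_ids, not the whole logical domain.
auto filtered_projected_logical = projectId(filtered_ids, logical_domain);
```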
Greptile Overview

Greptile Summary

This PR refactors getContigMergeOfInnerSize. The main functional change is that the contig-inner-size computation now attempts to multiply projected extents of allocation IDs after matching against the mapper's logical IDs, relying on ir_utils::getReachableIds to map each allocation ID back to a logical ID.

Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant S as Scheduler/VectorizeHeuristic
    participant M as ContiguousInnerDimensionsMapper
    participant TV as TensorView (reference)
    participant TV2 as TensorView (target)
    S->>M: map(reference_tv, logical_ids)
    activate M
    M->>M: recording_=true
    M->>M: addProjectedExtent(logical_id, commonOrConstExtent)
    M->>M: projectId(filtered_ids, logical_domain, logical_domain)
    M->>M: projectId(filtered_ids, root_domain, allocation_domain)
    M->>M: recording_=false
    M->>M: traverse spanning tree
    deactivate M
    S->>M: getTvToContigMergeOfInnerSizeMap()
    activate M
    loop for each tv in tv_infos_
        M->>M: getContigMergeOfInnerSize(tv)
        M->>TV2: alloc = getMaybeAllocationDomain()
        M->>TV2: contiguity = getContiguity()
        M->>M: projected_dims = mappedLogicalIds(tv)
        M->>M: iterate alloc & contiguity (reverse)
        M->>M: logical_id = ir_utils::getReachableIds(logical_domain, {alloc_id})
        M->>M: if logical_id matches next projected_dim
        M->>M: product *= getProjectedExtent(alloc_id)
    end
    deactivate M
```
3 files reviewed, 2 comments
````cpp
  for (auto [alloc_id, cont] :
       zip(alloc | std::views::reverse, contiguity | std::views::reverse)) {
    auto is_treated_as_size_one = [](IterDomain* id) {
      return id->isReduction() || id->isBroadcast() || id->isParallelized() ||
          id->extent()->isOneInt();
    };
    if (is_treated_as_size_one(alloc_id)) {
      continue;
    }
    // ...
    auto contiguity_i = contiguity.at(alloc_ii);
    if (!contiguity_i.has_value()) {
      NVF_THROW("contiguity flag at alloc_ii can't be null");
    } else {
      // Not contiguous
      if (!contiguity_i.value()) {
        break;
      }
    NVF_ERROR(cont.has_value());
    if (!cont.value()) {
      break;
    }
    // ...
    // Get the logical ID corresponding to the allocation ID.
    auto exprs = DependencyCheck::getAllExprsBetween(
        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
        {alloc_iid});
    IterDomain* logical_id = alloc_iid;
    Val* num_devices = tv->container()->oneVal();
    bool only_valid_device_split = true;
    for (Expr* expr : exprs | std::views::reverse) {
      if (!isValidDeviceSplit(expr)) {
        only_valid_device_split = false;
        break;
      }
      auto* split = expr->as<Split>();
      logical_id = split->in();
      num_devices = SimplifyingIrBuilder::mulExpr(num_devices, split->factor());
    while (projected_dim != projected_dims.rend() &&
           is_treated_as_size_one(*projected_dim)) {
      projected_dim++;
    }
    // ...
    // Non device split could lead to padding, which prevents vectorization
    if (!only_valid_device_split) {
      break;
    }
    IterDomain* logical_id = [&]() {
      std::vector<IterDomain*> reachable_ids =
          ir_utils::getReachableIds(tv->getLogicalDomain(), {alloc_id});
      NVF_ERROR_EQ(reachable_ids.size(), 1);
      return reachable_ids.front();
    }();
    // ...
    // Mapping order isn't correct, cannot expand vectorization dimension.
    if (projected_dims[--projected_dims_i] != logical_id) {
    if (projected_dim == projected_dims.rend() ||
        *projected_dim != logical_id) {
      break;
    }
    // ...
    Val* sharded_extent;
    if (logical_id->isDeviceDim()) {
      sharded_extent = tv->container()->oneVal();
    } else {
      sharded_extent = SimplifyingIrBuilder::divExpr(
          getProjectedExtent(logical_id), num_devices);
    }
    product_of_inner_extents =
        SimplifyingIrBuilder::mulExpr(product_of_inner_extents, sharded_extent);
    // This assumes projected_dim can be matched only once. This assumption is
    // OK for now but when we get to non-outermost sharding such as
    // ```
    // [iS0]
    // /  \.
    // iS1  iS2
    // /  \.
    // iDIDx3  iS4
    // ```
    // We may want to allow multiple contiguous allocation IDs to match
    // projected_dim.
    projected_dim++;
    // ...
    product_of_inner_extents = SimplifyingIrBuilder::mulExpr(
        product_of_inner_extents, getProjectedExtent(alloc_id));
````
Allocation extent mismatched
getContigMergeOfInnerSize now multiplies getProjectedExtent(alloc_id) (i.e., allocation ID) after matching logical_id against mappedLogicalIds(tv) (logical IDs). projected_extent_ values originate from recording logical/root projections and are not guaranteed to include allocation IDs; in those cases getProjectedExtent(alloc_id) will throw "Not projected" at runtime. Even when present, using allocation-ID extents here changes semantics vs the previous logical-ID-based computation and can incorrectly size the contig inner extent for TVs with an allocation permutation. Consider multiplying the projected extent for the matched logical_id (or ensure allocation IDs are always recorded consistently before using them here).
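A minimal sketch of the alternative this comment suggests, reusing the loop variables from the diff above (an assumption about the surrounding code, not the PR's committed change): accumulate the extent recorded for the matched logical ID rather than querying the allocation ID.

```cpp
// Sketch only: assumes the surrounding reverse loop where logical_id has
// already been derived from alloc_id. Multiplying the logical ID's projected
// extent avoids asking the mapper for an allocation ID it may never have
// recorded.
if (projected_dim == projected_dims.rend() || *projected_dim != logical_id) {
  break;
}
product_of_inner_extents = SimplifyingIrBuilder::mulExpr(
    product_of_inner_extents, getProjectedExtent(logical_id));
projected_dim++;
```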
```cpp
  const std::vector<IterDomain*>& alloc = tv->getMaybeAllocationDomain();
  const std::vector<std::optional<bool>>& contiguity = tv->getContiguity();
  // ...
  NVF_ERROR(hasMappedDims(tv));
  // ...
  const std::vector<IterDomain*>& projected_dims = mappedLogicalIds(tv);
  auto alloc_no_reductions = TensorDomain::noReductions(alloc);
  // ...
  std::vector<std::optional<bool>> contiguity = tv->domain()->contiguity();
  NVF_ERROR_EQ(contiguity.size(), alloc.size());
  // Appears after reductions the reduction domain often has a contiguity entry.
  // This only matters if the result of the reduction is an output
  if (contiguity.size() != alloc_no_reductions.size()) {
    std::vector<std::optional<bool>> new_contiguity;
    for (auto i : arange(alloc.size())) {
      if (!alloc[i]->isReduction()) {
        new_contiguity.push_back(contiguity.at(i));
      }
    }
    contiguity = new_contiguity;
  }
  // ...
  auto alloc_no_reductions_size = alloc_no_reductions.size();
  // ...
  NVF_ERROR_EQ(alloc_no_reductions_size, contiguity.size());
  // ...
  Val* product_of_inner_extents = tv->container()->oneVal();
  // Order is important, need to make sure dimensions match up correctly with
  // what was propogated through the mapper. The mapper's dimensions is
  // propogated in the order of the reference, if that order doesn't match the
  // tensor we're mapping too then a transpose interfered with expanded the
  // vectorize dimension.
  size_t projected_dims_i = projected_dims.size();
  // ...
  for (auto i : arange(alloc_no_reductions_size)) {
    if (projected_dims_i == 0) {
      break;
    }
    auto alloc_ii = alloc_no_reductions_size - i - 1;
    auto alloc_iid = alloc_no_reductions.at(alloc_ii);
    // ...
    if (alloc_iid->extent()->isOneInt() || alloc_iid->isBroadcast()) {
      if (projected_dims[projected_dims_i - 1] == alloc_iid) {
        --projected_dims_i;
      }
  auto projected_dim = projected_dims.rbegin();
  // Wish I could `zip(alloc, contiguity) | std::views::reverse` here. It
  // doesn't compile.
  for (auto [alloc_id, cont] :
       zip(alloc | std::views::reverse, contiguity | std::views::reverse)) {
```
Contiguity/alloc size assumption
This loop zips tv->getMaybeAllocationDomain() with tv->getContiguity() and iterates them in lockstep. If getContiguity() is defined in terms of the logical/root domain (as it historically was via tv->domain()->contiguity()), TVs with a distinct allocation domain can have a different rank/order, making the zip silently drop trailing elements and compute an incorrect inner-extent product. At minimum this should assert alloc.size() == contiguity.size() before zipping (or fetch contiguity for the allocation domain explicitly).
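A minimal sketch of the guard suggested here, assuming the same accessors that appear in the diff above (illustrative, not the PR's code):

```cpp
// Sketch only: fail loudly on a rank mismatch before zipping, so a TV with a
// distinct allocation domain cannot silently drop trailing contiguity flags.
const std::vector<IterDomain*>& alloc = tv->getMaybeAllocationDomain();
const std::vector<std::optional<bool>>& contiguity = tv->getContiguity();
NVF_ERROR_EQ(alloc.size(), contiguity.size());
```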
Makes the code less error-prone and removes the reliance on isValidDeviceSplit, so that non-outermost sharding can be supported in the future.
Should be an NFC.