This is the repository for the paper "Hierarchical Refinement: Optimal Transport to Infinity and Beyond," which scales optimal transport linearly in space and log-linearly in time by using a hierarchical strategy that constructs multiscale partitions from low-rank optimal transport.
Figure 1: The Hierarchical Refinement algorithm: low-rank optimal transport is used to progressively refine the partitions from the previous (coarser) scale.
Hierarchical Refinement (HiRef) only requires two n×d-dimensional point clouds X and Y (torch tensors) as input.
Before running HiRef, call the rank-annealing scheduler to find a sequence of ranks that minimizes the number of calls to the low-rank optimal transport subroutine while remaining under a machine-specific maximal rank.
- `n`: the size of the dataset
- `hierarchy_depth` (κ): the depth of the hierarchy of levels used in the refinement strategy
- `max_Q`: the maximal terminal rank at the base case
- `max_rank`: the maximal rank of the intermediate sub-problems
Import the rank annealing module and compute the rank schedule:
```python
import rank_annealing

rank_schedule = rank_annealing.optimal_rank_schedule(
    n=n, hierarchy_depth=hierarchy_depth, max_Q=max_Q, max_rank=max_rank
)
```
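For intuition, a valid rank schedule has to cover all n points: the intermediate ranks stay at or below `max_rank`, the terminal rank stays at or below `max_Q`, and the product of the ranks reaches n. The helper below is a minimal sketch of these constraints (an illustrative function, not part of the repository; the scheduler's exact objective may differ):

```python
import math

def check_rank_schedule(n, schedule, max_rank, max_Q):
    """Illustrative validity check for a rank schedule (r_1, ..., r_k):
    intermediate ranks bounded by max_rank, terminal rank bounded by
    max_Q, and the product of ranks large enough to cover n points."""
    ok_intermediate = all(r <= max_rank for r in schedule[:-1])
    ok_terminal = schedule[-1] <= max_Q
    ok_coverage = math.prod(schedule) >= n
    return ok_intermediate and ok_terminal and ok_coverage
```

For example, for n = 1000 the schedule (10, 10, 10) satisfies these constraints with `max_rank=16` and `max_Q=12`, while a two-level schedule (10, 10) would leave points uncovered.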
Import HR_OT and initialize the class using only the point clouds (you can additionally input the cost matrix C if desired), along with any relevant parameters (e.g., sq_Euclidean) for your problem.
```python
import HR_OT

hrot = HR_OT.HierarchicalRefinementOT.init_from_point_clouds(
    X, Y, rank_schedule, base_rank=1, device=device
)
```
Run HiRef and return paired tuples from X and Y (the bijective Monge map between the datasets):

```python
Gamma_hrot = hrot.run(return_as_coupling=False)
```
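For intuition, a bijective pairing of n points corresponds to a permutation matrix scaled by 1/n under uniform marginals. The sketch below makes that concrete (a hypothetical helper, assuming the pairing is given as a list of index pairs (i, j); it is not the repository's `return_as_coupling=True` code path):

```python
import numpy as np

def pairs_to_coupling(pairs, n):
    """Build a dense coupling matrix from a list of index pairs (i, j).
    A bijective pairing is a permutation matrix scaled by 1/n, so that
    every row and column sums to the uniform marginal weight 1/n."""
    P = np.zeros((n, n))
    for i, j in pairs:
        P[i, j] = 1.0 / n
    return P
```

Note that storing this dense matrix takes O(n²) memory, which is exactly what the tuple representation avoids.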
To print the Optimal Transport (OT) cost, simply call:

```python
cost_hrot = hrot.compute_OT_cost()
print(f"Refinement Cost: {cost_hrot.item()}")
```
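For reference, with a squared-Euclidean cost and uniform marginals, the transport cost of a bijective pairing is just the mean squared distance over the matched pairs. A NumPy sketch (hypothetical helper, assuming index pairs and array inputs; the repository's `compute_OT_cost` may normalize differently):

```python
import numpy as np

def monge_map_cost(X, Y, pairs):
    """Squared-Euclidean transport cost of a bijective pairing,
    averaged over the matched pairs (uniform 1/n weights)."""
    diffs = np.array([X[i] - Y[j] for i, j in pairs])
    return float((diffs ** 2).sum(axis=1).mean())
```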
For questions, discussions, or collaboration inquiries, feel free to reach out at [email protected] or [email protected].