Point Cloud Optimization

import multipers as mp
from multipers.data import noisy_annulus, three_annulus
import gudhi as gd
import numpy as np
import matplotlib.pyplot as plt
import torch
# torch.autograd.set_detect_anomaly(True)
from multipers.plots import plot_signed_measures, plot_signed_measure
from tqdm import tqdm
torch.manual_seed(1)
np.random.seed(1)

Spatially localized optimization

The goal of this notebook is to generate cycles on the modes of a fixed measure.

In this example, the measure is defined (in the cell below) as a sum of three Gaussian measures.

## The density function of the measure
def custom_map(x, sigma=.17, threshold=None):  # `threshold` is unused in the body, kept for interface compatibility
    if x.ndim == 1:
        x = x[None, :]
    assert x.ndim == 2
    # three Gaussian bumps centered at these basepoints
    basepoints = torch.tensor([[0.2, 0.2], [0.8, 0.4], [0.4, 0.7]]).T
    out = -(torch.exp(-(((x[:, :, None] - basepoints[None, :, :]) / sigma).square()).sum(dim=1))).sum(dim=-1)
    return 1 + out
x = np.linspace(0, 1, 100)
mesh = np.meshgrid(x, x)
coordinates = np.concatenate([stuff.flatten()[:, None] for stuff in mesh], axis=1)
coordinates = torch.from_numpy(coordinates)
plt.scatter(*coordinates.T, c=custom_map(coordinates), cmap="viridis_r")
plt.colorbar()

We start from a uniform point cloud, which we will then optimize.

x = np.random.uniform(size=(500,2))
x = torch.tensor(x, requires_grad=True)
plt.scatter(*x.detach().numpy().T, c=custom_map(x).detach().numpy(), cmap="viridis_r")
plt.colorbar()

The usual filtration functions (e.g., RipsLowerstar, Cubical, DelaunayLowerstar) automatically detect the differentiability of the input and propagate the gradient in that case. Nothing changes!

x = np.random.uniform(size=(300,2))
x = torch.tensor(x, requires_grad=True)
from multipers.filtrations import RipsLowerstar, DelaunayLowerstar
# st = DelaunayLowerstar(points=x, function=custom_map(x), flagify=True)
st = RipsLowerstar(points=x, function=custom_map(x))
st.collapse_edges(-1) # little preprocessing
st.expansion(2)       # 2-dimensional simplices are necessary for H1
sm_diff, = mp.signed_measure(st, degree=1, plot=True)
print(sm_diff[0].requires_grad) # should be True
True

For this example we use the following loss. Given a signed measure \(\mu\), define

\[\mathrm{loss}(\mu) := \int\varphi(x) d\mu(x)\]

where \(x := (r,d)\in \mathbb R^2\) (\(r\) for radius, and \(d\) for the codensity value of the original measure) and

\[\varphi(x) = \varphi(r,d) = r\times(\mathrm{threshold}-d).\]

This can be interpreted as follows:

  • we maximize the radii of the negative points (maximizing the radius of the cycles),

  • we minimize the radii of the positive points (the edges connecting the points that create the cycles); this yields cleaner cycles,

  • we care more about cycles that are close to the modes (the \(\mathrm{threshold}-d\) part). The threshold is meant to make cycles that are not close enough to the modes progressively stop forming loops.
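
Since the signed measure is returned as a (points, weights) pair, the integral above is simply a weighted sum over the points of the measure. As a minimal illustrative sketch (the function name and the signature of phi here are ours, not part of the library):

def integrate_against_measure(phi, sm):
    # sm = (points, weights): points is an (n, 2) tensor of (radius, codensity) pairs,
    # weights is an (n,) array of signs (+1 / -1)
    points, weights = sm
    return (torch.as_tensor(weights) * phi(points)).sum()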

threshold = .65
def softplus(x):  # smooth ReLU helper (not used below)
    return torch.log(1 + torch.exp(x))
# @torch.compile(dynamic=True)
def loss_function(x, sm):
    pts, weights = sm
    radius, density = pts.T

    phi = lambda r, d: (r * (threshold - d)).sum()
    # integrate phi against the signed measure: positive points minus negative points
    loss = phi(radius[weights > 0], density[weights > 0]) - phi(radius[weights < 0], density[weights < 0])
    return loss

loss_function(x, sm_diff) # test that it works: it should raise no error and the result should carry a gradient
tensor(0.2417, dtype=torch.float64, grad_fn=<SubBackward0>)

As Delaunay complexes are faster to use in low-dimensional Euclidean space, we switch from Rips to Delaunay and optimize this `loss_function`.

Another optimization would be to copy the simplextree into a slicer, on which it is possible to compute a minimal presentation. Computing the Hilbert function is already quite fast, so this will not lead to a significant improvement here, but consider this optimization when computing harder invariants!
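
For reference, here is a minimal sketch of that idea. Constructing a slicer with `mp.Slicer` and passing it to `mp.signed_measure` follow the library's usual usage; the minimal-presentation step itself is only indicated as a comment, since the exact helper to call may depend on your multipers version.

slicer = mp.Slicer(st)  # copy the simplextree into a slicer
# a (degree-1) minimal presentation of the slicer could be computed here before
# taking the signed measure; see the multipers documentation for the helper to use
sm_diff, = mp.signed_measure(slicer, degree=1)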

from multipers.filtrations import DelaunayLowerstar
xinit = np.random.uniform(size=(500,2)) # initial dataset
x = torch.tensor(xinit, requires_grad=True)
adam = torch.optim.Adam([x], lr=0.01) #optimizer
losses = []
plt.scatter(*x.detach().numpy().T, c=custom_map(x, threshold=np.inf).detach().numpy(), cmap="viridis_r")
plt.show()
for i in range(101): # gradient steps
    # Little caveat: Delaunay complexes are hard to differentiate.
    # We hence weaken the Delaunay complex into a flag complex, which is easier to differentiate,
    # with the `flagify=True` flag.
    st = DelaunayLowerstar(points=x, function=custom_map(x), flagify=True)
    sm_diff, = mp.signed_measure(st, degree=1)

    # Rips version
    # st = RipsLowerstar(points=x,function=custom_map(x))
    # st.collapse_edges(-1)
    # st.expansion(2)
    # sm_diff, = mp.signed_measure(st, degree=1)
    
    adam.zero_grad()
    loss = loss_function(x,sm_diff)
    loss.backward()
    adam.step()
    losses.append([loss.detach().numpy()])
    with torch.no_grad():
        if i % 10 == 1: # plot part
            base=4
            ncols=3
            fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols*base, base))
            ax1.scatter(*x.detach().numpy().T, c=custom_map(x, threshold=np.inf).detach().numpy(), cmap="viridis_r", )
            plot_signed_measure(sm_diff, ax=ax2)
            ax3.plot(losses, label="loss")
            plt.show()
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols*base, base))
ax1.scatter(*xinit.T, c=custom_map(torch.tensor(xinit)).detach().numpy(),cmap="viridis_r")
ax2.scatter(*x.detach().numpy().T, c=custom_map(x).detach().numpy(),cmap="viridis_r")
ax3.plot(losses)

plt.show()

We now observe an onion-like structure around the modes of the background measure.

How to interpret this? The density constraint (in the loss) and the radius constraint fight against each other, so each cycle has to balance itself at a local optimum, which leads to these onion layers.

Density-preserving optimization

This example is taken from the paper Differentiability and Optimization of Multiparameter Persistent Homology, and is an extension of the point cloud experiment of the Gudhi optimization notebook.
One can check from the Gudhi notebook that a compactness regularization term is necessary in the one-parameter persistence setting; this issue does not arise when optimizing a Rips-Density bi-filtration, as cycles naturally balance between scale and density, and hence do not diverge.

from multipers.filtrations.density import KDE
X = np.block([
    [np.random.uniform(low=-0.1,high=.2,size=(100,2))],
    [mp.data.noisy_annulus(300,0, 0.85,1)]
])
bandwidth = .1
# the codensity: minus the log-density of a kernel density estimate
custom_map2 = lambda X: -KDE(bandwidth=bandwidth, return_log=True).fit(X).score_samples(X)
codensity = custom_map2(X)
plt.scatter(*X.T, c=-codensity)
plt.gca().set_aspect(1)
def norm_loss(sm_diff, norm=1.):
    # push the positive points of the measure toward the origin, and the negative points away from it
    pts, weights = sm_diff
    loss = (torch.norm(pts[weights>0], p=norm, dim=1)).sum() - (torch.norm(pts[weights<0], p=norm, dim=1)).sum()
    return loss / pts.shape[0]
x = torch.from_numpy(X).clone().requires_grad_(True)
opt = torch.optim.Adam([x], lr=.01)
losses = []
for i in range(100):
    opt.zero_grad()
    st = DelaunayLowerstar(points=x, function=custom_map2(x), flagify=True)
    sm_diff, = mp.signed_measure(st, degree=1)
    loss = norm_loss(sm_diff)
    loss.backward()
    losses.append(loss.detach().numpy())
    opt.step()
    if i % 10 == 0:
        with torch.no_grad():
            base=4
            ncols=3
            fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols*base, base))
            ax1.scatter(*x.detach().numpy().T, c=custom_map2(x).detach().numpy(), cmap="viridis_r", )
            plot_signed_measure(sm_diff, ax=ax2)
            ax3.plot(losses, label="loss")
            plt.show()

with torch.no_grad():
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols*base, base))
    ax1.scatter(*X.T, c=custom_map2(torch.tensor(X)).detach().numpy(),cmap="viridis_r")
    ax2.scatter(*x.detach().numpy().T, c=custom_map2(x).detach().numpy(),cmap="viridis_r")
    ax3.plot(losses)
    plt.show()