Utils

Utility functionalities or recurrent operations for the machine learning pipeline.

Tensor operations



get_displacements

 get_displacements (x)

Returns the displacements of trajectory x, a tensor of shape [dim, length].
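As an illustration, here is a minimal sketch that assumes the displacements are the differences between consecutive positions along the length axis, so a [dim, length] trajectory yields [dim, length - 1] displacements. To stay dependency-free it works on nested Python lists (one list of positions per dimension) rather than tensors, and the name get_displacements_sketch is hypothetical. For a PyTorch tensor, torch.diff(x, dim=-1) computes the same differences.

```python
def get_displacements_sketch(x):
    """Hypothetical re-implementation: x is a list of `dim` coordinate lists of equal length."""
    # Difference between each position and the previous one, per dimension
    return [[b - a for a, b in zip(coord, coord[1:])] for coord in x]


traj = [[0.0, 1.0, 3.0, 2.0],   # x coordinate
        [0.0, 0.5, 0.5, 1.5]]   # y coordinate
print(get_displacements_sketch(traj))  # [[1.0, 2.0, -1.0], [0.5, 0.0, 1.0]]
```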



split_tensor

 split_tensor (t, indices)

Splits input tensor t according to indices in the first dimension.
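A minimal sketch, assuming indices are the positions at which t is cut (not the sizes of the resulting chunks). The hypothetical split_tensor_sketch works on Python lists; for tensors, torch.tensor_split(t, indices) splits along the first dimension at the given positions in the same way.

```python
def split_tensor_sketch(t, indices):
    """Hypothetical: split sequence t at the given positions along its first dimension."""
    # Add the start and end of t so that consecutive bounds delimit each chunk
    bounds = [0, *indices, len(t)]
    return [t[a:b] for a, b in zip(bounds[:-1], bounds[1:])]


print(split_tensor_sketch([0, 0, 1, 1, 1, 2], [2, 5]))  # [[0, 0], [1, 1, 1], [2]]
```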



lengths_from_cps

 lengths_from_cps (cps, length=200)

Returns the lengths of the segments delimited by the change points cps in a trajectory whose total length is given by length.
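Assuming cps holds the sorted positions of the change points, the segment lengths are the gaps between consecutive boundaries once 0 and length are added at both ends. A hypothetical list-based sketch:

```python
def lengths_from_cps_sketch(cps, length=200):
    """Hypothetical: lengths of the segments delimited by sorted change points cps."""
    bounds = [0, *cps, length]
    # Each segment length is the gap between two consecutive boundaries
    return [b - a for a, b in zip(bounds[:-1], bounds[1:])]


print(lengths_from_cps_sketch([50, 120]))  # [50, 70, 80]
```

By construction, the lengths always add up to length.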

Segmentation post-processing



fit_segments

 fit_segments (pred, pen=1.0, return_cps=False, kernel='linear',
               min_size=2, jump=1, params=None)

Fit piecewise constant segments to input signal pred.
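The signature of fit_segments (pen, kernel, min_size, jump, params) resembles the KernelCPD detector of the ruptures library; that is an assumption about its internals, not something stated here. With a linear kernel, the kernel cost of a segment reduces to the sum of squared deviations from the segment mean. The dependency-free sketch below minimizes that cost plus a penalty pen per change point by exhaustive dynamic programming (optimal partitioning), which is O(n²) and ignores jump and params. The name fit_segments_sketch is hypothetical.

```python
def fit_segments_sketch(pred, pen=1.0, min_size=2, return_cps=False):
    """Hypothetical stand-in for fit_segments: penalised least-squares segmentation
    solved exactly by dynamic programming (optimal partitioning)."""
    n = len(pred)
    if n == 0:
        return ([], []) if return_cps else []
    # Prefix sums let us evaluate any segment cost in O(1)
    s, s2 = [0.0], [0.0]
    for v in pred:
        s.append(s[-1] + v)
        s2.append(s2[-1] + v * v)

    def cost(a, b):
        # Sum of squared deviations of pred[a:b] from its mean
        return (s2[b] - s2[a]) - (s[b] - s[a]) ** 2 / (b - a)

    inf = float("inf")
    best = [0.0] + [inf] * n   # best[b]: minimal penalised cost of pred[:b]
    last = [0] * (n + 1)       # last[b]: start index of the final segment of pred[:b]
    for b in range(min_size, n + 1):
        for a in range(0, b - min_size + 1):
            if best[a] == inf:
                continue  # pred[:a] cannot be split into segments of at least min_size
            total = best[a] + cost(a, b) + (pen if a > 0 else 0.0)
            if total < best[b]:
                best[b], last[b] = total, a

    # Backtrack the change points from the end of the signal
    cps, b = [], n
    while b > 0:
        b = last[b]
        if b > 0:
            cps.append(b)
    cps.reverse()

    # Replace every segment by its mean value (piecewise constant fit)
    bounds = [0, *cps, n]
    fitted = []
    for a, b in zip(bounds[:-1], bounds[1:]):
        fitted += [(s[b] - s[a]) / (b - a)] * (b - a)
    return (fitted, cps) if return_cps else fitted


signal = [0.1, 0.0, 0.2, 0.1, 1.0, 1.1, 0.9, 1.0]
fitted, cps = fit_segments_sketch(signal, pen=0.1, return_cps=True)
print(cps)  # [4]
```

A larger pen discourages change points: with pen=10.0 the same signal is fitted by a single constant segment.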

fit_segments is mainly intended to process predictions of continuous values, whereas the following functions are intended to post-process discrete predictions. Notably, post_process_prediction takes a prediction of discrete categories over a trajectory and extracts the most likely change points and segments, minimizing the impact of spurious mistakes within the predicted segments (see the example below).



post_process_prediction

 post_process_prediction (pred, n_change_points=1)

Segmentation prediction post-processing to find change points and classes.



abundance

 abundance (val, t)

Abundance of value val in tensor t.



majority_vote

 majority_vote (t)

Returns the majority (most frequent) value in t.
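A minimal sketch of abundance and majority_vote on Python lists. Two assumptions apply: abundance is taken here as the fraction of entries equal to val (it could equally be a raw count), and ties in the majority vote go to the value encountered first, which is how collections.Counter.most_common orders elements with equal counts. The function names are hypothetical.

```python
from collections import Counter


def abundance_sketch(val, t):
    """Hypothetical: fraction of entries of t equal to val."""
    return sum(1 for v in t if v == val) / len(t)


def majority_vote_sketch(t):
    """Hypothetical: most frequent value in t (ties go to the value seen first)."""
    return Counter(t).most_common(1)[0][0]


t = [2, 2, 1, 2, 2, 0]
print(abundance_sketch(2, t))    # 0.6666666666666666
print(majority_vote_sketch(t))   # 2
```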



get_split_classes

 get_split_classes (splits)

Returns majority class of each split.



change_points_from_splits

 change_points_from_splits (splits)

Returns the change point positions corresponding to a list of splits.
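Both helpers can be sketched on lists of lists, assuming a change point sits at the index where each split after the first begins, i.e. at the cumulative length of all previous splits. Applied to the two splits shown in the post_process_prediction example of this section, the sketch recovers the change point 8 and the classes 0 and 2 (as plain integers rather than tensors). The function names are hypothetical.

```python
from collections import Counter


def get_split_classes_sketch(splits):
    """Hypothetical: majority class of each split."""
    return [Counter(s).most_common(1)[0][0] for s in splits]


def change_points_from_splits_sketch(splits):
    """Hypothetical: a change point sits where each split ends, except the last one."""
    cps, position = [], 0
    for s in splits[:-1]:
        position += len(s)
        cps.append(position)
    return cps


splits = [[0, 0, 0, 0, 1, 1, 0, 0], [2, 2, 2, 2, 1, 2, 2]]
print(get_split_classes_sketch(splits))          # [0, 2]
print(change_points_from_splits_sketch(splits))  # [8]
```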



get_splits

 get_splits (t)

Splits tensor t into chunks with the same value.
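On a Python list, chunks of consecutive equal values are exactly what itertools.groupby produces, so a hypothetical sketch is a one-liner. For a 1D tensor, a comparable result should follow from torch.unique_consecutive(t, return_counts=True) combined with torch.split(t, counts.tolist()).

```python
from itertools import groupby


def get_splits_sketch(t):
    """Hypothetical: split t into runs of consecutive equal values."""
    # groupby starts a new group every time the value changes
    return [list(run) for _, run in groupby(t)]


print(get_splits_sketch([0, 0, 1, 1, 1, 0, 2]))  # [[0, 0], [1, 1, 1], [0], [2]]
```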



find_change_points

 find_change_points (t)

Finds points in tensor t where the value changes.

prediction = tensor([0, 0, 0, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, 2, 2])
cps, classes, splits = post_process_prediction(prediction)
cps, classes, splits
(tensor([8]),
 [tensor([0]), tensor([2])],
 [tensor([0, 0, 0, 0, 1, 1, 0, 0]), tensor([2, 2, 2, 2, 1, 2, 2])])
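The output above can be reproduced with one possible strategy, which is an assumption and not necessarily how post_process_prediction works internally: among the raw change points of the prediction, keep the n_change_points that minimize the number of points disagreeing with the majority class of their segment. The exhaustive search over combinations is only practical for a few candidates, and all names below are hypothetical.

```python
from collections import Counter
from itertools import combinations


def majority(seq):
    """Most frequent value in seq (ties go to the value seen first)."""
    return Counter(seq).most_common(1)[0][0]


def raw_change_points(seq):
    """Indices i where seq[i] differs from seq[i - 1]."""
    return [i for i in range(1, len(seq)) if seq[i] != seq[i - 1]]


def split_at(seq, cps):
    bounds = [0, *cps, len(seq)]
    return [seq[a:b] for a, b in zip(bounds[:-1], bounds[1:])]


def post_process_sketch(pred, n_change_points=1):
    """Hypothetical: keep the n_change_points raw change points that minimise the
    number of points disagreeing with the majority class of their segment."""
    if not pred:
        return [], [], []
    candidates = raw_change_points(pred)
    best_cps, best_errors = [], None
    for cps in combinations(candidates, min(n_change_points, len(candidates))):
        splits = split_at(pred, list(cps))
        # Points that differ from their segment's majority class count as mistakes
        errors = sum(len(s) - Counter(s).most_common(1)[0][1] for s in splits)
        if best_errors is None or errors < best_errors:
            best_cps, best_errors = list(cps), errors
    splits = split_at(pred, best_cps)
    return best_cps, [majority(s) for s in splits], splits


prediction = [0, 0, 0, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, 2, 2]
print(post_process_sketch(prediction))
# ([8], [0, 2], [[0, 0, 0, 0, 1, 1, 0, 0], [2, 2, 2, 2, 1, 2, 2]])
```

Here the cut at 8 leaves three mistakes (two 1s on the left, one 1 on the right), fewer than any other raw change point.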

Model evaluation



mean_relative_error

 mean_relative_error (pred, true, base=10)

Mean relative error, assuming pred and true are expressed as logarithms in base base.



mean_absolute_error

 mean_absolute_error (pred, true)

Mean absolute error between pred and true.
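A minimal sketch of both metrics on Python lists. For mean_relative_error we assume that, because pred and true are logarithms in base base (for instance, the log10 of a diffusion coefficient), they are first mapped back to linear scale and the error is |base^pred - base^true| / base^true averaged over samples; the actual definition may differ. The function names are hypothetical.

```python
def mean_absolute_error_sketch(pred, true):
    """Average of |pred - true| over all samples."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(pred)


def mean_relative_error_sketch(pred, true, base=10):
    """Hypothetical: pred and true hold log_base values; compare them in linear scale."""
    errors = [abs(base ** p - base ** t) / base ** t for p, t in zip(pred, true)]
    return sum(errors) / len(errors)


print(mean_absolute_error_sketch([0.5, 1.2], [0.7, 1.0]))  # close to 0.2 (floating-point rounding)
print(mean_relative_error_sketch([1.0, 2.0], [1.0, 1.0]))  # 4.5
```

In the second call, a prediction of 10^2 against a true value of 10^1 is a relative error of 9, and the exact prediction contributes 0, giving a mean of 4.5.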



jaccard_index

 jaccard_index (true_positive, false_positive, false_negative)

Computes the Jaccard index a.k.a. Tanimoto index.



assign_changepoints

 assign_changepoints (true, pred)

Matches predicted and true changepoints by solving a linear sum assignment problem.



evaluate_cp_prediction

 evaluate_cp_prediction (true, pred, changepoint_threshold=5)

Evaluates the change point prediction.

Since the changepoint detection algorithm can provide an arbitrary number of change points, we solve a linear sum assignment problem to perform the matching between the ground truth and the predicted changepoints.

Then, we consider as valid predictions, i.e., true positives (TP), those predicted changepoints that lie within a threshold (changepoint_threshold) of their assigned ground truth changepoint. All the predicted change points that are not TP are false positives (FP). Finally, the ground truth change points that do not have a predicted counterpart within the threshold are false negatives (FN).
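The matching and counting can be sketched as follows. Instead of a dedicated solver (scipy.optimize.linear_sum_assignment is the usual choice in Python), this dependency-free sketch solves the assignment exactly by brute force over permutations, which is only practical for a handful of change points. The cost is the absolute distance between positions, and a matched pair counts as a TP when that distance is at most changepoint_threshold; both choices are assumptions, and the function names are hypothetical.

```python
from itertools import permutations


def assign_changepoints_sketch(true, pred):
    """Pairs (i, j) matching true[i] to pred[j] that minimise the total |true[i] - pred[j]|.
    Exhaustive search over permutations: exact, but only for a few change points."""
    if len(true) <= len(pred):
        # Each true change point gets a distinct predicted one
        best = min(permutations(range(len(pred)), len(true)),
                   key=lambda perm: sum(abs(true[i] - pred[j]) for i, j in enumerate(perm)))
        return [(i, j) for i, j in enumerate(best)]
    # More true than predicted change points: each prediction gets a distinct true one
    best = min(permutations(range(len(true)), len(pred)),
               key=lambda perm: sum(abs(true[i] - pred[j]) for j, i in enumerate(perm)))
    return [(i, j) for j, i in enumerate(best)]


def evaluate_cp_sketch(true, pred, changepoint_threshold=5):
    """Returns (TP, FP, FN); assumes an inclusive distance threshold."""
    pairs = assign_changepoints_sketch(true, pred)
    tp = sum(abs(true[i] - pred[j]) <= changepoint_threshold for i, j in pairs)
    return tp, len(pred) - tp, len(true) - tp


print(assign_changepoints_sketch([50, 120], [48, 90, 123]))  # [(0, 0), (1, 2)]
print(evaluate_cp_sketch([50, 120], [48, 90, 123]))          # (2, 1, 0)
```

The prediction at 90 is left unmatched, so it becomes a false positive, while both ground truth change points are recovered within the threshold.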

To evaluate the change point detection, we use the Jaccard index, which is a function of the TP, FP and FN: \[JI = \frac{TP}{TP + FP + FN}\,.\]
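For example, with TP = 2, FP = 1 and FN = 0 the Jaccard index is 2 / (2 + 1 + 0) = 2/3, and it only reaches 1 when there are no FP and no FN. A direct sketch of the formula follows; returning 1.0 when there are no change points at all (0/0) is an assumption, not something specified here, and the function name is hypothetical.

```python
def jaccard_index_sketch(true_positive, false_positive, false_negative):
    """JI = TP / (TP + FP + FN)."""
    total = true_positive + false_positive + false_negative
    # Assumption: with no change points at all (0/0) we return 1.0; the original may differ
    return true_positive / total if total > 0 else 1.0


print(jaccard_index_sketch(2, 1, 0))  # 0.6666666666666666
```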



validate_andi_3_alpha

 validate_andi_3_alpha (m, dim=1, task=3, bs=128, pct=1, path=None)

Validates model on the AnDi test set for task 3 (segmentation) predicting anomalous exponents.



validate_andi_3_models

 validate_andi_3_models (m, dim=1, task=3, bs=128, pct=1, path=None)

Validates model on the AnDi test set for task 3 (segmentation) predicting diffusion models.



validate_andi_1

 validate_andi_1 (m, dim=1, bs=1, pct=1, task=1, path=None)

Validates model on the AnDi test set for task 1 (anomalous exponent).



eval_andi_metrics

 eval_andi_metrics (dls, model)

Evaluates model on the validation set to obtain the AnDi challenge metrics.

Figures

Here, we define colors and colormaps for our plots.

import matplotlib.colors as clr

color_order = ['blue', 'orange', 'yellow', 'purple', 'green']
color_dict = {
    'blue':   {'dark': (0.2745098, 0.4, 0.6),
               'medium': (0.39607843, 0.5254902, 0.71764706),
               'light': (0.65098039, 0.79215686, 0.94117647)},
    'orange': {'dark': (0.71764706, 0.36470588, 0.24313725),
               'medium': (0.88627451, 0.4627451, 0.34901961),
               'light': (1.0, 0.63921569, 0.44705882)},
    'yellow': {'dark': (0.85882353, 0.58431373, 0.18039216),
               'medium': (0.89803922, 0.68235294, 0.39607843),
               'light': (0.96470588, 0.84705882, 0.52941176)},
    'purple': {'dark': (0.6627451, 0.16078431, 0.30980392),
               'medium': (0.7372549, 0.39607843, 0.55294118),
               'light': (0.89019608, 0.38823529, 0.52941176)},
    'green':  {'dark': (0.22352941, 0.46666667, 0.4549019607843137),
               'medium': (0.29803922, 0.60784314, 0.58431373),
               'light': (0.50980392, 0.76862745, 0.76470588)}
}

colors = [color_dict[k]['medium'] for k in color_order]
colors_light = [color_dict[k]['light'] for k in color_order]
colors_dark = [color_dict[k]['dark'] for k in color_order]

cmap_hist1 = clr.LinearSegmentedColormap.from_list(
    'custom cm', ['w', 
                  color_dict['blue']['light'],
                  color_dict['blue']['dark']],
                  N=256
)
cmap_hist2 = clr.LinearSegmentedColormap.from_list(
    'custom cm', ['w', 
                  color_dict['orange']['light'],
                  color_dict['orange']['dark']],
                  N=256
)
cmap_points = clr.LinearSegmentedColormap.from_list(
    'custom cm', [color_dict['yellow']['light'], 
                  color_dict['purple']['light'],
                  color_dict['blue']['medium']],
                  N=256
)

fig_size = 4
linewidth = 2
alpha_grid = 0.2
scatter_size = 12

D_units = r"($\mu$m$^2$/s)"  # raw string so that "\m" is not parsed as an escape sequence