Skip to content

6. plot

torchfsm.plot.plot_1D_field ¤

plot_1D_field(
    ax: Axes,
    data: Union[ndarray, Tensor],
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    title_loc="center",
    show_ticks=True,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    extend_value_range: bool = True,
    grid=True,
    **kwargs
)

Plot a 1D field.

Parameters:

Name Type Description Default
ax Axes

The axes to plot on.

required
data Union[ndarray, Tensor]

The data to plot.

required
x_label Optional[str]

The label for the x-axis. Defaults to None.

None
y_label Optional[str]

The label for the y-axis. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
title_loc str

The location of the title. Defaults to "center".

'center'
show_ticks bool

Whether to show ticks. Defaults to True.

True
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
vmin Optional[float]

The lower limit of the y-axis. Defaults to None.

None
vmax Optional[float]

The upper limit of the y-axis. Defaults to None.

None
extend_value_range bool

Whether to extend the value range. Defaults to True.

True
grid bool

Whether to show grid lines. Defaults to True.

True
**kwargs

Additional keyword arguments for the plot.

{}
Source code in torchfsm/plot/core/field.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def plot_1D_field(
    ax: plt.Axes,
    data: Union[np.ndarray, torch.Tensor],
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    title_loc="center",
    show_ticks=True,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    extend_value_range: bool = True,
    grid=True,
    **kwargs,
):
    """
    Plot a 1D field as a line on ``ax``.

    Args:
        ax (plt.Axes): The axes to plot on.
        data (Union[np.ndarray, torch.Tensor]): The data to plot. Tensors are
            detached and moved to CPU; other array-likes go through ``np.asarray``.
        x_label (Optional[str], optional): The label for the x-axis. Defaults to None.
        y_label (Optional[str], optional): The label for the y-axis. Defaults to None.
        title (Optional[str], optional): The title of the plot. Defaults to None.
        title_loc (str, optional): The location of the title. Defaults to "center".
        show_ticks (bool, optional): Whether to show ticks. When False, all ticks are
            removed and ``ticks_x``/``ticks_y`` are ignored. Defaults to True.
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the
            x-axis as ``(positions, labels)``. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the
            y-axis as ``(positions, labels)``. Defaults to None.
        vmin (Optional[float], optional): The lower limit of the y-axis. Defaults to None,
            which leaves the lower limit unchanged.
        vmax (Optional[float], optional): The upper limit of the y-axis. Defaults to None,
            which leaves the upper limit unchanged.
        extend_value_range (bool, optional): Whether to pad the given y-limits outward by
            5% of their magnitude so the curve does not touch the frame. Defaults to True.
        grid (bool, optional): Whether to show grid lines. Defaults to True.
        **kwargs: Additional keyword arguments forwarded to ``ax.plot``.

    Raises:
        ValueError: If ``data`` is not one-dimensional.
    """
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    elif not isinstance(data, np.ndarray):
        data = np.asarray(data)
    if data.ndim != 1:
        raise ValueError("Only support 1D data.")
    ax.plot(data, **kwargs)
    if not show_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        if ticks_x is not None:
            ax.set_xticks(ticks_x[0], labels=ticks_x[1])
        if ticks_y is not None:
            ax.set_yticks(ticks_y[0], labels=ticks_y[1])
    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if title is not None:
        ax.set_title(title, loc=title_loc)
    if vmin is not None or vmax is not None:
        bottom, top = vmin, vmax
        if extend_value_range:
            # Move each bound *away* from the data. Simply multiplying by 1.05
            # would pull a positive lower bound (or negative upper bound)
            # inward and clip the curve; for vmin <= 0 <= vmax both forms agree.
            if bottom is not None:
                bottom = bottom - 0.05 * abs(bottom)
            if top is not None:
                top = top + 0.05 * abs(top)
        # A ``None`` bound leaves that side of the y-limits unchanged.
        ax.set_ylim(bottom, top)
    if grid:
        ax.grid()

torchfsm.plot.plot_2D_field ¤

plot_2D_field(
    ax: Axes,
    data: Union[ndarray, Tensor],
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    title_loc="center",
    interpolation="none",
    aspect="auto",
    cmap: Union[str, Colormap] = "twilight",
    show_ticks=True,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    rasterized: bool = True,
    **kwargs
)

Plot a 2D field.

Parameters:

Name Type Description Default
ax Axes

The axes to plot on.

required
data Union[ndarray, Tensor]

The data to plot.

required
x_label Optional[str]

The label for the x-axis. Defaults to None.

None
y_label Optional[str]

The label for the y-axis. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
title_loc str

The location of the title. Defaults to "center".

'center'
interpolation str

The interpolation method. Defaults to "none".

'none'
aspect str

The aspect ratio. Defaults to "auto".

'auto'
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
show_ticks bool

Whether to show ticks. Defaults to True.

True
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
rasterized bool

Whether to rasterize the image. Defaults to True.

True
**kwargs

Additional keyword arguments for the plot.

{}
Source code in torchfsm/plot/core/field.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def plot_2D_field(
    ax: plt.Axes,
    data: Union[np.ndarray, torch.Tensor],
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    title_loc="center",
    interpolation="none",
    aspect="auto",
    cmap: Union[str, Colormap] = "twilight",
    show_ticks=True,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    rasterized: bool = True,
    **kwargs,
):
    """
    Render a 2D field as an image on ``ax``.

    The array is transposed before drawing and shown with ``origin="lower"``, so
    the first axis of ``data`` runs horizontally, the second runs vertically, and
    element ``[0, 0]`` sits in the bottom-left corner.

    Args:
        ax (plt.Axes): The axes to draw on.
        data (Union[np.ndarray, torch.Tensor]): The 2D field. Tensors are detached and
            moved to CPU; other array-likes are converted with ``np.asarray``.
        x_label (Optional[str], optional): Label for the x-axis. Defaults to None.
        y_label (Optional[str], optional): Label for the y-axis. Defaults to None.
        title (Optional[str], optional): Plot title. Defaults to None.
        title_loc (str, optional): Title placement passed to ``ax.set_title``. Defaults to "center".
        interpolation (str, optional): Interpolation passed to ``ax.imshow``. Defaults to "none".
        aspect (str, optional): Aspect ratio passed to ``ax.imshow``. Defaults to "auto".
        cmap (Union[str, Colormap], optional): Colormap for the image. Defaults to "twilight".
        show_ticks (bool, optional): When False, default ticks are cleared. Custom
            ``ticks_x``/``ticks_y`` are applied afterwards and still take effect. Defaults to True.
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): ``(positions, labels)`` for
            the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): ``(positions, labels)`` for
            the y-axis. Defaults to None.
        rasterized (bool, optional): Whether to rasterize the image. Defaults to True.
        **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

    Returns:
        The object returned by ``ax.imshow`` (useful for attaching a colorbar).

    Raises:
        ValueError: If ``data`` is not two-dimensional.
    """
    if isinstance(data, torch.Tensor):
        field = data.detach().cpu().numpy()
    elif isinstance(data, np.ndarray):
        field = data
    else:
        field = np.asarray(data)
    if field.ndim != 2:
        raise ValueError("Only support 2D data.")

    image = ax.imshow(
        field.T,
        interpolation=interpolation,
        cmap=cmap,
        origin="lower",
        aspect=aspect,
        rasterized=rasterized,
        **kwargs,
    )

    if x_label is not None:
        ax.set_xlabel(x_label)
    if y_label is not None:
        ax.set_ylabel(y_label)
    if title is not None:
        ax.set_title(title, loc=title_loc)

    # Clear default ticks first; explicit custom ticks below override this.
    if not show_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
    if ticks_x is not None:
        positions_x, labels_x = ticks_x[0], ticks_x[1]
        ax.set_xticks(positions_x, labels=labels_x)
    if ticks_y is not None:
        positions_y, labels_y = ticks_y[0], ticks_y[1]
        ax.set_yticks(positions_y, labels=labels_y)
    return image

torchfsm.plot.plot_3D_field ¤

plot_3D_field(
    ax: Axes,
    data: Union[ndarray, Tensor],
    bottom_label: Optional[str] = None,
    left_label: Optional[str] = None,
    title: Optional[str] = None,
    title_loc="center",
    aspect="auto",
    cmap: Union[str, Colormap] = "twilight",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    return_cmap: bool = False,
    distance_scale: float = 10,
    background=(0, 0, 0, 0),
    width=512,
    height=512,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    gamma_correction: float = 2.4,
    show_3d_coordinates: bool = True,
    coordinates_size: float = 0.1,
    coordinates_x_loc: float = -0.05,
    coordinates_y_loc: float = -0.05,
    arrow_length: float = 0.6,
    x_arrow_label: str = "x",
    y_arrow_label: str = "y",
    z_arrow_label: str = "z",
    x_arrow_color: str = "r",
    y_arrow_color: str = "g",
    z_arrow_color: str = "b",
    arrow_length_ratio: float = 0.25,
    arrow_linewidth: int = 1,
    **kwargs
)

Plot a 3D field. Powered by https://github.com/KeKsBoTer/vape4d

Parameters:

Name Type Description Default
ax Axes

The axes to plot on.

required
data Union[ndarray, Tensor]

The data to plot.

required
bottom_label Optional[str]

The label for the bottom axis. Defaults to None.

None
left_label Optional[str]

The label for the left axis. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
title_loc str

The location of the title. Defaults to "center".

'center'
aspect str

The aspect ratio. Defaults to "auto".

'auto'
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
vmin Optional[float]

The minimum value for the color scale. Defaults to None.

None
vmax Optional[float]

The maximum value for the color scale. Defaults to None.

None
return_cmap bool

Whether to return the colormap. Defaults to False.

False
distance_scale float

The distance scale for rendering. Defaults to 10.

10
background tuple

The background color. Defaults to (0, 0, 0, 0).

(0, 0, 0, 0)
width int

The width of the rendered image. Defaults to 512.

512
height int

The height of the rendered image. Defaults to 512.

512
alpha_func Union[Literal['zigzag', 'central_peak', 'central_valley', 'linear_increase', 'linear_decrease'], AlphaFunction]

The alpha function. Defaults to "zigzag".

'zigzag'
gamma_correction float

The gamma correction factor. Defaults to 2.4.

2.4
show_3d_coordinates bool

Whether to overlay a small 3D coordinate system on the axes. Defaults to True.

True
coordinates_size float

Size of the 3D axes relative to the 2D axes. Defaults to 0.1.

0.1
coordinates_x_loc float

X location offset for the 3D axes. Defaults to -0.05.

-0.05
coordinates_y_loc float

Y location offset for the 3D axes. Defaults to -0.05.

-0.05
arrow_length float

Length of the arrows. Defaults to 0.6.

0.6
x_arrow_label str

Label for the X axis. Defaults to 'x'.

'x'
y_arrow_label str

Label for the Y axis. Defaults to 'y'.

'y'
z_arrow_label str

Label for the Z axis. Defaults to 'z'.

'z'
x_arrow_color str

Color for the X axis arrow. Defaults to 'r'.

'r'
y_arrow_color str

Color for the Y axis arrow. Defaults to 'g'.

'g'
z_arrow_color str

Color for the Z axis arrow. Defaults to 'b'.

'b'
arrow_length_ratio float

Ratio of the arrow head length to the total arrow length. Defaults to 0.25.

0.25
arrow_linewidth int

Line width of the arrows. Defaults to 1.

1
**kwargs

Additional keyword arguments for the plot.

{}
Source code in torchfsm/plot/core/field.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
def plot_3D_field(
    ax: plt.Axes,
    data: Union[np.ndarray, torch.Tensor],
    bottom_label: Optional[str] = None,
    left_label: Optional[str] = None,
    title: Optional[str] = None,
    title_loc="center",
    aspect="auto",
    cmap: Union[str, Colormap] = "twilight",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    return_cmap: bool = False,
    distance_scale: float = 10,
    background=(0, 0, 0, 0),
    width=512,
    height=512,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    gamma_correction: float = 2.4,
    show_3d_coordinates: bool = True,
    coordinates_size: float = 0.1,
    coordinates_x_loc: float = -0.05,
    coordinates_y_loc: float = -0.05,
    arrow_length: float = 0.6,
    x_arrow_label: str = 'x',
    y_arrow_label: str = 'y',
    z_arrow_label: str = 'z',
    x_arrow_color: str = 'r',
    y_arrow_color: str = 'g',
    z_arrow_color: str = 'b',
    arrow_length_ratio: float = 0.25,
    arrow_linewidth: int = 1,
    **kwargs,
):
    """
    Volume-render a 3D field into an image and draw it on ``ax``.
    Powered by https://github.com/KeKsBoTer/vape4d

    Args:
        ax (plt.Axes): The axes to draw on.
        data (Union[np.ndarray, torch.Tensor]): The field, shaped ``[X, Y, Z]`` or
            ``[1, X, Y, Z]``. Tensors are detached and moved to CPU.
        bottom_label (Optional[str], optional): Label for the bottom axis. Defaults to None.
        left_label (Optional[str], optional): Label for the left axis. Defaults to None.
        title (Optional[str], optional): Plot title. Defaults to None.
        title_loc (str, optional): Title placement. Defaults to "center".
        aspect (str, optional): Aspect ratio of the drawn image. Defaults to "auto".
        cmap (Union[str, Colormap], optional): Colormap used for rendering. Defaults to "twilight".
        vmin (Optional[float], optional): Lower bound of the color scale. Defaults to None.
        vmax (Optional[float], optional): Upper bound of the color scale. Defaults to None.
        return_cmap (bool, optional): If True, also return ``cmap`` exactly as it was
            passed in. Defaults to False.
        distance_scale (float, optional): Distance scale for rendering. Defaults to 10.
        background (tuple, optional): Background color. Defaults to (0, 0, 0, 0).
        width (int, optional): Width of the rendered image. Defaults to 512.
        height (int, optional): Height of the rendered image. Defaults to 512.
        alpha_func (Union[Literal["zigzag", "central_peak", "central_valley", "linear_increase", "linear_decrease"], AlphaFunction], optional):
            Opacity transfer function: a named preset or an ``AlphaFunction``. Defaults to "zigzag".
        gamma_correction (float, optional): Gamma correction factor. Defaults to 2.4.
        show_3d_coordinates (bool, optional): Whether to overlay a small 3D coordinate
            system on the axes. Defaults to True.
        coordinates_size (float, optional): Size of the coordinate overlay relative to ``ax``. Defaults to 0.1.
        coordinates_x_loc (float, optional): X offset of the coordinate overlay. Defaults to -0.05.
        coordinates_y_loc (float, optional): Y offset of the coordinate overlay. Defaults to -0.05.
        arrow_length (float, optional): Length of the arrows. Defaults to 0.6.
        x_arrow_label (str, optional): Label of the X arrow. Defaults to 'x'.
        y_arrow_label (str, optional): Label of the Y arrow. Defaults to 'y'.
        z_arrow_label (str, optional): Label of the Z arrow. Defaults to 'z'.
        x_arrow_color (str, optional): Color of the X arrow. Defaults to 'r'.
        y_arrow_color (str, optional): Color of the Y arrow. Defaults to 'g'.
        z_arrow_color (str, optional): Color of the Z arrow. Defaults to 'b'.
        arrow_length_ratio (float, optional): Arrow-head length as a fraction of the arrow length. Defaults to 0.25.
        arrow_linewidth (int, optional): Line width of the arrows. Defaults to 1.
        **kwargs: Extra keyword arguments forwarded to ``render_3d_field``.

    Returns:
        Whatever ``_plot_3D_field`` returns, or ``(that value, cmap)`` when
        ``return_cmap`` is True.

    Raises:
        ValueError: If ``data`` is neither ``[X, Y, Z]`` nor ``[1, X, Y, Z]``.
    """
    if isinstance(data, torch.Tensor):
        volume = data.detach().cpu().numpy()
    elif isinstance(data, np.ndarray):
        volume = data
    else:
        volume = np.asarray(data)

    # The renderer expects a leading singleton axis.
    if volume.ndim == 3:
        volume = volume[np.newaxis]
    elif volume.ndim != 4 or volume.shape[0] != 1:
        raise ValueError("Only support 3D data with shape of [X,Y,Z] or [1,X,Y,Z].")

    rendered = render_3d_field(
        volume,
        cmap,
        vmin,
        vmax,
        distance_scale,
        background,
        width,
        height,
        alpha_func,
        gamma_correction,
        **kwargs,
    )

    coordinate_overlay = dict(
        show_3d_coordinates=show_3d_coordinates,
        coordinates_size=coordinates_size,
        coordinates_x_loc=coordinates_x_loc,
        coordinates_y_loc=coordinates_y_loc,
        arrow_length=arrow_length,
        x_arrow_label=x_arrow_label,
        y_arrow_label=y_arrow_label,
        z_arrow_label=z_arrow_label,
        x_arrow_color=x_arrow_color,
        y_arrow_color=y_arrow_color,
        z_arrow_color=z_arrow_color,
        arrow_length_ratio=arrow_length_ratio,
        arrow_linewidth=arrow_linewidth,
    )
    artist = _plot_3D_field(
        ax,
        rendered,
        bottom_label=bottom_label,
        left_label=left_label,
        title=title,
        title_loc=title_loc,
        aspect=aspect,
        **coordinate_overlay,
    )
    return (artist, cmap) if return_cmap else artist

torchfsm.plot.plot_traj ¤

plot_traj(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
    ],
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_t: Optional[str] = "t",
    ticks_t: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    show_time_index: bool = True,
    animation: bool = True,
    fps=30,
    show_in_notebook: bool = True,
    animation_engine: Literal["jshtml", "html5"] = "html5",
    save_name: Optional[str] = None,
    show_3d_coordinates: bool = True,
    **kwargs
) -> Optional[FuncAnimation]

Plot a trajectory. The dimension of the trajectory can be 1D, 2D, or 3D.

Parameters:

Name Type Description Default
traj Union[SpatialTensor['B T C H ...'], SpatialArray['B T C H ...']]

The trajectory to plot.

required
channel_names Optional[Sequence[str]]

The names of the channels. Defaults to None.

None
batch_names Optional[Sequence[str]]

The names of the batches. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
vmin Optional[Union[float, Sequence[float]]]

The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
vmax Optional[Union[float, Sequence[float]]]

The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
use_sym_colormap bool

Whether to use a symmetric colormap. Defaults to False.

False
alpha_func Union[Literal['zigzag', 'central_peak', 'central_valley', 'linear_increase', 'linear_decrease'], AlphaFunction]

The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".

'zigzag'
num_colorbar_value int

The number of values for the colorbar. Defaults to 4.

4
c_bar_labels Optional[Sequence[str]]

The labels for the colorbar. Defaults to None. If provided, it should have the same length as the number of channels. If not provided, the colorbar will not have labels.

None
cbar_pad Optional[float]

The padding for the colorbar. Defaults to 0.1.

0.1
ctick_format Optional[str]

The format for the colorbar ticks. Defaults to "%.1f".

'%.1f'
subfig_size float

The size of the subfigures. Defaults to 2.5.

2.5
real_size_ratio bool

Whether to use the real size ratio for the subfigures. Defaults to False.

False
width_correction float

The correction factor for the width of the subfigures. Defaults to 1.0.

1.0
height_correction float

The correction factor for the height of the subfigures. Defaults to 1.0.

1.0
space_x Optional[float]

The space between subfigures in the x direction. Defaults to 0.7.

0.7
space_y Optional[float]

The space between subfigures in the y direction. Defaults to 0.1.

0.1
label_x Optional[str]

The label for the x-axis. Defaults to "x".

'x'
label_y Optional[str]

The label for the y-axis. Defaults to "y".

'y'
label_t Optional[str]

The label for the time index. Defaults to "t".

't'
ticks_t Tuple[Sequence[float], Sequence[str]]

Custom ticks for the time index. Defaults to None.

None
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
show_ticks Union[Literal['auto'], bool]

Whether to show ticks. Defaults to "auto".

'auto'
show_time_index bool

Whether to show the time index in the plot. Defaults to True.

True
animation bool

Whether to create an animation. Defaults to True.

True
fps int

The frames per second for the animation. Defaults to 30.

30
show_in_notebook bool

Whether to show the plot in a Jupyter notebook. Defaults to True.

True
animation_engine Literal['jshtml', 'html5']

The engine to use for the animation. Defaults to "html5".

'html5'
save_name Optional[str]

The name of the file to save the plot. Defaults to None.

None
show_3d_coordinates bool

Whether to show 3D coordinate axes when plotting 3D data. Defaults to True.

True
**kwargs

Additional keyword arguments for the plot.

{}

Returns:

Type Description
Optional[FuncAnimation]

Optional[FuncAnimation]: If animation is True and not show_in_notebook, returns a FuncAnimation object.

Source code in torchfsm/plot/app/traj_field.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def plot_traj(
    traj: Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]],
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_t: Optional[str] = "t",
    ticks_t: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    show_time_index: bool = True,
    animation: bool = True,
    fps=30,
    show_in_notebook: bool = True,
    animation_engine: Literal["jshtml", "html5"] = "html5",
    save_name: Optional[str] = None,
    show_3d_coordinates: bool = True,
    **kwargs,
) -> Optional[FuncAnimation]:
    """
    Plot a trajectory whose spatial part may be 1D, 2D, or 3D.

    This function does no work of its own: every option is forwarded unchanged
    to ``ChannelWisedPlotter().plot``.

    Args:
        traj (Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]): Trajectory data to draw.
        channel_names (Optional[Sequence[str]], optional): Display names for the channels. Defaults to None.
        batch_names (Optional[Sequence[str]], optional): Display names for the batch entries. Defaults to None.
        title (Optional[str], optional): Figure title. Defaults to None.
        vmin (Optional[Union[float, Sequence[float]]], optional): Lower bound of the color scale. Defaults to None.
            A sequence must contain one value per channel.
        vmax (Optional[Union[float, Sequence[float]]], optional): Upper bound of the color scale. Defaults to None.
            A sequence must contain one value per channel.
        cmap (Union[str, Colormap], optional): Colormap name or object. Defaults to "twilight".
        use_sym_colormap (bool, optional): Use a color scale symmetric about zero. Defaults to False.
        alpha_func (Union[Literal["zigzag","central_peak","central_valley","linear_increase","linear_decrease",],AlphaFunction,], optional):
            Opacity mapping applied to the colormap when plotting 3D data. Defaults to "zigzag".
        num_colorbar_value (int, optional): Number of tick values on each colorbar. Defaults to 4.
        c_bar_labels (Optional[Sequence[str]], optional): Colorbar labels, one per channel. Defaults to None,
            in which case the colorbars are unlabeled.
        cbar_pad (Optional[float], optional): Padding between a subplot and its colorbar. Defaults to 0.1.
        ctick_format (Optional[str], optional): printf-style format for colorbar ticks. Defaults to "%.1f".
        subfig_size (float, optional): Size of each subfigure. Defaults to 2.5.
        real_size_ratio (bool, optional): Size subfigures using the data's real aspect ratio. Defaults to False.
        width_correction (float, optional): Scale factor applied to subfigure width. Defaults to 1.0.
        height_correction (float, optional): Scale factor applied to subfigure height. Defaults to 1.0.
        space_x (Optional[float], optional): Horizontal spacing between subfigures. Defaults to 0.7.
        space_y (Optional[float], optional): Vertical spacing between subfigures. Defaults to 0.1.
        label_x (Optional[str], optional): x-axis label. Defaults to "x".
        label_y (Optional[str], optional): y-axis label. Defaults to "y".
        label_t (Optional[str], optional): Label for the time index. Defaults to "t".
        ticks_t (Tuple[Sequence[float], Sequence[str]], optional): Custom (positions, labels) for the time index. Defaults to None.
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom (positions, labels) for the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom (positions, labels) for the y-axis. Defaults to None.
        show_ticks (Union[Literal["auto"], bool], optional): Whether to draw ticks. Defaults to "auto".
        show_time_index (bool, optional): Whether to display the time index. Defaults to True.
        animation (bool, optional): Whether to build an animation. Defaults to True.
        fps (int, optional): Animation frame rate. Defaults to 30.
        show_in_notebook (bool, optional): Whether to display the result in a Jupyter notebook. Defaults to True.
        animation_engine (Literal["jshtml", "html5"], optional): Notebook animation backend. Defaults to "html5".
        save_name (Optional[str], optional): File name to save the output to. Defaults to None.
        show_3d_coordinates (bool, optional): Whether to draw coordinate axes for 3D data. Defaults to True.
        **kwargs: Extra keyword arguments passed through to the plotter.

    Returns:
        Optional[FuncAnimation]: A `FuncAnimation` when `animation` is True and `show_in_notebook` is False.
    """
    plotter = ChannelWisedPlotter()
    return plotter.plot(
        traj=traj,
        # titles and axis labels
        title=title,
        channel_names=channel_names,
        batch_names=batch_names,
        label_x=label_x,
        label_y=label_y,
        label_t=label_t,
        # color mapping
        cmap=cmap,
        use_sym_colormap=use_sym_colormap,
        alpha_func=alpha_func,
        vmin=vmin,
        vmax=vmax,
        # colorbar
        num_colorbar_value=num_colorbar_value,
        c_bar_labels=c_bar_labels,
        cbar_pad=cbar_pad,
        ctick_format=ctick_format,
        # figure layout
        subfig_size=subfig_size,
        real_size_ratio=real_size_ratio,
        width_correction=width_correction,
        height_correction=height_correction,
        space_x=space_x,
        space_y=space_y,
        # ticks and annotations
        show_ticks=show_ticks,
        ticks_t=ticks_t,
        ticks_x=ticks_x,
        ticks_y=ticks_y,
        show_time_index=show_time_index,
        show_3d_coordinates=show_3d_coordinates,
        # animation and output
        animation=animation,
        fps=fps,
        show_in_notebook=show_in_notebook,
        animation_engine=animation_engine,
        save_name=save_name,
        **kwargs,
    )

torchfsm.plot.plot_field ¤

plot_field(
    field: Union[
        SpatialTensor["B C H ..."],
        SpatialArray["B C H ..."],
    ],
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_t: Optional[str] = "t",
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    show_3d_coordinates: bool = True,
    **kwargs
)

Plot a field. The dimension of the field can be 1D, 2D, or 3D.

Parameters:

Name Type Description Default
field Union[SpatialTensor['B C H ...'], SpatialArray['B C H ...']]

The field to plot.

required
channel_names Optional[Sequence[str]]

The names of the channels. Defaults to None.

None
batch_names Optional[Sequence[str]]

The names of the batches. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
vmin Optional[Union[float, Sequence[float]]]

The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
vmax Optional[Union[float, Sequence[float]]]

The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
use_sym_colormap bool

Whether to use a symmetric colormap. Defaults to False.

False
alpha_func Union[Literal['zigzag', 'central_peak', 'central_valley', 'linear_increase', 'linear_decrease'], AlphaFunction]

The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".

'zigzag'
num_colorbar_value int

The number of values for the colorbar. Defaults to 4.

4
c_bar_labels Optional[Sequence[str]]

The labels for the colorbar. Defaults to None. If provided, it should have the same length as the number of channels. If not provided, the colorbar will not have labels.

None
cbar_pad Optional[float]

The padding for the colorbar. Defaults to 0.1.

0.1
ctick_format Optional[str]

The format for the colorbar ticks. Defaults to "%.1f".

'%.1f'
subfig_size float

The size of the subfigures. Defaults to 2.5.

2.5
real_size_ratio bool

Whether to use the real size ratio for the subfigures. Defaults to False.

False
width_correction float

The correction factor for the width of the subfigures. Defaults to 1.0.

1.0
height_correction float

The correction factor for the height of the subfigures. Defaults to 1.0.

1.0
space_x Optional[float]

The space between subfigures in the x direction. Defaults to 0.7.

0.7
space_y Optional[float]

The space between subfigures in the y direction. Defaults to 0.1.

0.1
label_x Optional[str]

The label for the x-axis. Defaults to "x".

'x'
label_y Optional[str]

The label for the y-axis. Defaults to "y".

'y'
label_t Optional[str]

The label for the time index. Defaults to "t".

't'
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
show_ticks Union[Literal['auto'], bool]

Whether to show ticks. Defaults to "auto".

'auto'
save_name Optional[str]

The name of the file to save the plot. Defaults to None.

None
show_3d_coordinates bool

Whether to show 3D coordinate axes when plotting 3D data. Defaults to True.

True
Source code in torchfsm/plot/app/traj_field.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def plot_field(
    field: Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]],
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_t: Optional[str] = "t",
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    show_3d_coordinates: bool = True,
    **kwargs,
):
    """
    Plot a single field (no time axis) whose spatial part may be 1D, 2D, or 3D.

    The field is turned into a one-frame trajectory by inserting a length-1 axis
    at position 1. It is then drawn with ``ChannelWisedPlotter().plot`` with the
    time index hidden.

    Args:
        field (Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]): Field data to draw. A
            ``torch.Tensor`` is moved to the CPU, detached, and converted to a NumPy array first.
        channel_names (Optional[Sequence[str]], optional): Display names for the channels. Defaults to None.
        batch_names (Optional[Sequence[str]], optional): Display names for the batch entries. Defaults to None.
        title (Optional[str], optional): Figure title. Defaults to None.
        vmin (Optional[Union[float, Sequence[float]]], optional): Lower bound of the color scale. Defaults to None.
            A sequence must contain one value per channel.
        vmax (Optional[Union[float, Sequence[float]]], optional): Upper bound of the color scale. Defaults to None.
            A sequence must contain one value per channel.
        cmap (Union[str, Colormap], optional): Colormap name or object. Defaults to "twilight".
        use_sym_colormap (bool, optional): Use a color scale symmetric about zero. Defaults to False.
        alpha_func (Union[Literal["zigzag","central_peak","central_valley","linear_increase","linear_decrease",],AlphaFunction,], optional):
            Opacity mapping applied to the colormap when plotting 3D data. Defaults to "zigzag".
        num_colorbar_value (int, optional): Number of tick values on each colorbar. Defaults to 4.
        c_bar_labels (Optional[Sequence[str]], optional): Colorbar labels, one per channel. Defaults to None,
            in which case the colorbars are unlabeled.
        cbar_pad (Optional[float], optional): Padding between a subplot and its colorbar. Defaults to 0.1.
        ctick_format (Optional[str], optional): printf-style format for colorbar ticks. Defaults to "%.1f".
        subfig_size (float, optional): Size of each subfigure. Defaults to 2.5.
        real_size_ratio (bool, optional): Size subfigures using the data's real aspect ratio. Defaults to False.
        width_correction (float, optional): Scale factor applied to subfigure width. Defaults to 1.0.
        height_correction (float, optional): Scale factor applied to subfigure height. Defaults to 1.0.
        space_x (Optional[float], optional): Horizontal spacing between subfigures. Defaults to 0.7.
        space_y (Optional[float], optional): Vertical spacing between subfigures. Defaults to 0.1.
        label_x (Optional[str], optional): x-axis label. Defaults to "x".
        label_y (Optional[str], optional): y-axis label. Defaults to "y".
        label_t (Optional[str], optional): Label for the time index. Defaults to "t".
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom (positions, labels) for the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom (positions, labels) for the y-axis. Defaults to None.
        show_ticks (Union[Literal["auto"], bool], optional): Whether to draw ticks. Defaults to "auto".
        save_name (Optional[str], optional): File name to save the output to. Defaults to None.
        show_3d_coordinates (bool, optional): Whether to draw coordinate axes for 3D data. Defaults to True.
        **kwargs: Extra keyword arguments passed through to the plotter.

    Returns:
        Whatever ``ChannelWisedPlotter().plot`` returns for the one-frame trajectory.
    """
    data = field
    if isinstance(data, torch.Tensor):
        # Move to host memory and drop autograd history so NumPy can consume it.
        data = data.cpu().detach().numpy()
    # "B C H ..." -> "B T C H ..." with T == 1: a single-frame trajectory.
    single_frame = np.expand_dims(data, 1)

    plotter = ChannelWisedPlotter()
    return plotter.plot(
        traj=single_frame,
        # titles and axis labels
        title=title,
        channel_names=channel_names,
        batch_names=batch_names,
        label_x=label_x,
        label_y=label_y,
        label_t=label_t,
        # color mapping
        cmap=cmap,
        use_sym_colormap=use_sym_colormap,
        alpha_func=alpha_func,
        vmin=vmin,
        vmax=vmax,
        # colorbar
        num_colorbar_value=num_colorbar_value,
        c_bar_labels=c_bar_labels,
        cbar_pad=cbar_pad,
        ctick_format=ctick_format,
        # figure layout
        subfig_size=subfig_size,
        real_size_ratio=real_size_ratio,
        width_correction=width_correction,
        height_correction=height_correction,
        space_x=space_x,
        space_y=space_y,
        # ticks and annotations
        show_ticks=show_ticks,
        ticks_x=ticks_x,
        ticks_y=ticks_y,
        # The time axis was added artificially above, so its index is not shown.
        show_time_index=False,
        show_3d_coordinates=show_3d_coordinates,
        # NOTE(review): animation is forced on even though there is only one
        # frame; presumably the plotter handles T == 1 specially — confirm.
        animation=True,
        save_name=save_name,
        **kwargs,
    )

torchfsm.plot.plot_traj_frame ¤

plot_traj_frame(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
    ],
    n_frames: int = 5,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_t: Optional[str] = "t",
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    compare_mode: Literal[
        "t_wised",
        "channel_wised",
        "channel_wised_universal",
    ] = "channel_wised_universal",
    frame_start_index: int = 0,
    show_3d_coordinates: bool = True,
    **kwargs
)

Plot frames of a single trajectory. The dimension of the trajectory can be 1D, 2D, or 3D.

Parameters:

Name Type Description Default
traj Union[SpatialTensor['B T C H ...'], SpatialArray['B T C H ...']]

The trajectory to plot.

required
n_frames int

The number of frames to plot.

5
channel_names Optional[Sequence[str]]

The names of the channels. Defaults to None.

None
batch_names Optional[Sequence[str]]

The names of the batches. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
vmin Optional[Union[float, Sequence[float]]]

The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
vmax Optional[Union[float, Sequence[float]]]

The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
use_sym_colormap bool

Whether to use a symmetric colormap. Defaults to False.

False
alpha_func Union[Literal['zigzag', 'central_peak', 'central_valley', 'linear_increase', 'linear_decrease'], AlphaFunction]

The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".

'zigzag'
num_colorbar_value int

The number of values for the colorbar. Defaults to 4.

4
c_bar_labels Optional[Sequence[str]]

The labels for the colorbar. Defaults to None. If provided, it should have the same length as the number of channels. If not provided, the colorbar will not have labels.

None
cbar_pad Optional[float]

The padding for the colorbar. Defaults to 0.1.

0.1
ctick_format Optional[str]

The format for the colorbar ticks. Defaults to "%.1f".

'%.1f'
subfig_size float

The size of the subfigures. Defaults to 2.5.

2.5
real_size_ratio bool

Whether to use the real size ratio for the subfigures. Defaults to False.

False
width_correction float

The correction factor for the width of the subfigures. Defaults to 1.0.

1.0
height_correction float

The correction factor for the height of the subfigures. Defaults to 1.0.

1.0
space_x Optional[float]

The space between subfigures in the x direction. Defaults to 0.7.

0.7
space_y Optional[float]

The space between subfigures in the y direction. Defaults to 0.1.

0.1
label_x Optional[str]

The label for the x-axis. Defaults to "x".

'x'
label_y Optional[str]

The label for the y-axis. Defaults to "y".

'y'
label_t Optional[str]

The label for the time index. Defaults to "t".

't'
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
show_ticks Union[Literal['auto'], bool]

Whether to show ticks. Defaults to "auto".

'auto'
frame_start_index int

The starting index for the frame numbers. Defaults to 0.

0
show_3d_coordinates bool

Whether to show 3D coordinates for 3D plots. Defaults to True.

True
**kwargs

Additional keyword arguments for the plot.

{}
Source code in torchfsm/plot/app/frame.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def plot_traj_frame(
    traj: Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]],
    n_frames: int = 5,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_t: Optional[str] = "t",
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    compare_mode: Literal[
        "t_wised", "channel_wised", "channel_wised_universal"
    ] = "channel_wised_universal",
    frame_start_index: int = 0,
    show_3d_coordinates: bool = True,
    **kwargs,
):
    """
    Plot frames of a single trajectory. The dimension of the trajectory can be 1D, 2D, or 3D.

    ``n_frames`` frames are picked along the time axis with ``uniformly_select_frames``.
    They are split per channel and then laid out by ``concate_fields_plot``.

    Args:
        traj (Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]): The trajectory to plot.
            Torch tensors are moved to the CPU and converted to NumPy arrays first.
        n_frames (int): The number of frames to plot. Must be at least 1 and at most the
            number of time steps ``T``.
        channel_names (Optional[Sequence[str]], optional): The names of the channels. Defaults to None,
            which yields "channel 0", "channel 1", ...
        batch_names (Optional[Sequence[str]], optional): The names of the batches. Defaults to None,
            which yields "batch 0", "batch 1", ...
        title (Optional[str], optional): The title of the plot. Defaults to None.
        vmin (Optional[Union[float, Sequence[float]]], optional): The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.
        vmax (Optional[Union[float, Sequence[float]]], optional): The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.
        cmap (Union[str, Colormap], optional): The colormap to use. Defaults to "twilight".
        use_sym_colormap (bool, optional): Whether to use a symmetric colormap. Defaults to False.
        alpha_func (Union[Literal["zigzag","central_peak","central_valley","linear_increase","linear_decrease",],AlphaFunction,], optional): The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".
        num_colorbar_value (int, optional): The number of values for the colorbar. Defaults to 4.
        c_bar_labels (Optional[Sequence[str]], optional): The labels for the colorbar. Defaults to None.
            If provided, it should have the same length as the number of channels.
            If not provided, the colorbar will not have labels.
        cbar_pad (Optional[float], optional): The padding for the colorbar. Defaults to 0.1.
        ctick_format (Optional[str], optional): The format for the colorbar ticks. Defaults to "%.1f".
        subfig_size (float, optional): The size of the subfigures. Defaults to 2.5.
        real_size_ratio (bool, optional): Whether to use the real size ratio for the subfigures. Defaults to False.
        width_correction (float, optional): The correction factor for the width of the subfigures. Defaults to 1.0.
        height_correction (float, optional): The correction factor for the height of the subfigures. Defaults to 1.0.
        space_x (Optional[float], optional): The space between subfigures in the x direction. Defaults to 0.7.
            If None, falls back to 0.7 for "t_wised" and 0.2 for the channel-wised modes.
        space_y (Optional[float], optional): The space between subfigures in the y direction. Defaults to 0.1.
            If None, falls back to 0.1 for "t_wised" and 0.2 for the channel-wised modes.
        label_x (Optional[str], optional): The label for the x-axis. Defaults to "x".
        label_y (Optional[str], optional): The label for the y-axis. Defaults to "y".
        label_t (Optional[str], optional): The label for the time index. It is also used as the
            prefix of the per-frame titles ("t=0", "t=5", ...). Defaults to "t".
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the y-axis. Defaults to None.
        show_ticks (Union[Literal["auto"], bool], optional): Whether to show ticks. Defaults to "auto".
        save_name (Optional[str], optional): The name of the file to save the plot. Defaults to None.
        compare_mode (Literal["t_wised", "channel_wised", "channel_wised_universal"], optional):
            How the frames are arranged. Defaults to "channel_wised_universal".
            - "t_wised": uses ``ChannelWisedPlotter`` with ``universal_minmax=False``.
            - "channel_wised": uses ``BatchWisedPlotter`` with ``universal_minmax=False``.
            - "channel_wised_universal": uses ``BatchWisedPlotter`` with ``universal_minmax=True``.
        frame_start_index (int, optional): The starting index for the frame numbers. It is added to
            each selected frame index in the frame titles. Defaults to 0.
        show_3d_coordinates (bool, optional): Whether to show 3D coordinates for 3D plots. Defaults to True.
        **kwargs: Additional keyword arguments forwarded to ``concate_fields_plot``.

    Returns:
        Whatever ``concate_fields_plot`` returns.

    Raises:
        ValueError: If ``traj`` has fewer than 4 dimensions, if ``n_frames`` is smaller than 1
            or larger than the number of time steps, or if ``compare_mode`` is unknown.
    """
    if isinstance(traj, torch.Tensor):
        traj = traj.cpu().detach().numpy()
    if traj.ndim < 4:
        raise ValueError("Trajectory must have at least 4 dimensions (B, T, C, H, ...)")
    if n_frames < 1:
        raise ValueError(f"n_frames must be at least 1, but got {n_frames}.")
    if traj.shape[1] < n_frames:
        raise ValueError(
            f"Trajectory has only {traj.shape[1]} frames, but {n_frames} are requested."
        )
    frames, frame_indices = uniformly_select_frames(traj, n_frames, True)
    # One array per channel, each laid out as [B, T, H, ...]. np.array copies,
    # so the per-channel arrays never share memory with `frames`.
    channel_wised_data = [
        np.array(frames[:, :, c_i, ...]) for c_i in range(frames.shape[2])
    ]
    if compare_mode == "t_wised":
        ploter = ChannelWisedPlotter()
        space_x = default(space_x, 0.7)
        space_y = default(space_y, 0.1)
    elif compare_mode in ("channel_wised", "channel_wised_universal"):
        ploter = BatchWisedPlotter()
        space_x = default(space_x, 0.2)
        space_y = default(space_y, 0.2)
    else:
        raise ValueError(
            f"Unknown compare_mode: {compare_mode}. Must be 't_wised', 'channel_wised' or 'channel_wised_universal'."
        )
    # Assigned for every valid mode. Only the universal mode requests a shared
    # min/max; "channel_wised" previously left this name unbound.
    universal_minmax = compare_mode == "channel_wised_universal"

    time_names = [f"{label_t}={i + frame_start_index}" for i in frame_indices]
    channel_names = default(
        channel_names, [f"channel {i}" for i in range(len(channel_wised_data))]
    )
    batch_names = default(
        batch_names, [f"batch {i}" for i in range(frames.shape[0])]
    )
    # One row of "<batch>, <channel>" labels per channel.
    batch_names = [
        [f"{batch_name}, {channel_name}" for batch_name in batch_names]
        for channel_name in channel_names
    ]
    return concate_fields_plot(
        ploters=ploter,
        fields=channel_wised_data,
        channel_names=time_names,
        batch_names=batch_names,
        vmin=vmin,
        vmax=vmax,
        universal_minmax=universal_minmax,
        subfig_size=subfig_size,
        space_x=space_x,
        space_y=space_y,
        cbar_pad=cbar_pad,
        c_bar_labels=c_bar_labels,
        real_size_ratio=real_size_ratio,
        num_colorbar_value=num_colorbar_value,
        ctick_format=ctick_format,
        show_ticks=show_ticks,
        cmap=cmap,
        use_sym_colormap=use_sym_colormap,
        ticks_x=ticks_x,
        ticks_y=ticks_y,
        save_name=save_name,
        alpha_func=alpha_func,
        show_time_index=False,
        title=title,
        width_correction=width_correction,
        height_correction=height_correction,
        label_x=label_x,
        label_y=label_y,
        label_t=label_t,
        show_3d_coordinates=show_3d_coordinates,
        **kwargs,
    )

torchfsm.plot.plot_traj_slice ¤

plot_traj_slice(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
    ],
    slice_control: Sequence[
        Optional[Union[int, float]]
    ] = None,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_z: Optional[str] = "z",
    label_t: Optional[str] = "t",
    ticks_t: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_z: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    show_time_index: bool = True,
    animation: bool = True,
    fps=30,
    show_in_notebook: bool = True,
    animation_engine: Literal["jshtml", "html5"] = "html5",
    save_name: Optional[str] = None,
    **kwargs
) -> Optional[FuncAnimation]

Plot the trajectory slices.

Parameters:

Name Type Description Default
traj Union[SpatialTensor['B T C H ...'], SpatialArray['B T C H ...']]

The trajectory to plot.

required
slice_control Sequence[Optional[Union[int, float]]]

The control points for slicing the trajectory. Defaults to None. If None, it will slice at the middle of each dimension.

None
channel_names Optional[Sequence[str]]

The names of the channels. Defaults to None.

None
batch_names Optional[Sequence[str]]

The names of the batches. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
vmin Optional[Union[float, Sequence[float]]]

The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
vmax Optional[Union[float, Sequence[float]]]

The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
use_sym_colormap bool

Whether to use a symmetric colormap. Defaults to False.

False
alpha_func Union[Literal['zigzag', 'central_peak', 'central_valley', 'linear_increase', 'linear_decrease'], AlphaFunction]

The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".

'zigzag'
num_colorbar_value int

The number of values for the colorbar. Defaults to 4.

4
c_bar_labels Optional[Sequence[str]]

The labels for the colorbar. Defaults to None. If provided, it should have the same length as the number of channels. If not provided, the colorbar will not have labels.

None
cbar_pad Optional[float]

The padding for the colorbar. Defaults to 0.1.

0.1
ctick_format Optional[str]

The format for the colorbar ticks. Defaults to "%.1f".

'%.1f'
subfig_size float

The size of the subfigures. Defaults to 2.5.

2.5
real_size_ratio bool

Whether to use the real size ratio for the subfigures. Defaults to False.

False
width_correction float

The correction factor for the width of the subfigures. Defaults to 1.0.

1.0
height_correction float

The correction factor for the height of the subfigures. Defaults to 1.0.

1.0
space_x Optional[float]

The space between subfigures in the x direction. Defaults to 0.7.

0.7
space_y Optional[float]

The space between subfigures in the y direction. Defaults to 0.1.

0.1
label_x Optional[str]

The label for the x-axis. Defaults to "x".

'x'
label_y Optional[str]

The label for the y-axis. Defaults to "y".

'y'
label_t Optional[str]

The label for the time index. Defaults to "t".

't'
ticks_t Tuple[Sequence[float], Sequence[str]]

Custom ticks for the time index. Defaults to None.

None
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
show_ticks Union[Literal['auto'], bool]

Whether to show ticks. Defaults to "auto".

'auto'
show_time_index bool

Whether to show the time index in the plot. Defaults to True.

True
animation bool

Whether to create an animation. Defaults to True.

True
fps int

The frames per second for the animation. Defaults to 30.

30
show_in_notebook bool

Whether to show the plot in a Jupyter notebook. Defaults to True.

True
animation_engine Literal['jshtml', 'html5']

The engine to use for the animation. Defaults to "html5".

'html5'
save_name Optional[str]

The name of the file to save the plot. Defaults to None.

None

Returns: FuncAnimation: If animation is True and not show_in_notebook, returns a FuncAnimation object.

Source code in torchfsm/plot/app/slice.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def plot_traj_slice(
    traj: Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]],
    slice_control: Sequence[Optional[Union[int, float]]] = None,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_z: Optional[str] = "z",
    label_t: Optional[str] = "t",
    ticks_t: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_z: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    show_time_index: bool = True,
    animation: bool = True,
    fps=30,
    show_in_notebook: bool = True,
    animation_engine: Literal["jshtml", "html5"] = "html5",
    save_name: Optional[str] = None,
    **kwargs,
) -> Optional[FuncAnimation]:
    """
    Plot the trajectory slices.

    ``traj`` is cut into slices by ``traj_slices`` according to ``slice_control``.
    The slice names, axis labels and ticks come from ``_get_slice_names``. The
    slices are then drawn by ``concate_traj_plots`` with a ``ChannelWisedPlotter``.

    Args:
        traj (Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]): The trajectory to plot.
            It must have exactly 2 or 3 spatial dimensions after the (B, T, C) axes.
        slice_control (Sequence[Optional[Union[int, float]]], optional): The control points for slicing the trajectory. Defaults to None.
            If None, 0.5 is used for every spatial dimension (described by the author as slicing at the middle of each dimension).
        channel_names (Optional[Sequence[str]], optional): The names of the channels. Defaults to None.
        batch_names (Optional[Sequence[str]], optional): The names of the batches. Defaults to None,
            which yields "batch 0", "batch 1", ...
        title (Optional[str], optional): The title of the plot. Defaults to None.
        vmin (Optional[Union[float, Sequence[float]]], optional): The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.
        vmax (Optional[Union[float, Sequence[float]]], optional): The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.
        cmap (Union[str, Colormap], optional): The colormap to use. Defaults to "twilight".
        use_sym_colormap (bool, optional): Whether to use a symmetric colormap. Defaults to False.
        alpha_func (Union[Literal["zigzag","central_peak","central_valley","linear_increase","linear_decrease",],AlphaFunction,], optional): The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".
        num_colorbar_value (int, optional): The number of values for the colorbar. Defaults to 4.
        c_bar_labels (Optional[Sequence[str]], optional): The labels for the colorbar. Defaults to None.
            If provided, it should have the same length as the number of channels.
            If not provided, the colorbar will not have labels.
        cbar_pad (Optional[float], optional): The padding for the colorbar. Defaults to 0.1.
        ctick_format (Optional[str], optional): The format for the colorbar ticks. Defaults to "%.1f".
        subfig_size (float, optional): The size of the subfigures. Defaults to 2.5.
        real_size_ratio (bool, optional): Whether to use the real size ratio for the subfigures. Defaults to False.
        width_correction (float, optional): The correction factor for the width of the subfigures. Defaults to 1.0.
        height_correction (float, optional): The correction factor for the height of the subfigures. Defaults to 1.0.
        space_x (Optional[float], optional): The space between subfigures in the x direction. Defaults to 0.7.
        space_y (Optional[float], optional): The space between subfigures in the y direction. Defaults to 0.1.
        label_x (Optional[str], optional): The label for the x-axis. Defaults to "x".
        label_y (Optional[str], optional): The label for the y-axis. Defaults to "y".
        label_z (Optional[str], optional): The label for the z-axis. It is passed to ``_get_slice_names``
            together with ``label_x``/``label_y``. Defaults to "z".
        label_t (Optional[str], optional): The label for the time index. Defaults to "t".
        ticks_t (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the time index. Defaults to None.
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the y-axis. Defaults to None.
        ticks_z (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the z-axis. They are passed
            to ``_get_slice_names`` together with ``ticks_x``/``ticks_y``. Defaults to None.
        show_ticks (Union[Literal["auto"], bool], optional): Whether to show ticks. Defaults to "auto".
        show_time_index (bool, optional): Whether to show the time index in the plot. Defaults to True.
        animation (bool, optional): Whether to create an animation. Defaults to True.
        fps (int, optional): The frames per second for the animation. Defaults to 30.
        show_in_notebook (bool, optional): Whether to show the plot in a Jupyter notebook. Defaults to True.
        animation_engine (Literal["jshtml", "html5"], optional): The engine to use for the animation. Defaults to "html5".
        save_name (Optional[str], optional): The name of the file to save the plot. Defaults to None.
        **kwargs: Additional keyword arguments forwarded to ``concate_traj_plots``.

    Returns:
        Optional[FuncAnimation]: The result of ``concate_traj_plots``. The author documents this as a
        ``FuncAnimation`` when ``animation`` is True and ``show_in_notebook`` is False.

    Raises:
        ValueError: If ``traj`` does not have 2 or 3 spatial dimensions.
    """
    # Everything after the leading (B, T, C) axes is spatial.
    spatial_dims = len(traj.shape) - 3
    if spatial_dims not in (2, 3):
        raise ValueError(
            f"Trajectory must have 2 or 3 spatial dimensions, but got {spatial_dims}."
        )

    controls = [0.5] * spatial_dims if slice_control is None else slice_control
    slices = traj_slices(traj, controls)
    (
        slice_names,
        axis_labels_x,
        axis_labels_y,
        axis_ticks_x,
        axis_ticks_y,
    ) = _get_slice_names(
        slice_control=controls,
        label_x=label_x,
        label_y=label_y,
        label_z=label_z,
        ticks_x=ticks_x,
        ticks_y=ticks_y,
        ticks_z=ticks_z,
    )

    base_batch_names = default(
        batch_names, [f"batch {i}" for i in range(traj.shape[0])]
    )
    # One row of "<batch>, <slice>" labels per slice.
    row_labels = [
        [f"{batch_name}, {slice_name}" for batch_name in base_batch_names]
        for slice_name in slice_names
    ]

    return concate_traj_plots(
        ploters=ChannelWisedPlotter(),
        trajs=slices,
        # titles and axis labels
        channel_names=channel_names,
        batch_names=row_labels,
        title=title,
        label_x=axis_labels_x,
        label_y=axis_labels_y,
        label_t=label_t,
        # ticks
        show_ticks=show_ticks,
        ticks_x=axis_ticks_x,
        ticks_y=axis_ticks_y,
        ticks_t=ticks_t,
        show_time_index=show_time_index,
        # colour mapping
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        use_sym_colormap=use_sym_colormap,
        alpha_func=alpha_func,
        num_colorbar_value=num_colorbar_value,
        c_bar_labels=c_bar_labels,
        cbar_pad=cbar_pad,
        ctick_format=ctick_format,
        # layout
        subfig_size=subfig_size,
        real_size_ratio=real_size_ratio,
        width_correction=width_correction,
        height_correction=height_correction,
        space_x=space_x,
        space_y=space_y,
        # animation and output
        animation=animation,
        fps=fps,
        show_in_notebook=show_in_notebook,
        animation_engine=animation_engine,
        save_name=save_name,
        **kwargs,
    )

torchfsm.plot.plot_traj_frame_slice ¤

plot_traj_frame_slice(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
    ],
    slice_control: Sequence[
        Optional[Union[int, float]]
    ] = None,
    n_frames: int = 5,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_z: Optional[str] = "z",
    label_t: Optional[str] = "t",
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_z: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    compare_mode: Literal[
        "t_wised",
        "channel_wised",
        "channel_wised_universal",
    ] = "channel_wised_universal",
    frame_start_index: int = 0,
    **kwargs
)

Plot the trajectory slices.

Parameters:

Name Type Description Default
traj Union[SpatialTensor['B T C H ...'], SpatialArray['B T C H ...']]

The trajectory to plot.

required
slice_control Sequence[Optional[Union[int, float]]]

The control points for slicing the trajectory. Defaults to None. If None, it will slice at the middle of each dimension.

None
n_frames int

The number of frames to plot.

5
channel_names Optional[Sequence[str]]

The names of the channels. Defaults to None.

None
batch_names Optional[Sequence[str]]

The names of the batches. Defaults to None.

None
title Optional[str]

The title of the plot. Defaults to None.

None
vmin Optional[Union[float, Sequence[float]]]

The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
vmax Optional[Union[float, Sequence[float]]]

The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.

None
cmap Union[str, Colormap]

The colormap to use. Defaults to "twilight".

'twilight'
use_sym_colormap bool

Whether to use a symmetric colormap. Defaults to False.

False
alpha_func Union[Literal['zigzag', 'central_peak', 'central_valley', 'linear_increase', 'linear_decrease'], AlphaFunction]

The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".

'zigzag'
num_colorbar_value int

The number of values for the colorbar. Defaults to 4.

4
c_bar_labels Optional[Sequence[str]]

The labels for the colorbar. Defaults to None. If provided, it should have the same length as the number of channels. If not provided, the colorbar will not have labels.

None
cbar_pad Optional[float]

The padding for the colorbar. Defaults to 0.1.

0.1
ctick_format Optional[str]

The format for the colorbar ticks. Defaults to "%.1f".

'%.1f'
subfig_size float

The size of the subfigures. Defaults to 2.5.

2.5
real_size_ratio bool

Whether to use the real size ratio for the subfigures. Defaults to False.

False
width_correction float

The correction factor for the width of the subfigures. Defaults to 1.0.

1.0
height_correction float

The correction factor for the height of the subfigures. Defaults to 1.0.

1.0
space_x Optional[float]

The space between subfigures in the x direction. Defaults to 0.7.

0.7
space_y Optional[float]

The space between subfigures in the y direction. Defaults to 0.1.

0.1
label_x Optional[str]

The label for the x-axis. Defaults to "x".

'x'
label_y Optional[str]

The label for the y-axis. Defaults to "y".

'y'
label_t Optional[str]

The label for the time index. Defaults to "t".

't'
ticks_t Tuple[Sequence[float], Sequence[str]]

Custom ticks for the time index. Defaults to None.

required
ticks_x Tuple[Sequence[float], Sequence[str]]

Custom ticks for the x-axis. Defaults to None.

None
ticks_y Tuple[Sequence[float], Sequence[str]]

Custom ticks for the y-axis. Defaults to None.

None
show_ticks Union[Literal['auto'], bool]

Whether to show ticks. Defaults to "auto".

'auto'
save_name Optional[str]

The name of the file to save the plot. Defaults to None.

None
compare_mode Literal['t_wised', 'channel_wised', 'channel_wised_universal']

The mode to compare the data. Defaults to "channel_wised_universal". - "t_wised": Compare the data across time. - "channel_wised": Compare the data across channels. - "channel_wised_universal": Compare the data across channels with universal min-max scaling

'channel_wised_universal'
frame_start_index int

The starting index for the frame numbers. Defaults to 0.

0

Returns: FuncAnimation: If animation is True and not show_in_notebook, returns a FuncAnimation object.

Source code in torchfsm/plot/app/frame_slice.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def plot_traj_frame_slice(
    traj: Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]],
    slice_control: Sequence[Optional[Union[int, float]]] = None,
    n_frames: int = 5,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_z: Optional[str] = "z",
    label_t: Optional[str] = "t",
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_z: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    compare_mode: Literal[
        "t_wised", "channel_wised", "channel_wised_universal"
    ] = "channel_wised_universal",
    frame_start_index: int = 0,
    **kwargs,
):
    """
    Plot a selection of frames from slices of a trajectory.

    The trajectory is sliced with ``traj_slices`` according to ``slice_control``.
    Per-slice axis labels and ticks are derived with ``_get_slice_names``.
    Batch names are suffixed with the slice name (one list of names per slice).
    Everything is then handed to ``_plot_traj_frame_group``.

    Args:
        traj (Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]): The trajectory to plot.
            Axis 0 is the batch, axis 1 the frames (time) and axis 2 the channels.
            It must have 2 or 3 trailing spatial dimensions.
        slice_control (Sequence[Optional[Union[int, float]]], optional): The control points for slicing the trajectory. Defaults to None.
            If None, ``[0.5] * n_dim`` is used, i.e. a slice at the middle of each spatial dimension.
        n_frames (int): The number of frames to plot. Must be at least 1 and at most ``traj.shape[1]``. Defaults to 5.
        channel_names (Optional[Sequence[str]], optional): The names of the channels. Defaults to None.
        batch_names (Optional[Sequence[str]], optional): The names of the batches. Defaults to None, which yields "batch 0", "batch 1", ...
        title (Optional[str], optional): The title of the plot. Defaults to None.
        vmin (Optional[Union[float, Sequence[float]]], optional): The minimum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.
        vmax (Optional[Union[float, Sequence[float]]], optional): The maximum value for the color scale. Defaults to None. If a sequence is provided, it should have the same length as the number of channels.
        cmap (Union[str, Colormap], optional): The colormap to use. Defaults to "twilight".
        use_sym_colormap (bool, optional): Whether to use a symmetric colormap. Defaults to False.
        alpha_func (Union[Literal["zigzag","central_peak","central_valley","linear_increase","linear_decrease",],AlphaFunction,], optional): The alpha function for the colormap when plotting 3D data. Defaults to "zigzag".
        num_colorbar_value (int, optional): The number of values for the colorbar. Defaults to 4.
        c_bar_labels (Optional[Sequence[str]], optional): The labels for the colorbar. Defaults to None.
            If provided, it should have the same length as the number of channels.
            If not provided, the colorbar will not have labels.
        cbar_pad (Optional[float], optional): The padding for the colorbar. Defaults to 0.1.
        ctick_format (Optional[str], optional): The format for the colorbar ticks. Defaults to "%.1f".
        subfig_size (float, optional): The size of the subfigures. Defaults to 2.5.
        real_size_ratio (bool, optional): Whether to use the real size ratio for the subfigures. Defaults to False.
        width_correction (float, optional): The correction factor for the width of the subfigures. Defaults to 1.0.
        height_correction (float, optional): The correction factor for the height of the subfigures. Defaults to 1.0.
        space_x (Optional[float], optional): The space between subfigures in the x direction. Defaults to 0.7.
        space_y (Optional[float], optional): The space between subfigures in the y direction. Defaults to 0.1.
        label_x (Optional[str], optional): The label for the x-axis. Defaults to "x".
        label_y (Optional[str], optional): The label for the y-axis. Defaults to "y".
        label_z (Optional[str], optional): The label for the z-axis. Defaults to "z".
        label_t (Optional[str], optional): The label for the time index. Defaults to "t".
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the y-axis. Defaults to None.
        ticks_z (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the z-axis. Defaults to None.
        show_ticks (Union[Literal["auto"], bool], optional): Whether to show ticks. Defaults to "auto".
        save_name (Optional[str], optional): The name of the file to save the plot. Defaults to None.
        compare_mode (Literal["t_wised", "channel_wised", "channel_wised_universal"], optional): The mode to compare the data. Defaults to "channel_wised_universal".
            - "t_wised": Compare the data across time.
            - "channel_wised": Compare the data across channels.
            - "channel_wised_universal": Compare the data across channels with universal min-max scaling
        frame_start_index (int): The starting index for the frame numbers. Defaults to 0.
        **kwargs: Forwarded unchanged to ``_plot_traj_frame_group``.

    Returns:
        Whatever ``_plot_traj_frame_group`` returns for the assembled slices.

    Raises:
        ValueError: If ``traj`` does not have 2 or 3 spatial dimensions, if
            ``n_frames`` is smaller than 1, or if ``traj`` has fewer than
            ``n_frames`` frames.
    """
    # Everything after the (B, T, C) leading axes is spatial.
    n_dim = len(traj.shape) - 3
    if n_dim not in (2, 3):
        raise ValueError(
            f"Trajectory must have 2 or 3 spatial dimensions, but got {n_dim}."
        )
    if n_frames < 1:
        raise ValueError(f"n_frames must be at least 1, but got {n_frames}.")
    if traj.shape[1] < n_frames:
        raise ValueError(
            f"Trajectory has only {traj.shape[1]} frames, but {n_frames} are requested."
        )
    # make_slices:
    if slice_control is None:
        slice_control = [0.5] * n_dim
    slices = traj_slices(traj, slice_control)
    slice_names, label_xs, label_ys, ticks_xs, ticks_ys = _get_slice_names(
        slice_control=slice_control,
        label_x=label_x,
        label_y=label_y,
        label_z=label_z,
        ticks_x=ticks_x,
        ticks_y=ticks_y,
        ticks_z=ticks_z,
    )
    batch_names = default(batch_names, [f"batch {i}" for i in range(traj.shape[0])])
    # One list of row labels per slice: every batch name gets the slice name appended.
    batch_names = [
        [f"{batch_name}, {slice_name}" for batch_name in batch_names]
        for slice_name in slice_names
    ]
    return _plot_traj_frame_group(
        trajs=slices,
        n_frames=n_frames,
        channel_names=channel_names,
        batch_names=batch_names,
        title=title,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        use_sym_colormap=use_sym_colormap,
        alpha_func=alpha_func,
        num_colorbar_value=num_colorbar_value,
        c_bar_labels=c_bar_labels,
        cbar_pad=cbar_pad,
        ctick_format=ctick_format,
        subfig_size=subfig_size,
        real_size_ratio=real_size_ratio,
        width_correction=width_correction,
        height_correction=height_correction,
        space_x=space_x,
        space_y=space_y,
        label_x=label_xs,
        label_y=label_ys,
        label_t=label_t,
        ticks_x=ticks_xs,
        ticks_y=ticks_ys,
        show_ticks=show_ticks,
        save_name=save_name,
        compare_mode=compare_mode,
        frame_start_index=frame_start_index,
        **kwargs,
    )

torchfsm.plot.plot_field_slice ¤

plot_field_slice(
    field: Union[
        SpatialTensor["B C H ..."],
        SpatialArray["B C H ..."],
    ],
    slice_control: Sequence[
        Optional[Union[int, float]]
    ] = None,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_z: Optional[str] = "z",
    label_t: Optional[str] = "t",
    ticks_t: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_z: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    **kwargs
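)

Plot slices of a single field. The field is treated as a one-frame trajectory: a length-1 axis is inserted at position 1, and the result is passed to `plot_traj_slice` with the time index hidden (`show_time_index=False`).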
Source code in torchfsm/plot/app/slice.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def plot_field_slice(
    field: Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]],
    slice_control: Sequence[Optional[Union[int, float]]] = None,
    channel_names: Optional[Sequence[str]] = None,
    batch_names: Optional[Sequence[str]] = None,
    title: Optional[str] = None,
    vmin: Optional[Union[float, Sequence[float]]] = None,
    vmax: Optional[Union[float, Sequence[float]]] = None,
    cmap: Union[str, Colormap] = "twilight",
    use_sym_colormap: bool = False,
    alpha_func: Union[
        Literal[
            "zigzag",
            "central_peak",
            "central_valley",
            "linear_increase",
            "linear_decrease",
        ],
        AlphaFunction,
    ] = "zigzag",
    num_colorbar_value: int = 4,
    c_bar_labels: Optional[Sequence[str]] = None,
    cbar_pad: Optional[float] = 0.1,
    ctick_format: Optional[str] = "%.1f",
    subfig_size: float = 2.5,
    real_size_ratio: bool = False,
    width_correction: float = 1.0,
    height_correction: float = 1.0,
    space_x: Optional[float] = 0.7,
    space_y: Optional[float] = 0.1,
    label_x: Optional[str] = "x",
    label_y: Optional[str] = "y",
    label_z: Optional[str] = "z",
    label_t: Optional[str] = "t",
    ticks_t: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_x: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_y: Tuple[Sequence[float], Sequence[str]] = None,
    ticks_z: Tuple[Sequence[float], Sequence[str]] = None,
    show_ticks: Union[Literal["auto"], bool] = "auto",
    save_name: Optional[str] = None,
    **kwargs,
):
    """
    Plot slices of a single (time-independent) field.

    The field is treated as a one-frame trajectory:
    - A ``torch.Tensor`` is detached, moved to the CPU and converted to a NumPy array.
    - A length-1 axis is inserted at position 1.
    - The result is passed to ``plot_traj_slice`` with ``show_time_index=False``.

    Every argument except ``field`` is forwarded unchanged to ``plot_traj_slice``;
    see that function for the full semantics.

    Args:
        field (Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]): The field to plot.
        slice_control (Sequence[Optional[Union[int, float]]], optional): The control points for slicing. Defaults to None.
        channel_names (Optional[Sequence[str]], optional): The names of the channels. Defaults to None.
        batch_names (Optional[Sequence[str]], optional): The names of the batches. Defaults to None.
        title (Optional[str], optional): The title of the plot. Defaults to None.
        vmin (Optional[Union[float, Sequence[float]]], optional): The minimum value for the color scale. Defaults to None.
        vmax (Optional[Union[float, Sequence[float]]], optional): The maximum value for the color scale. Defaults to None.
        cmap (Union[str, Colormap], optional): The colormap to use. Defaults to "twilight".
        use_sym_colormap (bool, optional): Whether to use a symmetric colormap. Defaults to False.
        alpha_func (Union[Literal[...], AlphaFunction], optional): The alpha function used for 3D data. Defaults to "zigzag".
        num_colorbar_value (int, optional): The number of values for the colorbar. Defaults to 4.
        c_bar_labels (Optional[Sequence[str]], optional): The labels for the colorbar. Defaults to None.
        cbar_pad (Optional[float], optional): The padding for the colorbar. Defaults to 0.1.
        ctick_format (Optional[str], optional): The format for the colorbar ticks. Defaults to "%.1f".
        subfig_size (float, optional): The size of the subfigures. Defaults to 2.5.
        real_size_ratio (bool, optional): Whether to use the real size ratio for the subfigures. Defaults to False.
        width_correction (float, optional): The correction factor for the subfigure width. Defaults to 1.0.
        height_correction (float, optional): The correction factor for the subfigure height. Defaults to 1.0.
        space_x (Optional[float], optional): The space between subfigures in the x direction. Defaults to 0.7.
        space_y (Optional[float], optional): The space between subfigures in the y direction. Defaults to 0.1.
        label_x (Optional[str], optional): The label for the x-axis. Defaults to "x".
        label_y (Optional[str], optional): The label for the y-axis. Defaults to "y".
        label_z (Optional[str], optional): The label for the z-axis. Defaults to "z".
        label_t (Optional[str], optional): The label for the time index. Defaults to "t".
        ticks_t (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the time index. Defaults to None.
        ticks_x (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the x-axis. Defaults to None.
        ticks_y (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the y-axis. Defaults to None.
        ticks_z (Tuple[Sequence[float], Sequence[str]], optional): Custom ticks for the z-axis. Defaults to None.
        show_ticks (Union[Literal["auto"], bool], optional): Whether to show ticks. Defaults to "auto".
        save_name (Optional[str], optional): The name of the file to save the plot. Defaults to None.
        **kwargs: Additional keyword arguments forwarded to ``plot_traj_slice``.

    Returns:
        Whatever ``plot_traj_slice`` returns.
    """
    if isinstance(field, torch.Tensor):
        # Detach before the device copy so the copy is never tracked by autograd.
        field = field.detach().cpu().numpy()
    # Insert a singleton time axis so the field looks like a one-frame trajectory.
    field = np.expand_dims(field, 1)
    return plot_traj_slice(
        traj=field,
        slice_control=slice_control,
        channel_names=channel_names,
        batch_names=batch_names,
        title=title,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        use_sym_colormap=use_sym_colormap,
        alpha_func=alpha_func,
        num_colorbar_value=num_colorbar_value,
        c_bar_labels=c_bar_labels,
        cbar_pad=cbar_pad,
        ctick_format=ctick_format,
        subfig_size=subfig_size,
        real_size_ratio=real_size_ratio,
        width_correction=width_correction,
        height_correction=height_correction,
        space_x=space_x,
        space_y=space_y,
        label_x=label_x,
        label_y=label_y,
        label_z=label_z,
        label_t=label_t,
        ticks_t=ticks_t,
        ticks_x=ticks_x,
        ticks_y=ticks_y,
        ticks_z=ticks_z,
        show_ticks=show_ticks,
        save_name=save_name,
        # A single field has no meaningful time axis, so hide the frame index.
        show_time_index=False,
        **kwargs,
    )