Skip to content

9. utils

torchfsm.utils.clean_up_memory ¤

clean_up_memory()

Clean up the memory by calling garbage collector and emptying the cache.

Source code in torchfsm/_utils.py
43
44
45
46
47
48
def clean_up_memory():
    """
    Free memory that is no longer referenced.

    Runs Python's garbage collector and afterwards asks PyTorch to empty
    its CUDA cache.
    """
    gc.collect()  # collect unreachable Python objects first
    torch.cuda.empty_cache()  # then empty PyTorch's CUDA cache

torchfsm.utils.print_gpu_memory ¤

print_gpu_memory(prefix='', device='cuda:1')
Source code in torchfsm/_utils.py
50
51
52
53
def print_gpu_memory(prefix="", device="cuda:1"):
    """
    Print the allocated and reserved CUDA memory of a device in megabytes.

    Args:
        prefix: Text placed in front of the printed report.
        device: Device whose memory counters are queried.
    """
    bytes_per_mb = 1024**2
    allocated_mb = torch.cuda.memory_allocated(device) / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved(device) / bytes_per_mb
    print(f"{prefix}Allocated: {allocated_mb:.2f} MB, Reserved: {reserved_mb:.2f} MB")

torchfsm.utils.statistics_traj ¤

statistics_traj(
    traj: ValueList[Union[Tensor, ndarray]],
) -> Tuple[float, float, float, float]

Compute the mean, std, min, and max of a trajectory.

Parameters:

Name Type Description Default
traj ValueList[Union[Tensor, ndarray]]

The trajectory to compute statistics for. The trajectory can be a list of tensors or numpy arrays, or a single tensor or numpy array.

required

Returns:

Name Type Description
tuple Tuple[float, float, float, float]

A tuple containing the mean, std, min, and max of the trajectory. Each entry is a list with one value per channel, computed over all batches, time steps, and spatial points.

Source code in torchfsm/utils/traj_manipulate.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def statistics_traj(traj: ValueList[Union[torch.Tensor, np.ndarray]])-> Tuple[list, list, list, list]:
    """
    Compute per-channel mean, std, min, and max of one or more trajectories.

    Args:
        traj (ValueList[Union[torch.Tensor, np.ndarray]]): The trajectory to compute statistics for.
            Either a single tensor / numpy array or a list (or tuple) of them.
            Every trajectory is laid out as [B, T, C, H, ...]; all trajectories
            must agree on every dimension from C onwards.

    Returns:
        tuple: Four lists ``(means, stds, mins, maxs)``, each holding one float
            per channel. Every statistic is taken over all batches, time steps,
            and spatial points of all given trajectories.

    Raises:
        ValueError: If an empty list/tuple is given.
    """
    # Accept tuples as well as lists; anything else is a single trajectory.
    items = list(traj) if isinstance(traj, (list, tuple)) else [traj]
    if not items:
        raise ValueError("traj must contain at least one trajectory.")
    tensors = [
        item if isinstance(item, torch.Tensor) else torch.from_numpy(item)
        for item in items
    ]
    # [B, T, C, H, ...] -> [B*T, C, H, ...] so trajectories of different
    # batch size / length can be concatenated along the first axis.
    merged_shape = (-1, *tensors[0].shape[2:])
    traj_all = torch.cat([t.reshape(merged_shape) for t in tensors], dim=0)
    means, stds, mins, maxs = [], [], [], []
    for channel in traj_all.unbind(dim=1):
        means.append(channel.mean().item())
        stds.append(channel.std().item())
        mins.append(channel.min().item())
        maxs.append(channel.max().item())
    return means, stds, mins, maxs

torchfsm.utils.randomly_clip_traj ¤

randomly_clip_traj(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
        FourierTensor["B T C H ..."],
        FourierArray["B T C H ..."],
    ],
    length: int,
) -> Union[
    SpatialTensor["B length C H ..."],
    SpatialArray["B length C H ..."],
    FourierTensor["B length C H ..."],
    FourierArray["B length C H ..."],
]

Randomly clip a trajectory to a specified length.

Parameters:

Name Type Description Default
traj Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]

The trajectory to clip. The trajectory can be a tensor or numpy array with shape [B, T, C, H, ...].

required
length int

The length to clip the trajectory to. The length should be less than the original length of the trajectory.

required

Returns:

Type Description
Union[SpatialTensor['B length C H ...'], SpatialArray['B length C H ...'], FourierTensor['B length C H ...'], FourierArray['B length C H ...']]

Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]: The clipped trajectory. The clipped trajectory has shape [B, length, C, H, ...].

Source code in torchfsm/utils/traj_manipulate.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def randomly_clip_traj(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
        FourierTensor["B T C H ..."],
        FourierArray["B T C H ..."],
    ],
    length: int,
)-> Union[
        SpatialTensor["B length C H ..."],
        SpatialArray["B length C H ..."],
        FourierTensor["B length C H ..."],
        FourierArray["B length C H ..."],
    ]:

    """
    Randomly clip a trajectory to a specified length.

    Every batch entry gets its own, independently drawn, contiguous window of
    ``length`` frames.

    Args:
        traj (Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]): The trajectory to clip.
            The trajectory can be a tensor or numpy array with shape [B, T, C, H, ...].
        length (int): The length to clip the trajectory to.
            Must satisfy ``1 <= length <= T``.

    Returns:
        Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]: The clipped trajectory.
            The clipped trajectory has shape [B, length, C, H, ...] and the same
            container type (tensor / numpy array) as the input.

    Raises:
        ValueError: If ``length`` is not in ``[1, T]``.
    """
    is_nparray = isinstance(traj, np.ndarray)
    traj = torch.from_numpy(traj) if is_nparray else traj
    ori_len_time = traj.shape[1]
    if not 0 < length <= ori_len_time:
        raise ValueError(
            f"length should be in [1, {ori_len_time}], but got {length}."
        )
    # torch.randint excludes its upper bound, so add 1 to make the last valid
    # start index (ori_len_time - length) reachable.
    starts = torch.randint(
        0, ori_len_time - length + 1, (traj.shape[0],)
    ).tolist()
    new_traj = torch.stack(
        [traj[i, start : start + length] for i, start in enumerate(starts)],
        dim=0,
    )
    return new_traj.numpy() if is_nparray else new_traj

torchfsm.utils.randomly_select_frames ¤

randomly_select_frames(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
        FourierTensor["B T C H ..."],
        FourierArray["B T C H ..."],
    ],
    n_frames: int,
    return_frame_indices: bool = False,
) -> Union[
    SpatialTensor["B n_frames C H ..."],
    SpatialArray["B n_frames C H ..."],
    FourierTensor["B n_frames C H ..."],
    FourierArray["B n_frames C H ..."],
]

Randomly select a specified number of frames from a trajectory.

Parameters:

Name Type Description Default
traj Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]

The trajectory to select frames from. The trajectory can be a tensor or numpy array with shape [B, T, C, H, ...].

required
n_frames int

The number of frames to select. The number of frames should be less than the original length of the trajectory.

required

Returns:

Type Description
Union[SpatialTensor['B n_frames C H ...'], SpatialArray['B n_frames C H ...'], FourierTensor['B n_frames C H ...'], FourierArray['B n_frames C H ...']]

Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]: The selected frames. The selected frames have shape [B, n_frames, C, H, ...].

Source code in torchfsm/utils/traj_manipulate.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def randomly_select_frames(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
        FourierTensor["B T C H ..."],
        FourierArray["B T C H ..."],
    ],
    n_frames: int,
    return_frame_indices: bool = False
)-> Union[
        SpatialTensor["B n_frames C H ..."],
        SpatialArray["B n_frames C H ..."],
        FourierTensor["B n_frames C H ..."],
        FourierArray["B n_frames C H ..."],
    ]:
    """
    Pick ``n_frames`` random time frames out of a trajectory.

    The frame indices are drawn once with ``torch.randint`` (so an index may
    repeat and the indices are not sorted) and the same indices are applied to
    every batch entry.

    Args:
        traj (Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]): Trajectory of shape
            [B, T, C, H, ...], given as a tensor or a numpy array.
        n_frames (int): How many frames to draw.
        return_frame_indices (bool): When True, the drawn frame indices are returned as well.

    Returns:
        Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]: The selected frames with
            shape [B, n_frames, C, H, ...], in the same container type as the input.
            With ``return_frame_indices=True`` a tuple ``(frames, indices)`` is returned,
            where ``indices`` is an integer tensor of length ``n_frames``.
    """
    came_from_numpy = isinstance(traj, np.ndarray)
    frames = torch.from_numpy(traj) if came_from_numpy else traj
    frame_ids = torch.randint(0, frames.shape[1], (n_frames,))
    picked = frames[:, frame_ids]
    if came_from_numpy:
        picked = picked.numpy()
    if not return_frame_indices:
        return picked
    return picked, frame_ids

torchfsm.utils.uniformly_select_frames ¤

uniformly_select_frames(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
        FourierTensor["B T C H ..."],
        FourierArray["B T C H ..."],
    ],
    n_frames: int,
    return_frame_indices: bool = False,
) -> Union[
    SpatialTensor["B n_frames C H ..."],
    SpatialArray["B n_frames C H ..."],
    FourierTensor["B n_frames C H ..."],
    FourierArray["B n_frames C H ..."],
]

Uniformly select a specified number of frames from a trajectory.

Parameters:

Name Type Description Default
traj Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]

The trajectory to select frames from. The trajectory can be a tensor or numpy array with shape [B, T, C, H, ...].

required
n_frames int

The number of frames to select. The number of frames should be less than the original length of the trajectory.

required

Returns:

Type Description
Union[SpatialTensor['B n_frames C H ...'], SpatialArray['B n_frames C H ...'], FourierTensor['B n_frames C H ...'], FourierArray['B n_frames C H ...']]

Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]: The selected frames. The selected frames have shape [B, n_frames, C, H, ...].

Source code in torchfsm/utils/traj_manipulate.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def uniformly_select_frames(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
        FourierTensor["B T C H ..."],
        FourierArray["B T C H ..."],
    ],
    n_frames: int,
    return_frame_indices: bool = False
)-> Union[
        SpatialTensor["B n_frames C H ..."],
        SpatialArray["B n_frames C H ..."],
        FourierTensor["B n_frames C H ..."],
        FourierArray["B n_frames C H ..."],
    ]:
    """
    Pick ``n_frames`` evenly spaced time frames out of a trajectory.

    The indices are ``n_frames`` evenly spaced values between the first frame
    (0) and the last frame (T - 1), truncated to integers.

    Args:
        traj (Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]): Trajectory of shape
            [B, T, C, H, ...], given as a tensor or a numpy array.
        n_frames (int): How many frames to pick.
        return_frame_indices (bool): When True, the chosen frame indices are returned as well.

    Returns:
        Union[SpatialTensor, SpatialArray, FourierTensor, FourierArray]: The selected frames with
            shape [B, n_frames, C, H, ...], in the same container type as the input.
            With ``return_frame_indices=True`` a tuple ``(frames, indices)`` is returned,
            where ``indices`` is an integer tensor of length ``n_frames``.
    """
    came_from_numpy = isinstance(traj, np.ndarray)
    frames = torch.from_numpy(traj) if came_from_numpy else traj
    last_frame = frames.shape[1] - 1
    frame_ids = torch.linspace(0, last_frame, n_frames).long()
    picked = frames[:, frame_ids]
    if came_from_numpy:
        picked = picked.numpy()
    if not return_frame_indices:
        return picked
    return picked, frame_ids

torchfsm.utils.default ¤

default(value, default)

Return the default value if the value is None.

Parameters:

Name Type Description Default
value

The value to check.

required
default

The default value to return if value is None.

required

Returns:

Type Description

The value if it is not None, otherwise the default value.

Source code in torchfsm/_utils.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
def default(value, default):
    """
    Fall back to ``default`` when ``value`` is None.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``""`` are returned unchanged.

    Args:
        value: The value to check.
        default: The fallback returned when ``value`` is None.

    Returns:
        ``value`` when it is not None, otherwise ``default``.
    """
    if value is None:
        return default
    return value

torchfsm.utils.format_device_dtype ¤

format_device_dtype(
    device: Optional[Union[device, str]] = None,
    dtype: Optional[dtype] = None,
) -> Tuple[torch.device, torch.dtype]

Format the device and dtype for PyTorch.

Parameters:

Name Type Description Default
device Optional[Union[device, str]]

The device to use. If None, defaults to CPU.

None
dtype Optional[dtype]

The data type to use. If None, defaults to float32.

None

Returns:

Type Description
Tuple[device, dtype]

tuple[torch.device, torch.dtype]: The formatted device and dtype.

Source code in torchfsm/_utils.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def format_device_dtype(
    device: Optional[Union[torch.device, str]] = None,
    dtype: Optional[torch.dtype] = None,
)-> Tuple[torch.device, torch.dtype]:
    """
    Format the device and dtype for PyTorch.

    A non-CPU device given without an index (e.g. ``"cuda"`` or
    ``torch.device("cuda")``) is normalized to index 0, so the result is the
    same whether the device was passed as a string or as a ``torch.device``.

    Args:
        device (Optional[Union[torch.device, str]]): The device to use. If None, defaults to CPU.
        dtype (Optional[torch.dtype]): The data type to use. If None, defaults to float32.

    Returns:
        tuple[torch.device, torch.dtype]: The formatted device and dtype.
    """
    if device is None:
        device = torch.device("cpu")
    elif isinstance(device, (str, torch.device)):
        # torch.device() accepts both strings and existing device objects.
        device = torch.device(device)
        if device.index is None and device.type != "cpu":
            # Pin an index-less accelerator device to its first device.
            device = torch.device(device.type, 0)
    if dtype is None:
        dtype = torch.float32
    return device, dtype

torchfsm.utils.traj_slices ¤

traj_slices(
    traj: Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
    ],
    slice_control: Sequence[Optional[Union[int, float]]],
) -> Sequence[
    Union[
        SpatialTensor["B T C H ..."],
        SpatialArray["B T C H ..."],
    ]
]

Slice a trajectory along specified dimensions.

Parameters:

Name Type Description Default
traj Union[SpatialTensor['B T C H ...'], SpatialArray['B T C H ...']]

The trajectory to slice.

required
slice_control Sequence[Optional[Union[int, float]]]

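A sequence of slice values for each dimension. If a value is None, that dimension will not be sliced. If a value is negative, it will slice from the end of that dimension. If a value is positive, it will slice from the start of that dimension. A value strictly between 0 and 1 is interpreted as a fraction of that dimension's length, while a value of 1 or greater is used directly as an integer index.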

required

Returns:

Type Description
Sequence[Union[SpatialTensor['B T C H ...'], SpatialArray['B T C H ...']]]

Union[Sequence[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]]: A sequence of sliced trajectories. Each element corresponds to a slice along one dimension.

Source code in torchfsm/utils/slice.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def traj_slices(
    traj: Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]],
    slice_control: Sequence[Optional[Union[int, float]]],
) -> Sequence[Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]]:
    """
    Slice a trajectory along specified dimensions.

    Args:
        traj (Union[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]): The trajectory to slice.
        slice_control (Sequence[Optional[Union[int,float]]]): One entry per spatial dimension.
            If a value is None, that dimension will not be sliced.
            A value with magnitude below 1 is treated as a fraction of the
            dimension's length (e.g. 0.5 selects the middle, -0.25 a quarter
            from the end).
            Any other value is used as an integer index; negative indices
            count from the end of that dimension.

    Returns:
        Union[Sequence[SpatialTensor["B T C H ..."], SpatialArray["B T C H ..."]]]:
            A sequence of sliced trajectories, one per non-None entry of
            ``slice_control`` in order. Each slice no longer has the spatial
            axis it was taken along.

    Raises:
        ValueError: If the number of slice values does not match the number of
            spatial dimensions, if the trajectory is 1D, or if every slice
            value is None.
    """
    n_dim = len(traj.shape) - 3
    if n_dim != len(slice_control):
        raise ValueError(
            f"The number of slice control values {len(slice_control)} should be equal to the number of dimensions {n_dim} in the input trajectory."
        )
    if n_dim == 1:
        raise ValueError("Cannot slice 1D trajectory.")
    slices = []
    for i, slice_value in enumerate(slice_control):
        if slice_value is None:
            continue
        axis = i + 3  # skip the B, T, C axes
        if -1 < slice_value < 1:
            # Fractional position along the axis.
            index = int(traj.shape[axis] * slice_value)
        else:
            # Plain index; negative values count from the end.
            index = int(slice_value)
        selector = [slice(None)] * len(traj.shape)
        selector[axis] = index
        slices.append(traj[tuple(selector)])
    if not slices:
        raise ValueError("No slice value provided.")
    return slices

torchfsm.utils.field_slices ¤

field_slices(
    field: Union[
        SpatialTensor["B C H ..."],
        SpatialArray["B C H ..."],
    ],
    slice_control: Sequence[Optional[Union[int, float]]],
) -> Sequence[
    Union[
        SpatialTensor["B C H ..."],
        SpatialArray["B C H ..."],
    ]
]

Slice a field along specified dimensions.

Args:

- field (Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]): The field to slice.
- slice_control (Sequence[Optional[Union[int,float]]]): A sequence of slice values for each dimension. If a value is None, that dimension will not be sliced. If a value is negative, it will slice from the end of that dimension. If a value is positive, it will slice from the start of that dimension.

Returns:

- Union[Sequence[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]]: A sequence of sliced fields. Each element corresponds to a slice along one dimension.

Source code in torchfsm/utils/slice.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def field_slices(
    field: Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]],
    slice_control: Sequence[Optional[Union[int, float]]],
) -> Sequence[Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]]:
    """
    Slice a field along specified dimensions.

    The field is given a singleton time axis and then handed to
    ``traj_slices``, so the two functions share the same slicing rules.

    Args:
        field (Union[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]): The field to slice.
        slice_control (Sequence[Optional[Union[int,float]]]): A sequence of slice values for each dimension.
            If a value is None, that dimension will not be sliced.
            See ``traj_slices`` for how the values are interpreted.

    Returns:
        Union[Sequence[SpatialTensor["B C H ..."], SpatialArray["B C H ..."]]]:
            A sequence of sliced fields. Each element corresponds to a slice along one dimension.
            The singleton time axis inserted at position 1 is still present in
            every returned slice.

    Raises:
        TypeError: If ``field`` is neither a torch tensor nor a numpy array.
    """
    # Insert a length-1 time axis so the field looks like [B, 1, C, H, ...].
    if isinstance(field, torch.Tensor):
        traj = field.unsqueeze(1)
    elif isinstance(field, np.ndarray):
        traj = np.expand_dims(field, axis=1)
    else:
        raise TypeError(
            f"field should be a torch.Tensor or np.ndarray, but got {type(field).__name__}."
        )
    return traj_slices(traj=traj, slice_control=slice_control)

torchfsm.utils.test_sim_dt ¤

test_sim_dt(
    operator: Operator,
    u_0: Tensor,
    max_sim_dt: float,
    min_sim_dt: float,
    mesh: Optional[
        Union[
            Sequence[tuple[float, float, int]],
            MeshGrid,
            FourierMesh,
        ]
    ] = None,
    stop_criteria: float = None,
    initial_step: int = 1,
    dt_decrement: float = 0.5,
    **kwargs
) -> tuple[dict, dict, dict]

Test the simulation with varying time steps.

Parameters:

Name Type Description Default
operator Operator

The operator to be tested.

required
u_0 Tensor

Initial condition tensor.

required
max_sim_dt float

Maximum simulation time step.

required
min_sim_dt float

Minimum simulation time step.

required
mesh Optional[Union[Sequence[tuple[float, float, int]], MeshGrid, FourierMesh]]

Mesh information or mesh object.

None
stop_criteria float

Criteria to stop the simulation if reached.

None
initial_step int

Number of integration steps taken when the time step equals max_sim_dt; the step count is scaled up as the time step shrinks.

1
dt_decrement float

Factor by which to reduce the time step on failure.

0.5
**kwargs

Additional keyword arguments for the operator's integrate method.

{}

Returns:

Type Description
tuple[dict, dict, dict]

tuple[dict, dict, dict]: A tuple containing: - errors: Dictionary of errors for each time step. - mean_differences: Mean differences between frames for each time step. - std_differences: Standard deviations of differences for each time step.

Source code in torchfsm/utils/test.py
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def test_sim_dt(
    operator: Operator,
    u_0: torch.Tensor,
    max_sim_dt: float,
    min_sim_dt: float,
    mesh: Optional[
            Union[Sequence[tuple[float, float, int]], MeshGrid, FourierMesh]
        ] = None,
    stop_criteria: Optional[float] = None,
    initial_step: int = 1,
    dt_decrement: float = 0.5,
    **kwargs
) -> tuple[dict, dict, dict]:
    """Test the simulation with varying time steps.

    Starting at ``max_sim_dt``, the operator is integrated repeatedly with a
    shrinking time step while ``current_dt > min_sim_dt``. After every
    successful run the recorded result is compared with the result of the
    previous successful run. Progress is reported with ``print``.

    Args:
        operator (Operator): The operator to be tested.
        u_0 (torch.Tensor): Initial condition tensor.
        max_sim_dt (float): Maximum simulation time step (the first one tried).
        min_sim_dt (float): Minimum simulation time step; the loop ends once the
            current time step is no longer greater than this value.
        mesh (Optional[Union[Sequence[tuple[float, float, int]], MeshGrid, FourierMesh]]): Mesh information or mesh object.
        stop_criteria (float, optional): If given, the loop stops as soon as the
            error between two consecutive successful runs drops below this value.
        initial_step (int, optional): Number of integration steps used at
            ``max_sim_dt``. For smaller time steps the step count is
            ``int(max_sim_dt // current_dt * initial_step)``.
        dt_decrement (float, optional): Factor the time step is multiplied by
            after a run fails with ``NanSimulationError``. After a successful
            run the time step is always halved, independent of this value.
        **kwargs: Additional keyword arguments for the operator's integrate method.

    Returns:
        tuple[dict, dict, dict]: A tuple containing:
            - errors: Mean absolute difference to the previous successful run,
              keyed by time step (no entry for the first successful run).
            - mean_differences: Mean of the recorder's differences, keyed by time step.
            - std_differences: Standard deviation of the recorder's differences,
              keyed by time step.
    """


    current_dt = max_sim_dt
    recorder=_DtTestRecorder()
    current_frame = None
    previous_frame = None
    # All three result dicts are keyed by the time step that produced the entry.
    errors={}
    mean_differences={}
    std_differences={}
    while current_dt > min_sim_dt:
        print(f"Testing with dt={current_dt}")
        try:
            operator.integrate(
                u_0= u_0,
                trajectory_recorder=recorder,
                mesh=mesh,
                dt=current_dt,
                # More steps for smaller dt, scaled relative to max_sim_dt.
                step=int(max_sim_dt // current_dt * initial_step),
                nan_check=True,
                progressive= True,
                **kwargs
            )
        except NanSimulationError:
            # Failed run: shrink dt by dt_decrement, reset the recorder, retry.
            print(f"Simulation failed with dt={current_dt}, reducing dt.")
            current_dt *= dt_decrement
            recorder.teardown()
            continue
        # NOTE(review): recorder.differences is presumably the frame-to-frame
        # change collected during the run — confirm against _DtTestRecorder.
        mean_dif=np.mean(recorder.differences)
        std_dif=np.std(recorder.differences)
        mean_differences[current_dt] = mean_dif
        std_differences[current_dt] = std_dif
        # NOTE(review): recorder.trajectory is presumably the final recorded
        # state of the run — confirm against _DtTestRecorder.
        current_frame = recorder.trajectory
        current_error = None
        if previous_frame is not None:
            # Mean absolute difference to the result of the previous successful dt.
            current_error = (current_frame - previous_frame).abs().mean().item()
            errors[current_dt] = current_error
            if stop_criteria is not None:
                if current_error < stop_criteria:
                    print(f"Simulation converged with dt={current_dt}, error={current_error:.3e}")
                    # Leaves the loop without calling recorder.teardown().
                    break
        previous_frame = current_frame
        # Successful run: always halve dt (dt_decrement is only used on failure).
        current_dt /= 2
        recorder.teardown()
        msg = ""
        if current_error is not None:
            msg += f"Error of the last frame: {current_error:.3e}; "
        msg += f"Difference between frames: {mean_dif:.3e}±{std_dif:.3e}"
        print(msg)
    # Printed on every exit path, including after the convergence break above.
    print("Reach minimum dt:", current_dt)
    return errors, mean_differences, std_differences

torchfsm.utils.collect_energy_spectrum ¤

collect_energy_spectrum(
    u: SpatialTensor["1 C H ..."],
    mesh: Optional[
        Union[
            Sequence[tuple[float, float, int]],
            MeshGrid,
            FourierMesh,
        ]
    ] = None,
    progressive: bool = False,
    exact=False,
    n_bins: Optional[int] = None,
)

Collect the energy spectrum from a spatial tensor with batch size 1.

Parameters:

Name Type Description Default
u SpatialTensor['1 C H ...']

The input spatial tensor with batch size 1.

required
mesh Optional[Union[Sequence[tuple[float, float, int]], MeshGrid, FourierMesh]]

The mesh grid or Fourier mesh to use for FFT. If None, a default mesh is created based on the shape of u. Default is None.

None
progressive bool

Whether to show a progress bar during computation. Default is False.

False
exact bool

Whether to compute the exact energy spectrum without binning. Default is False.

False
n_bins Optional[int]

Number of bins to smooth the spectrum if exact is False. If None, it defaults to min(50, u.shape[-1] // 2). Default is None.

None

Returns:

Type Description

Tuple[List[float], List[float]]: Two lists containing the wave numbers and their corresponding energy spectrum values.

Source code in torchfsm/utils/spectrum.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def collect_energy_spectrum(
    u: SpatialTensor["1 C H ..."],
    mesh: Optional[
        Union[Sequence[tuple[float, float, int]], MeshGrid, FourierMesh]
    ] = None,
    progressive: bool = False,
    exact=False,
    n_bins: Optional[int] = None,
):
    """
    Collect the energy spectrum of a single (batch size 1) spatial field.

    The field is transformed with the mesh's FFT. For every Fourier mode the
    energy is ``0.5 * sum_c |u_hat_c|^2 * 4 * pi * |k|^2`` with
    ``|k|`` the norm of ``2 * pi`` times the mesh's ``bf_vector``.
    The energies are then either averaged inside equally wide wavenumber bins
    (``exact=False``) or averaged over modes that share exactly the same
    wavenumber norm (``exact=True``).

    Args:
        u (SpatialTensor["1 C H ..."]): The input spatial tensor with batch size 1.
        mesh (Optional[Union[Sequence[tuple[float, float, int]], MeshGrid, FourierMesh]]): The mesh used for
            the FFT. If None, a mesh spanning (0, 1) in every spatial direction with the resolution
            of ``u`` is created. Anything that is not already a ``FourierMesh`` is passed to
            ``FourierMesh`` on the device of ``u``. Default is None.
        progressive (bool): Whether to show a progress bar; only used when ``exact`` is True. Default is False.
        exact (bool): Whether to compute the exact energy spectrum without binning. Default is False.
        n_bins (Optional[int]): Number of bins when ``exact`` is False. If None, it defaults to
            min(50, u.shape[-1] // 2). Default is None.

    Returns:
        Tuple[List[float], List[float]]: Wave numbers and the matching energy values.
            With binning, the wave numbers are the bin centers and empty bins
            hold 0. In exact mode the entry with the smallest wave number is
            dropped from both lists.

    Raises:
        ValueError: If the batch size of ``u`` is not 1.
    """
    if u.shape[0] != 1:
        raise ValueError("Batch size of u must be 1 for collecting energy spectrum.")
    if mesh is None:
        mesh = [(0, 1, n_points) for n_points in u.shape[2:]]
    f_mesh = mesh if isinstance(mesh, FourierMesh) else FourierMesh(mesh, device=u.device)

    u_fft = f_mesh.fft(u)
    k_vec_norm = torch.norm(f_mesh.bf_vector * 2 * torch.pi, dim=1, keepdim=True)
    # Sum the squared magnitudes over the channel axis.
    channel_power = torch.sum(torch.abs(u_fft) ** 2, dim=1, keepdim=True)
    energy_fft = 0.5 * channel_power * (4 * torch.pi * k_vec_norm**2)

    if exact:
        # Group the energies by their exact wavenumber norm.
        grouped = defaultdict(list)
        pairs = zip(k_vec_norm.view(-1), energy_fft.view(-1))
        if progressive:
            pairs = tqdm(pairs, total=k_vec_norm.numel())
        for k, e in pairs:
            grouped[k.item()].append(e.item())
        # The k == 0 entry is fixed to 0.0; every other entry is the group mean.
        averaged = {
            k: (0.0 if k == 0 else sum(values) / len(values))
            for k, values in grouped.items()
        }
        sorted_k = sorted(averaged)
        sorted_e = [averaged[k] for k in sorted_k]
        # Drop the entry with the smallest wavenumber.
        return sorted_k[1:], sorted_e[1:]

    k_max = torch.max(k_vec_norm)
    if n_bins is None:
        n_bins = min(50, u.shape[-1] // 2)
    k_bins = torch.linspace(0, k_max, n_bins + 1)
    k_centers = (k_bins[1:] + k_bins[:-1]) / 2
    radial = torch.zeros(len(k_centers), device=k_vec_norm.device)
    # Bins are half-open [lower, upper), so modes exactly at k_max fall outside
    # the last bin.
    for i, (lower, upper) in enumerate(zip(k_bins[:-1], k_bins[1:])):
        in_bin = (k_vec_norm >= lower) & (k_vec_norm < upper)
        if torch.any(in_bin):
            radial[i] = torch.mean(energy_fft[in_bin])
    return k_centers.cpu().numpy().tolist(), radial.cpu().numpy().tolist()