Skip to content

8. traj_recorder

torchfsm.traj_recorder.IntervalController ¤

A class to control the recording of trajectories at specified intervals. An instance can be passed as the control_func argument of recorder objects.

Parameters:

Name Type Description Default
interval int

The interval at which to record the trajectory.

1
start int

The step at which to start recording the trajectory.

0
Source code in torchfsm/traj_recorder.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class IntervalController:
    """
    A class to control the recording of trajectories at specified intervals.
        An instance can be passed as the `control_func` argument of recorder objects.

    A step is accepted when it is at or after ``start`` and ``step - start`` is a
    multiple of ``interval``; e.g. ``IntervalController(interval=2, start=1)``
    accepts steps 1, 3, 5, ...

    Args:
        interval (int): The interval at which to record the trajectory. Must be >= 1.
        start (int): The step at which to start recording the trajectory.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """

    def __init__(
        self,
        interval: int = 1,
        start: int = 0,
    ) -> None:
        # Fail fast: interval=0 would otherwise only surface later as a
        # ZeroDivisionError inside __call__; non-positive intervals are meaningless.
        if interval < 1:
            raise ValueError(f"interval must be a positive integer, got {interval}.")
        self.start = start
        self.interval = interval

    def __call__(self, step: int) -> bool:
        """Return True if ``step`` falls on the recording schedule."""
        return step >= self.start and (step - self.start) % self.interval == 0
start instance-attribute ¤
start = start
interval instance-attribute ¤
interval = interval
__init__ ¤
__init__(interval: int = 1, start: int = 0) -> None
Source code in torchfsm/traj_recorder.py
18
19
20
21
22
23
24
def __init__(
    self,
    interval: int = 1,
    start: int = 0,
) -> None:
    """
    Store the recording schedule.

    Args:
        interval (int): The interval at which to record the trajectory. Must be >= 1.
        start (int): The step at which to start recording the trajectory.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """
    # interval=0 would otherwise only fail later, as a ZeroDivisionError in __call__.
    if interval < 1:
        raise ValueError(f"interval must be a positive integer, got {interval}.")
    self.start = start
    self.interval = interval
__call__ ¤
__call__(step: int) -> bool
Source code in torchfsm/traj_recorder.py
26
27
def __call__(self, step: int) -> bool:
    """
    Decide whether ``step`` falls on the recording schedule.

    Steps before ``start`` are rejected; from ``start`` on, every
    ``interval``-th step is accepted.
    """
    if step < self.start:
        return False
    offset = step - self.start
    return offset % self.interval == 0

torchfsm.traj_recorder._TrajRecorder ¤

A base class for trajectory recorders. A recorder is an object that helps to control the recording of trajectories during a simulation.

Parameters:

Name Type Description Default
control_func Optional[Callable[[int], bool]]

A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.

None
include_initial_state bool

If True, the initial state will be included in the trajectory.

True
Source code in torchfsm/traj_recorder.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class _TrajRecorder:
    """
    A base class for trajectory recorders.
        A recorder is an object that helps to control the recording of trajectories during a simulation.

    Subclasses implement ``_record`` (store one frame) and the ``trajectory`` property.

    Args:
        control_func (Optional[Callable[[int],bool]]): A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.
        include_initial_state (bool): If True, the initial state will be included in the trajectory.
    """

    def __init__(
        self,
        control_func: Optional[Callable[[int], bool]] = None,
        include_initial_state: bool = True,
    ):
        control_func = default(control_func, lambda step: True)
        # Step 0 is decided by include_initial_state alone; every later step
        # is delegated to the (possibly user-supplied) control_func.
        if include_initial_state:
            self.control_func = lambda step: True if step == 0 else control_func(step)
        else:
            self.control_func = lambda step: False if step == 0 else control_func(step)
        # Read by subclasses to decide whether `trajectory` stays in Fourier space.
        self.return_in_fourier = False
        self._shutdown_flag = False

    def set_shutdown_flag(self):
        """
        Set the shutdown flag to True.
            After this, `record` ignores every further frame.
        """
        self._shutdown_flag = True

    def record(self, step: int, frame: torch.Tensor):
        """
        Record the trajectory at a given step.

        The frame is handed to ``_record`` only when the recorder has not been
        shut down (see ``set_shutdown_flag``) and ``control_func(step)`` is True.

        Args:
            step (int): The current step.
            frame (torch.Tensor): The current frame to be recorded.
        """
        # Honour set_shutdown_flag(): once set, every further frame is ignored.
        if self._shutdown_flag:
            return
        if self.control_func(step):
            self._record(step, frame)

    def _record(self, step: int, frame: torch.Tensor):
        """
        Record the trajectory at a given step.
            This method should be implemented by subclasses.

        Args:
            step (int): The current step.
            frame (torch.Tensor): The current frame to be recorded.
        """
        raise NotImplementedError

    def _traj_ifft(self, trajectory: torch.Tensor):
        """
        Perform an inverse FFT on the trajectory over every dimension except the first three.

        NOTE(review): presumably the layout is (batch, time, channel, *spatial);
        confirm against the simulation code.

        Args:
            trajectory (torch.Tensor): The trajectory to be transformed.

        Returns:
            torch.Tensor: The transformed (complex) trajectory.
        """
        # Trailing dims -1, -2, ..., -(ndim - 3): all dims after the leading three.
        fft_dim = tuple(-1 * (i + 1) for i in range(len(trajectory.shape) - 3))
        return torch.fft.ifftn(trajectory, dim=fft_dim)

    def _field_ifft(self, field: torch.Tensor):
        """
        Perform an inverse FFT on a single field over every dimension except the first two.

        NOTE(review): presumably the layout is (batch, channel, *spatial); confirm.

        Args:
            field (torch.Tensor): The field to be transformed.

        Returns:
            torch.Tensor: The transformed (complex) field.
        """
        # Trailing dims -1, -2, ..., -(ndim - 2): all dims after the leading two.
        fft_dim = tuple(-1 * (i + 1) for i in range(len(field.shape) - 2))
        return torch.fft.ifftn(field, dim=fft_dim)

    @property
    def trajectory(self):
        """
        Get the recorded trajectory.
            This property should be implemented by subclasses.

        Set the ``return_in_fourier`` attribute to True to get the trajectory in
        Fourier space instead of real space (default False).

        Returns:
            torch.Tensor: The recorded trajectory.
        """
        raise NotImplementedError
control_func instance-attribute ¤
control_func = lambda step: (
    True if step == 0 else control_func(step)
)
return_in_fourier instance-attribute ¤
return_in_fourier = False
trajectory property ¤
trajectory

Get the recorded trajectory. This method should be implemented by subclasses.

This property takes no arguments. Set the return_in_fourier attribute to True to get the trajectory in Fourier space instead of real space (default False).

Returns:

Type Description

torch.tensor: The recorded trajectory.

__init__ ¤
__init__(
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
)
Source code in torchfsm/traj_recorder.py
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
):
    """
    Build the per-step recording filter and reset the recorder state.

    Args:
        control_func (Optional[Callable[[int],bool]]): Step filter applied to every
            step after 0; when None, every step is accepted.
        include_initial_state (bool): Whether step 0 is recorded.
    """
    user_filter = default(control_func, lambda step: True)
    # Step 0 is answered by include_initial_state; all other steps by user_filter.
    keep_initial = bool(include_initial_state)
    self.control_func = lambda step: keep_initial if step == 0 else user_filter(step)
    self.return_in_fourier = False
    self._shutdown_flag = False
set_shutdown_flag ¤
set_shutdown_flag()

Set the shutdown flag to True. This will prevent any further recording of trajectories.

Source code in torchfsm/traj_recorder.py
53
54
55
56
57
58
def set_shutdown_flag(self):
    """
    Mark the recorder as shut down by setting ``_shutdown_flag`` to True,
    signalling that no further trajectory frames should be recorded.
    """
    self._shutdown_flag = True
record ¤
record(step: int, frame: tensor)

Record the trajectory at a given step.

Parameters:

Name Type Description Default
step int

The current step.

required
frame tensor

The current frame to be recorded.

required
Source code in torchfsm/traj_recorder.py
60
61
62
63
64
65
66
67
68
69
def record(self, step: int, frame: torch.Tensor):
    """
    Record the trajectory at a given step.

    The frame is handed to ``_record`` only when the recorder has not been
    shut down (see ``set_shutdown_flag``) and ``control_func(step)`` is True.

    Args:
        step (int): The current step.
        frame (torch.Tensor): The current frame to be recorded.
    """
    # Honour set_shutdown_flag(): once set, every further frame is ignored.
    if self._shutdown_flag:
        return
    if self.control_func(step):
        self._record(step, frame)

torchfsm.traj_recorder.AutoRecorder ¤

Bases: _TrajRecorder

A recorder that saves the trajectory on the same device as the simulation.

Parameters:

Name Type Description Default
control_func Optional[Callable[[int], bool]]

A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.

None
include_initial_state bool

If True, the initial state will be included in the trajectory.

True
Source code in torchfsm/traj_recorder.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class AutoRecorder(_TrajRecorder):
    """
    A recorder that keeps the trajectory on the same device as the simulation.

    Every accepted frame is cloned into an in-memory list; the list is stacked
    along ``dim=1`` only when ``trajectory`` is read.

    Args:
        control_func (Optional[Callable[[int],bool]]): A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.
        include_initial_state (bool): If True, the initial state will be included in the trajectory.
    """

    def __init__(
        self,
        control_func: Optional[Callable[[int], bool]] = None,
        include_initial_state: bool = True,
    ):
        super().__init__(control_func, include_initial_state)
        self._trajectory = list()

    def _record(self, step: int, frame: torch.tensor):
        # A tensor here means the buffer was frozen; refuse further appends.
        if isinstance(self._trajectory, torch.Tensor):
            raise RuntimeError("The trajectory has been finalized.")
        self._trajectory.append(frame.clone())

    @property
    def trajectory(self):
        """Stacked frames (real space unless ``return_in_fourier``), or None if empty."""
        if len(self._trajectory) == 0:
            return None
        stacked = torch.stack(self._trajectory, dim=1)
        return stacked if self.return_in_fourier else self._traj_ifft(stacked).real
trajectory property ¤
trajectory
control_func instance-attribute ¤
control_func = lambda step: (
    True if step == 0 else control_func(step)
)
return_in_fourier instance-attribute ¤
return_in_fourier = False
__init__ ¤
__init__(
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
)
Source code in torchfsm/traj_recorder.py
132
133
134
135
136
137
138
def __init__(
    self,
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
):
    """
    Initialise the base recorder and start with an empty frame buffer.

    Args:
        control_func (Optional[Callable[[int],bool]]): Step filter forwarded to the base class.
        include_initial_state (bool): Whether step 0 is recorded; forwarded to the base class.
    """
    super().__init__(control_func, include_initial_state)
    # Frames accepted by `record` are appended to this buffer.
    self._trajectory = list()
set_shutdown_flag ¤
set_shutdown_flag()

Set the shutdown flag to True. This will prevent any further recording of trajectories.

Source code in torchfsm/traj_recorder.py
53
54
55
56
57
58
def set_shutdown_flag(self):
    """
    Mark the recorder as shut down by setting ``_shutdown_flag`` to True,
    signalling that no further trajectory frames should be recorded.
    """
    self._shutdown_flag = True
record ¤
record(step: int, frame: tensor)

Record the trajectory at a given step.

Parameters:

Name Type Description Default
step int

The current step.

required
frame tensor

The current frame to be recorded.

required
Source code in torchfsm/traj_recorder.py
60
61
62
63
64
65
66
67
68
69
def record(self, step: int, frame: torch.Tensor):
    """
    Record the trajectory at a given step.

    The frame is handed to ``_record`` only when the recorder has not been
    shut down (see ``set_shutdown_flag``) and ``control_func(step)`` is True.

    Args:
        step (int): The current step.
        frame (torch.Tensor): The current frame to be recorded.
    """
    # Honour set_shutdown_flag(): once set, every further frame is ignored.
    if self._shutdown_flag:
        return
    if self.control_func(step):
        self._record(step, frame)

torchfsm.traj_recorder.CPURecorder ¤

Bases: AutoRecorder

A recorder that saves the trajectory on the CPU memory. This is useful for large trajectories that may not fit in GPU memory during simulation.

Parameters:

Name Type Description Default
control_func Optional[Callable[[int], bool]]

A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.

None
include_initial_state bool

If True, the initial state will be included in the trajectory.

True
real_time_ifft bool

If True, each frame is transformed to real space as soon as it is recorded (only when return_in_fourier is False), instead of transforming the whole trajectory when it is read. Default is True.

True
Source code in torchfsm/traj_recorder.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class CPURecorder(AutoRecorder):
    """
    A recorder that saves the trajectory on the CPU memory.
        This is useful for large trajectories that may not fit in GPU memory during simulation.

    Args:
        control_func (Optional[Callable[[int],bool]]): A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.
        include_initial_state (bool): If True, the initial state will be included in the trajectory.
        real_time_ifft (bool): If True, each frame is transformed to real space as soon as it is
            recorded (only when `return_in_fourier` is False at record time), instead of
            transforming the whole trajectory when it is read. Default is True.
    """

    def __init__(
        self,
        control_func: Optional[Callable[[int], bool]] = None,
        include_initial_state: bool = True,
        real_time_ifft: bool = True,
    ):
        super().__init__(control_func, include_initial_state)
        self.real_time_ifft = real_time_ifft

    def _record(self, step: int, frame: torch.Tensor):
        # `.cpu()` copies device tensors but returns CPU tensors unchanged.
        cpu_frame = frame.cpu()
        if self.real_time_ifft and not self.return_in_fourier:
            # ifftn allocates a fresh tensor, so no defensive clone is needed.
            self._trajectory.append(self._field_ifft(cpu_frame).real)
        elif frame.is_cpu:
            # cpu_frame aliases the caller's tensor; clone so later in-place
            # updates by the simulation do not alter the stored frame.
            self._trajectory.append(cpu_frame.clone())
        else:
            self._trajectory.append(cpu_frame)

    @property
    def trajectory(self):
        """Stacked CPU frames (real space unless ``return_in_fourier``), or None if empty."""
        if len(self._trajectory) == 0:
            return None
        # NOTE(review): assumes return_in_fourier is not toggled between recording
        # and reading; otherwise frames stored in one space are returned as the other.
        if self.return_in_fourier or self.real_time_ifft:
            return torch.stack(self._trajectory, dim=1)
        else:
            return self._traj_ifft(torch.stack(self._trajectory, dim=1)).real
real_time_ifft instance-attribute ¤
real_time_ifft = real_time_ifft
trajectory property ¤
trajectory
control_func instance-attribute ¤
control_func = lambda step: (
    True if step == 0 else control_func(step)
)
return_in_fourier instance-attribute ¤
return_in_fourier = False
__init__ ¤
__init__(
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
    real_time_ifft: bool = True,
)
Source code in torchfsm/traj_recorder.py
167
168
169
170
171
172
173
174
def __init__(
    self,
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
    real_time_ifft: bool = True,
):
    """
    Initialise the CPU recorder.

    Args:
        control_func (Optional[Callable[[int],bool]]): Step filter forwarded to the parent class.
        include_initial_state (bool): Whether step 0 is recorded; forwarded to the parent class.
        real_time_ifft (bool): Whether frames are moved to real space as they are recorded.
    """
    super().__init__(control_func, include_initial_state)
    # Consulted when recording and when reading the trajectory.
    self.real_time_ifft = real_time_ifft
set_shutdown_flag ¤
set_shutdown_flag()

Set the shutdown flag to True. This will prevent any further recording of trajectories.

Source code in torchfsm/traj_recorder.py
53
54
55
56
57
58
def set_shutdown_flag(self):
    """
    Mark the recorder as shut down by setting ``_shutdown_flag`` to True,
    signalling that no further trajectory frames should be recorded.
    """
    self._shutdown_flag = True
record ¤
record(step: int, frame: tensor)

Record the trajectory at a given step.

Parameters:

Name Type Description Default
step int

The current step.

required
frame tensor

The current frame to be recorded.

required
Source code in torchfsm/traj_recorder.py
60
61
62
63
64
65
66
67
68
69
def record(self, step: int, frame: torch.Tensor):
    """
    Record the trajectory at a given step.

    The frame is handed to ``_record`` only when the recorder has not been
    shut down (see ``set_shutdown_flag``) and ``control_func(step)`` is True.

    Args:
        step (int): The current step.
        frame (torch.Tensor): The current frame to be recorded.
    """
    # Honour set_shutdown_flag(): once set, every further frame is ignored.
    if self._shutdown_flag:
        return
    if self.control_func(step):
        self._record(step, frame)

torchfsm.traj_recorder.DiskRecorder ¤

Bases: _TrajRecorder

A recorder that saves the trajectory on the disk. This is useful for large trajectories that may not fit in GPU memory during simulation. The trajectory is saved in a temporary cache and then written to disk at specified intervals.

Parameters:

Name Type Description Default
control_func Optional[Callable[[int], bool]]

A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.

None
include_initial_state bool

If True, the initial state will be included in the trajectory.

True
cache_dir Optional[str]

The directory where the trajectory will be saved. Default is "./saved_traj/".

None
cache_freq int

The frequency at which to save the trajectory to disk. Default is 1.

1
temp_cache_loc Literal['auto', 'cpu']

The location of the temporary cache. Default is "cpu".

'cpu'
save_format Literal['numpy', 'torch']

The format in which to save the trajectory. Default is "torch".

'torch'
Source code in torchfsm/traj_recorder.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
class DiskRecorder(_TrajRecorder):
    """
    A recorder that saves the trajectory on the disk.
        This is useful for large trajectories that may not fit in GPU memory during simulation.
        The trajectory is saved in a temporary cache and then written to disk at specified intervals.

    Frames are buffered and written to ``cache_dir`` in chunks of ``cache_freq``
    recorded frames. Each chunk is stacked along ``dim=1`` and saved as
    ``temp_cache_{step}``, where ``step`` is the step of the chunk's last frame.
    ``np.save`` appends ``.npy`` in numpy format. Frames still buffered when
    the simulation ends are written only by an explicit call to ``flush``.

    Args:
        control_func (Optional[Callable[[int],bool]]): A function that takes a step as input and returns a boolean indicating whether to record the trajectory at that step.
        include_initial_state (bool): If True, the initial state will be included in the trajectory.
        cache_dir (Optional[str]): The directory where the trajectory will be saved (created if missing). Default is "./saved_traj/".
        cache_freq (int): The number of recorded frames per file written to disk. Default is 1.
        temp_cache_loc (Literal["auto","cpu"]): The location of the temporary cache. Default is "cpu".
        save_format (Literal["numpy","torch"]): The format in which to save the trajectory. Default is "torch".
    """

    def __init__(
        self,
        control_func: Optional[Callable[[int], bool]] = None,
        include_initial_state: bool = True,
        cache_dir: Optional[str] = None,
        cache_freq: int = 1,
        temp_cache_loc: Literal["auto", "cpu"] = "cpu",
        save_format: Literal["numpy", "torch"] = "torch",
    ):
        super().__init__(control_func, include_initial_state)
        self.cache_dir = default(cache_dir, "./saved_traj/")
        self.cache_freq = cache_freq
        self._trajectory = []
        self.temp_cache_loc = temp_cache_loc
        if self.temp_cache_loc not in ["auto", "cpu"]:
            raise ValueError("temp_cache_loc must be either 'auto' or 'cpu'.")
        self.save_format = save_format
        if self.save_format not in ["numpy", "torch"]:
            raise ValueError("save_format must be either 'numpy' or 'torch'.")

    def _record(self, step: int, frame: torch.Tensor):
        # Always buffer the incoming frame first; the previous version dropped the
        # frame that arrived when the cache was full (and referenced a missing `self.traj`).
        if self.temp_cache_loc == "cpu" and not frame.is_cpu:
            self._trajectory.append(frame.cpu())
        else:
            self._trajectory.append(frame.clone())
        self._last_step = step
        if len(self._trajectory) >= self.cache_freq:
            self._flush(step)

    def flush(self):
        """Write any still-buffered frames to disk (no-op when the buffer is empty)."""
        if self._trajectory:
            # _last_step is set by every _record call, so it exists whenever the buffer is non-empty.
            self._flush(self._last_step)

    def _flush(self, step: int):
        """Stack the buffered frames, write them as one file named after ``step``, and clear the buffer."""
        import os  # local import: keeps the module-level import block unchanged

        temp_cache = torch.stack(self._trajectory, dim=1)
        if not temp_cache.is_cpu:
            temp_cache = temp_cache.to("cpu")
        if not self.return_in_fourier:
            temp_cache = self._traj_ifft(temp_cache).real
        os.makedirs(self.cache_dir, exist_ok=True)
        # os.path.join works whether or not cache_dir ends with a separator.
        path = os.path.join(self.cache_dir, f"temp_cache_{step}")
        if self.save_format == "numpy":
            np.save(path, temp_cache.numpy())
        else:
            torch.save(temp_cache, path)
        self._trajectory = []

    @property
    def trajectory(self):
        """Always None: frames are written to disk rather than kept in memory."""
        return None
cache_dir instance-attribute ¤
cache_dir = default(cache_dir, './saved_traj/')
cache_freq instance-attribute ¤
cache_freq = cache_freq
temp_cache_loc instance-attribute ¤
temp_cache_loc = temp_cache_loc
save_format instance-attribute ¤
save_format = save_format
trajectory property ¤
trajectory
control_func instance-attribute ¤
control_func = lambda step: (
    True if step == 0 else control_func(step)
)
return_in_fourier instance-attribute ¤
return_in_fourier = False
__init__ ¤
__init__(
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
    cache_dir: Optional[str] = None,
    cache_freq: int = 1,
    temp_cache_loc: Literal["auto", "cpu"] = "cpu",
    save_format: Literal["numpy", "torch"] = "torch",
)
Source code in torchfsm/traj_recorder.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def __init__(
    self,
    control_func: Optional[Callable[[int], bool]] = None,
    include_initial_state: bool = True,
    cache_dir: Optional[str] = None,
    cache_freq: int = 1,
    temp_cache_loc: Literal["auto", "cpu"] = "cpu",
    save_format: Literal["numpy", "torch"] = "torch",
):
    """
    Configure where and how buffered frames are written to disk.

    Args:
        control_func (Optional[Callable[[int],bool]]): Step filter forwarded to the base class.
        include_initial_state (bool): Whether step 0 is recorded; forwarded to the base class.
        cache_dir (Optional[str]): Output directory; "./saved_traj/" when None.
        cache_freq (int): Number of frames buffered before a write.
        temp_cache_loc (Literal["auto","cpu"]): Where buffered frames live.
        save_format (Literal["numpy","torch"]): On-disk format.

    Raises:
        ValueError: If ``temp_cache_loc`` or ``save_format`` is not one of the allowed values.
    """
    super().__init__(control_func, include_initial_state)
    self.cache_dir = default(cache_dir, "./saved_traj/")
    self.cache_freq = cache_freq
    self._trajectory = list()
    self.temp_cache_loc = temp_cache_loc
    if temp_cache_loc not in ("auto", "cpu"):
        raise ValueError("temp_cache_loc must be either 'auto' or 'cpu'.")
    self.save_format = save_format
    if save_format not in ("numpy", "torch"):
        raise ValueError("save_format must be either 'numpy' or 'torch'.")
set_shutdown_flag ¤
set_shutdown_flag()

Set the shutdown flag to True. This will prevent any further recording of trajectories.

Source code in torchfsm/traj_recorder.py
53
54
55
56
57
58
def set_shutdown_flag(self):
    """
    Mark the recorder as shut down by setting ``_shutdown_flag`` to True,
    signalling that no further trajectory frames should be recorded.
    """
    self._shutdown_flag = True
record ¤
record(step: int, frame: tensor)

Record the trajectory at a given step.

Parameters:

Name Type Description Default
step int

The current step.

required
frame tensor

The current frame to be recorded.

required
Source code in torchfsm/traj_recorder.py
60
61
62
63
64
65
66
67
68
69
def record(self, step: int, frame: torch.Tensor):
    """
    Record the trajectory at a given step.

    The frame is handed to ``_record`` only when the recorder has not been
    shut down (see ``set_shutdown_flag``) and ``control_func(step)`` is True.

    Args:
        step (int): The current step.
        frame (torch.Tensor): The current frame to be recorded.
    """
    # Honour set_shutdown_flag(): once set, every further frame is ignored.
    if self._shutdown_flag:
        return
    if self.control_func(step):
        self._record(step, frame)