"""Callbacks hooked into :mod:`torchgpe` simulations (source of ``torchgpe.utils.callbacks``)."""
import torch
from abc import ABCMeta
class Callback(metaclass=ABCMeta):
    """Base class for callbacks.

    Before a simulation starts, the callback is provided with the instance of the
    :class:`gpe.bec2D.gas.Gas` (stored in the :py:attr:`gpe.utils.callbacks.Callback.gas`
    variable) and with a dictionary of parameters for the simulation (stored in
    :py:attr:`gpe.utils.callbacks.Callback.propagation_params`).

    Every ``on_*`` hook defaults to a no-op, so subclasses only need to override
    the hooks they care about.
    """

    def __init__(self) -> None:
        #: gpe.bec2D.gas.Gas: The instance of the :class:`gpe.bec2D.gas.Gas` class. Populated when the simulation starts.
        self.gas = None
        #: dict: A dictionary of parameters for the simulation. Populated when the simulation starts.
        self.propagation_params = None

    def set_gas(self, gas) -> None:
        """Store the gas instance observed by this callback.

        Args:
            gas (gpe.bec2D.gas.Gas): The gas being simulated.
        """
        self.gas = gas

    def set_propagation_params(self, propagation_params) -> None:
        """Store the parameters of the current simulation.

        Args:
            propagation_params (dict): The simulation parameters.
        """
        self.propagation_params = propagation_params

    def on_propagation_begin(self) -> None:
        """Function called by the :class:`gpe.bec2D.gas.Gas` class before the simulation begins."""

    def on_propagation_end(self) -> None:
        """Function called by the :class:`gpe.bec2D.gas.Gas` class after the simulation ends."""

    def on_epoch_begin(self, epoch: int) -> None:
        """Function called by the :class:`gpe.bec2D.gas.Gas` at the beginning of each epoch.

        Args:
            epoch (int): The epoch number
        """

    def on_epoch_end(self, epoch: int) -> None:
        """Function called by the :class:`gpe.bec2D.gas.Gas` at the end of each epoch.

        Args:
            epoch (int): The epoch number
        """
class LInfNorm(Callback):
    r"""Callback computing the :math:`L_\infty` norm of the change of the wavefunction over an epoch.

    The :math:`L_\infty` norm is defined as:

    .. math::

        L_\infty = \max_{(x,y)}|\Psi_t - \Psi_{t+\Delta t}|

    Args:
        compute_every (int): Optional. The number of epochs after which the norm is computed. Defaults to 1.
        print_every (int): Optional. The number of epochs after which, if computed, the norm is also printed. Defaults to 1.
    """

    def __init__(self, compute_every=1, print_every=1) -> None:
        super().__init__()
        #: list: A list of the computed norms, one entry per computed epoch, moved to the CPU
        self.norms = []
        self.compute_every = compute_every
        self.print_every = print_every

    def on_epoch_begin(self, epoch: int):
        """At the beginning of an epoch, if its number is a multiple of ``compute_every``, stores a copy of the wave function of the gas.

        Args:
            epoch (int): The epoch number
        """
        if epoch % self.compute_every != 0:
            return
        # Take a snapshot rather than an alias: if ``gas.psi`` is updated in
        # place during the epoch, a plain reference would follow those updates
        # and the difference computed in ``on_epoch_end`` would be zero.
        self.psi = self.gas.psi.clone()

    def on_epoch_end(self, epoch: int):
        r"""At the end of an epoch, if its number is a multiple of ``compute_every``, uses the stored wave function of the gas to compute
        the :math:`L_\infty` norm. If the epoch number is a multiple of ``print_every`` as well, the value of the norm is printed on screen.

        Args:
            epoch (int): The epoch number
        """
        if epoch % self.compute_every != 0:
            return
        psi = self.gas.psi
        self.norms.append(torch.max(torch.abs(psi - self.psi)).cpu())
        # Release the snapshot so it is not kept alive between computed epochs.
        del self.psi
        if epoch % self.print_every == 0:
            print(self.norms[-1])
class L1Norm(Callback):
    r"""Callback computing the :math:`L_1` norm of the change of the wavefunction over an epoch.

    The :math:`L_1` norm is defined as:

    .. math::

        L_1 = \sum_{(x,y)}|\Psi_t - \Psi_{t+\Delta t}| \, dx \, dy

    Args:
        compute_every (int): Optional. The number of epochs after which the norm is computed. Defaults to 1.
        print_every (int): Optional. The number of epochs after which, if computed, the norm is also printed. Defaults to 1.
    """

    def __init__(self, compute_every=1, print_every=1) -> None:
        super().__init__()
        #: list: A list of the computed norms, one entry per computed epoch, moved to the CPU
        self.norms = []
        self.compute_every = compute_every
        self.print_every = print_every

    def on_epoch_begin(self, epoch: int):
        """At the beginning of an epoch, if its number is a multiple of ``compute_every``, stores a copy of the wave function of the gas.

        Args:
            epoch (int): The epoch number
        """
        if epoch % self.compute_every != 0:
            return
        # Take a snapshot rather than an alias: if ``gas.psi`` is updated in
        # place during the epoch, a plain reference would follow those updates
        # and the difference computed in ``on_epoch_end`` would be zero.
        self.psi = self.gas.psi.clone()

    def on_epoch_end(self, epoch: int):
        r"""At the end of an epoch, if its number is a multiple of ``compute_every``, uses the stored wave function of the gas to compute
        the :math:`L_1` norm. If the epoch number is a multiple of ``print_every`` as well, the value of the norm is printed on screen.

        Args:
            epoch (int): The epoch number
        """
        if epoch % self.compute_every != 0:
            return
        psi = self.gas.psi
        # Discretized integral: sum over the grid weighted by the cell area dx*dy.
        self.norms.append(
            (torch.sum(torch.abs(psi - self.psi)) * self.gas.dx * self.gas.dy).cpu()
        )
        # Release the snapshot so it is not kept alive between computed epochs.
        del self.psi
        if epoch % self.print_every == 0:
            print(self.norms[-1])
class L2Norm(Callback):
    r"""Callback computing the :math:`L_2` norm of the change of the wavefunction over an epoch.

    The :math:`L_2` norm is defined as:

    .. math::

        L_2 = \sqrt{\sum_{(x,y)}|\Psi_t - \Psi_{t+\Delta t}|^2 \, dx \, dy}

    Args:
        compute_every (int): Optional. The number of epochs after which the norm is computed. Defaults to 1.
        print_every (int): Optional. The number of epochs after which, if computed, the norm is also printed. Defaults to 1.
    """

    def __init__(self, compute_every=1, print_every=1) -> None:
        super().__init__()
        #: list: A list of the computed norms, one entry per computed epoch, moved to the CPU
        self.norms = []
        self.compute_every = compute_every
        self.print_every = print_every

    def on_epoch_begin(self, epoch: int):
        """At the beginning of an epoch, if its number is a multiple of ``compute_every``, stores a copy of the wave function of the gas.

        Args:
            epoch (int): The epoch number
        """
        if epoch % self.compute_every != 0:
            return
        # Take a snapshot rather than an alias: if ``gas.psi`` is updated in
        # place during the epoch, a plain reference would follow those updates
        # and the difference computed in ``on_epoch_end`` would be zero.
        self.psi = self.gas.psi.clone()

    def on_epoch_end(self, epoch: int):
        r"""At the end of an epoch, if its number is a multiple of ``compute_every``, uses the stored wave function of the gas to compute
        the :math:`L_2` norm. If the epoch number is a multiple of ``print_every`` as well, the value of the norm is printed on screen.

        Args:
            epoch (int): The epoch number
        """
        if epoch % self.compute_every != 0:
            return
        psi = self.gas.psi
        # Discretized integral of |delta psi|^2 weighted by the cell area dx*dy.
        squared = torch.sum(torch.abs(psi - self.psi) ** 2) * self.gas.dx * self.gas.dy
        self.norms.append(torch.sqrt(squared).cpu())
        # Release the snapshot so it is not kept alive between computed epochs.
        del self.psi
        if epoch % self.print_every == 0:
            print(self.norms[-1])