def gamma(x: torch.Tensor) -> torch.Tensor:
    """
    (Complete) Gamma function.

    pytorch has not implemented the commonly used (complete) Gamma function. We define
    it through the exponential of the log-gamma function so that autograd works, as in
    https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122

    ``gammaln`` yields the logarithm of the *absolute value* of the Gamma function, so
    the sign is restored explicitly for negative non-integer arguments.

    :param x: Input tensor
    :return: Gamma function evaluated element-wise at ``x``; ``+inf`` at the poles
        (zero and the negative integers)
    """
    magnitude = torch.exp(gammaln(x))
    # Gamma is negative on (-1, 0), (-3, -2), ..., i.e. where x < 0 and x mod 2 lies
    # strictly between 1 and 2. Poles give a remainder of exactly 0 or 1, so they keep
    # the +inf produced by exp(gammaln(x)).
    negative_branch = (x < 0) & (torch.remainder(x, 2) > 1)
    return torch.where(negative_branch, -magnitude, magnitude)
class_CustomExp1(torch.autograd.Function):@staticmethoddefforward(ctx,x):# this implementation is inspired by the one in scipy:# https://github.com/scipy/scipy/blob/34d91ce06d4d05e564b79bf65288284247b1f3e3/scipy/special/xsf/expint.h#L22ctx.save_for_backward(x)# ConstantsSCIPY_EULER=(0.577215664901532860606512090082402431# Euler-Mascheroni constant)inf=torch.inf# Handle case when x == 0result=torch.full_like(x,inf)mask=x>0# Compute for x <= 1x_small=x[mask&(x<=1)]ifx_small.numel()>0:e1=torch.ones_like(x_small)r=torch.ones_like(x_small)forkinrange(1,26):r=-r*k*x_small/(k+1.0)**2e1+=riftorch.all(torch.abs(r)<=torch.abs(e1)*1e-15):breakresult[mask&(x<=1)]=-SCIPY_EULER-torch.log(x_small)+x_small*e1# Compute for x > 1x_large=x[mask&(x>1)]ifx_large.numel()>0:m=20+(80.0/x_large).to(torch.int32)t0=torch.zeros_like(x_large)forkinrange(m.max(),0,-1):t0=k/(1.0+k/(x_large+t0))t=1.0/(x_large+t0)result[mask&(x>1)]=torch.exp(-x_large)*treturnresult@staticmethoddefbackward(ctx,grad_output):(x,)=ctx.saved_tensorsreturn-grad_output*torch.exp(-x)/x
def exp1(x):
    r"""
    Exponential integral E1.

    For a real number :math:`x > 0` the exponential integral can be defined as

    .. math::

        E1(x) = \int_{x}^{\infty} \frac{e^{-t}}{t} dt

    The evaluation is delegated to ``_CustomExp1``, a custom autograd function.

    :param x: Input tensor (x > 0)
    :return: Exponential integral E1(x)
    """
    e1_values = _CustomExp1.apply(x)
    return e1_values
def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """
    Compute the regularized incomplete gamma function complement for integer exponents.

    Only exponents 1 to 6 are supported; each one maps to a closed-form expression
    built from ``exp``, ``erfc`` and the exponential integral ``exp1``.

    :param exponent: Exponent of the power law
    :param z: Value at which to evaluate the function
    :return: Regularized incomplete gamma function complement
    :raises ValueError: If ``exponent`` is not one of 1, 2, 3, 4, 5 or 6
    """
    if exponent == 1:
        decay = torch.exp(-z)
        value = decay / z
    elif exponent == 2:
        prefactor = torch.sqrt(torch.pi / z)
        value = prefactor * torch.erfc(torch.sqrt(z))
    elif exponent == 3:
        value = exp1(z)
    elif exponent == 4:
        decay = torch.exp(-z)
        erfc_part = torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
        value = 2 * (decay - erfc_part)
    elif exponent == 5:
        decay = torch.exp(-z)
        value = decay - z * exp1(z)
    elif exponent == 6:
        exp_part = (2 - 4 * z) * torch.exp(-z)
        erfc_part = 4 * torch.sqrt(torch.pi * z**3) * torch.erfc(torch.sqrt(z))
        value = (exp_part + erfc_part) / 3
    else:
        raise ValueError(f"Unsupported exponent: {exponent}")
    return value