chore: remove obsolete code for torch<2

Minimum torch version is 2.1 now.
Enno Hermann 2024-05-08 15:50:48 +02:00
parent 865a48156d
commit 6d563af623
2 changed files with 1 addition and 9 deletions


@@ -1,5 +1,4 @@
 import torch
-from packaging.version import Version
 from torch import nn
 from torch.nn import functional as F
@@ -90,10 +89,7 @@ class InvConvNear(nn.Module):
         self.no_jacobian = no_jacobian
         self.weight_inv = None
-        if Version(torch.__version__) < Version("1.9"):
-            w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
-        else:
-            w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
+        w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
         if torch.det(w_init) < 0:
             w_init[:, 0] = -1 * w_init[:, 0]

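For context, the surviving branch uses torch.linalg.qr, the documented replacement for the deprecated torch.qr. A minimal standalone sketch of this orthogonal weight initialization (num_splits = 4 is a hypothetical value; InvConvNear derives it from its own configuration):

    import torch

    num_splits = 4  # hypothetical; InvConvNear sets this itself

    # "complete" mode yields a square orthogonal Q, a valid initial
    # weight for an invertible 1x1 convolution
    w_init = torch.linalg.qr(
        torch.FloatTensor(num_splits, num_splits).normal_(), "complete"
    )[0]

    # flip one column if needed so det(w_init) = +1, i.e. the flow's
    # log-determinant term starts at zero
    if torch.det(w_init) < 0:
        w_init[:, 0] = -1 * w_init[:, 0]

    # orthogonality check: Q @ Q.T == I up to float tolerance
    assert torch.allclose(w_init @ w_init.T, torch.eye(num_splits), atol=1e-5)

Note that for a square input matrix, the reduced and complete QR factorizations coincide, so the two deleted branches were numerically equivalent anyway; only the API call changed.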

@@ -7,7 +7,6 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
-from packaging import version
 from torch import einsum, nn
@@ -44,9 +43,6 @@ class Attend(nn.Module):
         self.register_buffer("mask", None, persistent=False)
         self.use_flash = use_flash
-        assert not (
-            use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
-        ), "in order to use flash attention, you must be using pytorch 2.0 or above"
         # determine efficient attention configs for cuda and cpu
         self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
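Since torch>=2.1 is now guaranteed, F.scaled_dot_product_attention is always present, which is what made the runtime version assert redundant. A minimal sketch of that fused entry point with hypothetical tensor shapes (illustrative only, not the module's actual forward):

    import torch
    import torch.nn.functional as F

    # (batch, heads, seq_len, dim_head) -- hypothetical shapes
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # stable API since torch 2.0; dispatches to a flash, memory-efficient,
    # or plain math backend automatically
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 16, 64])

The EfficientAttentionConfig fields (enable_flash, enable_math, enable_mem_efficient) mirror the backend toggles of torch.backends.cuda.sdp_kernel, which is presumably how the module steers that dispatch on CUDA.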