chore: remove obsolete code for torch<2

Minimum torch version is 2.1 now.
This commit is contained in:
Enno Hermann 2024-05-08 15:50:48 +02:00
parent 865a48156d
commit 6d563af623
2 changed files with 1 additions and 9 deletions

View File

@ -1,5 +1,4 @@
import torch
from packaging.version import Version
from torch import nn
from torch.nn import functional as F
@ -90,10 +89,7 @@ class InvConvNear(nn.Module):
self.no_jacobian = no_jacobian
self.weight_inv = None
if Version(torch.__version__) < Version("1.9"):
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
else:
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]

View File

@ -7,7 +7,6 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn
@ -44,9 +43,6 @@ class Attend(nn.Module):
self.register_buffer("mask", None, persistent=False)
self.use_flash = use_flash
assert not (
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
), "in order to use flash attention, you must be using pytorch 2.0 or above"
# determine efficient attention configs for cuda and cpu
self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])