Skip to content

Commit 26e8c47

Browse files
authored
fix: TorchFix Errors (#1262)
* fix torchfix errors * fix torchfix errors x2 * fix torchfix errors x3
1 parent e848409 commit 26e8c47

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

segmentation_models_pytorch/base/heads.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch.nn as nn
2+
23
from .modules import Activation
34

45

@@ -10,7 +11,7 @@ def __init__(
1011
in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
1112
)
1213
upsampling = (
13-
nn.UpsamplingBilinear2d(scale_factor=upsampling)
14+
nn.Upsample(mode="bilinear", scale_factor=upsampling, align_corners=True)
1415
if upsampling > 1
1516
else nn.Identity()
1617
)

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"""
3232

3333
from collections.abc import Iterable, Sequence
34-
from typing import Literal, List
34+
from typing import List, Literal
3535

3636
import torch
3737
from torch import nn
@@ -105,7 +105,9 @@ def __init__(
105105
)
106106

107107
scale_factor = 4 if output_stride == 16 and encoder_depth > 3 else 2
108-
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
108+
self.up = nn.Upsample(
109+
mode="bilinear", scale_factor=scale_factor, align_corners=True
110+
)
109111

110112
highres_in_channels = encoder_channels[2]
111113
highres_out_channels = 48 # proposed by authors of paper

segmentation_models_pytorch/losses/_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def wing_loss(
226226
idx_smaller = diff_abs < width
227227
idx_bigger = diff_abs >= width
228228

229-
loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature)
229+
loss[idx_smaller] = width * torch.log1p(diff_abs[idx_smaller] / curvature)
230230

231231
C = width - width * math.log(1 + width / curvature)
232232
loss[idx_bigger] = loss[idx_bigger] - C

0 commit comments

Comments
 (0)