Swin Transformer
Assume an input image of size $x \in R^{3\times 224\times 224}$.
Following ViT (Vision Transformer), a conv kernel is used to extract patches, which are then flattened into a sequence.
Dim Change:
Projection code:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
Flatten:
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
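A minimal, self-contained sketch of this patch-embedding step (the hyperparameters below, e.g. patch_size=4 and embed_dim=96, follow the Swin-T defaults; the class name PatchEmbedSketch is only illustrative):
import torch
import torch.nn as nn

class PatchEmbedSketch(nn.Module):
    # illustrative patch embedding: Conv2d projection + flatten, as in ViT/Swin
    def __init__(self, in_chans=3, embed_dim=96, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, 3, 224, 224] -> conv -> [B, 96, 56, 56]
        x = self.proj(x)
        # flatten spatial dims and move channels last: [B, 56*56, 96]
        return x.flatten(2).transpose(1, 2)

x = torch.randn(8, 3, 224, 224)
print(PatchEmbedSketch()(x).shape)  # torch.Size([8, 3136, 96])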
Notice: each Swin Transformer Block (SWTB) consists of multiple encoders (according to the architecture, the number of encoders per stage is 2, 2, 6, 2).
shift_size=0 if (i % 2 == 0) else window_size // 2
# even-indexed encoders perform W-MSA, odd-indexed ones perform SW-MSA, alternating
Input Sequence: [8, 56*56, 96] is unflattened back to [8, 56, 56, 96]
default window_size = 7
Window Attention
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
# Size:[B, num_win_row, win_size, num_win_col, win_size, C]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
# x.permute: [B, num_win_row, num_win_col, win_size, win_size, C]
# output: [B*num_win, win_size, win_size, C]
x is partitioned into small windows and multi-head self-attention is performed within each window (note: each window is flattened into a sequence before attention).
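As a rough usage sketch (assuming a window_partition helper defined exactly by the reshape/permute lines above), partitioning the stage-1 feature map gives:
import torch

def window_partition(x, window_size):
    # [B, H, W, C] -> [B*num_windows, window_size, window_size, C]
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

x = torch.randn(8, 56, 56, 96)            # unflattened stage-1 sequence
windows = window_partition(x, 7)          # 8 * (56/7) * (56/7) = 512 windows
print(windows.shape)                      # torch.Size([512, 7, 7, 96])
attn_input = windows.view(-1, 7 * 7, 96)  # each window flattened to a 49-token sequence before attention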
Relative Position Encoding (RPE) is incorporated into WindowAttention.
RPE in WindowAttention
As formulated: $Attention(Q,K,V) = SoftMax(QK^T/\sqrt{d}+B)V$,
B represents the relative position bias, which is added inside the attention computation. The values of $B \in R^{n\times n}$ (where $n$ is the number of elements in a window) are taken from $\hat{B}\in R^{(2h-1)\times(2w-1)}$, where $\hat{B}$ is a learnable relative bias table and $h$, $w$ are the window height and width.
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
The code above generates a 2-D coordinate tensor: coords_flatten[0] holds the row (h) coordinates and coords_flatten[1] the column (w) coordinates of every position in the window.
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
This code produces the final relative-position index used to look up entries in the relative bias table $\hat{B}$.
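A sketch of how this index is typically used in the forward pass (following the pattern of the official implementation; relative_position_bias_table plays the role of the learnable $\hat{B}$, with shape [(2*Wh-1)*(2*Ww-1), num_heads]):
import torch

Wh, Ww, num_heads = 7, 7, 3

# learnable table \hat{B}: one bias value per relative offset per head
relative_position_bias_table = torch.zeros((2 * Wh - 1) * (2 * Ww - 1), num_heads)

# relative_position_index computed exactly as above, shape [Wh*Ww, Wh*Ww]
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = (coords_flatten[:, :, None] - coords_flatten[:, None, :]).permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += Wh - 1
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
relative_position_index = relative_coords.sum(-1)

# gather B from \hat{B}: [num_heads, Wh*Ww, Wh*Ww], added to the attention logits before softmax
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(Wh * Ww, Wh * Ww, -1).permute(2, 0, 1).contiguous()
# attn = softmax(q @ k.transpose(-2, -1) * scale + relative_position_bias.unsqueeze(0))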
Shifted Window
Since window attention only considers elements inside a window, the shifted-window mechanism is proposed to capture inter-window features.
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
These slices partition the input patch mask of shape [1, H, W, 1] into different regions; each region is assigned a unique mask value.
Since the initial patches are shifted toward the upper-left by shift_size (default window_size // 2, i.e., 7 // 2 = 3), an attention mask is provided to identify elements that come from different original windows (e.g., regions 4, 5, 7, 8 come from four different windows).
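The shift itself is implemented with a cyclic roll (a sketch following the reference pattern; shift_size = 3 matches the default window_size // 2):
import torch

shift_size = 3                     # window_size // 2 for window_size = 7
x = torch.randn(8, 56, 56, 96)

# shift the feature map toward the upper-left before SW-MSA
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

# ... window partition + masked attention on shifted_x ...

# reverse the cyclic shift afterwards to restore the original layout
x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))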
Finally, the attention mask is calculated as follows:
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
# img_mask is the matrix as mentioned above
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask is, for each window, a matrix $\in R^{win\_size^2 \times win\_size^2}$ (full shape [nW, window_size^2, window_size^2]) that encodes whether two elements in that window come from the same original, pre-shift window.
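A sketch of how the mask is then applied inside window attention (the mask is broadcast over batch and heads and added to the logits before softmax; shapes in the comments assume window_size = 7 on a 56x56 map, i.e. nW = 64 windows per image):
import torch

B_, num_heads, N, nW = 512, 3, 49, 64      # B*nW windows, heads, tokens per window (7*7), windows per image
attn = torch.randn(B_, num_heads, N, N)    # raw Q*K^T / sqrt(d) logits
attn_mask = torch.zeros(nW, N, N)          # built as above: 0 for same-window pairs, -100 otherwise

# broadcast the per-window mask over batch and heads, then add it to the logits
attn = attn.view(B_ // nW, nW, num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_heads, N, N)
attn = attn.softmax(dim=-1)                # cross-window pairs receive ~0 attention weight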
By applying W-MSA and SW-MSA several times, the SwinT block returns an embedding of size $\in R^{B\times (H\cdot W)\times C}$, which is then sent into the Patch Merging block.
The input embedding is permuted and merged as follows:
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)
x = x.view(B, -1, 4 * C)
# reshaped into [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x)
# reduction: a Linear(4*C, 2*C) layer that reduces the channels to 2*C
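Put together, a minimal PatchMerging sketch (the class name is illustrative; reduction is the Linear(4*C, 2*C) described above):
import torch
import torch.nn as nn

class PatchMergingSketch(nn.Module):
    # downsample by 2x in each spatial dimension, channels C -> 2C
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x, H, W):
        # x: [B, H*W, C]
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        return self.reduction(self.norm(x))                     # [B, H/2*W/2, 2*C]

x = torch.randn(8, 56 * 56, 96)
print(PatchMergingSketch(96)(x, 56, 56).shape)  # torch.Size([8, 784, 192])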
The overall SwinT model simply stacks the modules as described in the global architecture figure and returns an embedding of shape [B, (H/32)*(W/32), 8*C]; the subsequent pooling is:
x = self.avgpool(x.transpose(1, 2)) # B 8*C 1
x = torch.flatten(x, 1)
# return [B,8*C]
x = self.head(x) # Linear(8*C, num_classes)
# return the classification vector
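For concreteness, a rough shape trace through the four stages for a 224x224 input with C = 96 (following the 2, 2, 6, 2 Swin-T configuration described above):
# [B, 3, 224, 224] -- patch embed   --> [B, 56*56,  96]  (stage 1, 2 encoders)
# [B, 56*56,  96]  -- patch merging --> [B, 28*28, 192]  (stage 2, 2 encoders)
# [B, 28*28, 192]  -- patch merging --> [B, 14*14, 384]  (stage 3, 6 encoders)
# [B, 14*14, 384]  -- patch merging --> [B,  7*7,  768]  (stage 4, 2 encoders)
# avgpool over the 7*7 tokens -> [B, 768] -> head -> [B, num_classes]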
The architecture of the Video Swin Transformer (VST) closely resembles that of the Swin Transformer (SwinT). A notable distinction is the inclusion of an additional temporal dimension T in the input features, corresponding to a fixed number of frames (typically 32) sampled from the input video. The overall architecture of VST follows the structure of SwinT, as shown in the image below:
Global Arch of VST
Convert the 2D convolution into a 3D convolution; with the default patch_size=(4, 4, 4), the output size becomes [T/4, H/4, W/4]:
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
After the projection and reshaping, the embedding is $X \in R^{B\times C\times\frac{T}{4}\times\frac{H}{4}\times\frac{W}{4}}$. Both the SwinT block and the Patch Merging block follow the methodology of SwinT while extending the dimensionality from 2D to 3D.
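A minimal shape check of this 3D patch embedding (assuming 32 sampled frames and the same 224x224 spatial size; embed_dim = 96 as in the 2D case):
import torch
import torch.nn as nn

proj = nn.Conv3d(3, 96, kernel_size=(4, 4, 4), stride=(4, 4, 4))

x = torch.randn(8, 3, 32, 224, 224)          # [B, C_in, T, H, W]
x = proj(x)                                  # [8, 96, 8, 56, 56] = [B, C, T/4, H/4, W/4]
D, Wh, Ww = x.size(2), x.size(3), x.size(4)
x = x.flatten(2).transpose(1, 2)             # [8, 8*56*56, 96], flattened for LayerNorm
print(D, Wh, Ww, x.shape)                    # 8 56 56 torch.Size([8, 25088, 96])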