SE, ECA, CA, SA, CBAM, ShuffleAttention, SimAM, CrissCrossAttention, SK, NAM, GAM and SOCA attention modules, with code (a quick note)




1. SE Channel Attention


SENet:
1. Apply global average pooling to the input feature map.
2. Pass the pooled vector through two fully connected layers.
3. Apply a Sigmoid to squash the values into the range 0-1.
4. Multiply these weights onto the original input feature map.

import torch
import torch.nn as nn
import math


class se_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(se_block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // ratio, bias=False),  # fully connected
            nn.ReLU(inplace=True),
            nn.Linear(channel // ratio, channel, bias=False),  # fully connected
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
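
A minimal usage sketch (the shapes below are chosen only for illustration): the block returns a tensor of the same shape as its input, with each channel rescaled by its learned weight.

# Quick sanity check (hypothetical shapes, not from the original post)
x = torch.randn(2, 64, 32, 32)       # batch of 2, 64 channels
se = se_block(channel=64, ratio=16)
out = se(x)
print(out.shape)                     # torch.Size([2, 64, 32, 32])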

2. ECA Channel Attention

ECANet is another channel attention mechanism and can be seen as an improved version of SENet.
Because convolution captures cross-channel interactions well, ECA replaces SE's fully connected layers with a 1D convolution whose kernel size is derived from the channel count.

class eca_block(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super(eca_block, self).__init__()
        # adaptive kernel size derived from the channel count (forced to be odd)
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1

        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)  # 1D convolution across channels
        y = self.sigmoid(y)
        return x * y.expand_as(x)  # expand_as broadcasts y to the same shape as x
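
As a quick worked example of the adaptive kernel-size rule in __init__ (my own illustration, with b=1 and gamma=2):

# Worked example of the kernel-size rule (illustrative channel counts)
for channel in (64, 256, 512, 1024):
    k = int(abs((math.log(channel, 2) + 1) / 2))
    k = k if k % 2 else k + 1
    print(channel, k)   # 64 -> 3, 256 -> 5, 512 -> 5, 1024 -> 5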

3. CA (Channel Attention)

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # global max pooling

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

Only the attention module itself is listed here; you still have to multiply the resulting weights onto the original input feature map,
e.g. x = x * self.ChannelAttention(x). See the CBAM code below for a full example.

4. SA (Spatial Attention)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)    # mean over channels
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # max over channels
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

Only the attention module itself is listed here; you still have to multiply the resulting weights onto the original input feature map,
e.g. x = x * self.SpatialAttention(x). See the CBAM code below for a full example.

5. CBAM (Channel Attention + Spatial Attention)

CBAM combines a channel attention mechanism with a spatial attention mechanism.
Channel attention: the input feature map goes through global average pooling and global max pooling separately. Both pooled results are processed by a shared fully connected block, the two outputs are added, and a Sigmoid squashes the result into 0-1. The resulting weights are multiplied onto the original input feature map.
Spatial attention: for each spatial position, take the maximum and the mean across channels. Stack the two maps, apply a convolution that reduces them to a single channel, and pass the result through a Sigmoid. The resulting weights are multiplied onto the original input feature map.

# Channel attention mechanism
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # global max pooling
        # 1x1 convolutions in place of fully connected layers
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


# Spatial attention mechanism
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)    # mean over channels
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # max over channels
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class cbam_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)      # channel attention
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)   # spatial attention

    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x
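
A minimal usage sketch (shapes chosen only for illustration). Note that cbam_block already multiplies the weights onto the input internally, so no extra x = x * ... is needed:

# Hypothetical example: wrap a feature map with CBAM
x = torch.randn(2, 128, 40, 40)
cbam = cbam_block(channel=128, ratio=8, kernel_size=7)
print(cbam(x).shape)   # torch.Size([2, 128, 40, 40])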

Limitations:
1. SE attention only models dependencies between channels and ignores spatial information.
2. CBAM adds a large convolution kernel to extract spatial features, but it still does not capture long-range dependencies.

6. ShuffleAttention

Code

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


# https://arxiv.org/pdf/2102.00240.pdf
class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G, c//G, h, w

        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G, c//(2*G), h, w

        # channel attention
        x_channel = self.avg_pool(x_0)                      # bs*G, c//(2*G), 1, 1
        x_channel = self.cweight * x_channel + self.cbias   # bs*G, c//(2*G), 1, 1
        x_channel = x_0 * self.sigmoid(x_channel)

        # spatial attention
        x_spatial = self.gn(x_1)                            # bs*G, c//(2*G), h, w
        x_spatial = self.sweight * x_spatial + self.sbias   # bs*G, c//(2*G), h, w
        x_spatial = x_1 * self.sigmoid(x_spatial)           # bs*G, c//(2*G), h, w

        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)      # bs*G, c//G, h, w
        out = out.contiguous().view(b, -1, h, w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out


Usage

self.ShuffleAttention = ShuffleAttention(channel=512,reduction=16,G=8)

Note: G=8 here is the number of groups, and every parameter above is channel // (2 * G) wide. If the channel value you pass in does not match the actual channel count of the input feature map, the GroupNorm and the weight/bias parameters will not line up with x_channel / x_spatial in forward and you will get a shape error.
Adjust channel (and, if needed, G) so that channel // (2 * G) equals the per-group width of x_channel seen in forward; the relevant lines are:

self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
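
For reference, a shape check I would run (the sizes here are my own assumptions, not from the original post): channel must equal the channel count of the input feature map and be divisible by 2 * G.

# Example: 512 input channels with G=8 groups -> per-branch width 512 // (2*8) = 32
x = torch.randn(4, 512, 20, 20)
sa = ShuffleAttention(channel=512, reduction=16, G=8)
print(sa(x).shape)   # torch.Size([4, 512, 20, 20])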

7. SimAM

Code

class SimAM(torch.nn.Module):
    def __init__(self, channels=None, out_channels=None, e_lambda=1e-4):
        super(SimAM, self).__init__()
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activaton(y)

Usage

self.SimAM = SimAM(512,512)
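
Worth noting (my reading of the code above, not a claim from the original post): SimAM is parameter-free, and the channels / out_channels arguments are never used in forward, so the same instance works for feature maps of any channel count.

# SimAM has no learnable parameters and ignores its channel arguments
sim = SimAM()
print(sum(p.numel() for p in sim.parameters()))   # 0
print(sim(torch.randn(2, 64, 16, 16)).shape)      # torch.Size([2, 64, 16, 16])
print(sim(torch.randn(2, 256, 8, 8)).shape)       # torch.Size([2, 256, 8, 8])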

8. CrissCrossAttention

Code

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax


def INF(B, H, W):
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + x

Usage

self.CrissCrossAttention = CrissCrossAttention(1024)

Note: the in_dim passed to __init__ must equal the channel count of the feature map you feed in (1024 in the call above). If your feature map has a different number of channels, change it accordingly; the affected layers are these:

self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
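
A shape sanity check (illustrative sizes, my own example): because of the residual gamma * (out_H + out_W) + x, the output keeps the input shape.

# Hypothetical check: 64-channel feature map
cca = CrissCrossAttention(in_dim=64)
x = torch.randn(2, 64, 20, 20)
print(cca(x).shape)   # torch.Size([2, 64, 20, 20])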

9. SKAttention

Code

from collections import OrderedDict


class SKAttention(nn.Module):
    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w
        ### fuse
        U = sum(conv_outs)  # bs,c,h,w
        ### reduce channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d
        ### calculate attention weights
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel
        attention_weights = torch.stack(weights, 0)   # k,bs,channel,1,1
        attention_weights = self.softmax(attention_weights)  # k,bs,channel,1,1
        ### fuse
        V = (attention_weights * feats).sum(0)
        return V

Usage

self.SKAttention = SKAttention(1024)
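
A small sketch of how the branches behave (my own example): SKAttention builds one convolution branch per entry in kernels, and the softmax over the branch axis turns the per-channel scores into weights that sum to 1 across branches.

# Hypothetical check with two branches instead of four
sk = SKAttention(channel=64, kernels=[3, 5], reduction=16, group=1, L=32)
x = torch.randn(2, 64, 20, 20)
print(len(sk.convs))   # 2 branches
print(sk(x).shape)     # torch.Size([2, 64, 20, 20])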

10. S2-MLPv2 Attention

Code

def spatial_shift1(x):
    b, w, h, c = x.size()
    x[:, 1:, :, :c // 4] = x[:, :w - 1, :, :c // 4]
    x[:, :w - 1, :, c // 4:c // 2] = x[:, 1:, :, c // 4:c // 2]
    x[:, :, 1:, c // 2:c * 3 // 4] = x[:, :, :h - 1, c // 2:c * 3 // 4]
    x[:, :, :h - 1, 3 * c // 4:] = x[:, :, 1:, 3 * c // 4:]
    return x


def spatial_shift2(x):
    b, w, h, c = x.size()
    x[:, :, 1:, :c // 4] = x[:, :, :h - 1, :c // 4]
    x[:, :, :h - 1, c // 4:c // 2] = x[:, :, 1:, c // 4:c // 2]
    x[:, 1:, :, c // 2:c * 3 // 4] = x[:, :w - 1, :, c // 2:c * 3 // 4]
    x[:, :w - 1, :, 3 * c // 4:] = x[:, 1:, :, 3 * c // 4:]
    return x


class SplitAttention(nn.Module):
    def __init__(self, channel=512, k=3):
        super().__init__()
        self.channel = channel
        self.k = k
        self.mlp1 = nn.Linear(channel, channel, bias=False)
        self.gelu = nn.GELU()
        self.mlp2 = nn.Linear(channel, channel * k, bias=False)
        self.softmax = nn.Softmax(1)

    def forward(self, x_all):
        b, k, h, w, c = x_all.shape
        x_all = x_all.reshape(b, k, -1, c)
        a = torch.sum(torch.sum(x_all, 1), 1)
        hat_a = self.mlp2(self.gelu(self.mlp1(a)))
        hat_a = hat_a.reshape(b, self.k, c)
        bar_a = self.softmax(hat_a)
        attention = bar_a.unsqueeze(-2)
        out = attention * x_all
        out = torch.sum(out, 1).reshape(b, h, w, c)
        return out


class S2Attention(nn.Module):
    def __init__(self, channels=512):
        super().__init__()
        self.mlp1 = nn.Linear(channels, channels * 3)
        self.mlp2 = nn.Linear(channels, channels)
        # pass channels through so SplitAttention's linear layers match;
        # the original snippet relied on SplitAttention's 512 default
        self.split_attention = SplitAttention(channels)

    def forward(self, x):
        b, c, w, h = x.size()
        x = x.permute(0, 2, 3, 1)
        x = self.mlp1(x)
        x1 = spatial_shift1(x[:, :, :, :c])
        x2 = spatial_shift2(x[:, :, :, c:c * 2])
        x3 = x[:, :, :, c * 2:]
        x_all = torch.stack([x1, x2, x3], 1)
        a = self.split_attention(x_all)
        x = self.mlp2(a)
        x = x.permute(0, 3, 1, 2)
        return x

Usage

self.S2Attention = S2Attention(512)
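
A quick shape check (illustrative sizes, my own example; channels must equal the channel count of the input feature map):

# Hypothetical check
s2 = S2Attention(channels=512)
x = torch.randn(2, 512, 14, 14)
print(s2(x).shape)   # torch.Size([2, 512, 14, 14])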

11. NAMAttention

Code

class Channel_Att(nn.Module):
    def __init__(self, channels, t=16):
        super(Channel_Att, self).__init__()
        self.channels = channels
        self.bn2 = nn.BatchNorm2d(self.channels, affine=True)

    def forward(self, x):
        residual = x
        x = self.bn2(x)
        weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())
        x = x.permute(0, 2, 3, 1).contiguous()
        x = torch.mul(weight_bn, x)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = torch.sigmoid(x) * residual
        return x


class NAMAttention(nn.Module):
    def __init__(self, channels, out_channels=None, no_spatial=True):
        super(NAMAttention, self).__init__()
        self.Channel_Att = Channel_Att(channels)

    def forward(self, x):
        x_out1 = self.Channel_Att(x)
        return x_out1

Usage

self.NAMAttention = NAMAttention(512)
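
One detail worth noting (my reading of the code above): the channel weights are simply the normalised absolute values of the BatchNorm scale factors, so the block adds no parameters beyond that BatchNorm layer.

# Hypothetical check: parameter count equals the BN layer's weight + bias
nam = NAMAttention(512)
print(sum(p.numel() for p in nam.parameters()))   # 1024 (512 weights + 512 biases)
print(nam(torch.randn(2, 512, 20, 20)).shape)     # torch.Size([2, 512, 20, 20])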

12. SOCA

Code

from torch.autograd import Function


class Covpool(Function):
    @staticmethod
    def forward(ctx, input):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        I_hat = (-1. / M / M) * torch.ones(M, M, device=x.device) + (1. / M) * torch.eye(M, M, device=x.device)
        I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1).type(x.dtype)
        y = x.bmm(I_hat).bmm(x.transpose(1, 2))
        ctx.save_for_backward(input, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, I_hat = ctx.saved_tensors
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input.bmm(x).bmm(I_hat)
        grad_input = grad_input.reshape(batchSize, dim, h, w)
        return grad_input


class Sqrtm(Function):
    @staticmethod
    def forward(ctx, input, iterN):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
        A = x.div(normA.view(batchSize, 1, 1).expand_as(x))
        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad=False, device=x.device)
        Z = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, iterN, 1, 1)
        if iterN < 2:
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
        else:
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
            Z[:, 0, :, :] = ZY
            for i in range(1, iterN - 1):
                ZY = 0.5 * (I3 - Z[:, i - 1, :, :].bmm(Y[:, i - 1, :, :]))
                Y[:, i, :, :] = Y[:, i - 1, :, :].bmm(ZY)
                Z[:, i, :, :] = ZY.bmm(Z[:, i - 1, :, :])
            ZY = 0.5 * Y[:, iterN - 2, :, :].bmm(I3 - Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]))
        y = ZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        ctx.save_for_backward(input, A, ZY, normA, Y, Z)
        ctx.iterN = iterN
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, A, ZY, normA, Y, Z = ctx.saved_tensors
        iterN = ctx.iterN
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        der_postCom = grad_output * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        der_postComAux = (grad_output * ZY).sum(dim=1).sum(dim=1).div(2 * torch.sqrt(normA))
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        if iterN < 2:
            der_NSiter = 0.5 * (der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace))
        else:
            dldY = 0.5 * (der_postCom.bmm(I3 - Y[:, iterN - 2, :, :].bmm(Z[:, iterN - 2, :, :])) -
                          Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]).bmm(der_postCom))
            dldZ = -0.5 * Y[:, iterN - 2, :, :].bmm(der_postCom).bmm(Y[:, iterN - 2, :, :])
            for i in range(iterN - 3, -1, -1):
                YZ = I3 - Y[:, i, :, :].bmm(Z[:, i, :, :])
                ZY = Z[:, i, :, :].bmm(Y[:, i, :, :])
                dldY_ = 0.5 * (dldY.bmm(YZ) -
                               Z[:, i, :, :].bmm(dldZ).bmm(Z[:, i, :, :]) -
                               ZY.bmm(dldY))
                dldZ_ = 0.5 * (YZ.bmm(dldZ) -
                               Y[:, i, :, :].bmm(dldY).bmm(Y[:, i, :, :]) -
                               dldZ.bmm(ZY))
                dldY = dldY_
                dldZ = dldZ_
            der_NSiter = 0.5 * (dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
        grad_input = der_NSiter.div(normA.view(batchSize, 1, 1).expand_as(x))
        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
        for i in range(batchSize):
            grad_input[i, :, :] += (der_postComAux[i]
                                    - grad_aux[i] / (normA[i] * normA[i])) \
                                   * torch.ones(dim, device=x.device).diag()
        return grad_input, None


def CovpoolLayer(var):
    return Covpool.apply(var)


def SqrtmLayer(var, iterN):
    return Sqrtm.apply(var, iterN)


class SOCA(nn.Module):
    # second-order channel attention
    def __init__(self, channel, reduction=8):
        super(SOCA, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch_size, C, h, w = x.shape  # x: NxCxHxW
        N = int(h * w)
        min_h = min(h, w)
        h1 = 1000
        w1 = 1000
        if h < h1 and w < w1:
            x_sub = x
        elif h < h1 and w > w1:
            W = (w - w1) // 2
            x_sub = x[:, :, :, W:(W + w1)]
        elif w < w1 and h > h1:
            H = (h - h1) // 2
            x_sub = x[:, :, H:H + h1, :]
        else:
            H = (h - h1) // 2
            W = (w - w1) // 2
            x_sub = x[:, :, H:(H + h1), W:(W + w1)]
        cov_mat = CovpoolLayer(x_sub)        # global covariance pooling layer
        cov_mat_sqrt = SqrtmLayer(cov_mat, 5)  # matrix square root layer (pre-norm, Newton-Schulz iteration and post-compensation, 5 iterations)
        cov_mat_sum = torch.mean(cov_mat_sqrt, 1)
        cov_mat_sum = cov_mat_sum.view(batch_size, C, 1, 1)
        y_cov = self.conv_du(cov_mat_sum)
        return y_cov * x

Usage

self.SOCA = SOCA(512)
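
A word of caution (my own note, based on the code above): Covpool builds a channel x channel covariance matrix, so the cost of the matrix square root grows with the square of the channel count. A small smoke test looks like this:

# Hypothetical smoke test on a small feature map
soca = SOCA(channel=64, reduction=8)
x = torch.randn(2, 64, 16, 16)
print(soca(x).shape)   # torch.Size([2, 64, 16, 16])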

13. GAMAttention

Code

class GAMAttention(nn.Module):
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAMAttention, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group
            else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group
            else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        return out


def channel_shuffle(x, groups=2):
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    out = out.view(B, C, H, W)
    return out

Usage

self.GAMAttention = GAMAttention(512,512)
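
For completeness, a small usage sketch (my own example): c1 is the input channel count and c2 the output channel count of the spatial branch; in the call above they are kept equal so the block can be dropped in without changing shapes. With group=True, the grouped convolutions require c1, c1 // rate and c2 to be divisible by rate.

# Hypothetical check with equal input/output channels
gam = GAMAttention(c1=128, c2=128, group=True, rate=4)
x = torch.randn(2, 128, 20, 20)
print(gam(x).shape)   # torch.Size([2, 128, 20, 20])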

Summary

Just a quick note for myself.

Author: Chaoy6565. Original post: https://blog.csdn.net/weixin_45464524/article/details/129641355
