压缩感知正交匹跟踪算法(OMP)代码实现之图像上压缩与恢复 - Go语言中文社区

压缩感知正交匹跟踪算法(OMP)代码实现之图像上压缩与恢复


算法遇到了问题,被卡了挺久。后来发现是个小问题,就是忘记将最终数据转换为uint8,导致图像老是无法正常显示,但是检查代码的整个逻辑过程又没有错误。这是因为在图像中,数据是以两个字节保存的。加上今天的去兼职,时间不太够,代码也是在地铁上给调试成功的。
原理我就不多说了。就是我发现压缩再复原,图像会出现很多噪点,而且压缩率越大,恢复出来的图像噪点越多。

具体原因还不是很清楚,我猜大致以下原因:

1.不可避免会产生噪点,也就是无法做到百分百完美复原。我看参考论文好像也没办法完美复原。

2.恢复矩阵或者压缩矩阵选择得不好。这个有待商榷,不过看了几篇论文,选择的矩阵都不同,不知这个会不会有一定影响。

下图,为采样率为0.75时,恢复图(左)和原图(右)的对比。

在这里插入图片描述

代码参考了别人写的代码,原文链接实在找不到了。代码如下:

'''
基于图像的压缩感知代码的实现(正交匹配跟踪算法)
参考文献: 任晓馨. 压缩感知贪婪匹配追踪类重建算法研究[D].
作者: kyda
时间: 2019.7.22
'''
import cv2
import numpy as np
import numpy.linalg as lg
import math

def get_matri_new_plus(A):
    '''
    求A矩阵的伪逆
    :param A: 输入矩阵
    :return A_new: 求解的伪逆矩阵
    '''
    A_new = A.T.dot(A)
    if A_new.ndim < 2:
        A_new = 1 / A_new
        A_new = A_new * A.T
    else:
        A_new = lg.inv(A_new)
        A_new = A_new.dot(A.T)
    return A_new

def get_A_ba(A):
    '''
    求解A_ba,也就是矩阵列向量单位化
    :param A: 输入矩阵
    :return A_new: 返回列向量单位后的矩阵
    '''
    A_new = A.T
    for row in range(A_new.shape[0]):
        norm2 = lg.norm(A_new[row], ord=2)           # 求向量的2范数
        A_new[row] /= norm2                          # 单位化
    return A_new


def find_pos_VecInMatrix(A, v):
    '''
    求解某个向量在矩阵中的相应行位置
    :param A: 输入矩阵
    :param v: 输入向量
    :return pos: 位置
    '''
    A_rownum = A.shape[0]                           # 矩阵行的数量
    for pos in range(A_rownum):
        if (v == A[pos]).all():                       # 判断向量v和矩阵中的某一行是否相等
            return pos


N = 256                                             # 向量长度
sampleRate = 0.75                                   # 采样率,也就是256长度的向量只采样64个
im = cv2.imread('lena.png')                        # 读取图像
im = cv2.resize(im, (N, N))                         # 将图像修改成256X256的大小,这么做主要是方便计算

Phi=np.random.randn(int(sampleRate*N), N)                       # 生成高斯随机采样矩阵
s = Phi.dot(im)                                     # 获得线性测量,s其实就是我们要采样的数据

#生成稀疏基DCT矩阵
mat_dct_1d = np.zeros((256,256))
v=range(256)
for k in range(0,256):
    dct_1d = np.cos(np.dot(v, k*math.pi/256))
    if k > 0:
        dct_1d = dct_1d-np.mean(dct_1d)
    mat_dct_1d[:, k] = dct_1d/np.linalg.norm(dct_1d)

P = Phi.dot(mat_dct_1d.T)                           # 恢复矩阵(测量矩阵*正交反变换矩阵)


def cs_omp(s, mat):
    '''
    压缩感知主体算法
    :param s: 待重构向量
    :param mat: 重构矩阵
    :param N: 重构向量的长度
    :return: 重构完成的向量
    '''
    global N, K
    m = math.floor(3*(s.shape[0])/4)                # 算法迭代次数(m>=K)
    hat_y = np.zeros(N)                             # 待重构的谱域(变换域)向量
    Aug_t = None                                    # 增量矩阵(初始值为空矩阵)
    r_n = s                                         # 残差值

    T_tran = mat.T
    A_ba = T_tran.copy()                            # 初始化A矩阵
    A_ba1 = A_ba                                    # 复制
    pos_selected_v = []
    for times in range(m):
        w = A_ba.dot(r_n)                           # 求各列向量和残差的投影系数
        pos = np.argmax(abs(w))                     # 最大投影系数所对应的位置
        b_k = A_ba[pos]
        A_ba = np.delete(A_ba, pos, axis=0)         # 选中的列删除
        pos1 = find_pos_VecInMatrix(A_ba1, b_k)
        pos_selected_v.append(pos1)                 # 纪录最大投影系数的位置
        if Aug_t is None:                          # 将最大贡献向量(未单位化的)加入Aug_t矩阵
            Aug_t = T_tran[pos1]
        else:
            Aug_t = np.c_[Aug_t, T_tran[pos1]]
        A_new_plus = get_matri_new_plus(Aug_t)
        aug_y = A_new_plus.dot(s)                   # 解最小二乘解
        r_n = s - Aug_t.dot(aug_y)                  # 求残差

    for item in range(m):                          # 重构的谱域向量
        pos = pos_selected_v[item]
        x_r = aug_y[item]
        try:
            hat_y[pos] = x_r
        except np.ComplexWarning:
            pass

    hat_x = mat_dct_1d.T.dot(hat_y)                 # 做逆傅里叶变换重构得到时域信号
    return hat_x

restruct_im = np.zeros([256, 256, 3])               # 重构的图像
for x in range(256):
    for y in range(3):
        restruct_s = cs_omp(s[:, x].T[y], P)        # 重构向量
        restruct_im[x].T[y] = restruct_s            # 将重构的向量写入图像

restruct_im = np.array(restruct_im, dtype=np.uint8) # 这里一定要将矩阵的数据形式转换为uint8,不然图片无法正常显示

cv2.imshow('im_re', restruct_im)
cv2.imshow('im_ss', im)
cv2.waitKey(0)

欢迎关注我的公众号,可以交流问题:
在这里插入图片描述

版权声明:本文来源CSDN,感谢博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/weixin_37720172/article/details/97612621
站方申明:本站部分内容来自社区用户分享,若涉及侵权,请联系站方删除。
  • 发表于 2021-06-13 20:07:40
  • 阅读 ( 1128 )
  • 分类:算法

0 条评论

请先 登录 后评论

官方社群

GO教程

推荐文章

猜你喜欢