分治策略(2)—算法导论(4)
2015年09月13日 13:36

1. 写在最前面

这一篇博文主要介绍基于分治策略的矩阵乘法的Strassen算法。

2. 普通矩阵乘法算法

在介绍Strassen算法之前,我们先看看普通矩阵乘法算法。

我们都知道,矩阵乘法的基本计算规则是:

若 A = (aij)和 B = b(ij)是 n × n 的方阵(i,j = 1,2,3...),则 C = A · B中的元素 CijC_{ij}为:Cij=k=1naikbkjC_{ij}=\sum_{k=1}^n \cdot a_{ik} \cdot b_{kj}

下面给出 Java 实现代码:

public static void main(String[] args) {
    int[][] a = new int[][] { //
            { 1, 0, 1, 2 }, //
            { 1, 2, 0, 2 }, //
            { 0, 2, 1, 0 }, //
            { 0, 0, 1, 2 },//
    };
    int[][] b = new int[][] { //
            { 1, 0, 1, 2 }, //
            { 1, 2, 0, 2 }, //
            { 0, 2, 1, 0 }, //
            { 0, 0, 1, 2 },//
    };
    printMatrix(squareMatrixMutiply(a, b));
}


/**
 * 基本矩阵乘法(假定矩阵a和矩阵b都是n×n的矩阵,且n为2的幂)
 * @param a 矩阵a
 * @param b 矩阵b
 * @return
 */
private static int[][] squareMatrixMutiply(int[][] a, int[][] b) {
    int[][] c = new int[a.length][a.length];
    for (int i = 0; i < c.length; i++) {
        for (int j = 0; j < c.length; j++) {
            c[i][j] = 0;
            for (int k = 0; k < c.length; k++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

/**
 * 打印矩阵
 *
 * @param matrix
 */
private static void printMatrix(int[][] matrix) {
    for (int[] is : matrix) {
        for (int i : is) {
            System.out.print(i + "\t");
        }
        System.out.println();
    }
}

结果:

3. 一个简单的分治算法

我们再换一种思想,采用分治法来解决这个问题。

为简单起见,当使用 分治法(Divide and Conquer) 计算矩阵C = A * B时,假定三个矩阵都是n×n的矩阵,并且n为2的幂。分治法 还是上一篇提到的三个步骤,算法的核心就是这个公式:

(C11C12C21C22)=(A11A12A21A22)(B11B12B21B22)\Biggl(\begin{matrix}C_{11} & C_{12}\cr C_{21} & C_{22}\cr\end{matrix}\Biggl) = \Biggl(\begin{matrix}A_{11} & A_{12}\cr A_{21} & A_{22}\cr\end{matrix}\Biggl) \cdot \Biggl(\begin{matrix}B_{11} & B_{12}\cr B_{21} & B_{22}\cr\end{matrix}\Biggl)

其中,Aij,Bij,Cij分别是A,B,C矩阵的n / 2 * n / 2的子矩阵,即:

A=(A11A12A21A22),B=(B11B12B21B22),C=(C11C12C21C22)A = \Biggl(\begin{matrix}A_{11} & A_{12}\cr A_{21} & A_{22}\cr\end{matrix}\Biggl) , B = \Biggl(\begin{matrix}B_{11} & B_{12}\cr B_{21} & B_{22}\cr\end{matrix}\Biggl), C = \Biggl(\begin{matrix}C_{11} & C_{12}\cr C_{21} & C_{22}\cr\end{matrix}\Biggl)

值得说明的是,我们不必创建子数组,那将浪费θ(n²)的时间来复制数组元素;明智的做法是直接根据下标运算。

下图是伪代码(其中所说的“(4.9)”即为上图所给的三个等式):

下面给出Java实现代码:

public static void main(String[] args) {
    int[][] a = new int[][] { //
            { 1, 0, 1, 2 }, //
            { 1, 2, 0, 2 }, //
            { 0, 2, 1, 0 }, //
            { 0, 0, 1, 2 },//
    };
    int[][] b = new int[][] { //
            { 1, 0, 1, 2 }, //
            { 1, 2, 0, 2 }, //
            { 0, 2, 1, 0 }, //
            { 0, 0, 1, 2 },//
    };
    printMatrix(squareMatrixMutiplyByRecursive(new ChildMatrix(a, 0, 0, a.length), new ChildMatrix(b, 0, 0, b.length), 0, 0, 0, 0));
}

/**
 * 打印矩阵
 */
private static void printMatrix(int[][] matrix) {
    for (int[] is : matrix) {
        for (int i : is) {
            System.out.print(i + "\t");
        }
        System.out.println();
    }
}

/**
 * 基于分治法的矩阵乘法
 */
private static int[][] squareMatrixMutiplyByRecursive(ChildMatrix matrixA, ChildMatrix matrixB, int lastStartRowA, int lastStartColumnA, int lastStartRowB,
        int lastStartColumnB) {
    int[][] c = new int[matrixA.length][matrixA.length];
    if (matrixA.length == 1) {
        c[0][0] = matrixA.getFromParentMatrix(matrixA.startRow, matrixA.startColumn) * //
                matrixB.getFromParentMatrix(matrixB.startRow, matrixB.startColumn);
        return c;
    }
    int childLength = matrixA.length / 2;
    // 第一步:分解
    ChildMatrix childMatrixA11 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA, lastStartColumnA, childLength);
    ChildMatrix childMatrixA12 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA, lastStartColumnA + childLength, childLength);
    ChildMatrix childMatrixA21 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA + childLength, lastStartColumnA, childLength);
    ChildMatrix childMatrixA22 = new ChildMatrix(matrixA.parentMatrix, lastStartRowA + childLength, lastStartColumnA + childLength, childLength);

    ChildMatrix childMatrixB11 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB, lastStartColumnB, childLength);
    ChildMatrix childMatrixB12 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB, lastStartColumnB + childLength, childLength);
    ChildMatrix childMatrixB21 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB + childLength, lastStartColumnB, childLength);
    ChildMatrix childMatrixB22 = new ChildMatrix(matrixB.parentMatrix, lastStartRowB + childLength, lastStartColumnB + childLength, childLength);
    // 第二步:解决
    int[][] temp1 = squareMatrixMutiplyByRecursive(childMatrixA11, childMatrixB11, 0, 0, 0, 0);
    int[][] temp2 = squareMatrixMutiplyByRecursive(childMatrixA12, childMatrixB21, 0, childLength, childLength, 0);
    int[][] c11 = sumMatrix(temp1, temp2);

    int[][] temp3 = squareMatrixMutiplyByRecursive(childMatrixA11, childMatrixB12, 0, 0, 0, childLength);
    int[][] temp4 = squareMatrixMutiplyByRecursive(childMatrixA12, childMatrixB22, 0, childLength, childLength, childLength);
    int[][] c12 = sumMatrix(temp3, temp4);

    int[][] temp5 = squareMatrixMutiplyByRecursive(childMatrixA21, childMatrixB11, childLength, 0, 0, 0);
    int[][] temp6 = squareMatrixMutiplyByRecursive(childMatrixA22, childMatrixB21, childLength, childLength, childLength, 0);
    int[][] c21 = sumMatrix(temp5, temp6);

    int[][] temp7 = squareMatrixMutiplyByRecursive(childMatrixA21, childMatrixB12, childLength, 0, 0, childLength);
    int[][] temp8 = squareMatrixMutiplyByRecursive(childMatrixA22, childMatrixB22, childLength, childLength, childLength, childLength);
    int[][] c22 = sumMatrix(temp7, temp8);
    // 第三步:合并
    for (int i = 0; i < c.length; i++) {
        for (int j = 0; j < c.length; j++) {
            if (i < childLength && j < childLength) {
                c[i][j] = c11[i][j];
            } else if (i < childLength && j < c.length) {
                int[][] child = c12;
                c[i][j] = child[i][j - childLength];
            } else if (i < c.length && j < childLength) {
                int[][] child = c21;
                c[i][j] = child[i - childLength][j];
            } else {
                int[][] child = c22;
                c[i][j] = child[i - childLength][j - childLength];
            }
        }
    }
    return c;
}

private static int[][] sumMatrix(int[][] a, int[][] b) {
    int[][] c = new int[a.length][b.length];
    for (int i = 0; i < a.length; i++) {
        for (int j = 0; j < a.length; j++) {
            c[i][j] += a[i][j];
            c[i][j] += b[i][j];
        }
    }
    return c;
}

/**
 * ChildMatrix 表示某个矩阵的一个子矩阵
 */
static class ChildMatrix {
    /**
     * 父矩阵
     */
    int[][] parentMatrix;

    /**
     * 子矩阵在父矩阵中的起始行坐标
     */
    int startRow;

    /**
     * 子矩阵在父矩阵中的起始列坐标
     */
    int startColumn;

    /**
     * 子矩阵长度
     */
    int length;

    public ChildMatrix(int[][] parentMatrix, int startRow, int startColumn, int length) {
        super();
        this.parentMatrix = parentMatrix;
        this.startRow = startRow;
        this.startColumn = startColumn;
        this.length = length;
    }

    /**
     * 获取父矩阵的row行,colum列元素
     */
    public int getFromParentMatrix(int row, int colum) {
        return parentMatrix[row][colum];
    }
}

结果是:

4. Strassen算法

下面正式介绍Strassen算法。

Strassen算法正是上述分治算法的改进。它的核心思想是令递归树稍微不那么茂盛,它只进行7次递归(上述分治法递归了8次)。

Strassen算法的操作步骤如下:

  1. 分解矩阵A,B,C为:
A=(A11A12A21A22),B=(B11B12B21B22),C=(C11C12C21C22)A = \Biggl(\begin{matrix}A_{11} & A_{12}\cr A_{21} & A_{22}\cr\end{matrix}\Biggl) , B = \Biggl(\begin{matrix}B_{11} & B_{12}\cr B_{21} & B_{22}\cr\end{matrix}\Biggl), C = \Biggl(\begin{matrix}C_{11} & C_{12}\cr C_{21} & C_{22}\cr\end{matrix}\Biggl)

同样不要创建子数组而只是进行下标计算。

  1. 创建10个n/2 ×n/2的矩阵S1,S2,S3…,S10,其计算公式如下:
S3=A21+A22S4=A21+A11S5=A11+A22S6=A11+A22S7=A12+A22S8=A21+A22S9=A11+A21S10=A11+A12S_3 = A_{21} + A_{22}\\ S_4 = A_{21} + A_{11}\\ S_5 = A_{11} + A_{22}\\ S_6 = A_{11} + A_{22}\\ S_7 = A_{12} + A_{22}\\ S_8 = A_{21} + A_{22}\\ S_9 = A_{11} + A_{21}\\ S_{10} = A_{11} + A_{12}
  1. 递归地计算7个矩阵积P1, P2…P3,P7,计算公式如下:
P1=A11S1=A11B12A11B22P2=S2B22=A11B22A12B22P3=S3B11=A21B11+A22B11P4=A22S4=A22B21A22B11P5=S5S6=A11B11+A12B22+A22B11+A22B22P6=S7S8=A12B21+A12B22A22B21A22B22P7=S9S10=A11B11+A11B12A21B11A21B12P_1 = A_{11} \cdot S_1 = A_{11} \cdot B_{12} - A_{11} \cdot B_{22}\\ P_2 = S_2 \cdot B_{22} = A_{11} \cdot B_{22} - A_{12} \cdot B_{22}\\ P_3 = S_3 \cdot B_{11} = A_{21} \cdot B_{11} + A_{22} \cdot B_{11}\\ P_4 = A_{22} \cdot S_4 = A_{22} \cdot B_{21} - A_{22} \cdot B_{11}\\ P_5 = S_5 \cdot S_6 = A_{11} \cdot B_{11} + A_{12} \cdot B_{22} + A_{22} \cdot B_{11} + A_{22} \cdot B_{22}\\ P_6 = S_7 \cdot S_8 = A_{12} \cdot B_{21} + A_{12} \cdot B_{22} - A_{22} \cdot B_{21} - A_{22} \cdot B_{22}\\ P_7 = S_9 \cdot S_{10} = A_{11} \cdot B_{11} + A_{11} \cdot B_{12} - A_{21} \cdot B_{11} - A_{21} \cdot B_{12}
  1. 计算Cij,计算公式如下:
C11=P5+P4P2+P6C12=P1+P2C21=P3+P4C22=P5+P1P3P7C_{11} = P_5 + P_4 - P_2 + P_6\\ C_{12} = P_1 + P_2\\ C_{21} = P_3 + P_4\\ C_{22} = P_5 + P_1 - P_3 - P_7

实现代码就不给出了,与上面类似。

5.算法分析

  1. 普通矩阵乘法 对于普通的矩阵乘法,3次嵌套循环,每层执行n次,所需时间为θ(n³);

  2. 简单分治算法

    1. 基本情况:T(1) = θ(1);
    2. 递归情况:分解后,矩阵规模变为原来的1/2。递归八次,用时8T(n/2);4次矩阵加法,每个矩阵中的元素个数为n² / 4, 用时θ(n²);其余用时θ(1)。因此共用时8T(n/2) + θ(n²)。即:
T(n)={Θ(1)若n = 18T(n/2)+Θ(n2)若n > 1T(n) = \begin{cases} \Theta(1) & \text{若n = 1}\\ 8T(n / 2) + \Theta(n^2) & \text{若n > 1} \end{cases}

可解得,T(n) = θ(n³)。可见 分治算法 并不优于普通矩阵乘法。

  1. Strassen算法 Strassen算法分析与上面基本一致,不同的是只进行了7次递归,并且额外多了几次n / 2 × n / 2矩阵的加法,但只是常数次。因此Strassen算法用时为:
T(n)={Θ(1)若n = 17T(n/2)+Θ(n2)若n > 1T(n) = \begin{cases} \Theta(1) & \text{若n = 1}\\ 7T(n / 2) + \Theta(n^2) & \text{若n > 1} \end{cases}

可解得,T(n)=Θ(nlg7)T(n) = \Theta(n^{lg7})

End