C++实现惰性求值

C++实现惰性求值

By Z.H. Fu

切问录 www.fuzihao.org

惰性求值在解释性语言中有着广泛的应用,通过惰性求值,延迟了临时变量的计算,避免了无谓的变量产生,是实现大规模计算的重要的优化手段。然而,很多惰性求值的应用仅限于一些脚本语言,本文介绍了如何用C++实现惰性求值。

什么是惰性求值?

我们先来看下面一段Python程序,在python中输入

1
2
3
4
5
a = [1,2,3]
b = [4,5,6]
c = a + b
for i in c:
print c

这个程序很快执行,但是试想一下,如果a和b都是很大的List,那么生成中间变量,再去遍历他将是一个极为浪费时间的做法,我们希望c只在求值(即print)的时候再来计算,这样能节省出产生临时变量的时间。

C++实现

我们以矩阵类来展示惰性求值的实现方法。首先我们有一个共同的基类:

1
2
3
4
5
6
7
8
9
10
11
template <typename _Scalar>
class MatrixBase{
public:
virtual size64_t getRowNum()const = 0;
virtual size64_t getColNum()const = 0;
virtual ~MatrixBase(){};
virtual _Scalar operator()(size64_t rowid, size64_t colid)const = 0;
inline operator Matrix<_Scalar>(){
return ...
}
};

这是一个虚基类,我们注意到,除了定义了一些常用的接口外,我们还定义了一个转换到其子类的类型转换函数Matrix<_Scalar>,该子类定义为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
template <typename _Scalar>
class Matrix : public MatrixBase<_Scalar>{
public:
Matrix(){}
Matrix(size64_t row, size64_t col){init(row, col);}
inline _Scalar operator()(size64_t rowid, size64_t colid)const{
return mPData[rowid * mColMem + colid];
}
inline size64_t getRowNum()const{return mRow;}
inline size64_t getColNum()const{return mCol;}
//override cast in a fast way
inline operator Matrix<_Scalar>(){return *this;}

private:
size64_t mCol;
size64_t mRow;
_Scalar* mPData;
};

这个类继承了基类,同时重写了类型转换方法,因为把自己转换成自己等价于直接返回自己。下面就是这个问题的关键所在,我们定义一个Holder来装下一个操作和其对应的操作对象,在没有明确要求下不对其求值。每种运算对应一个Holder,我们以点乘运算为例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
template <typename _Scalar>
class DotHolder : public MatrixBase<_Scalar>{
public:
DotHolder(const MatrixBase<_Scalar>& lhs, const MatrixBase<_Scalar>& rhs) : mLhs(&lhs), mRhs(&rhs){
//init code here
}
~DotHolder(){}
inline size64_t getRowNum()const{return mRow;}
inline size64_t getColNum()const{return mCol;}
//注意这个重载的operator(),只有再调用这个时,才会计算,否则只是保存起来计算参数,并返回这个Holder
inline _Scalar operator()(size64_t i, size64_t j)const{
_Scalar s = (*mLhs)(i, 0) * (*mRhs)(0, j);
for(size64_t k = 1; k < mK; ++k){
s += (*mLhs)(i, k) * (*mRhs)(k, j);
}
return s;
}

private:
size64_t mCol;
size64_t mRow;
size64_t mK;
const MatrixBase<_Scalar> * mLhs;
const MatrixBase<_Scalar> * mRhs;
};

并在虚基类中添加:

1
2
3
inline DotHolder<_Scalar> dot(const MatrixBase<_Scalar>& rhs){
return DotHolder<_Scalar>(*this, rhs);
}

我们看到,我们如果进行矩阵的点乘运算,返回的是一个DotHolder类型,他做的事情仅仅是将计算参数保存起来,并不进行计算,直到有运算显示调用了括号操作来取值。这样,这个变量能很快返回结果,只在最后求值得时候来计算,省略了临时变量,及大地提高了程序效率。