
CSR 稀疏矩阵压缩 c++实现
使用c++实现了稀疏矩阵的CSR压缩
·
先实现一个稀疏矩阵类,这里选择通过实现随机值矩阵,但是将其中大部分值设置为0,在将原先矩阵输入之后转为稀疏矩阵保存。
使用CSR格式的稀疏矩阵。
CSR将非零值保存在一个一维数组中,value或者data, 然后将非零值对应的列索引保存在另一个一维数组中colIdx。
另外一个行指针 row_ptr, 其中row_ptr[i]表示第i行的非零元素在data中的起始位置, 当然如果第i行全为0,则row_ptr[i] 等于 row_ptr[i+1] 。
为了方便,使用vector先保存非零元素data 以及 对应索引。因为事先是不知道矩阵中非零元素个数的。
CSR稀疏矩阵类实现:
class SparseMatrix
{
public:
SparseMatrix(){};
SparseMatrix(
u_int num_rows,
u_int num_cols,
float *data,
u_int *col_index,
u_int *row_ptr):num_rows(num_rows),num_cols(num_cols),data(data),col_index(col_index),row_ptr(row_ptr){}
SparseMatrix(
u_int num_rows,
float *data,
u_int *col_index,
u_int *row_ptr):num_rows(num_rows),data(data),col_index(col_index),row_ptr(row_ptr){}
SparseMatrix(float* arr, u_int rows, u_int cols){
/*
通过原始矩阵的一维数组,转化为CSR格式的稀疏矩阵
params:
arr 指向原始矩阵的一维数组
rows 原始矩阵的行维度
cols 原始矩阵的列维度
*/
num_rows = rows;
num_cols = cols;
row_ptr = new u_int[num_rows + 1];
std::vector<u_int> vec_col_index;
std::vector<u_int> vec_data;
vec_data.reserve(num_cols);
vec_col_index.reserve(num_cols);
row_ptr[0] = 0;
for(u_int i = 0; i<num_rows; i++){
row_ptr[i+1] = row_ptr[i];
for(u_int j = 0; j< num_cols; j++){
u_int index = i*cols + j;
if(abs(arr[index]) > EPSILON){
vec_data.emplace_back(arr[index]);
vec_col_index.emplace_back(j);
row_ptr[i+1]++;
}
}
}
this->data_length= vec_data.size();
data = new float[data_length];
col_index = new u_int[data_length];
for(int i=0;i<data_length;i++){
data[i] = vec_data[i];
col_index[i] = vec_col_index[i];
}
}
~SparseMatrix(){
if(data != nullptr){
delete[] data;
}
if(col_index != nullptr){
delete[] col_index;
}
if(row_ptr != nullptr){
delete[] row_ptr;
}
}
void printSparseMatrix(){
/*
通过CSR打印出原始矩阵样貌
*/
for(u_int row=0; row<num_rows; row++){
u_int row_start = row_ptr[row];
u_int row_end = row_ptr[row+1];
u_int p_row = row_start;
for(int i=0;i<num_cols; i++){
if(i==col_index[p_row]&&(p_row<row_end)){
printf("%2.2f ", data[p_row]);
p_row++;
}else{
printf("%2.2f ", 0.0);
}
}
printf("\n");
}
}
u_int num_rows;
u_int num_cols;
float *data;
u_int *col_index;
u_int *row_ptr;
u_int data_length;
};
int main(){
float arr[] = {
1,7,0,0,5,0,3,9,0,2,8,0,0,0,0,6
};
SparseMatrix sm = SparseMatrix(arr, 4,4);
sm.printSparseMatrix();
return 0;
}
最后结果:
OK!!!
点击阅读全文
更多推荐
所有评论(0)