/*
 * NOTE(review): the header names of the four #include directives were lost in
 * this copy of the file. <stddef.h> (size_t) and <string.h> (memset) are
 * required by the code below; <stdio.h> and <stdlib.h> are a best guess for
 * the remaining two -- confirm against the original.
 */
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Four doubles packed into one 32-byte vector (GCC/Clang vector extension). */
typedef double v4d __attribute__ ((vector_size (32)));

/*
 * Leaf kernel, defined elsewhere. It receives only the base pointers and the
 * row strides (m for a, p for b and c), not the tile extents, so it presumably
 * works on a fixed-size tile and accumulates into c -- TODO confirm against
 * its definition.
 */
void matmulinner(double a[], v4d b[], v4d c[], size_t m, size_t p);

/* Largest multiple of 4 that does not exceed x/2. */
static size_t half_rounded_to_4(size_t x)
{
  return (x / 2) & ~(size_t)3;
}

/*
 * Recursive blocking step of the matrix multiplication.
 *
 * Layout, as implied by the pointer arithmetic below:
 *   a: rows of m doubles          (row stride m)
 *   b: rows of p v4d vectors      (row stride p)
 *   c: rows of p v4d vectors      (row stride p)
 *
 * m1 is the extent of the current sub-problem along the columns of a / rows
 * of b; n1 is its extent along the rows of a and c. Any extent that is >= 8
 * is split into a lower part (half, rounded down to a multiple of 4) and the
 * remainder. Once both extents are below 8 the leaf kernel is called.
 *
 * Assumptions: p divisible by 4 (before conversion to vectors), b and c are
 * 32-byte-aligned. NOTE(review): since the leaf kernel is not told m1/n1,
 * m and n presumably have to be multiples of 4 as well -- confirm.
 *
 * n is not used here; it is kept so the parameter list mirrors matmul().
 */
static void matmul1(double a[], v4d b[], v4d c[],
                    size_t m, size_t n, size_t p,
                    size_t m1, size_t n1)
{
  const int split_m = m1 >= 8;
  const int split_n = n1 >= 8;

  if (!split_m && !split_n) {
    matmulinner(a, b, c, m, p);
    return;
  }

  if (split_m && split_n) {
    const size_t m_lo = half_rounded_to_4(m1);
    const size_t m_hi = m1 - m_lo;
    const size_t n_lo = half_rounded_to_4(n1);
    const size_t n_hi = n1 - n_lo;

    /*
     * Quadrant order is kept from the original: each call shares one operand
     * with the previous one (b, then c, then b again), presumably for cache
     * reuse.
     */
    matmul1(a,                   b,            c,            m, n, p, m_lo, n_lo);
    matmul1(a + n_lo * m,        b,            c + n_lo * p, m, n, p, m_lo, n_hi);
    matmul1(a + n_lo * m + m_lo, b + m_lo * p, c + n_lo * p, m, n, p, m_hi, n_hi);
    matmul1(a + m_lo,            b + m_lo * p, c,            m, n, p, m_hi, n_lo);
    return;
  }

  if (split_m) {
    /* Only the inner dimension is large: both halves target the same c. */
    const size_t m_lo = half_rounded_to_4(m1);
    const size_t m_hi = m1 - m_lo;

    matmul1(a,        b,            c, m, n, p, m_lo, n1);
    matmul1(a + m_lo, b + m_lo * p, c, m, n, p, m_hi, n1);
    return;
  }

  /* Only the row dimension is large: both halves read the same b. */
  {
    const size_t n_lo = half_rounded_to_4(n1);
    const size_t n_hi = n1 - n_lo;

    matmul1(a,            b, c,            m, n, p, m1, n_lo);
    matmul1(a + n_lo * m, b, c + n_lo * p, m, n, p, m1, n_hi);
  }
}

/*
 * Entry point: clears c and runs the blocked multiplication on the full
 * problem.
 *
 *   a: n rows of m doubles
 *   b: m rows of p doubles, stored as p/4 v4d vectors per row
 *   c: n rows of p doubles, stored as p/4 v4d vectors per row (output)
 *
 * p is given in doubles and must be divisible by 4; b and c must be
 * 32-byte-aligned.
 */
void matmul(double a[], v4d b[], v4d c[], size_t m, size_t n, size_t p)
{
  p = p / 4;  /* from here on p counts v4d vectors per row */

  /*
   * c holds n*p vectors, so the byte count must use the vector size. Using
   * sizeof(double) here would clear only a quarter of the output.
   */
  memset(c, 0, n * p * sizeof *c);

  matmul1(a, b, c, m, n, p, m, n);
}