Files
linguist/samples/Terra/sgemm-old.t
2016-01-28 11:22:27 +07:00

205 lines
4.0 KiB
Perl

local IO = terralib.includec("stdio.h")
local stdlib = terralib.includec("stdlib.h")
-- NB: block size (in elements) for the m and n dimensions; the blocked loops
-- in my_sgemm step mm and nn by NB.
local NB = 64
-- V: number of floats per SIMD vector (used as vector(float,V)).
local V = 8
-- Load V consecutive floats starting at data[idx] and return them as a
-- single vector(float,V).
--   data : base pointer of a float buffer
--   idx  : element index of the first float to load
-- The load is issued with 4-byte (float) alignment. A plain dereference of
-- &vector(float,V) lets the backend assume the vector's natural alignment,
-- which neither malloc'ed buffers nor caller-provided matrices guarantee.
terra vecload(data : &float, idx : int)
    var addr = &data[idx]
    return terralib.attrload([&vector(float,V)](addr), { align = 4 })
end
-- 8-wide float vector type shared by the intrinsic binding and hadd below.
local float8 = vector(float,8)

-- Binding for the LLVM intrinsic "llvm.x86.avx.hadd.ps.256"
-- (two 8-float vectors in, one 8-float vector out).
haddavx = terralib.intrinsic("llvm.x86.avx.hadd.ps.256", { float8, float8 } -> float8)

-- Reduce an 8-float vector to one float using two rounds of the horizontal-add
-- intrinsic followed by a scalar add of elements 0 and 4.
terra hadd(v : float8)
    var pairsums = haddavx(v, v)
    var quadsums = haddavx(pairsums, pairsums)
    var lo = quadsums[0]
    var hi = quadsums[4]
    return lo + hi
end
-- Register-tile and blocking parameters for the kernel generator.
-- AR: rows of A (rows of C) handled per register tile; the m loop steps by AR.
local AR = 2
-- BR: rows of the packed B copy (columns of C) handled per register tile; the n loop steps by BR.
local BR = 2
-- KR: number of V-wide vectors consumed along k per unrolled loop iteration.
local KR = 2
-- NK: block size (in elements) along the k dimension.
local NK = 64
-- BLOCKB / BLOCKA: select the block-contiguous layout for the packed copies of B / A.
local BLOCKB = true
local BLOCKA = true

-- True when x has no fractional part.
local function isinteger(x) return math.floor(x) == x end

-- The unrolled k loop advances by KR * V and must tile NK exactly.
assert(isinteger(NK / (KR * V)), "NK must be a multiple of KR * V")
-- Register tiles must tile an NB x NB block exactly.
assert(isinteger(NB / AR), "NB must be a multiple of AR")
assert(isinteger(NB / BR), "NB must be a multiple of BR")
-- The blocked index formula  panel * K + NB * (kk + row - panel) + (k - kk)
-- uses NB both as the row stride inside a block (rows hold NK elements) and as
-- the multiplier of the kk block offset. Blocks only stay disjoint when the two
-- sizes coincide, so reject any other configuration up front.
if BLOCKA or BLOCKB then
    assert(NB == NK, "blocked packing layout requires NB == NK")
end
-- blockregisters: macro that emits the unrolled kernel for one AR x BR tile of C.
-- Arguments (Terra expressions supplied at the call site):
--   C           output buffer, written at C[(m + i)*ldc + (n + j)]
--   A, B        operand buffers read through vecload; the index formulas used
--               depend on the BLOCKA / BLOCKB settings
--   K           panel stride in the blocked formulas and row stride of B in the
--               non-blocked formula
--   lda         row stride of A, only read when BLOCKA is false
--   ldc         row stride of C
--   mm, nn, kk  origin of the enclosing block
--   m, n        origin of this tile
-- Returns a list of statements: accumulator declarations, the k loop over
-- [kk, kk + NK) in steps of V*KR, and one store per tile element.
blockregisters = macro(function(C,A,B,K,lda,ldc,mm,nn,kk,m,n)
-- Build a 0-based I x J table of fresh symbols named nm..i..j.
local function mkmatrix(nm,I,J)
local r = {}
for i = 0,I-1 do
r[i] = {}
for j = 0,J-1 do
r[i][j] = symbol(nm..tostring(i)..tostring(j))
end
end
return r
end
-- as/bs hold the vectors loaded from A/B per (row, k-unroll step); cs holds the
-- per-tile-element vector accumulators.
local as,bs,cs = mkmatrix("a",AR,KR),mkmatrix("b",BR,KR),mkmatrix("c",AR,BR)
local stmts = terralib.newlist()
-- Declare every accumulator and initialise it to zero.
for i = 0, AR-1 do
for j = 0, BR-1 do
stmts:insert(quote var [cs[i][j]] : vector(float,V) = 0.f end)
end
end
local k = symbol("k")
local kloopbody = terralib.newlist()
-- Symbols whose load statement has already been appended to kloopbody.
local alreadyloaded = {}
-- Return the symbol vs[i][j]; on first request also append the load statement
-- produced by loadfn, so each operand vector is loaded once per loop iteration
-- and before its first use.
local function get(vs,i,j,loadfn)
if not alreadyloaded[vs[i][j]] then
alreadyloaded[vs[i][j]] = true
kloopbody:insert(loadfn(vs[i][j]))
end
return vs[i][j]
end
-- Operand vector for tile row i of A at k-unroll step j.
local function getA(i,j)
return get(as,i,j,function(sym)
local result
if BLOCKA then
-- Blocked layout: same index formula as the TA packing loop in my_sgemm.
result = quote
var [sym] = vecload(A, mm * K + NB*(kk + m - mm + i) + k - kk + j *V)
end
else
-- Unblocked: read row (m + i) of A directly using lda.
result = quote
var [sym] = vecload(A, (m + i) * lda + k + j * V)
end
end
return result
end)
end
-- Operand vector for tile row i of the packed B copy at k-unroll step j.
local function getB(i,j)
return get(bs,i,j,function(sym)
local result
if BLOCKB then
-- Blocked layout: same index formula as the TB packing loop in my_sgemm.
result = quote
var [sym] = vecload(B, nn * K + NB*(kk + n - nn + i) + k - kk + j * V)
end
else
-- Unblocked: B is indexed as row (n + i) with k contiguous and stride K.
result = quote
var [sym] = vecload(B, (n + i) * K + k + j * V)
end
end
return result
end)
end
-- Emit the loop body: for every k-unroll step and tile element, a vector
-- multiply-add into the matching accumulator. Loads are appended lazily by
-- getA/getB, so the order of these calls fixes the order of generated code.
for l = 0, KR-1 do
for i = 0, AR-1 do
for j = 0, BR-1 do
local aa = getA(i,l)
local bb = getB(j,l)
kloopbody:insert(quote
[cs[i][j]] = [cs[i][j]] + aa * bb
end)
end
end
end
-- Wrap the body in the k loop covering one NK-wide block.
stmts:insert(quote
for [k] = kk, kk + NK, V*KR do
kloopbody
end
end)
-- Reduce each vector accumulator to a scalar and write it to C.
for i = 0, AR-1 do
for j = 0, BR-1 do
-- Build a balanced tree of scalar adds over lanes [b, e) of the accumulator.
-- The assert requires every range to split evenly.
local function getsum(b,e)
if b + 1 == e then
return `[cs[i][j]][b]
else
local mid = (e + b)/2
assert(math.floor(mid) == mid)
local lhs = getsum(b,mid)
local rhs = getsum(mid,e)
return `lhs + rhs
end
end
local sum
-- The "and false" makes this condition always false, so the hadd path is
-- never taken and getsum is always used.
if V == 8 and false then
sum = `hadd([cs[i][j]])
else
sum = getsum(0,V)
end
-- The first k block (kk == 0) overwrites C; later blocks accumulate into it.
stmts:insert(quote
var r = sum
if kk == 0 then
C[(m + i)*ldc + (n + j)] = r
else
C[(m + i)*ldc + (n + j)] = C[(m + i)*ldc + (n + j)] + r
end
end)
end
end
return stmts
end)
-- Blocked single-precision matrix multiply: C = A * B for row-major A (M x K,
-- row stride lda), B (K x N, row stride ldb) and C (M x N, row stride ldc).
--   gettime, alpha, beta : accepted but never read by this body.
-- B is always copied into a temporary buffer with k contiguous per column of B;
-- A is copied into a temporary buffer when BLOCKA is set. Both copies are freed
-- before returning.
-- The loops always process whole NB / NK blocks, so M and N are expected to be
-- multiples of NB and K a multiple of NK; the malloc results are not checked.
terra my_sgemm(gettime : {} -> double, M : int, N : int, K : int, alpha : float, A : &float, lda : int, B : &float, ldb : int,
               beta : float, C : &float, ldc : int)
    -- Pack B, one NK x NB block at a time.
    var TB = [&float](stdlib.malloc(K * N * sizeof(float)))
    for k0 = 0, K, NK do
        for n0 = 0, N, NB do
            for k = k0, k0 + NK do
                for n = n0, n0 + NB do
                    var elem = B[k*ldb + n]
                    if BLOCKB then
                        TB[n0 * K + NB*(k0 + n - n0) + k - k0] = elem
                    else
                        TB[n*K + k] = elem
                    end
                end
            end
        end
    end

    -- Use A in place unless a blocked copy is requested.
    var TA = A
    if BLOCKA then
        TA = [&float](stdlib.malloc(M * K * sizeof(float)))
        for m0 = 0, M, NB do
            for k0 = 0, K, NK do
                for m = m0, m0 + NB do
                    for k = k0, k0 + NK do
                        var elem = A[m*lda + k]
                        TA[m0 * K + NB*(k0 + m - m0) + k - k0] = elem
                    end
                end
            end
        end
    end

    -- For every (m, n, k) block, run the generated register-tile kernel over
    -- each AR x BR tile inside it.
    for m0 = 0, M, NB do
        for n0 = 0, N, NB do
            for k0 = 0, K, NK do
                for m = m0, m0 + NB, AR do
                    for n = n0, n0 + NB, BR do
                        blockregisters(C,TA,TB,K,lda,ldc,m0,n0,k0,m,n)
                    end
                end
            end
        end
    end

    -- Release the temporaries; TA is only owned here when it was allocated above.
    stdlib.free(TB)
    if BLOCKA then
        stdlib.free(TA)
    end
end
-- Compile the kernel, print it with printpretty, and write it to my_sgemm.o
-- under the exported symbol name "my_sgemm".
my_sgemm:compile()
my_sgemm:printpretty()
local exported = { my_sgemm = my_sgemm }
terralib.saveobj("my_sgemm.o", exported)