391 lines
9.6 KiB
Erlang
391 lines
9.6 KiB
Erlang
% @doc
|
|
% bit matrices
|
|
-module(wfc_bm).
|
|
|
|
-export_type([
|
|
bit/0,
|
|
rc/0, shape/0,
|
|
bm/0
|
|
]).
|
|
|
|
-export([
|
|
%% accessors
|
|
shape/1, bits/1,
|
|
%% constructors
|
|
zeros/1, idn/1, diag/1,
|
|
%% addressing
|
|
rcth/2, rc_to_idx0/2,
|
|
rth/2, cth/2,
|
|
rcput/3,
|
|
%% arithmetic
|
|
add/2, mul/2, dot/2,
|
|
%% not quite arithmetic
|
|
mand/2, bsand/2, parity/1,
|
|
%% block shit
|
|
hjoin/2, hjoin/1,
|
|
vjoin/2, vjoin/1,
|
|
blocks/1,
|
|
%% fancier constructors
|
|
from_list/1, bmt/1,
|
|
%% useful operations
|
|
transpose/1, from_bitfun/2,
|
|
to_list/1
|
|
]).
|
|
|
|
% A single matrix entry; arithmetic in this module is mod 2 (bxor / band).
-type bit() :: 0 | 1.

% A 1-based {rc, Row, Col} address; the same tuple doubles as a shape.
-type rc() :: {rc, pos_integer(), pos_integer()}.

% Matrix dimensions as {rc, NumRows, NumCols}.
-type shape() :: rc().

% A bit matrix: its shape plus all entries packed row by row into a bitstring.
-type bm() :: {bm, shape(), bitstring()}.
|
|
|
|
|
|
|
|
% @doc Return the {rc, NumRows, NumCols} shape stored in a matrix.
-spec shape(bm()) -> shape().
shape({bm, TheShape, _Bits}) ->
    TheShape.
|
|
|
|
|
|
% @doc Return the packed row-by-row bitstring holding a matrix's entries.
-spec bits(bm()) -> bitstring().
bits({bm, _Shape, Payload}) ->
    Payload.
|
|
|
|
|
|
|
|
% @doc All-zero matrix of the given shape. Both dimensions must be
% positive; anything else fails the guard with function_clause.
-spec zeros(shape()) -> bm().
zeros({rc, NR, NC} = Shape) when NR > 0, NC > 0 ->
    {bm, Shape, <<0:(NR * NC)>>}.
|
|
|
|
|
|
|
|
% @doc Square Size x Size identity matrix (delegates to diag/1).
-spec idn(pos_integer()) -> bm().
idn(Size) when Size > 0 ->
    Square = {rc, Size, Size},
    diag(Square).
|
|
|
|
|
|
|
|
% @doc Matrix of the given shape with 1s on the main diagonal and 0s
% elsewhere. For non-square shapes the diagonal stops at the smaller
% dimension: the row count for fat matrices, the column count for tall ones.
-spec diag(shape()) -> bm().
diag(Shape = {rc, NR, NC}) ->
    DiagLen = min(NR, NC),
    SetOne = fun(I, Acc) -> rcput({rc, I, I}, 1, Acc) end,
    lists:foldl(SetOne, zeros(Shape), lists:seq(1, DiagLen)).
|
|
|
|
|
|
|
|
% @doc Return a copy of the matrix with the entry at 1-based address RC set
% to Val.
%
% The address is bounds-checked exactly like rcth/2. Without the check, a
% column past NC would silently wrap into the next row, because the packed
% offset stays in range. Val must be 0 or 1, since `Val:1` would otherwise
% silently truncate it. Violations fail with function_clause.
-spec rcput(rc(), bit(), bm()) -> bm().
rcput(RC = {rc, RowIdx1, ColIdx1}, Val, {bm, Shape = {rc, NR, NC}, Bits})
        when 1 =< RowIdx1, RowIdx1 =< NR,
             1 =< ColIdx1, ColIdx1 =< NC,
             (Val =:= 0 orelse Val =:= 1) ->
    Skip = rc_to_idx0(RC, Shape),
    % split around the target bit and splice the new value in
    <<Before:Skip/bits, _Old:1, After/bits>> = Bits,
    {bm, Shape, <<Before/bits, Val:1, After/bits>>}.
|
|
|
|
|
|
% @doc Read the entry at a 1-based {rc, Row, Col} address. Out-of-range
% addresses fail the guard with function_clause.
-spec rcth(rc(), bm()) -> bit().
rcth({rc, Row, Col} = RC, {bm, {rc, NR, NC} = Shape, Bits})
        when Row >= 1, Row =< NR,
             Col >= 1, Col =< NC ->
    Offset = rc_to_idx0(RC, Shape),
    <<_:Offset, Bit:1, _/bits>> = Bits,
    Bit.
|
|
|
|
|
|
% @doc Extract row RowIdx1 (1-based) as a 1 x NC matrix with a single
% binary match.
%
% The row segment must be matched with /bits. The default segment type is
% integer, which would put an integer into the bitstring slot of the
% returned bm() and break dot/2 and mul/2 downstream.
-spec rth(RowIdx1 :: pos_integer(), bm()) -> bm().
rth(RowIdx1, {bm, {rc, NR, NC}, Bits})
        when 1 =< RowIdx1, RowIdx1 =< NR ->
    % rows are packed back to back, each NC bits long
    Skip = (RowIdx1 - 1) * NC,
    <<_:Skip/bits, RowBits:NC/bits, _/bits>> = Bits,
    {bm, {rc, 1, NC}, RowBits}.
|
|
|
|
|
|
|
|
% @doc Extract column ColIdx1 (1-based) as an NR x 1 matrix. Entries are
% read one at a time with rcth/2, top row first.
-spec cth(ColIdx1 :: pos_integer(), bm()) -> bm().
cth(ColIdx1, M = {bm, {rc, NR, NC}, _})
        when 1 =< ColIdx1, ColIdx1 =< NC ->
    AppendEntry =
        fun(RowIdx1, Acc) ->
            Bit = rcth({rc, RowIdx1, ColIdx1}, M),
            <<Acc/bits, Bit:1>>
        end,
    Column = lists:foldl(AppendEntry, <<>>, lists:seq(1, NR)),
    {bm, {rc, NR, 1}, Column}.
|
|
|
|
|
|
% @doc Convert a 1-based {rc, Row, Col} address into the 0-based bit offset
% inside the packed bitstring. The layout is <<Row1:NC, Row2:NC, ...>>, so
% whole rows are skipped first, then columns within the row.
-spec rc_to_idx0(rc(), shape()) -> non_neg_integer().
rc_to_idx0({rc, RowIdx1, ColIdx1}, {rc, _NR, NC}) ->
    (RowIdx1 - 1) * NC + (ColIdx1 - 1).
|
|
|
|
|
|
% @doc Entrywise sum of two same-shape matrices mod 2, computed as one XOR
% of the packed bits read as integers. Mismatched shapes fail with
% function_clause; mismatched payload sizes fail with badmatch.
-spec add(bm(), bm()) -> bm().
add({bm, Shape, LeftBits}, {bm, Shape, RightBits}) ->
    Width = bit_size(LeftBits),
    % payloads must be the same length, not just the same declared shape
    Width = bit_size(RightBits),
    <<Left:Width>> = LeftBits,
    <<Right:Width>> = RightBits,
    {bm, Shape, <<(Left bxor Right):Width>>}.
|
|
|
|
|
|
|
|
% @doc Matrix product M1 * M2 mod 2. M1 must be NR x K and M2 must be K x NC;
% other shapes fail with function_clause.
%
% Entry (R, C) is dot/2 of row R of M1 with column C of M2. Rows and columns
% are extracted once up front rather than once per cell, so each column is
% built once instead of NR times. Entries are emitted row-major, matching
% the packed layout.
-spec mul(bm(), bm()) -> bm().
mul(M1 = {bm, {rc, NR, Same}, _}, M2 = {bm, {rc, Same, NC}, _}) ->
    Rows = [rth(R, M1) || R <- lists:seq(1, NR)],
    Cols = [cth(C, M2) || C <- lists:seq(1, NC)],
    Bits = << <<(dot(Row, Col)):1>> || Row <- Rows, Col <- Cols >>,
    {bm, {rc, NR, NC}, Bits}.
|
|
|
|
|
|
|
|
% @doc Dot product mod 2 of a 1 x K row matrix with a K x 1 column matrix.
% AND the two bit vectors, then return the parity of the result.
%
% parity/1 walks a bitstring, so the AND-ed integer has to be re-packed into
% a K-bit bitstring. Passing the integer directly fails with function_clause.
-spec dot(bm(), bm()) -> bit().
dot({bm, {rc, 1, Same}, RowBits}, {bm, {rc, Same, 1}, ColBits}) ->
    <<RowInt:Same>> = RowBits,
    <<ColInt:Same>> = ColBits,
    parity(<<(RowInt band ColInt):Same>>).
|
|
|
|
|
|
% @doc Entrywise AND of two matrices of the same shape. Mismatched shapes
% fail with function_clause.
%
% The contract takes and returns bm(). The previous spec claimed
% bitstring(), which contradicts the clause head.
-spec mand(bm(), bm()) -> bm().
mand({bm, Shape, Bits1}, {bm, Shape, Bits2}) ->
    {bm, Shape, bsand(Bits1, Bits2)}.
|
|
|
|
|
|
|
|
% @doc Bitwise AND of two bitstrings of equal bit length. Unequal lengths
% fail the guard with function_clause.
-spec bsand(bitstring(), bitstring()) -> bitstring().
bsand(Left, Right) when bit_size(Left) =:= bit_size(Right) ->
    Width = bit_size(Left),
    % read both as unsigned integers of the same width, AND, re-pack
    <<L:Width>> = Left,
    <<R:Width>> = Right,
    <<(L band R):Width>>.
|
|
|
|
|
|
% @doc Return 0 if the bitstring holds an even number of 1s, 1 if odd.
-spec parity(bitstring()) -> bit().
parity(Bits) ->
    count_ones(Bits, 0) rem 2.

% Count the set bits, consuming one bit per step.
count_ones(<<1:1, Rest/bits>>, Count) -> count_ones(Rest, Count + 1);
count_ones(<<0:1, Rest/bits>>, Count) -> count_ones(Rest, Count);
count_ones(<<>>, Count) -> Count.
|
|
|
|
|
|
% @doc Glue two matrices with the same number of rows side by side.
%
%   [R1     [R1'       [R1 R1'
%    R2  ,   R2'   ->   R2 R2'
%    R3]     R3']       R3 R3']
-spec hjoin(bm(), bm()) -> bm().
hjoin({bm, {rc, NR, LeftW}, LeftBits}, {bm, {rc, NR, RightW}, RightBits}) ->
    Glued = glue_rows(LeftW, LeftBits, RightW, RightBits, <<>>),
    {bm, {rc, NR, LeftW + RightW}, Glued}.

% Peel one row off each side and append both to the accumulator until both
% inputs are exhausted.
glue_rows(_LeftW, <<>>, _RightW, <<>>, Acc) ->
    Acc;
glue_rows(LeftW, LeftBits, RightW, RightBits, Acc) ->
    <<LeftRow:LeftW/bits, LeftRest/bits>> = LeftBits,
    <<RightRow:RightW/bits, RightRest/bits>> = RightBits,
    glue_rows(LeftW, LeftRest, RightW, RightRest,
              <<Acc/bits, LeftRow/bits, RightRow/bits>>).
|
|
|
|
|
|
|
|
% @doc Horizontally join a NONEMPTY list of matrices that all have the same
% number of rows, left to right. An empty list fails with function_clause.
-spec hjoin([bm()]) -> bm().
hjoin([First | Rest]) ->
    lists:foldl(fun(Next, Acc) -> hjoin(Acc, Next) end, First, Rest).
|
|
|
|
|
|
|
|
% @doc Stack two matrices with the same number of columns, first on top.
% Because storage is row by row, this is a plain bitstring concatenation.
-spec vjoin(bm(), bm()) -> bm().
vjoin({bm, {rc, TopRows, NC}, TopBits}, {bm, {rc, BottomRows, NC}, BottomBits}) ->
    Stacked = <<TopBits/bits, BottomBits/bits>>,
    {bm, {rc, TopRows + BottomRows, NC}, Stacked}.
|
|
|
|
|
|
|
|
% @doc Vertically join a NONEMPTY list of matrices that all have the same
% number of columns, top to bottom. An empty list fails with function_clause.
-spec vjoin([bm()]) -> bm().
vjoin([First | Rest]) ->
    lists:foldl(fun(Next, Acc) -> vjoin(Acc, Next) end, First, Rest).
|
|
|
|
|
|
% @doc Assemble a matrix from a grid of blocks. Each inner list is one band
% of blocks, joined horizontally; the bands are then stacked vertically.
%
%   blocks([[A, B],      [A B
%           [C, D]])  ->  C D]
-spec blocks([[bm()]]) -> bm().
blocks(BlockRows) ->
    vjoin([hjoin(Band) || Band <- BlockRows]).
|
|
|
|
|
|
% @doc Build a matrix from a nonempty list of rows, each a list of 0/1.
%
% Every row must have the same length as the first. Checking only the total
% entry count lets ragged input such as [[1,1],[1],[1,1,1]] through as a
% bogus 3x2 matrix, so each row is checked and a mismatch fails with
% badmatch. Every entry must be 0 or 1, since `B:1` would otherwise silently
% truncate it; violations fail with function_clause.
-spec from_list([[bit()]]) -> bm().
from_list(Rows = [FirstRow | _]) ->
    NR = length(Rows),
    NC = length(FirstRow),
    % NC is already bound, so this is an assertion on every row's length
    lists:foreach(fun(Row) -> NC = length(Row) end, Rows),
    Bits = << <<(assert_bit(B)):1>> || Row <- Rows, B <- Row >>,
    {bm, {rc, NR, NC}, Bits}.

% Accept only the two legal entry values.
assert_bit(B) when B =:= 0; B =:= 1 -> B.
|
|
|
|
|
|
% @doc Boole-Mobius transform matrix of arity N, built recursively as
%
%   T(0) = [1]        T(N) = [T(N-1)  T(N-1)
%                             0       T(N-1)]
%
% Unrolling the recursion, entry (R, C) is 1 exactly when the bits of R-1 are
% a subset of the bits of C-1, so the result is upper triangular. The result
% is 2^N x 2^N, so be careful with large N.
%
% The sub-matrix T(N-1) is computed once per level rather than three times,
% which turns 3^N recursive calls into N.
%
% NOTE(review): the original author suspected the orientation "might be
% backwards". This is the transpose of the lower-triangular (subset-of-row)
% convention. Confirm whether callers multiply a row vector (v * T) or a
% column vector (T * v).
-spec bmt(Arity :: pos_integer()) -> bm().
bmt(N) when N > 0 ->
    Sub = bmt(N - 1),
    blocks([[Sub, Sub],
            [ztn(N - 1), Sub]]);
bmt(0) ->
    from_list([[1]]).
|
|
|
|
|
|
% Square zero matrix with side 2^N.
ztn(N) when N >= 0 ->
    Side = two_to_the(N),
    zeros({rc, Side, Side}).

% 2^Exp for a non-negative integer exponent, via a left shift.
two_to_the(Exp) when Exp >= 0 ->
    1 bsl Exp.
|
|
|
|
|
|
% @doc Transpose: an NR x NC matrix becomes NC x NR, and entry (R, C) of the
% result is entry (C, R) of the input.
-spec transpose(bm()) -> bm().
transpose(M = {bm, {rc, NR, NC}, _}) ->
    Mirror = fun({rc, Row, Col}) -> rcth({rc, Col, Row}, M) end,
    from_bitfun({rc, NC, NR}, Mirror).
|
|
|
|
|
|
% @doc Build a matrix of the given shape by calling BitFun once per entry
% with its 1-based {rc, Row, Col} address, in row-major order.
%
% Both dimensions must be positive, matching zeros/1; otherwise the call
% fails with function_clause. Without this guard a zero or negative
% dimension made the walk skip its terminal case and recurse forever.
%
% As before, only the lowest bit of each BitFun result is stored.
-spec from_bitfun(shape(), fun( (rc()) -> bit() )) -> bm().
from_bitfun(Shape = {rc, NR, NC}, BitFun) when NR > 0, NC > 0 ->
    Bits = << <<(BitFun({rc, R, C})):1>>
              || R <- lists:seq(1, NR), C <- lists:seq(1, NC) >>,
    {bm, Shape, Bits}.
|
|
|
|
|
|
|
|
% @doc Unpack a matrix into a list of rows, each a list of 0/1 entries;
% the inverse of from_list/1.
-spec to_list(bm()) -> [[bit()]].
to_list(M = {bm, {rc, NR, NC}, _}) ->
    RowOf = fun(R) -> [rcth({rc, R, C}, M) || C <- lists:seq(1, NC)] end,
    lists:map(RowOf, lists:seq(1, NR)).
|