sophia/priv/stdlib/List.aes
2020-02-07 19:51:12 +01:00

222 lines
7.1 KiB
Plaintext

include "ListInternal.aes"
namespace List =
/* Tells whether `l` contains no elements. */
function is_empty(l : list('a)) : bool = switch(l)
  _::_ => false
  []   => true
/* Returns the head of `l` wrapped in `Some`, or `None` when `l` is empty. */
function first(l : list('a)) : option('a) = switch(l)
  x::_ => Some(x)
  []   => None
/* Returns everything after the head of `l` wrapped in `Some`,
 * or `None` when `l` is empty.
 */
function tail(l : list('a)) : option(list('a)) = switch(l)
  _::rest => Some(rest)
  []      => None
/* Returns the final element of `l` wrapped in `Some`,
 * or `None` when `l` is empty.
 */
function last(l : list('a)) : option('a) = switch(l)
  [] => None
  x::rest => switch(rest)
    [] => Some(x)      // `x` was the only remaining element
    _  => last(rest)
/* Returns `l` without its final element wrapped in `Some`,
 * or `None` when `l` is empty. The non-empty case is handled
 * by `drop_last_unsafe`.
 */
function drop_last(l : list('a)) : option(list('a)) = switch(l)
  _::_ => Some(drop_last_unsafe(l))
  []   => None
/* Returns `l` without its final element. Aborts when `l` is empty. */
function drop_last_unsafe(l : list('a)) : list('a) = switch(l)
  []  => abort("drop_last_unsafe: list empty")
  [_] => []
  x::rest => x::drop_last_unsafe(rest)
/* Returns the first element of `l` for which `p` holds wrapped in `Some`,
 * or `None` when there is no such element.
 */
function find(p : 'a => bool, l : list('a)) : option('a) = switch(l)
  x::rest =>
    if(p(x)) Some(x)
    else find(p, rest)
  [] => None
/* Returns the positions (counted from 0) of all elements of `l`
 * for which `p` holds, in ascending order.
 */
function find_indices(p : 'a => bool, l : list('a)) : list(int) =
  find_indices_(p, l, 0)

/* Worker for `find_indices`; `n` is the position of the head of `l`. */
private function find_indices_(p : 'a => bool, l : list('a), n : int) : list(int) =
  switch(l)
    [] => []
    x::xs =>
      // The tail is processed before `p` is applied to the current element.
      let later = find_indices_(p, xs, n + 1)
      if(p(x)) n::later
      else later
/* Returns the element at position `n` (counted from 0) wrapped in `Some`,
 * or `None` when the end of `l` is reached before `n` counts down to 0.
 */
function nth(n : int, l : list('a)) : option('a) =
  switch(l)
    x::rest =>
      if(n == 0) Some(x)
      else nth(n - 1, rest)
    [] => None
/* Unsafe version of `nth`: returns the element at position `n`
 * (counted from 0) directly. Aborts once the end of `l` is reached
 * without finding it; the message depends on whether `n` is negative
 * at that point.
 */
function get(n : int, l : list('a)) : 'a =
  switch(l)
    x::rest =>
      if(n == 0) x
      else get(n - 1, rest)
    [] =>
      if(n < 0) abort("Negative index get")
      else abort("Out of index get")
/* Returns the number of elements in `l`. */
function length(l : list('a)) : int =
  length_(l, 0)

/* Worker for `length`; `acc` holds the number of elements seen so far. */
private function length_(l : list('a), acc : int) : int = switch(l)
  _::rest => length_(rest, acc + 1)
  []      => acc
/* Returns the list produced by the range expression `[a..b]`. */
function from_to(a : int, b : int) : list(int) = [a..b]
/* Builds the ascending list `a, a + s, a + 2*s, ...` of all such values
 * that do not exceed `b`. Returns `[]` when `b < a`.
 * Aborts when the step `s` is not positive.
 */
function from_to_step(a : int, b : int, s : int) : list(int) =
  // A zero step would fail inside `mod`, and a negative one would make the
  // worker move its bound away from `a` without ever terminating, so both
  // are rejected up front. The empty range is handled explicitly so that the
  // result does not depend on how `mod` treats a negative left operand.
  // Otherwise `b` is rounded down to the last value reachable from `a`.
  if(s < 1) abort("from_to_step: non-positive step")
  elif(b < a) []
  else from_to_step_(a, b - (b - a) mod s, s, [])

/* Worker for `from_to_step`: walks from `b` down towards `a` in steps of `s`,
 * prepending each visited value to `acc`.
 */
private function from_to_step_(a : int, b : int, s : int, acc : list(int)) : list(int) =
  if(b < a) acc
  else from_to_step_(a, b - s, s, b::acc)
/* Unsafe. Returns `l` with its `n`th element (counted from 0) replaced by `e`.
 * Aborts when `n` is negative or when `l` has no `n`th element.
 */
function replace_at(n : int, e : 'a, l : list('a)) : list('a) =
  // The message names this function (it used to say "insert_at").
  if(n < 0) abort("replace_at underflow")
  else replace_at_(n, e, l)

/* Worker for `replace_at`; expects `n` to be non-negative. */
private function replace_at_(n : int, e : 'a, l : list('a)) : list('a) =
  switch(l)
    [] => abort("replace_at overflow")
    h::t =>
      if(n == 0) e::t
      else h::replace_at_(n - 1, e, t)
/* Unsafe. Returns `l` with `e` inserted so that it becomes the element at
 * position `n` (counted from 0). Aborts when `n` is negative or when `l`
 * ends before position `n` is reached.
 */
function insert_at(n : int, e : 'a, l : list('a)) : list('a) =
  if(n < 0) abort("insert_at underflow")
  else insert_at_(n, e, l)

/* Worker for `insert_at`. */
private function insert_at_(n : int, e : 'a, l : list('a)) : list('a) =
  switch(l)
    [] =>
      if(n == 0) [e]
      else abort("insert_at overflow")
    x::rest =>
      if(n == 0) e::l
      else x::insert_at_(n - 1, e, rest)
/* Inserts `x` into `l` directly before the first element `h` for which
 * `cmp(x, h)` holds; when there is no such element `x` is placed at the end.
 */
function insert_by(cmp : (('a, 'a) => bool), x : 'a, l : list('a)) : list('a) =
  switch(l)
    h::t =>
      if(!cmp(x, h)) h::insert_by(cmp, x, t)  // x is not before h, keep looking
      else x::l
    [] => [x]
/* Right fold: replaces the empty list with `nil` and combines every element
 * with the folded result of the elements after it using `cons`.
 */
function foldr(cons : ('a, 'b) => 'b, nil : 'b, l : list('a)) : 'b = switch(l)
  x::rest => cons(x, foldr(cons, nil, rest))
  []      => nil
/* Left fold: starting from `acc`, feeds the elements of `l` one by one,
 * front to back, into `rcons` together with the value accumulated so far.
 */
function foldl(rcons : ('b, 'a) => 'b, acc : 'b, l : list('a)) : 'b = switch(l)
  x::rest => foldl(rcons, rcons(acc, x), rest)
  []      => acc
/* Applies `f` to every element of `l`, front to back, and returns unit. */
function foreach(l : list('a), f : 'a => unit) : unit =
  switch(l)
    x::rest =>
      f(x)
      foreach(rest, f)
    [] => ()
/* Returns the elements of `l` in opposite order. */
function reverse(l : list('a)) : list('a) =
  reverse_(l, [])

/* Worker for `reverse`; `acc` holds the already visited elements, reversed. */
private function reverse_(l : list('a), acc : list('a)) : list('a) = switch(l)
  x::rest => reverse_(rest, x::acc)
  []      => acc
/* Returns the list of results of applying `f` to each element of `l`,
 * keeping the original order.
 */
function map(f : 'a => 'b, l : list('a)) : list('b) = switch(l)
  x::rest => f(x)::map(f, rest)
  []      => []
/* Forwards `f` and `l` unchanged to `ListInternal.flat_map`
 * (defined in the included "ListInternal.aes") and returns its result.
 */
function flat_map(f : 'a => list('b), l : list('a)) : list('b) =
  ListInternal.flat_map(f, l)
/* Returns the elements of `l` for which `p` holds, in their original order. */
function filter(p : 'a => bool, l : list('a)) : list('a) = switch(l)
  [] => []
  x::xs =>
    // The tail is filtered before `p` is applied to the current element.
    let kept = filter(p, xs)
    if(p(x)) x::kept
    else kept
/* Returns the first `n` elements of `l`, or all of `l` when it has fewer
 * than `n` elements. Aborts when `n` is negative.
 */
function take(n : int, l : list('a)) : list('a) =
  if(n < 0) abort("Take negative number of elements")
  else take_(n, l)

/* Worker for `take`. */
private function take_(n : int, l : list('a)) : list('a) =
  switch(l)
    [] => []
    x::rest =>
      if(n == 0) []
      else x::take_(n - 1, rest)
/* Returns `l` without its first `n` elements, or `[]` when `l` has fewer
 * than `n` elements. Aborts when `n` is negative.
 */
function drop(n : int, l : list('a)) : list('a) =
  if(n < 0) abort("Drop negative number of elements")
  else drop_(n, l)

/* Worker for `drop`. */
private function drop_(n : int, l : list('a)) : list('a) =
  switch(l)
    [] => []
    _::rest =>
      if(n == 0) l
      else drop_(n - 1, rest)
/* Returns the longest prefix of `l` in which every element satisfies `p`. */
function take_while(p : 'a => bool, l : list('a)) : list('a) = switch(l)
  x::rest =>
    if(p(x)) x::take_while(p, rest)
    else []
  [] => []
/* Drops elements from the front of `l` for as long as `p` holds and returns
 * the remainder, starting at the first element that does not satisfy `p`.
 */
function drop_while(p : 'a => bool, l : list('a)) : list('a) = switch(l)
  x::rest =>
    if(!p(x)) l
    else drop_while(p, rest)
  [] => []
/* Splits `l` into a pair of lists: the elements that satisfy `p` and the
 * elements that do not. Both lists keep the original relative order.
 */
function partition(p : 'a => bool, l : list('a)) : (list('a) * list('a)) = switch(l)
  [] => ([], [])
  x::xs =>
    // The tail is split before `p` is applied to the current element.
    let (matching, others) = partition(p, xs)
    if(p(x)) (x::matching, others)
    else (matching, x::others)
/* Concatenates all the inner lists of `l`, in order, into a single list. */
function flatten(l : list(list('a))) : list('a) = switch(l)
  xs::rest => xs ++ flatten(rest)
  []       => []
/* Tells whether `p` holds for every element of `l` (true for the empty
 * list). Elements after the first one failing `p` are not examined.
 */
function all(p : 'a => bool, l : list('a)) : bool = switch(l)
  x::rest => p(x) && all(p, rest)
  []      => true
/* Tells whether `p` holds for at least one element of `l` (false for the
 * empty list). Elements after the first one satisfying `p` are not examined.
 */
function any(p : 'a => bool, l : list('a)) : bool = switch(l)
  x::rest => p(x) || any(p, rest)
  []      => false
/* Adds up the elements of `l` with `foldl`, starting from 0. */
function sum(l : list(int)) : int = foldl ((a, b) => a + b, 0, l)
/* Multiplies the elements of `l` together with `foldl`, starting from 1. */
function product(l : list(int)) : int = foldl((a, b) => a * b, 1, l)
/* Zips two lists by applying the two-argument function `f` to the elements
 * at the same position of `l1` and `l2`. The result is as long as the
 * shorter input; the remaining tail of the longer one is dropped.
 * Public, so that it can be called from outside the namespace.
 */
function zip_with( f : ('a, 'b) => 'c
                 , l1 : list('a)
                 , l2 : list('b)
                 ) : list('c) = switch ((l1, l2))
  (h1::t1, h2::t2) => f(h1, h2)::zip_with(f, t1, t2)
  // At least one of the lists is exhausted.
  _ => []
/* Zips two lists into a list of pairs by calling `zip_with` with a function
 * that builds the pair `(a, b)`. Drops the tail of the longer list.
 */
function zip(l1 : list('a), l2 : list('b)) : list('a * 'b) = zip_with((a, b) => (a, b), l1, l2)
/* Splits a list of pairs into a pair of lists: all first components and
 * all second components, each in the original order.
 */
function unzip(l : list('a * 'b)) : (list('a) * list('b)) = switch(l)
  [] => ([], [])
  (a, b)::rest =>
    let (as, bs) = unzip(rest)
    (a::as, b::bs)
/* Orders `l` according to `lesser_cmp`. The head is taken as a pivot, the
 * remaining elements are split with `partition` into those for which
 * `lesser_cmp(x, pivot)` holds and the others, and both parts are sorted
 * recursively and joined around the pivot.
 * TODO: Improve?
 */
function sort(lesser_cmp : ('a, 'a) => bool, l : list('a)) : list('a) = switch(l)
  [] => []
  pivot::rest =>
    let (smaller, others) = partition((x) => lesser_cmp(x, pivot), rest)
    sort(lesser_cmp, smaller) ++ pivot::sort(lesser_cmp, others)
/* Returns `l` with `delim` placed between every two neighbouring elements.
 * Lists with fewer than two elements are returned as they are.
 */
function intersperse(delim : 'a, l : list('a)) : list('a) = switch(l)
  [] => []
  x::rest => switch(rest)
    [] => [x]      // no delimiter after the final element
    _  => x::delim::intersperse(delim, rest)
/* Pairs every element of `l` with its position, counted from 0. */
function enumerate(l : list('a)) : list(int * 'a) =
  enumerate_(l, 0)

/* Worker for `enumerate`; `n` is the position of the head of `l`. */
private function enumerate_(l : list('a), n : int) : list(int * 'a) = switch(l)
  x::rest => (n, x)::enumerate_(rest, n + 1)
  []      => []