(* $Id: smatrices.ml 56 2006-03-02 14:43:44Z guesdon $ *)

module Simple = struct
  type ('a, 'b, 'c) matrix =
      { mat_i : 'a array ;
        mat_j : 'b array ;
        mat_mat : 'c array array ;
      }

  let latex oc ?(cols=10)
      string_of_i string_of_j string_of_v m =
    let p = Printf.fprintf in
    let i_elements = m.mat_i in
    let j_elements = m.mat_j in
    let ilen = Array.length i_elements in
    let jlen = Array.length j_elements in
    let f_begin i =
      p oc "\\begin{tabular}{|l|";
      for k = i to min (i + cols - 1) (ilen - 1) do p oc "r|" done;
      p oc "}\n\\hline\n";
      for k = i to min (i + cols - 1) (ilen - 1) do
        p oc "& %s" (string_of_i i_elements.(k))
      done;
      p oc "\\\\\n\\hline\n"
    in
    let f_end () =
      p oc "\\end{tabular}\n\n"
    in
    let rec iter i =
      if i < ilen then
        (
         f_begin i;
         for j = 0 to jlen - 1 do
           p oc "%s " (string_of_j j_elements.(j)) ;
           for k = i to min (i + cols - 1) (ilen - 1) do
             p oc "& %s" (string_of_v m.mat_mat.(k).(j))
           done;
           p oc "\\\\\n\\hline\n";
         done;
         f_end ();
         iter (i + cols)
        )
      else
        ()
    in
    iter 0

  let store_matrix f dm =
    let oc = open_out_bin f in
    output_value oc dm;
    close_out oc

  let load_matrix f =
    let ic = open_in_bin f in
    let v = input_value ic in
    close_in ic;
    v

  exception Uncompatible_matrices
  let concat_matrices_i ?check mat1 mat2 =
    if Array.length mat1.mat_j <> Array.length mat2.mat_j then
      raise Uncompatible_matrices;
    (
     match check with
       None -> ()
     | Some f -> Array.iteri (fun i v -> f v mat2.mat_j.(i)) mat1.mat_j
    );
    let len_i1 = Array.length mat1.mat_i in
    let len_i2 = Array.length mat2.mat_i in
    let len_i = len_i1 + len_i2 in
    let len_j = Array.length mat1.mat_j in
    let init_v =
      if len_j < 1 then
        None
      else
        if len_i1 < 1 then
          if len_i2 < 1 then
            None
          else
            Some mat2.mat_mat.(0).(0)
        else
          Some mat1.mat_mat.(0).(0)
    in
    let mat =
      match init_v with
        None -> [| |]
      | Some v ->
          let mat = Array.make_matrix len_i len_j v in
          for i = 0 to len_i1 - 1 do
            for j = 0 to len_j - 1 do
              mat.(i).(j) <- mat1.mat_mat.(i).(j)
            done
          done;
          for i = 0 to len_i2 - 1 do
            for j = 0 to len_j - 1 do
              mat.(len_i1 + i).(j) <- mat2.mat_mat.(i).(j)
            done
          done;
          mat
    in
    {
      mat_i = Array.append mat1.mat_i mat2.mat_i ;
      mat_j = mat1.mat_j ;
      mat_mat = mat;
    }

  let concat_matrices_j ?check mat1 mat2 =
    if Array.length mat1.mat_i <> Array.length mat2.mat_i then
      raise Uncompatible_matrices;
    (
     match check with
       None -> ()
     | Some f -> Array.iteri (fun i v -> f v mat2.mat_i.(i)) mat1.mat_i
    );
    let len_j1 = Array.length mat1.mat_j in
    let len_j2 = Array.length mat2.mat_j in
    let len_j = len_j1 + len_j2 in
    let len_i = Array.length mat1.mat_i in
    let init_v =
      if len_i < 1 then
        None
      else
        if len_j1 < 1 then
          if len_j2 < 1 then
            None
          else
            Some mat2.mat_mat.(0).(0)
        else
          Some mat1.mat_mat.(0).(0)
    in
    let mat =
      match init_v with
        None -> [| |]
      | Some v ->
          let mat = Array.make_matrix len_i len_j v in
          for j = 0 to len_j1 - 1 do
            for i = 0 to len_i - 1 do
              mat.(i).(j) <- mat1.mat_mat.(i).(j)
            done
          done;
          for j = 0 to len_j2 - 1 do
            for i = 0 to len_i - 1 do
              mat.(i).(len_j1 + j) <- mat2.mat_mat.(i).(j)
            done
          done;
          mat
    in
    {
      mat_i = mat1.mat_i ;
      mat_j = Array.append mat1.mat_j mat2.mat_j;
      mat_mat = mat;
    }
end

module Distances = struct
  type ('a, 'b) dist_matrix =
      {
        dist_elements : 'a array ;
        dist_matrix : 'b array array ;
        dist_samples : int array ;
      }

  let make_matrix ?(commutative=true) f_dist init_value elements =
    let len = Array.length elements in
    let mat = Array.make_matrix len len init_value in
    for i = 0 to len - 1 do
      for j = i + 1 to len - 1 do
        let d = f_dist elements.(i) elements.(j) in
        mat.(i).(j) <- d;
        if commutative then
          mat.(j).(i) <- d
        else
          mat.(j).(i) <- f_dist elements.(j) elements.(i)
      done
    done;
    {
      dist_elements = elements ;
      dist_matrix = mat ;
      dist_samples = Array.make len 1 ;
    }

  let group_distances ?(commutative=true) f_key
      ~intra ~inter dm =
    let t = Hashtbl.create 13 in
    let add element ind =
      let key = f_key element in
      try
        let l = Hashtbl.find t key in
        Hashtbl.replace t key (ind :: l)
      with
        Not_found -> Hashtbl.add t key [ind]
    in
    Array.iteri
      (fun i element -> add element i)
      dm.dist_elements;

    let key_indices = Hashtbl.fold
        (fun key l acc -> Array.append [| key, l |] acc)
        t
        [| |]
    in
    let new_len = Array.length key_indices in
    let mat = Array.make_matrix
        new_len new_len dm.dist_matrix.(0).(0)
    in
    let new_elements = Array.map fst key_indices in
    let old_indices i = snd key_indices.(i) in
    for i = 0 to new_len - 1 do
      for j = i to new_len - 1 do
        let old_i_s = old_indices i in
        let old_j_s = old_indices j in
        let g_group = if i = j then intra else inter in
        let v =
          let ll_dists =
            List.map
              (fun oi ->
                List.map
                  (fun oj ->
                    dm.dist_matrix.(oi).(oj)
                  )
                  old_j_s
              )
              old_i_s
          in
          g_group ll_dists
        in
        mat.(i).(j) <- v;
        if i <> j then
          if commutative then
            mat.(j).(i) <- v
          else
            let v =
              let ll_dists =
                List.map
                  (fun oj ->
                    List.map
                      (fun oi ->
                        dm.dist_matrix.(oj).(oi)
                      )
                      old_i_s
                  )
                  old_j_s
              in
              inter ll_dists
            in
            mat.(j).(i) <- v
      done
    done;
    { dist_elements = new_elements ;
      dist_matrix = mat ;
      dist_samples =
      Array.map
        (fun (_,l) ->
          List.fold_left (fun acc i -> acc + dm.dist_samples.(i)) 0
            l
        )
        key_indices ;
    }

  let latex oc
      ?(samples_in_rows=false)
      ?(cols=10)
      ?(fill=false)
      ?(rotate_titles=true)
      string_of_key string_of_dist dm =
    let p = Printf.fprintf in
    let len = Array.length dm.dist_elements in
    let f_begin i =
      p oc "\\begin{tabular}{|l|";
      for k = i to min (i + cols - 1) (len - 1) do p oc "r|" done;
      p oc "}\n\\hline\n";
      for k = i to min (i + cols - 1) (len - 1) do
        if rotate_titles then
          p oc "& \\rotatebox{90}{%s}" (string_of_key dm.dist_elements.(k))
        else
          p oc "& %s" (string_of_key dm.dist_elements.(k))
      done;
      if samples_in_rows then
        (
         p oc "\\\\\n\\hline\nsamples";
         for k = i to min (i + cols - 1) (len - 1) do
           p oc "& (%d)" dm.dist_samples.(k)
         done;
        );
      p oc "\\\\\n\\hline\n"
    in
    let f_end () =
      p oc "\\end{tabular}\n\n"
    in
    let rec iter i =
      if i < len then
        (
         f_begin i;
         for j = 0 to len - 1 do
           p oc "%s " (string_of_key dm.dist_elements.(j)) ;
           for k = i to min (i + cols - 1) (len - 1) do
             if fill or j <= k then
               p oc "& %s" (string_of_dist dm.dist_matrix.(k).(j))
             else
               p oc "&      ";
           done;
           p oc "\\\\\n\\hline\n";
         done;
         f_end ();
         iter (i + cols)
        )
      else
        ()
    in
    iter 0

  let latex_sorted_dists oc
      string_of_key string_of_dist comp dm =
    let len = Array.length dm.dist_elements in
    let subarray i t =
      Array.append
        (Array.sub t 0 i)
        (if i < len - 1 then
          (Array.sub t (i+1) (len - i - 1))
        else
          [| |]
        )
    in
    let f i _ =
      let teles = subarray i dm.dist_elements in
      let tdists = subarray i dm.dist_matrix.(i) in
      let t =
        Array.mapi
          (fun i ele -> (ele, tdists.(i)))
          teles
      in
      Array.sort (fun (_,d1) (_,d2) -> comp d1 d2) t;
      t
    in
    let t_couples = Array.mapi f dm.dist_elements in
    let p = Printf.fprintf in
    let print i ele =
      p oc "\n{\\bf %s}: " (string_of_key ele);
      let s = String.concat " $<$ "
          (Array.to_list
             (Array.map
                (fun (key,dist) ->
                  Printf.sprintf "%s(%s)"
                    (string_of_key key)
                    (string_of_dist dist)
                )
                t_couples.(i)
             )
          )
      in
      p oc "%s\n" s
    in
    Array.iteri print dm.dist_elements


  let store_dist_matrix f dm =
    let oc = open_out_bin f in
    output_value oc dm;
    close_out oc

  let load_dist_matrix f =
    let ic = open_in_bin f in
    let v = input_value ic in
    close_in ic;
    v
end