(* $Id: sequation.ml 56 2006-03-02 14:43:44Z guesdon $ *)

type solution = [`Real of float | `Complex of Complex.t]
let pi = 3.14159265

let print_check = ref (fun _ -> ())
(*let print_check = ref prerr_endline*)

let ( *? ) r c =
  Complex.polar (r *. (Complex.norm c)) (Complex.arg c)

let ( **? ) c r =
  let n = Complex.norm c
  and a = Complex.arg c in
  Complex.polar (n ** r) (a *. r)

(*
let complex_cubic_root c =
  let n = Complex.norm c
  and a = Complex.arg c in
  let n' =
    if n < 0.0 then
      -. ((-. n) ** (1. /. 3.))
    else
      n ** (1. /. 3.)
  in
  Complex.polar n' (a /. 3.)

let complex_square_root c =
  let n = Complex.norm c
  and a = Complex.arg c in
  let n' =
    if n < 0.0 then
      -. ((-. n) ** (1. /. 2.))
    else
      n ** (1. /. 2.)
  in
  Complex.polar n' (a /. 2.)
*)


let string_of_complex c =
  if c.Complex.im < 0.0 then
    Printf.sprintf "%f-%fi" c.Complex.re (-. c.Complex.im)
  else
    Printf.sprintf "%f+%fi" c.Complex.re c.Complex.im

let solve_deg_2 (a,b,c) =
  let delta = ((b ** 2.0) -. 4.0 *. a *. c) in
  !print_check (Printf.sprintf "solve_deg_2\na=%f" a);
  !print_check (Printf.sprintf "b=%f" b);
  !print_check (Printf.sprintf "c=%f" c);
  !print_check (Printf.sprintf "delta=%f" delta);
  let x1 = ((-. b) -. (sqrt delta)) /. (2.0 *. a) in
  let x2 = ((-. b) +. (sqrt delta)) /. (2.0 *. a) in
  let check x =
    !print_check (Printf.sprintf "a %f^2 + b %f + = %f"
                     x x
                     (a*. (x**2.0) +. b *. x +. c))
  in
  check x1;
  check x2;
  (`Real x1, `Real x2)

exception Not_degre_3

let solve_deg_3 (t,a2,a1,a0) =
  if t = 0.0 then raise Not_degre_3;
  let a2 = a2 /. t
  and a1 = a1 /. t
  and a0 = a0 /. t in
  let p = (-. 1. /. 3.) *. (a2 ** 2.0) +. a1
  and q = (2. /. 27.) *. (a2 ** 3.0) -. (1. /. 3.) *. a1 *. a2 +. a0 in
  let delta = 4. *. (p ** 3.0) +. 27. *. (q ** 2.0) in
  let j = Complex.polar 1.0 (2. *. pi /. 3.) in
  let j' = Complex.conj j in
  let check = function
      `Real x ->
        !print_check (Printf.sprintf " %f^3 + b/a %f^2 + c/a %f + d/a = %f"
                         x x x
                         ((x**3.0) +. a2 *. (x**2.0) +. a1 *. x +. a0))
    | `Complex c ->
        let x = string_of_complex c in
        let sum =
          List.fold_left
            Complex.add
            {Complex.re = a0; Complex.im = 0.0}
            [ (c **? 3.0) ;
              a2 *? (c **? 2.0);
              a1 *? c ;
            ]
        in
        !print_check (Printf.sprintf " %s^3 + b/a %s^2 + c/a %s + d/a = %s"
                         x x x (string_of_complex sum))
  in
  if delta >= 0.0 then
    (
     let u = ((-. 27. *. q +. 3. *. (sqrt 3.) *. (sqrt delta)) /. 2.0) ** (1. /. 3.) in
     let v0 = (-. 27. *. q -. 3. *. (sqrt 3.) *. (sqrt delta)) /. 2.0 in
     let v = if v0 >= 0.0 then v0 ** (1. /. 3.) else -. ((-. v0) ** (1. /. 3.)) in
     !print_check (Printf.sprintf "u=%f, v=%f" u v);
     let x1 = (1. /. 3.) *. (u +. v) in
     let x2 = (1. /. 3.) *? (Complex.add (u *? j) (v *? j'))
     and x3 = (1. /. 3.) *? (Complex.add (u *? j') (v *? j)) in
     let fc x = `Complex { x with Complex.re = x.Complex.re -. a2 /. 3. } in
     let x1 = `Real (x1 -. a2 /. 3.) in
     check x1;
     check (fc x2);
     check (fc x3);
     (x1, fc x2, fc x3)
    )
  else
    (
     let u0 =
       { Complex.re = (-. 27. *. q /. 2.0) ;
         Complex.im = (3. *. (sqrt 3.) *. (sqrt (-. delta)) /. 2.) ;
       }
     in
     let u = u0 **? (1. /. 3.) in
     !print_check (Printf.sprintf "u0=%f+%fi, u=%f+%fi"
                      u0.Complex.re u0.Complex.im
                      u.Complex.re u.Complex.im);
     let u' = Complex.conj u in
     let x1 = (Complex.add u u').Complex.re /. 3.
     and x2 = (Complex.add (Complex.mul j u) (Complex.mul j' u')).Complex.re /. 3.0
     and x3 =
       (Complex.add
          (Complex.mul (Complex.mul j j) u)
          (Complex.mul (Complex.mul j' j') u')).Complex.re /. 3.0
     in
     let f x = `Real (x -. a2 /. 3.) in
     check (f x1);
     check (f x2);
     check (f x3);
     (f x1, f x2, f x3)
    )

let solve_deg_4 ?(imeps=0.000001) ?(compeps=0.00000000001) (t,a1,b1,c1,d1) =
  !print_check
    (Printf.sprintf "solving %fX^4 + %fX^3 + %fX^2 + %fX + %f = 0"
                       t a1 b1 c1 d1);
  let a1 = a1 /. t
  and b1 = b1 /. t
  and c1 = c1 /. t
  and d1 = d1 /. t in
  let b = (-. 3. /. 8.) *. (a1 ** 2.0) +. b1
  and c = ((a1 ** 3.0) /. 8.) -. (a1 *. b1 /. 2.) +. c1
  and d = (-. 3. /. 256.) *. (a1 ** 4.0) +.
      (b1 *. (a1 ** 2.0) /. 16.) -. (a1 *. c1 /. 4.) +. d1 in
  let real_or_complex c =
    if c.Complex.im >= -. imeps && c.Complex.im <= imeps then
      `Real c.Complex.re
    else
      `Complex c
  in
  let to_sol = function
      `Real z ->
        let x = z -. a1 /. 4. in
        !print_check
          (Printf.sprintf "%f^4 + a1 %f^3 + b1 %f^2 + c1 %f + d1 = %f"
             x x x x
             ((x**4.0) +. a1 *. (x**3.0) +. b1 *. (x**2.0) +. c1 *. x +. d1)
          );
        `Real x
    | `Complex z ->
        let c = { z with Complex.re = z.Complex.re -. a1 /. 4. } in
        let s = string_of_complex c in
        let sum =
          List.fold_left
            Complex.add
            {Complex.re = d1; Complex.im = 0.0}
            [ (c **? 4.0) ;
              a1 *? (c **? 3.0);
              b1 *? (c **? 2.0);
              c1 *? c ;
            ]
        in
        !print_check (Printf.sprintf " %s^4 + b/a %s^3 + c/a %s^2 + d/a %s + e/a = %s"
                         s s s s (string_of_complex sum));

        `Complex c
  in
  let to_sol = function
      `Real x -> to_sol (`Real x)
    | `Complex c -> to_sol (real_or_complex c)
  in
  if d >= -. compeps && d <= compeps then
    let (z1,z2,z3) = solve_deg_3 (1.0, 0.0, b, c) in
    (to_sol (`Real 0.0), to_sol z1, to_sol z2, to_sol z3)
  else
    if c >= -. compeps && c <= compeps then
      let (z1,z2) = solve_deg_2 (1.0, b, d) in
      let to_sol2 = function
          `Real x -> to_sol (`Real (sqrt x))
        | `Complex _ -> `Real nan (* A VOIR *)
      in
      let to_sol2' = function
          `Real x -> to_sol (`Real (-. (sqrt x)))
        | `Complex _ -> `Real nan (* A VOIR *)
      in
      (to_sol2 z1, to_sol2' z1, to_sol2 z2, to_sol2' z2)
    else
      begin
        let (g1,g2,g3) =
          match solve_deg_3 (1.0, 8. *. b, 16. *.((b**2.0) -. 4. *. d), -. (64.0 *. (c**2.0))) with
            (`Real g1, `Real g2, `Real g3) ->
              ({ Complex.re = g1 ; im = 0.0 },
               { Complex.re = g2 ; im = 0.0 },
               { Complex.re = g3 ; im = 0.0 })
          | (`Real g1, `Complex g2, `Complex g3) ->
              ({ Complex.re = g1 ; im = 0.0 }, g2, g3)
          | _ -> assert false
        in
        let ro1 = Complex.sqrt g1 in
        let ro2 = Complex.sqrt g2 in
        let ro3 =
          let r = { Complex.re = (-. 8.) *. c; im = 0.0 } in
          prerr_endline (Printf.sprintf "-8c = %s" (string_of_complex r));
          Complex.div r (Complex.mul ro1 ro2)
        in
        let g3' = ro3 **? 2.0 in
        !print_check (Printf.sprintf "g3'=%s =? %s=g3"
                        (string_of_complex g3')
                      (string_of_complex g3)
                     );
(*
   let ro3 =
   if c >= 0.0 then
   Complex.neg (g3 **? (1. /. 2.))
   else
   g3 **? (1. /. 2.)
   in
*)

        let foo = Complex.mul ro1 (Complex.mul ro2 ro3) in
        !print_check
          (Printf.sprintf "ro1.ro2.ro3 = %s\n-8c = %f"
           (string_of_complex foo)
             (-. 8. *. c)
          );
        let (+?) = Complex.add and (-?) = Complex.sub in
        let q = 1. /. 4. in
        let x1 = to_sol (`Complex (q *? (ro1 +? ro2 +? ro3)))
        and x2 = to_sol (`Complex (q *? ((ro1 -? ro2) -? ro3)))
        and x3 = to_sol (`Complex (q *? ((ro2 -? ro3) -? ro1)))
        and x4 = to_sol (`Complex (q *? ((ro3 -? ro1) -? ro2))) in
        (x1, x2, x3, x4)
      end