private: Yes


Set Implicit Arguments.

Require Import Arith.
Require Import Lia.
Require Import Recdef.

(* The factorial function: [fact 0 = 1] and [fact (S m) = S m * fact m].
   The successor case is deliberately written as [S m * fact m] so that
   [fact (S m)] reduces by computation to exactly that product, which the
   proofs below (e.g. [invfact_correct]) exploit. *)
Fixpoint fact (n : nat) : nat :=
  match n with
  | O => 1
  | S m => S m * fact m
  end.

(* Helper for the inverse factorial.

   [invfact_i n i a] searches upwards for [k] with [fact k = n]. The
   intended invariant is [a = fact (1 + i)]. It holds for the initial call
   [invfact_i n 0 1] made by [invfact]. The recursive call preserves it,
   because it moves to [1 + i] and multiplies [a] by [S (S i)].

   Results:
   - [None] if [a = 0];
   - [Some (1 + i)] if [n = a];
   - [None] if [n < a] (the search overshot, so [n] is not a factorial);
   - otherwise, continue with [i + 1] and [a * S (S i)].

   Since the search counts up, termination has to be proved manually. The
   measure is the distance [n - a], which strictly decreases on every
   recursive call. If [a] were zero, [a * S (S i)] would stay zero and
   the measure would never decrease, so that case is rejected up front. *)
Function invfact_i n i a {measure (Nat.sub n) a} :=
  if a =? 0 then None else
    match n ?= a with
    | Lt => None
    | Eq => Some (1 + i)
    | Gt => invfact_i n (1 + i) (a * S (S i))
    end.
Proof.
  intros.
  (* teq0 records that the Gt branch was taken; turn it into a < n. *)
  apply Nat.compare_gt_iff in teq0.
  (* teq records that a =? 0 was false; turn it into a <> 0. *)
  apply Nat.eqb_neq in teq.
  lia.
Defined.

(* The inverse factorial. [invfact n] runs the search [invfact_i] from
   [i = 0] with accumulator [a = 1] (which is [fact 1]). It yields [Some k]
   exactly when [n = fact k] for some [k >= 1] (see [invfact_correct] and
   [invfact_fact]), and [None] otherwise. *)
Definition invfact (n : nat) : option nat :=
  invfact_i n 0 1.

(* Convenience lemma that undoes one step of the search. When [a > 0] and
   [n > a], the call [invfact_i n i a] takes its [Gt] branch. It is
   therefore equal to the recursive call [invfact_i n (S i) (a * S (S i))]
   that it makes. *)
Lemma invfact_back: forall n i a,
    a > 0 -> n > a -> invfact_i n (S i) (a * S (S i)) = invfact_i n i a.
Proof.
  intros.
  (* Fixing the second argument to [i] makes this unfold only the
     right-hand call [invfact_i n i a]. *)
  rewrite (invfact_i_equation _ i).
  replace (a =? 0) with false by (symmetry; rewrite Nat.eqb_neq; lia).
  replace (n ?= a) with Gt by (symmetry; apply Nat.compare_gt_iff; lia).
  reflexivity.
Qed.

(* Lemma: if the inverse factorial of [a] is [n] (with [n > 0]), then the
   inverse factorial of [(n + 1) * a] is [n + 1].

   The proof is by the functional induction principle of [invfact_i].
   - In the [Eq] case, the search for [S n * a] is unfolded two steps:
     first a [Gt] step, then an [Eq] step.
   - In the [Gt] case, [invfact_back] rewinds the induction hypothesis
     by one step. *)
Lemma invfact_step: forall n a,
    n > 0 -> invfact a = Some n -> invfact (S n * a) = Some (S n).
Proof.
  unfold invfact. intros n a H.
  apply invfact_i_ind; try easy; intros.
  - (* Eq branch: the searched value equals the accumulator. *)
    apply Nat.compare_eq_iff in e0. inversion H0. subst.
    rewrite invfact_i_equation.
    rewrite e. apply Nat.eqb_neq in e.
    replace (S (S i) * a0 ?= a0) with Gt
      by (symmetry; apply Nat.compare_gt_iff; lia).
    rewrite invfact_i_equation.
    replace (a0 * S (S i) =? 0) with false by (symmetry; rewrite Nat.eqb_neq; lia).
    now rewrite Nat.mul_comm, Nat.compare_refl.
  - (* Gt branch: rewind the induction hypothesis by one step. *)
    specialize (H0 H1). apply Nat.compare_gt_iff in e0. apply Nat.eqb_neq in e.
    rewrite invfact_back in H0; auto; lia.
Qed.

(* Theorem: for every n > 0, the inverse factorial of n! is n.
   The proof is by induction on n:
   - n = 0 contradicts the hypothesis;
   - n = 1 is checked directly;
   - otherwise, fact (S (S m)) unfolds to S (S m) * fact (S m), and
     invfact_step is applied to the induction hypothesis. *)
Theorem invfact_correct: forall n, n > 0 -> invfact (fact n) = Some n.
Proof.
  intros n Hpos.
  induction n as [|m IHm]; [easy|].
  destruct m as [|m]; [easy|].
  change (fact (S (S m))) with (S (S m) * fact (S m)).
  apply invfact_step; [lia|].
  apply IHm. lia.
Qed.

(* Lemma: the search [invfact_i] never returns [Some 0]. Its only [Some]
   result is [Some (1 + i)], which is a successor. Every other branch is
   [None] or a recursive call, so [easy] closes all cases of the
   induction principle. *)
Lemma invfact_i_zero: forall n i a, invfact_i n i a <> Some 0.
Proof.
  intros n i a.
  apply invfact_i_ind; easy.
Qed.

(* Theorem: invfact never produces a zero. This follows directly from
   invfact_i_zero, since invfact is a particular call of invfact_i. *)
Theorem invfact_zero: forall n, invfact n <> Some 0.
Proof.
  intro n.
  unfold invfact.
  apply invfact_i_zero.
Qed.

(* Lemma: every result produced by [invfact_i n i a] is strictly greater
   than [i].
   - The [None] branches are discharged by [try easy].
   - The [Eq] branch returns [Some (1 + i)].
   - The [Gt] branch recurses from [1 + i], so its induction hypothesis
     already gives [1 + i < x]. *)
Lemma invfact_i_lt: forall n i a x, invfact_i n i a = Some x -> i < x.
Proof.
  intros n i a. apply invfact_i_ind; try easy; intros.
  (* Eq branch: H equates the returned Some (1 + i) with Some x. *)
  - inversion H. subst. lia.
  (* Gt branch: H is the induction hypothesis for the recursive call. *)
  - specialize (H _ H0). lia.
Qed.

(* Theorem: if n is the inverse factorial of both a and b, then a = b.
   The proof is by the induction principle of the search for a, comparing
   it with the search for b at the same step.
   - In the Eq case, b's search must also stop. An Lt step would give
     None, and a Gt step would give a result larger than 1 + i
     (invfact_i_lt).
   - In the Gt case, b's search must also continue, and the induction
     hypothesis applies. *)
Theorem invfact_only: forall n a b, invfact a = Some n -> invfact b = Some n -> a = b.
Proof.
  unfold invfact. intros n a. apply invfact_i_ind; intros; try easy.
  - (* a's search stops here with Some (1 + i). *)
    apply Nat.compare_eq_iff in e0. inversion H. subst. clear H.
    rewrite invfact_i_equation in H0. rewrite e in H0.
    destruct (b ?= a0) eqn:E.
    + now apply Nat.compare_eq_iff in E.
    + easy.
    + (* b's search would continue and return something > 1 + i. *)
      pose proof (invfact_i_lt _ _ _ H0). lia.
  - (* a's search continues. *)
    rewrite invfact_i_equation in H1. rewrite e in H1.
    destruct (b ?= a0) eqn:E.
    + (* b's search would stop at 1 + i, too early for a's result. *)
      inversion H1. subst. apply invfact_i_lt in H0. lia.
    + easy.
    + auto.
Qed.

(* Corollary: if the inverse factorial of m is n, then m = n!.
   - A zero result is ruled out by invfact_zero.
   - For a successor result S k, both m and fact (S k) have inverse S k
     (the latter by invfact_correct), so invfact_only identifies them. *)
Corollary invfact_fact: forall m n, invfact m = Some n -> m = fact n.
Proof.
  intros m n Hinv.
  destruct n as [|k].
  - apply invfact_zero in Hinv. contradiction.
  - apply invfact_only with (n := S k).
    + exact Hinv.
    + apply invfact_correct. lia.
Qed.