private: Yes


Set Implicit Arguments.

Require Import Arith.
Require Import Lia.
Require Import Recdef.

(* [fact n] is n!, by structural recursion on [n]:
   [fact 0 = 1] and [fact (S m) = S m * fact m]. *)
Fixpoint fact (n : nat) : nat :=
  match n with
  | O => 1
  | S m => S m * fact m
  end.

(* Helper for the inverse factorial.

   [invfact_i n i a] compares [n] with the candidate [a]:
   - [a = 0] is rejected with [None]. Without this guard the candidate
     would never grow ([0 * _ = 0]) and the search would loop forever.
   - [n < a]: the candidate overshot [n], so the result is [None].
   - [n = a]: the result is [Some (1 + i)].
   - [n > a]: retry with [1 + i] and the next candidate [a * S (S i)].
   When started as in [invfact] ([i = 0], [a = 1]), every step keeps
   [a = fact (S i)], so the [Eq] answer [S i] satisfies [fact (S i) = n].

   Because the search counts upwards, Coq cannot see that it terminates.
   We therefore prove termination with the measure [n - a]. In the
   recursive branch [0 < a < n], hence [a * S (S i) >= 2 * a > a], so
   [n - a] strictly decreases. *)
Function invfact_i n i a {measure (Nat.sub n) a} :=
  if a =? 0 then None else
    match n ?= a with
    | Lt => None
    | Eq => Some (1 + i)
    | Gt => invfact_i n (1 + i) (a * S (S i))
    end.
Proof.
  (* Termination obligation for the recursive ([Gt]) call. [teq] and
     [teq0] are the branch conditions [(a =? 0) = false] and
     [(n ?= a) = Gt], as named by [Function]. *)
  intros.
  apply Nat.compare_gt_iff in teq0. (* teq0 : a < n *)
  apply Nat.eqb_neq in teq.         (* teq : a <> 0 *)
  lia.
  (* [Defined] keeps the termination proof transparent, presumably so
     that [invfact_i] can still reduce by computation. *)
Defined.

(* Inverse factorial. [invfact n] runs the search [invfact_i] from
   [i = 0] with the first candidate [a = 1] (= fact 1). It returns
   [Some m] when the search lands exactly on [n] = [fact m], and [None]
   once a candidate overshoots [n]. Because [fact 0 = fact 1 = 1], the
   answer for [n = 1] is [Some 1]. *)
Definition invfact (n : nat) : option nat :=
  invfact_i n 0 1.

(* Convenience lemma that reverses one step of computation. When the
   candidate [a] is nonzero and still below [n], the call
   [invfact_i n i a] takes the [Gt] branch. It is therefore equal to its
   own recursive call [invfact_i n (S i) (a * S (S i))]. Rewriting
   left-to-right folds that recursive call back into the caller's form. *)
Lemma invfact_back: forall n i a,
    a > 0 -> n > a -> invfact_i n (S i) (a * S (S i)) = invfact_i n i a.
Proof.
  (* Unfold only the right-hand call [invfact_i n i a]. *)
  intros. rewrite (invfact_i_equation _ i).
  (* Select the [else] branch: a <> 0. *)
  replace (a =? 0) with false by (symmetry; rewrite Nat.eqb_neq; lia).
  (* Select the [Gt] branch: a < n. *)
  replace (n ?= a) with Gt by (symmetry; apply Nat.compare_gt_iff; lia).
  reflexivity.
Qed.

(* Lemma: if the inverse factorial of [a] is [n], then the inverse
   factorial of [(n + 1) * a] is [n + 1]. The proof is by functional
   induction on the computation of [invfact_i a 0 1]. *)
Lemma invfact_step: forall n a,
    n > 0 -> invfact a = Some n -> invfact (S n * a) = Some (S n).
Proof.
  unfold invfact. intros n a H.
  (* The cases that return [None] contradict [invfact a = Some n] and are
     closed by [easy]; the [Eq] and recursive [Gt] cases remain. *)
  apply invfact_i_ind; try easy; intros.
  - (* Eq case: the search stopped with input = candidate, so [n = S i].
       Unfold the new search twice: the first step is [Gt] because
       [S (S i) * a0 > a0], and the second hits [Eq] after commuting the
       product. *)
    apply Nat.compare_eq_iff in e0. inversion H0. subst.
    rewrite invfact_i_equation.
    rewrite e. apply Nat.eqb_neq in e.
    replace (S (S i) * a0 ?= a0) with Gt by (symmetry; apply Nat.compare_gt_iff; lia).
    rewrite invfact_i_equation.
    replace (a0 * S (S i) =? 0) with false by (symmetry; rewrite Nat.eqb_neq; lia).
    now rewrite Nat.mul_comm, Nat.compare_refl.
  - (* Gt case: the induction hypothesis speaks about the next step
       [(1 + i, a0 * S (S i))]. [invfact_back] folds it back to [(i, a0)].
       Its side conditions [a0 > 0] and [S n * n0 > a0] are solved by
       [lia]. *)
    specialize (H0 H1). apply Nat.compare_gt_iff in e0. apply Nat.eqb_neq in e.
    rewrite invfact_back in H0; auto; lia.
Qed.

(* Theorem: for all n > 0, the inverse factorial of n! is n. The case
   n = 0 must be excluded: fact 0 = 1, and invfact 1 returns Some 1.
   The proof is by induction on n, using [invfact_step] to go from n + 1
   to n + 2. *)
Theorem invfact_correct: forall n, n > 0 -> invfact (fact n) = Some n.
Proof.
  intros.
  (* n = 0 contradicts the hypothesis. n = 1 is the concrete goal
     [invfact (fact 1) = Some 1], which [easy] closes (presumably by
     evaluating it). *)
  induction n; [easy|]; destruct n; [easy|].
  (* Remaining case: n + 2, with the induction hypothesis for n + 1. *)
  assert (S n > 0) by lia. specialize (IHn H0).
  replace (fact (S (S n))) with (S (S n) * (fact (S n))) by easy.
  now apply invfact_step.
Qed.

(* Lemma: the search [invfact_i] never returns [Some 0]. Its only [Some]
   result is [Some (1 + i)], a successor. Functional induction over
   [invfact_i] therefore closes every case with [easy]. *)
Lemma invfact_i_zero: forall n i a, invfact_i n i a <> Some 0.
Proof. intros. apply invfact_i_ind; easy. Qed.

(* Theorem: [invfact] never returns [Some 0]. This is the instance
   [i = 0], [a = 1] of [invfact_i_zero], reached by unfolding [invfact]. *)
Theorem invfact_zero: forall n, invfact n <> Some 0.
Proof. intros. apply invfact_i_zero. Qed.