forked from thanhnguyen-aws/plausible
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathMExp.lean
More file actions
444 lines (391 loc) · 19.4 KB
/
MExp.lean
File metadata and controls
444 lines (391 loc) · 19.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
import Plausible.Arbitrary
import Plausible.Chamelean.ArbitrarySizedSuchThat
import Plausible.Chamelean.Enumerators
import Plausible.Chamelean.DecOpt
import Plausible.Chamelean.TSyntaxCombinators
import Plausible.Chamelean.Schedules
import Plausible.Chamelean.UnificationMonad
import Plausible.Chamelean.Idents
open Plausible
open Idents
open Lean Parser Elab Term Command
-- Adapted from QuickChick source code
-- https://github.com/QuickChick/QuickChick/blob/internal-rewrite/plugin/newGenericLib.ml
/-- The sort of monad we are compiling to, i.e. one of the following:
  - An unconstrained / constrained generator (`Gen` / `OptionT Gen`)
  - An unconstrained / constrained enumerator (`Enumerator` / `OptionT Enumerator`)
  - A Checker (`Option Bool` monad) -/
inductive MonadSort
  /-- Unconstrained generator (`Gen`) -/
  | Gen
  /-- Constrained generator (`OptionT Gen`) -/
  | OptionTGen
  /-- Unconstrained enumerator (`Enumerator`) -/
  | Enumerator
  /-- Constrained enumerator (`OptionT Enumerator`) -/
  | OptionTEnumerator
  /-- Checker (the `Option Bool` monad) -/
  | Checker
  deriving Repr, BEq
/-- Returns `true` exactly when the given `MonadSort` is one of the two
    enumerator monads (`Enumerator` or `OptionT Enumerator`),
    and `false` for generators and checkers. -/
def MonadSort.isEnumerator : MonadSort → Bool
  | .Enumerator => true
  | .OptionTEnumerator => true
  | .Gen => false
  | .OptionTGen => false
  | .Checker => false
/-- An intermediate representation of monadic expressions that are
    used in generators/enumerators/checkers.
    - Schedules are compiled to `MExp`s, which are then compiled to Lean code
    - Note: `MExp`s make it easy to optimize generator code down the line
      (e.g. combine pattern-matches when we have disjoint patterns)
    - The cool thing about `MExp` is that we can interpret it differently
      based on the `MonadSort` -/
inductive MExp : Type where
  /-- `MRet e` represents `return e` in some monad -/
  | MRet (e : MExp)
  /-- `MBind monadSort m1 vars m2` represents `m1 >>= fun vars => m2` in a particular monad,
       as determined by `monadSort` -/
  | MBind (monadSort : MonadSort) (m1 : MExp) (vars : List Unknown) (m2 : MExp)
  /-- N-ary function application -/
  | MApp (f : MExp) (args : List MExp)
  /-- N-ary constructor application -/
  | MCtr (c : Name) (args : List MExp)
  /-- Some constant name (e.g. refers to functions) -/
  | MConst (name : Name)
  /-- `MMatch scrutinee [(p1, e1), …, (pn, en)]` represents
       ```lean
       match scrutinee with
       | p1 => e1
       ...
       | pn => en
       ```
  -/
  | MMatch (scrutinee : MExp) (cases : List (Pattern × MExp))
  /-- Refers to a variable identifier -/
  | MId (name : Name)
  /-- A function abstraction, where `args` is a list of variable names,
      and `body` is an `MExp` representing the function body -/
  | MFun (args : List Name) (body : MExp)
  /-- Signifies failure (corresponds to the term `OptionT.fail`) -/
  | MFail
  /-- Signifies running out of fuel -/
  | MOutOfFuel
  deriving Repr, Inhabited, BEq
/-- Maps a `ProducerSort` to the `MonadSort` of the corresponding
    *unconstrained* producer: generators map to `Gen`,
    enumerators map to `Enumerator`. -/
def prodSortToMonadSort (prodSort : ProducerSort) : MonadSort :=
  match prodSort with
  | .Generator => .Gen
  | .Enumerator => .Enumerator
/-- Maps a `ProducerSort` to the `MonadSort` of the corresponding
    *constrained* producer: generators map to `OptionT Gen`,
    enumerators map to `OptionT Enumerator`. -/
def prodSortToOptionTMonadSort (prodSort : ProducerSort) : MonadSort :=
  match prodSort with
  | .Generator => .OptionTGen
  | .Enumerator => .OptionTEnumerator
/-- Builds the `MExp` for a call to `EnumSizedSuchThat.enumSizedST`.
    - `prop` is the `Prop` constraining the enumerated value
    - `fuel` is the `MExp` for the fuel argument passed to the enumerator -/
def enumSizedST (prop : MExp) (fuel : MExp) : MExp :=
  let enumerator := MExp.MConst ``EnumSizedSuchThat.enumSizedST
  MExp.MApp enumerator [prop, fuel]
/-- Builds the `MExp` for a call to `ArbitrarySizedSuchThat.arbitrarySizedST`.
    - `prop` is the `Prop` constraining the generated value
    - `fuel` is the `MExp` for the fuel argument passed to the generator -/
def arbitrarySizedST (prop : MExp) (fuel : MExp) : MExp :=
  let generator := MExp.MConst ``ArbitrarySizedSuchThat.arbitrarySizedST
  MExp.MApp generator [prop, fuel]
/-- `mexpSome x` builds the `MExp` for `Option.some x`.
    The name `mexpSome` is used (rather than `some`) so that it does not clash
    with the `some` constructor of `Option`. -/
def mexpSome (x : MExp) : MExp :=
  let someCtor := MExp.MConst ``Option.some
  MExp.MApp someCtor [x]
/-- `someTrue` is the `MExp` for `some true`.
    - It shows up frequently when deriving checkers, hence the abbreviation. -/
def someTrue : MExp :=
  .MApp (.MConst ``Option.some) [.MConst ``true]
/-- `someFalse` is the `MExp` for `some false`.
    - It shows up frequently when deriving checkers, hence the abbreviation. -/
def someFalse : MExp :=
  .MApp (.MConst ``Option.some) [.MConst ``false]
/-- Collapses a `List α` into a single "tuple" by combining elements
    left-to-right with `pair`, i.e. `[a, b, c]` becomes `pair (pair a b) c`.
    A singleton list yields its only element unchanged.
    When `l` is empty the result is `default.get!`, which panics if `default`
    is `none`; most callers pass a non-empty list together with `none`. -/
def tupleOfList [Inhabited α] (pair : α → α → α) (l : List α) (default : Option α) : α :=
  match l with
  | [] => default.get!
  -- Folding over an empty tail returns `first`, so this also covers singletons
  | first :: rest => rest.foldl pair first
/-- Combines a list of `Pattern`s into one `Pattern` built from nested
    `Prod.mk` constructor patterns (via `tupleOfList`, with no default). -/
def patternTupleOfList (xs : List Pattern) : Pattern :=
  let mkPair (lhs rhs : Pattern) : Pattern := .CtorPattern ``Prod.mk [lhs, rhs]
  tupleOfList mkPair xs none
/-- Translates a `Pattern` into term syntax:
    - an unknown becomes an identifier
    - a constructor pattern becomes the constructor identifier applied to
      its recursively compiled arguments -/
partial def compilePattern (p : Pattern) : MetaM (TSyntax `term) :=
  match p with
  | .CtorPattern ctorName args => do
    let compiledArgs ← args.toArray.mapM compilePattern
    let ctorIdent := mkIdent ctorName
    `($ctorIdent:ident $compiledArgs*)
  | .UnknownPattern u =>
    let unknownIdent := mkIdent u
    `($unknownIdent:ident)
/-- Builds the `MExp` for invoking a `DecOpt` instance (a checker):
    `decOptChecker prop fuel` stands for the term `DecOpt.decOpt $prop $fuel`. -/
def decOptChecker (prop : MExp) (fuel : MExp) : MExp :=
  let checkerFn := MExp.MConst ``DecOpt.decOpt
  MExp.MApp checkerFn [prop, fuel]
/-- Translates a `ConstructorExpr` into an `MExp`: unknowns become `MId`s and
    constructor applications become `MCtr`s with recursively translated arguments. -/
partial def constructorExprToMExp (expr : ConstructorExpr) : MExp :=
  match expr with
  | .Ctor c args => .MCtr c (args.map constructorExprToMExp)
  | .Unknown u => .MId u
/-- Builds the `MExp` for a recursive call to the function named `f`.
    The identifiers `initSize` and `size'` are always passed first, followed by
    `args` (each `ConstructorExpr` translated to an `MExp`). -/
def recCall (f : Name) (args : List ConstructorExpr) : MExp :=
  let fuelArgs : List MExp := [.MId `initSize, .MId `size']
  let userArgs : List MExp := args.map constructorExprToMExp
  .MApp (.MId f) (fuelArgs ++ userArgs)
/-- Translates a `HypothesisExpr` (a constructor name paired with its arguments)
    into an `MCtr` whose arguments are the translated `ConstructorExpr`s. -/
def hypothesisExprToMExp (hypExpr : HypothesisExpr) : MExp :=
  match hypExpr with
  | (ctorName, ctorArgs) => .MCtr ctorName (ctorArgs.map constructorExprToMExp)
/-- The wildcard `Pattern` (i.e. `_` on the left-hand side of a match arm),
    encoded as an unknown whose name is `_`. -/
def wildCardPattern : Pattern :=
  Pattern.UnknownPattern `_
/-- Builds the `MExp` for a three-way match on a `scrutinee` of type `Option Bool`.
    `matchOptionBool scrutinee trueBranch falseBranch` represents
    ```lean
    match scrutinee with
    | some true => $trueBranch
    | some false => $falseBranch
    | none => $MExp.MOutOfFuel
    ```
-/
def matchOptionBool (scrutinee : MExp) (trueBranch : MExp) (falseBranch : MExp) : MExp :=
  -- Pattern `some b`, where `b` is the name of a `Bool` constructor
  let somePattern (b : Name) : Pattern := .CtorPattern ``some [.UnknownPattern b]
  let onTrue : Pattern × MExp := (somePattern ``true, trueBranch)
  let onFalse : Pattern × MExp := (somePattern ``false, falseBranch)
  let onNone : Pattern × MExp := (.UnknownPattern ``none, .MOutOfFuel)
  .MMatch scrutinee [onTrue, onFalse, onNone]
/-- `CompileScheduleM` is a monad for compiling `Schedule`s to `TSyntax term`s.
    Under the hood, this is just a `State` monad stacked on top of `TermElabM`,
    where the state is an `Array` of `TSyntax term`s, representing any auxiliary typeclass
    instances that need to be derived beforehand. -/
abbrev CompileScheduleM (α : Type) := StateT (TSyntaxArray `term) TermElabM α
/-- Produces the `MExp` for an unconstrained producer of values of type `ty`
    (given as a `TSyntax term`), selected by `prodSort`:
    - enumerators use `Enum.enum` and require an `Enum ty` instance
    - generators use `Arbitrary.arbitrary` and require an `Arbitrary ty` instance

    As a side effect, the required typeclass instance term is appended to the
    state of `CompileScheduleM`. -/
def unconstrainedProducer (prodSort : ProducerSort) (ty : TSyntax `term) : CompileScheduleM MExp := do
  -- Pick the typeclass and its producer method together, in one match
  let (typeClassName, producerName) :=
    match prodSort with
    | .Generator => (``Arbitrary, ``Arbitrary.arbitrary)
    | .Enumerator => (``Enum, ``Enum.enum)
  let typeClassInstance ← `( $(Lean.mkIdent typeClassName) $ty:term )
  -- Record the instance in the state and hand back the producer expression
  StateT.modifyGet fun instances =>
    (MExp.MConst producerName, instances.push typeClassInstance)
mutual
/-- Compiles an `MExp` to a Lean term (`TSyntax term`), according to the `DeriveSort` provided.
    - Running in `CompileScheduleM` lets the compilation throw errors
      (e.g. for an empty variable list or an unsupported `DeriveSort`/`MonadSort` pairing)
    - How an `MBind` is compiled depends on both `deriveSort` (what we are deriving)
      and the `MonadSort` stored in the bind (what is being invoked) -/
partial def mexpToTSyntax (mexp : MExp) (deriveSort : DeriveSort) : CompileScheduleM (TSyntax `term) :=
  match mexp with
  -- Identifiers and constants both compile to a plain identifier
  | .MId v | .MConst v => `($(mkIdent v))
  | .MApp func args => do
    let f ← mexpToTSyntax func deriveSort
    let compiledArgs ← args.toArray.mapM (fun e => mexpToTSyntax e deriveSort)
    `($f $compiledArgs*)
  | .MCtr ctorName args => do
    let compiledArgs ← args.toArray.mapM (fun e => mexpToTSyntax e deriveSort)
    `($(mkIdent ctorName) $compiledArgs*)
  | .MFun vars body => do
    let compiledBody ← mexpToTSyntax body deriveSort
    match vars with
    | [] => throwError "empty list of function arguments supplied to MFun"
    | [x] => `((fun $(mkIdent x) => $compiledBody))
    | _ =>
      -- When we have multiple args, create a tuple containing all of them
      -- in the argument of the lambda
      let args ← mkTuple vars
      `((fun $args:term => $compiledBody))
  | .MFail | .MOutOfFuel =>
    -- Note: right now we compile `MFail` and `MOutOfFuel` to the same Lean terms
    -- for simplicity, but in the future we may want to distinguish them
    match deriveSort with
    | .Generator | .Enumerator => `($failFn)
    | .Checker => `($(mkIdent ``some) $(mkIdent ``false))
    | .Theorem => throwError "compiling MExps for Theorem DeriveSorts not implemented"
  | .MRet e => do
    let e' ← mexpToTSyntax e deriveSort
    `(return $e')
  | .MBind monadSort m vars k => do
    -- Compile the monadic expression `m` and the continuation `k` to `TSyntax term`s
    let m1 ← mexpToTSyntax m deriveSort
    let k1 ← mexpToTSyntax k deriveSort
    match deriveSort, monadSort with
    | .Generator, .Gen
    | .Generator, .OptionTGen
    | .Enumerator, .Enumerator
    | .Enumerator, .OptionTEnumerator =>
      -- If there are multiple variables that are bound to the result
      -- of the monadic expression `m`, convert them to a tuple
      let compiledArgs ←
        if vars.isEmpty then
          throwError m!"empty list of vars supplied to MBind, deriveSort = {repr deriveSort}, monadSort = {repr monadSort}, m1 = {m1}, k1 = {k1}"
        else
          mkTuple vars
      -- If we have a producer, we can just produce a monadic bind
      `(do let $compiledArgs:term ← $m1:term ; $k1:term)
    | .Generator, .Checker
    | .Enumerator, .Checker => do
      -- If a producer invokes a checker, we have to invoke the checker
      -- provided by the `DecOpt` instance for the proposition, then pattern
      -- match on its result
      let trueCase ← `(Term.matchAltExpr| | $(mkIdent ``some) $(mkIdent ``true) => $k1)
      let wildCardCase ← `(Term.matchAltExpr| | _ => $failFn)
      let cases := #[trueCase, wildCardCase]
      `(match $m1:term with $cases:matchAlt*)
    | .Checker, .Checker =>
      -- If the continuation of the bind is just returning `some true`,
      -- we can just inline the checker call `m1` to avoid the extra indirection
      -- of calling checker combinator functions
      if k == someTrue then
        `($m1:term)
      else
        -- For checkers, we can just invoke `DecOpt.andOptList`
        `($andOptListFn [$m1:term, $k1:term])
    | .Checker, .Enumerator
    | .Checker, .OptionTEnumerator => do
      -- If there are multiple variables that are bound to the result
      -- of the enumerator `m`, convert them to a tuple
      let args ←
        if vars.isEmpty then
          throwError m!"empty list of vars supplied to MBind, deriveSort = {repr deriveSort}, monadSort = {repr monadSort}, m1 = {m1}, k1 = {k1}"
        else
          mkTuple vars
      -- We pass in `(min 2 initSize)` as the amount of fuel for the enumerator to avoid stack-overflow
      -- See https://github.com/ngernest/chamelean/issues/40 for details
      let fuelForEnumerator ← `($(mkIdent `min) 2 $initSizeIdent)
      match monadSort with
      | .Enumerator =>
        -- If a checker invokes an unconstrained enumerator,
        -- we call `EnumeratorCombinators.enumerating` a la QuickChick
        `($enumeratingFn $m1:term (fun $args:term => $k1:term) $fuelForEnumerator:term)
      | .OptionTEnumerator =>
        -- If a checker invokes a constrained enumerator,
        -- we call `EnumeratorCombinators.enumeratingOpt` a la QuickChick
        `($enumeratingOptFn $m1:term (fun $args:term => $k1:term) $fuelForEnumerator:term)
      | .(_) => throwError "Unreachable pattern match: Checkers can only invoke enumerators in this branch"
    | .Theorem, _ => throwError "Theorem DeriveSort not implemented yet"
    -- Any remaining pairing (e.g. a generator invoking an enumerator) is rejected
    | _, _ => throwError m!"Invalid monadic bind for deriveSort {repr deriveSort}"
  | .MMatch scrutinee cases => do
    -- Compile the scrutinee, the LHS & RHS of each case separately
    let compiledScrutinee ← mexpToTSyntax scrutinee deriveSort
    let compiledCases ← cases.toArray.mapM (fun (pattern, rhs) => do
      let lhs ← compilePattern pattern
      let compiledRHS ← mexpToTSyntax rhs deriveSort
      `(Term.matchAltExpr| | $lhs:term => $compiledRHS))
    `(match $compiledScrutinee:term with $compiledCases:matchAlt*)
/-- `MExp` representation of a constrained producer,
    parameterized by a `prodSort`, a list of variable names & their types `varsTys`,
    a `Prop` (`prop`) constraining the values being produced, and the `fuel`
    passed to the producer.
    - As a side effect, the `ArbitrarySizedSuchThat` / `EnumSizedSuchThat` instance term
      that the producer relies on is appended to the state of `CompileScheduleM`
    - Throws an error (rather than panicking) when `varsTys` is empty, so that the
      failure is reported like every other compilation error in `CompileScheduleM`
    - Note: this function corresponds to `such_that_producer`
      in the QuickChick code -/
partial def constrainedProducer (prodSort : ProducerSort) (varsTys : List (Name × ConstructorExpr)) (prop : MExp) (fuel : MExp) : CompileScheduleM MExp :=
  if varsTys.isEmpty then
    throwError "Received empty list of variables for constrainedProducer"
  else do
    -- Build the term for the typeclass instance that the constrained producer needs,
    -- i.e. `ArbitrarySizedSuchThat` / `EnumSizedSuchThat` applied to the
    -- argument types and to `fun args => prop`
    let (args, argTys) := List.unzip varsTys
    let argsTuple ← mkTuple args
    let argTyTerms ← monadLift $ argTys.toArray.mapM constructorExprToTSyntaxTerm
    let propBody ← mexpToTSyntax prop .Generator
    let typeClassName :=
      match prodSort with
      | .Enumerator => ``EnumSizedSuchThat
      | .Generator => ``ArbitrarySizedSuchThat
    let typeClassInstance ← `($(mkIdent typeClassName) $argTyTerms* (fun $argsTuple:term => $propBody))
    -- Add the `typeClassInstance` for the constrained producer to the state,
    -- then obtain the `MExp` representing the constrained producer
    StateT.modifyGet $ fun instances =>
      let producerWithArgs := MExp.MFun args prop
      let producerMExp :=
        match prodSort with
        | .Enumerator => enumSizedST producerWithArgs fuel
        | .Generator => arbitrarySizedST producerWithArgs fuel
      (producerMExp, instances.push typeClassInstance)
end
/-- Compiles a single `ScheduleStep` to an `MExp`, written in continuation-passing
    style: the caller supplies the `MExp` for everything that comes *after*
    this step, and the step is wrapped around it.

    Arguments:
    - `step`: the schedule step being compiled
    - `defFuel`: the `MExp` for the fuel handed to non-recursive constrained
      producers and to `DecOpt` checkers
    - `k`: the continuation, i.e. the remainder of the monadic `do`-block

    Per step:
    - `Unconstrained`: binds `v` to an unconstrained producer (or a recursive call)
    - `SuchThat`: binds the variables of `varsTys` to a constrained producer
      (or a recursive call)
    - `Check`: binds a checker with no bound variables
      (the second component of the step is currently ignored)
    - `Match`: matches the scrutinee against `pattern`, failing on anything else -/
def scheduleStepToMExp (step : ScheduleStep) (defFuel : MExp) (k : MExp) : CompileScheduleM MExp :=
  match step with
  | .Match scrutinee pattern =>
    let onMatch : Pattern × MExp := (pattern, k)
    let onMismatch : Pattern × MExp := (wildCardPattern, .MFail)
    pure (.MMatch (.MId scrutinee) [onMatch, onMismatch])
  | .Check src _ =>
    -- TODO: double check if we need to pattern-match on `scheduleSort` here
    -- TODO: handle checking hypotheses w/ negative polarity (currently not handled)
    let checker :=
      match src with
      | Source.Rec f args => recCall f args
      | Source.NonRec hypExpr => decOptChecker (hypothesisExprToMExp hypExpr) defFuel
    -- TODO: double check if this is right
    pure (.MBind .Checker checker [] k)
  | .SuchThat varsTys prod ps => do
    let producer ←
      match prod with
      | Source.Rec f args => pure (recCall f args)
      | Source.NonRec hypExpr =>
        constrainedProducer ps varsTys (hypothesisExprToMExp hypExpr) defFuel
    let boundVars := varsTys.map Prod.fst
    pure (.MBind (prodSortToOptionTMonadSort ps) producer boundVars k)
  | .Unconstrained v src prodSort => do
    let producer ←
      match src with
      | Source.Rec f args => pure (recCall f args)
      | Source.NonRec hyp =>
        let ty ← hypothesisExprToTSyntaxTerm hyp
        unconstrainedProducer prodSort ty
    pure (.MBind (prodSortToMonadSort prodSort) producer [v] k)
/-- Converts a `Schedule` (a list of `ScheduleStep`s along with a `ScheduleSort`,
    which acts as the conclusion of the schedule) to an `MExp`.
    - `mfuel` and `defFuel` are auxiliary `MExp`s representing the fuel
      for the function we are deriving (these correspond to `size` and `initSize`
      in the QuickChick code for the derived functions)
    - `mfuel` is only used for the checker call in a `TheoremSchedule`;
      `defFuel` is threaded through every schedule step
    - For producer schedules the epilogue is always `return <outputs>`: a single
      output is returned as-is, multiple outputs are first combined into nested
      `Prod.mk` applications -/
def scheduleToMExp (schedule : Schedule) (mfuel : MExp) (defFuel : MExp) : CompileScheduleM MExp :=
  let (scheduleSteps, scheduleSort) := schedule
  -- Determine the *epilogue* of the schedule (i.e. what happens after we
  -- have finished executing all the `scheduleStep`s)
  let epilogue : MExp :=
    match scheduleSort with
    | .ProducerSchedule _ conclusionOutputs =>
      -- Convert all the outputs in the conclusion to `mexp`s
      let conclusionMExps := constructorExprToMExp <$> conclusionOutputs
      match conclusionMExps with
      | [] => panic! "No outputs being returned in producer schedule"
      | outputs =>
        -- `tupleOfList` leaves a single output untouched and wraps multiple
        -- outputs in a tuple; either way the result must be wrapped in `MRet`
        -- so that the epilogue is a monadic value
        -- (`outputs` is non-empty here, so no default element is needed)
        MExp.MRet (tupleOfList (fun e1 e2 => .MApp (.MConst ``Prod.mk) [e1, e2]) outputs none)
    | .CheckerSchedule => someTrue
    | .TheoremSchedule conclusion typeClassUsed =>
      -- Create a pattern-match on the result of the checker
      -- on the conclusion, returning `some true` or `some false` accordingly
      let conclusionMExp := hypothesisExprToMExp conclusion
      let scrutinee :=
        if typeClassUsed then decOptChecker conclusionMExp mfuel
        else conclusionMExp
      matchOptionBool scrutinee someTrue someFalse
  -- Fold over the `scheduleSteps` and convert each of them to a functional `MExp`
  -- Note that the fold composes the `MExp`, and we use `foldr` since
  -- we want the `epilogue` to be the base-case of the fold
  List.foldrM (fun step acc => scheduleStepToMExp step defFuel acc)
    epilogue scheduleSteps