diff --git a/settings.gradle.kts b/settings.gradle.kts index 75ba8e1ab5..04e91e25df 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -4,3 +4,4 @@ include("usvm-core") include("usvm-jvm") include("usvm-util") include("usvm-sample-language") +include("usvm-ml-path-selection") diff --git a/usvm-ml-path-selection/README.md b/usvm-ml-path-selection/README.md new file mode 100644 index 0000000000..8bc742c259 --- /dev/null +++ b/usvm-ml-path-selection/README.md @@ -0,0 +1,44 @@ +## Machine Learning Path Selector + +### Entry point + +To run tests with this path selector use `jarRunner.kt`. You can pass a path to a configuration json as the first argument. Gathered statistics will be put in a folder according to your configuration. + +### Config + +A config object is declared inside `MLConfig.kt`. A detailed description of all the options is listed below: + +- `gameEnvPath` - a path to a folder that contains trained models (`rnn_cell.onnx`, `gnn_model.onnx`, `actor_model.onnx`) and a blacklist of tests to be skipped (`blacklist.txt`), also some logs are saved to this folder +- `dataPath` - a path to a folder to save all statistics into +- `defaultAlgorithm` - an algorithm to use if a trained model is not found, must be one of: `BFS`, `ForkDepthRandom` +- `postprocessing` - how actor model's outputs should be processed, must be one of: `Argmax` (choose an id of the maximum value), `Softmax` (sample from a distribution derived from the outputs via the softmax), `None` (sample from the outputs — only when they form a distribution) +- `mode` - a mode for `jarRunner.kt`, must be one of: `Calculation` (to calculate statistics used to train models), `Aggregation` (to aggregate statistics for different tests into one file), `Both` (to both calculate statistics and aggregate them), `Test` (to test this path selector with different time limits and compare it to other path selectors) +- `logFeatures` - whether to save statistics used to train models +- `shuffleTests` - whether to shuffle tests before running (affects the tests being run if the `dataConsumption` option is less than 100) +- `discounts` - time discounts used when testing path selectors +- `inputShape` - an input shape of an actor model +- `maxAttentionLength` - a maximum attention length of a PPO actor model +- `useGnn` - whether to use a GNN model +- `dataConsumption` - a percentage of tests to run +- `hardTimeLimit` - a time limit for one test +- `solverTimeLimit` - a time limit for one solver call +- `maxConcurrency` - a maximum number of threads running different tests concurrently +- `graphUpdate` - when to update block graph data, must be one of: `Once` (at the beginning of a test), `TestGeneration` (every time a new test is generated) +- `logGraphFeatuers` - whether to save graph statistics used to train a GNN model to a dataset file +- `gnnFeaturesCount` - a number of features that a GNN model returns +- `useRnn` - whether to use an RNN model +- `rnnStateShape` - a shape of an RNN state +- `rnnFeaturesCount` - a number of features that an RNN model returns +- `inputJars` - jars and their packages to run tests on + +### How to modify the metric + +To modify the metric you may change values of the `reward` property of the `ActionData` objects. They are written inside the property `path` of the `FeaturesLoggingPathSelector`. Currently, the metric is calculated in the `remove` method of the `FeaturesLoggingPathSelector`. + +### Training environment + +The training environment and its description are inside `environment.zip`. + +### "Modified" files + +Source files which names start with "Modified" are modified copies of files from other modules. They were modified to support this path selector. diff --git a/usvm-ml-path-selection/build.gradle.kts b/usvm-ml-path-selection/build.gradle.kts new file mode 100644 index 0000000000..42636b2eea --- /dev/null +++ b/usvm-ml-path-selection/build.gradle.kts @@ -0,0 +1,22 @@ +object MLVersions { + const val serialization = "1.5.1" + const val onnxruntime = "1.15.1" + const val dotlin = "1.0.2" +} + +plugins { + id("usvm.kotlin-conventions") + kotlin("plugin.serialization") version "1.8.21" +} + +dependencies { + implementation(project(":usvm-jvm")) + implementation(project(":usvm-core")) + + implementation("org.jacodb:jacodb-analysis:${Versions.jcdb}") + implementation("ch.qos.logback:logback-classic:${Versions.logback}") + + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:${MLVersions.serialization}") + implementation("io.github.rchowell:dotlin:${MLVersions.dotlin}") + implementation("com.microsoft.onnxruntime:onnxruntime:${MLVersions.onnxruntime}") +} diff --git a/usvm-ml-path-selection/environment.zip b/usvm-ml-path-selection/environment.zip new file mode 100644 index 0000000000..5b05175cdc Binary files /dev/null and b/usvm-ml-path-selection/environment.zip differ diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/CoverageCounter.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/CoverageCounter.kt new file mode 100644 index 0000000000..b26d71545e --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/CoverageCounter.kt @@ -0,0 +1,80 @@ +package org.usvm + +import kotlinx.serialization.Serializable +import java.util.concurrent.ConcurrentHashMap + +class CoverageCounter( + private val mlConfig: MLConfig +) { + private val testCoverages = ConcurrentHashMap>() + private val testStatementsCounts = ConcurrentHashMap() + private val testDiscounts = ConcurrentHashMap>() + private val testFinished = ConcurrentHashMap() + + fun addTest(testName: String, statementsCount: Float) { + testCoverages[testName] = List(mlConfig.discounts.size) { 0.0f } + testStatementsCounts[testName] = statementsCount + testDiscounts[testName] = List(mlConfig.discounts.size) { 1.0f } + testFinished[testName] = false + } + + fun updateDiscounts(testName: String) { + testDiscounts[testName] = testDiscounts.getValue(testName) + .mapIndexed { id, currentDiscount -> mlConfig.discounts[id] * currentDiscount } + } + + fun updateResults(testName: String, newCoverage: Float) { + val currentDiscounts = testDiscounts.getValue(testName) + testCoverages[testName] = testCoverages.getValue(testName) + .mapIndexed { id, currentCoverage -> currentCoverage + currentDiscounts[id] * newCoverage } + } + + fun finishTest(testName: String) { + testFinished[testName] = true + } + + fun reset() { + testCoverages.clear() + testStatementsCounts.clear() + testDiscounts.clear() + testFinished.clear() + } + + private fun getTotalCoverages(): List { + return testCoverages.values.reduce { acc, floats -> + acc.zip(floats).map { (total, value) -> total + value } + } + } + + @Serializable + data class TestStatistics( + private val discounts: Map, + private val statementsCount: Float, + private val finished: Boolean, + ) + + @Serializable + data class Statistics( + private val tests: Map, + private val totalDiscounts: Map, + private val totalStatementsCount: Float, + private val finishedTestsCount: Float, + ) + + fun getStatistics(): Statistics { + val discountStrings = mlConfig.discounts.map { it.toString() } + val testStatistics = testCoverages.mapValues { (test, coverages) -> + TestStatistics( + discountStrings.zip(coverages).toMap(), + testStatementsCounts.getValue(test), + testFinished.getValue(test), + ) + } + return Statistics( + testStatistics, + discountStrings.zip(getTotalCoverages()).toMap(), + testStatementsCounts.values.sum(), + testFinished.values.sumOf { if (it) 1.0 else 0.0 }.toFloat(), + ) + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLConfig.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLConfig.kt new file mode 100644 index 0000000000..fc0a8b67ee --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/MLConfig.kt @@ -0,0 +1,51 @@ +package org.usvm + +enum class Postprocessing { + Argmax, + Softmax, + None, +} + +enum class Mode { + Calculation, + Aggregation, + Both, + Test, +} + +enum class Algorithm { + BFS, + ForkDepthRandom, +} + +enum class GraphUpdate { + Once, + TestGeneration, +} + +data class MLConfig ( + val gameEnvPath: String = "../Game_env", + val dataPath: String = "../Data", + val defaultAlgorithm: Algorithm = Algorithm.BFS, + val postprocessing: Postprocessing = Postprocessing.Argmax, + val mode: Mode = Mode.Both, + val logFeatures: Boolean = true, + val shuffleTests: Boolean = true, + val discounts: List = listOf(1.0f, 0.998f, 0.99f), + val inputShape: List = listOf(1, -1, 77), + val maxAttentionLength: Int = -1, + val useGnn: Boolean = true, + val dataConsumption: Float = 100.0f, + val hardTimeLimit: Int = 30000, // in ms + val solverTimeLimit: Int = 10000, // in ms + val maxConcurrency: Int = 64, + val graphUpdate: GraphUpdate = GraphUpdate.Once, + val logGraphFeatures: Boolean = false, + val gnnFeaturesCount: Int = 8, + val useRnn: Boolean = true, + val rnnStateShape: List = listOf(4, 1, 512), + val rnnFeaturesCount: Int = 33, + val inputJars: Map> = mapOf( + Pair("../Game_env/jars/usvm-jvm-new.jar", listOf("org.usvm.samples", "com.thealgorithms")) + ) // path to jar file -> list of package names +) diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ModifiedUMachineOptions.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ModifiedUMachineOptions.kt new file mode 100644 index 0000000000..66e23e39f8 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ModifiedUMachineOptions.kt @@ -0,0 +1,19 @@ +package org.usvm + +enum class ModifiedPathSelectionStrategy { + /** + * Collects features according to states selected by any other path selector. + */ + FEATURES_LOGGING, + /** + * Collects features and feeds them to the ML model to select states. + * Extends FEATURE_LOGGING path selector. + */ + MACHINE_LEARNING, +} + +data class ModifiedUMachineOptions( + val basicOptions: UMachineOptions = UMachineOptions(), + val pathSelectionStrategies: List = + listOf(ModifiedPathSelectionStrategy.MACHINE_LEARNING) +) diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/machine/ModifiedJcMachine.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/machine/ModifiedJcMachine.kt new file mode 100644 index 0000000000..74f0669374 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/machine/ModifiedJcMachine.kt @@ -0,0 +1,193 @@ +package org.usvm.machine + +import mu.KLogging +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcMethod +import org.jacodb.api.cfg.JcInst +import org.jacodb.api.ext.methods +import org.usvm.* +import org.usvm.api.targets.JcTarget +import org.usvm.forkblacklists.TargetsReachableForkBlackList +import org.usvm.forkblacklists.UForkBlackList +import org.usvm.machine.interpreter.JcInterpreter +import org.usvm.machine.state.JcMethodResult +import org.usvm.machine.state.JcState +import org.usvm.machine.state.lastStmt +import org.usvm.ps.FeaturesLoggingPathSelector +import org.usvm.ps.modifiedCreatePathSelector +import org.usvm.statistics.* +import org.usvm.statistics.collectors.CoveredNewStatesCollector +import org.usvm.statistics.collectors.TargetsReachedStatesCollector +import org.usvm.statistics.distances.CfgStatistics +import org.usvm.statistics.distances.CfgStatisticsImpl +import org.usvm.statistics.distances.InterprocDistance +import org.usvm.statistics.distances.InterprocDistanceCalculator +import org.usvm.statistics.distances.MultiTargetDistanceCalculator +import org.usvm.statistics.distances.PlainCallGraphStatistics +import org.usvm.stopstrategies.createStopStrategy +import org.usvm.util.getMethodFullName +import org.usvm.util.originalInst + +val logger = object : KLogging() {}.logger + +class ModifiedJcMachine( + cp: JcClasspath, + private val options: ModifiedUMachineOptions, + private val interpreterObserver: JcInterpreterObserver? = null +) : UMachine() { + private val applicationGraph = JcApplicationGraph(cp) + + private val typeSystem = JcTypeSystem(cp) + private val components = JcComponents( + typeSystem, options.basicOptions.solverType, + options.basicOptions.useSolverForForks + ) + private val ctx = JcContext(cp, components) + + private val interpreter = JcInterpreter(ctx, applicationGraph, interpreterObserver) + + private val cfgStatistics = CfgStatisticsImpl(applicationGraph) + + fun analyze( + method: JcMethod, + targets: List = emptyList(), + coverageCounter: CoverageCounter? = null, + mlConfig: MLConfig? = null + ): List { + logger.debug("{}.analyze({}, {})", this, method, targets) + val initialState = interpreter.getInitialState(method, targets) + + val methodsToTrackCoverage = + when (options.basicOptions.coverageZone) { + CoverageZone.METHOD -> setOf(method) + CoverageZone.TRANSITIVE -> setOf(method) + // TODO: more adequate method filtering. !it.isConstructor is used to exclude default constructor which is often not covered + CoverageZone.CLASS -> method.enclosingClass.methods.filter { + it.enclosingClass == method.enclosingClass && !it.isConstructor + }.toSet() + } + + val coverageStatistics: CoverageStatistics = CoverageStatistics( + methodsToTrackCoverage, + applicationGraph + ) + + val callGraphStatistics = + when (options.basicOptions.targetSearchDepth) { + 0u -> PlainCallGraphStatistics() + else -> JcCallGraphStatistics( + options.basicOptions.targetSearchDepth, + applicationGraph, + typeSystem.topTypeStream(), + subclassesToTake = 10 + ) + } + + val transparentCfgStatistics = transparentCfgStatistics() + + val pathSelector = modifiedCreatePathSelector( + initialState, + options, + applicationGraph, + { coverageStatistics }, + { transparentCfgStatistics }, + { callGraphStatistics }, + { mlConfig }, + ) + + val statesCollector = + when (options.basicOptions.stateCollectionStrategy) { + StateCollectionStrategy.COVERED_NEW -> CoveredNewStatesCollector(coverageStatistics) { + it.methodResult is JcMethodResult.JcException + } + + StateCollectionStrategy.REACHED_TARGET -> TargetsReachedStatesCollector() + } + + val stopStrategy = createStopStrategy( + options.basicOptions, + targets, + coverageStatistics = { coverageStatistics }, + getCollectedStatesCount = { statesCollector.collectedStates.size } + ) + + val observers = mutableListOf>(coverageStatistics) + observers.add(TerminatedStateRemover()) + + if (interpreterObserver is UMachineObserver<*>) { + @Suppress("UNCHECKED_CAST") + observers.add(interpreterObserver as UMachineObserver) + } + + if (options.basicOptions.coverageZone == CoverageZone.TRANSITIVE) { + observers.add( + TransitiveCoverageZoneObserver( + initialMethod = method, + methodExtractor = { state -> state.lastStmt.location.method }, + addCoverageZone = { coverageStatistics.addCoverageZone(it) }, + ignoreMethod = { false } // TODO replace with a configurable setting + ) + ) + } + observers.add(statesCollector) + // TODO: use the same calculator which is used for path selector + if (targets.isNotEmpty()) { + val distanceCalculator = MultiTargetDistanceCalculator { stmt -> + InterprocDistanceCalculator( + targetLocation = stmt, + applicationGraph = applicationGraph, + cfgStatistics = cfgStatistics, + callGraphStatistics = callGraphStatistics + ) + } + interpreter.forkBlackList = TargetsReachableForkBlackList(distanceCalculator, shouldBlackList = { isInfinite }) + } else { + interpreter.forkBlackList = UForkBlackList.createDefault() + } + + val methodFullName = getMethodFullName(method) + if (coverageCounter != null) { + observers.add(CoverageCounterStatistics(coverageStatistics, coverageCounter, methodFullName)) + } + + run( + interpreter, + pathSelector, + observer = CompositeUMachineObserver(observers), + isStateTerminated = ::isStateTerminated, + stopStrategy = stopStrategy, + ) + + coverageCounter?.finishTest(methodFullName) + if (pathSelector is FeaturesLoggingPathSelector<*, *, *> && mlConfig != null) { + if (mlConfig.logFeatures && mlConfig.mode != Mode.Test) { + pathSelector.savePath() + } + } + + return statesCollector.collectedStates + } + + /** + * Returns a wrapper for the [cfgStatistics] that ignores [JcTransparentInstruction]s. + * Instead of calculating statistics for them, it just takes the statistics for + * their original instructions. + */ + private fun transparentCfgStatistics() = object : CfgStatistics { + override fun getShortestDistance(method: JcMethod, stmtFrom: JcInst, stmtTo: JcInst): UInt { + return cfgStatistics.getShortestDistance(method, stmtFrom.originalInst(), stmtTo.originalInst()) + } + + override fun getShortestDistanceToExit(method: JcMethod, stmtFrom: JcInst): UInt { + return cfgStatistics.getShortestDistanceToExit(method, stmtFrom.originalInst()) + } + } + + private fun isStateTerminated(state: JcState): Boolean { + return state.callStack.isEmpty() + } + + override fun close() { + components.close() + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/BlockGraph.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/BlockGraph.kt new file mode 100644 index 0000000000..5df7434bd0 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/BlockGraph.kt @@ -0,0 +1,196 @@ +package org.usvm.ps + +import io.github.rchowell.dotlin.digraph +import kotlinx.serialization.Serializable +import org.usvm.statistics.ApplicationGraph +import org.usvm.statistics.CoverageStatistics +import org.usvm.util.escape +import org.usvm.util.log +import java.nio.file.Path +import kotlin.io.path.writeText + +class BlockGraph( + private val applicationGraph: ApplicationGraph, + private val coverageStatistics: CoverageStatistics, + initialStatement: Statement, + private val forkCountsToExit: Map, + private val minForkCountsToExit: Map +) { + private val root: Block + private val successorsMap = mutableMapOf, List>().withDefault { listOf() } + private val predecessorsMap = mutableMapOf, List>().withDefault { listOf() } + private val coveredStatements = mutableMapOf>() + var currentBlockId = 0 + internal val blockList = mutableListOf>() + + init { + root = buildBlocks(initialStatement) + } + + private fun chooseNextStatement(statementQueue: ArrayDeque): Statement? { + var currentStatement = statementQueue.removeFirstOrNull() + while (currentStatement != null && coveredStatements.contains(currentStatement)) { + currentStatement = statementQueue.removeFirstOrNull() + } + return currentStatement + } + + private fun addSuccessor(block: Block, statement: Statement) { + successorsMap[block] = successorsMap.getValue(block) + statement + } + + private fun getPredecessors() { + val tmpPredecessorsMap = mutableMapOf, MutableList>() + .withDefault { mutableListOf() } + blockList.forEach { previousBlock -> + val lastStatement = previousBlock.path.last() + successors(previousBlock).forEach { nextBlock -> + tmpPredecessorsMap.getValue(nextBlock).add(lastStatement) + } + } + tmpPredecessorsMap.forEach { (block, predecessors) -> predecessorsMap[block] = predecessors } + } + + private fun buildBlocks(statement: Statement): Block { + var currentStatement = statement + val statementQueue = ArrayDeque() + val rootBlock = Block(this) + var currentBlock = rootBlock + while (true) { + if (coveredStatements.contains(currentStatement)) { + addSuccessor(currentBlock, currentStatement) + val nextStatement = chooseNextStatement(statementQueue) ?: break + currentStatement = nextStatement + currentBlock = Block(this) + continue + } + val predecessors = applicationGraph.predecessors(currentStatement).toList() + val successors = applicationGraph.successors(currentStatement).toList() + var newBlock = false + predecessors.forEach { previousStatement -> + val previousBlock = coveredStatements[previousStatement] + if (previousBlock == currentBlock) { + return@forEach + } + newBlock = true + } + if (newBlock && currentBlock.path.isNotEmpty()) { + addSuccessor(currentBlock, currentStatement) + currentBlock = Block(this) + } + coveredStatements[currentStatement] = currentBlock + currentBlock.path.add(currentStatement) + if (successors.size == 1) { + currentStatement = successors.first() + } else { + statementQueue.addAll(successors) + successors.forEach { + addSuccessor(currentBlock, it) + } + val nextStatement = chooseNextStatement(statementQueue) ?: break + currentStatement = nextStatement + currentBlock = Block(this) + } + } + getPredecessors() + return rootBlock + } + + private fun predecessors(block: Block): List> { + return predecessorsMap.getValue(block).map { coveredStatements[it]!! } + } + + private fun successors(block: Block): List> { + return successorsMap.getValue(block).map { coveredStatements[it]!! } + } + + fun getEdges(): Pair, List> { + return blockList.flatMap { block -> + predecessors(block).map { Pair(block.id, it.id) } + }.unzip() + } + + fun getBlock(statement: Statement): Block? { + return coveredStatements[statement] + } + + private fun getBlockFeatures(block: Block): BlockFeatures { + val firstStatement = block.path.first() + + val length = block.path.size + val predecessorsCount = predecessors(block).size + val successorsCount = successors(block).size + val totalCalleesCount = block.path.sumOf { applicationGraph.callees(it).count() } + val forkCountToExit = forkCountsToExit.getValue(firstStatement) + val minForkCountToExit = minForkCountsToExit.getValue(firstStatement) + val isCovered = firstStatement !in coverageStatistics.getUncoveredStatements() + + return BlockFeatures( + length.log(), + predecessorsCount.log(), + successorsCount.log(), + totalCalleesCount.log(), + forkCountToExit.log(), + minForkCountToExit.log(), + if (isCovered) 1.0f else 0.0f, + ) + } + + fun getGraphFeatures(): List { + return blockList.map { getBlockFeatures(it) } + } + + fun saveGraph(filePath: Path) { + val nodes = mutableListOf>() + val treeQueue = ArrayDeque>() + treeQueue.add(root) + val visitedBlocks = mutableSetOf>() + visitedBlocks.add(root) + while (treeQueue.isNotEmpty()) { + val currentNode = treeQueue.removeFirst() + nodes.add(currentNode) + val successors = successors(currentNode) + treeQueue.addAll(successors.filter { !visitedBlocks.contains(it) }) + visitedBlocks.addAll(successors) + } + val graph = digraph("BlockGraph") { + nodes.forEach { node -> + val nodeName = node.toString() + +nodeName + successors(node).forEach { child -> + val childName = child.toString() + nodeName - childName + } + } + } + filePath.parent.toFile().mkdirs() + filePath.writeText(graph.dot()) + } +} + +@Serializable +data class BlockFeatures( + val logLength: Float = 0.0f, + val logPredecessorsCount: Float = 0.0f, + val logSuccessorsCount: Float = 0.0f, + val logTotalCalleesCount: Float = 0.0f, + val logForkCountToExit: Float = 0.0f, + val logMinForkCountToExit: Float = 0.0f, + val isCovered: Float = 0.0f, +) + +data class Block( + val id: Int = 0, + var path: MutableList = mutableListOf() +) { + constructor(blockGraph: BlockGraph<*, Statement>) : this( + id = blockGraph.currentBlockId + ) { + blockGraph.currentBlockId += 1 + blockGraph.blockList.add(this) + } + + override fun toString(): String { + return "\"${id}: ${path.map { it.toString().escape() }}\"" + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesData.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesData.kt new file mode 100644 index 0000000000..e4edc06cff --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesData.kt @@ -0,0 +1,116 @@ +package org.usvm.ps + +import kotlinx.serialization.Serializable + +@Serializable +internal data class StateFeatures( + val logPredecessorsCount: Float = 0.0f, + val logSuccessorsCount: Float = 0.0f, + val logCalleesCount: Float = 0.0f, + val logLogicalConstraintsLength: Float = 0.0f, + val logStateTreeDepth: Float = 0.0f, + val logStatementRepetitionLocal: Float = 0.0f, + val logStatementRepetitionGlobal: Float = 0.0f, + val logDistanceToUncovered: Float = 0.0f, + val logLastNewDistance: Float = 0.0f, + val logPathCoverage: Float = 0.0f, + val logDistanceToBlockEnd: Float = 0.0f, + val logDistanceToExit: Float = 0.0f, + val logForkCount: Float = 0.0f, + val logStatementFinishCount: Float = 0.0f, + val logForkCountToExit: Float = 0.0f, + val logMinForkCountToExit: Float = 0.0f, + val logSubpathCount2: Float = 0.0f, + val logSubpathCount4: Float = 0.0f, + val logSubpathCount8: Float = 0.0f, + val logReward: Float = 0.0f, +) + +@Serializable +internal data class GlobalStateFeatures( + val averageLogLogicalConstraintsLength: Float = 0.0f, + val averageLogStateTreeDepth: Float = 0.0f, + val averageLogStatementRepetitionLocal: Float = 0.0f, + val averageLogStatementRepetitionGlobal: Float = 0.0f, + val averageLogDistanceToUncovered: Float = 0.0f, + val averageLogLastNewDistance: Float = 0.0f, + val averageLogPathCoverage: Float = 0.0f, + val averageLogDistanceToBlockEnd: Float = 0.0f, + val averageLogSubpathCount2: Float = 0.0f, + val averageLogSubpathCount4: Float = 0.0f, + val averageLogSubpathCount8: Float = 0.0f, + val averageLogReward: Float = 0.0f, + val logFinishedStatesCount: Float = 0.0f, + val finishedStatesFraction: Float = 0.0f, + val visitedStatesFraction: Float = 0.0f, + val totalCoverage: Float = 0.0f, +) + +@Serializable +internal data class ActionData( + val queue: List, + val globalStateFeatures: GlobalStateFeatures, + val chosenStateId: Int, + var reward: Float, + val graphId: Int = 0, + val blockIds: List, + val extraFeatures: List, +) + +internal fun stateFeaturesToFloatList(stateFeatures: StateFeatures): List { + return listOf( + stateFeatures.logPredecessorsCount, + stateFeatures.logSuccessorsCount, + stateFeatures.logCalleesCount, + stateFeatures.logLogicalConstraintsLength, + stateFeatures.logStateTreeDepth, + stateFeatures.logStatementRepetitionLocal, + stateFeatures.logStatementRepetitionGlobal, + stateFeatures.logDistanceToUncovered, + stateFeatures.logLastNewDistance, + stateFeatures.logPathCoverage, + stateFeatures.logDistanceToBlockEnd, + stateFeatures.logDistanceToExit, + stateFeatures.logForkCount, + stateFeatures.logStatementFinishCount, + stateFeatures.logForkCountToExit, + stateFeatures.logMinForkCountToExit, + stateFeatures.logSubpathCount2, + stateFeatures.logSubpathCount4, + stateFeatures.logSubpathCount8, + stateFeatures.logReward, + ) +} + +internal fun globalStateFeaturesToFloatList(globalStateFeatures: GlobalStateFeatures): List { + return listOf( + globalStateFeatures.averageLogLogicalConstraintsLength, + globalStateFeatures.averageLogStateTreeDepth, + globalStateFeatures.averageLogStatementRepetitionLocal, + globalStateFeatures.averageLogStatementRepetitionGlobal, + globalStateFeatures.averageLogDistanceToUncovered, + globalStateFeatures.averageLogLastNewDistance, + globalStateFeatures.averageLogPathCoverage, + globalStateFeatures.averageLogDistanceToBlockEnd, + globalStateFeatures.averageLogSubpathCount2, + globalStateFeatures.averageLogSubpathCount4, + globalStateFeatures.averageLogSubpathCount8, + globalStateFeatures.averageLogReward, + globalStateFeatures.logFinishedStatesCount, + globalStateFeatures.finishedStatesFraction, + globalStateFeatures.visitedStatesFraction, + globalStateFeatures.totalCoverage, + ) +} + +internal fun blockFeaturesToList(blockFeatures: BlockFeatures): List { + return listOf( + blockFeatures.logLength, + blockFeatures.logPredecessorsCount, + blockFeatures.logSuccessorsCount, + blockFeatures.logTotalCalleesCount, + blockFeatures.logForkCountToExit, + blockFeatures.logMinForkCountToExit, + blockFeatures.isCovered, + ) +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesLogger.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesLogger.kt new file mode 100644 index 0000000000..4ca7eb2e40 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesLogger.kt @@ -0,0 +1,154 @@ +package org.usvm.ps + +import io.github.rchowell.dotlin.digraph +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.* +import org.usvm.MLConfig +import org.usvm.PathsTrieNode +import org.usvm.RootNode +import org.usvm.UState +import org.usvm.util.escape +import org.usvm.util.getMethodFullName +import java.io.File +import kotlin.io.path.Path +import kotlin.io.path.writeText + +internal class FeaturesLogger, Statement, Method>( + method: Method, + blockGraph: BlockGraph<*, Statement>, + private val mlConfig: MLConfig +) { + private val filepath = Path(mlConfig.dataPath, "jsons").toString() + private val filename = getMethodFullName(method) + private val graphsPath = Path(mlConfig.gameEnvPath, "graphs").toString() + private val blockGraphsPath = Path(mlConfig.gameEnvPath, "block_graphs").toString() + private val jsonFormat = Json { encodeDefaults = true } + + init { + File(filepath).mkdirs() + blockGraph.saveGraph(Path(blockGraphsPath, filename, "graph.dot")) + } + + private fun getNodeName( + node: PathsTrieNode, + id: Int, + extraNodeInfo: (PathsTrieNode) -> String + ): String { + val statement = if (node is RootNode) { + "No Statement" + } else { + node.statement.toString() + } + var name = "\"$id: ${statement.escape()}" + name += extraNodeInfo(node) + name += "\"" + return name + } + + fun saveGraph( + pathsTreeRoot: PathsTrieNode, + step: Int, + extraNodeInfo: (PathsTrieNode) -> String + ) { + val nodes = mutableListOf>() + val treeQueue = ArrayDeque>() + treeQueue.add(pathsTreeRoot) + while (treeQueue.isNotEmpty()) { + val currentNode = treeQueue.removeFirst() + nodes.add(currentNode) + treeQueue.addAll(currentNode.children.values) + } + val nodeNames = nodes.zip(nodes.indices).associate { (node, id) -> + Pair(node, getNodeName(node, id, extraNodeInfo)) + }.withDefault { "" } + val graph = digraph("step$step") { + nodes.forEach { node -> + val nodeName = nodeNames.getValue(node) + +nodeName + node.children.values.forEach { child -> + val childName = nodeNames.getValue(child) + nodeName - childName + } + } + } + val path = Path(graphsPath, filename, "${graph.name}.dot") + path.parent.toFile().mkdirs() + path.writeText(graph.dot()) + } + + fun savePath( + path: List, + blockGraph: BlockGraph<*, Statement>, + probabilities: List>, + statementsCount: Int, + graphFeaturesList: List>, + getAllFeatures: (StateFeatures, ActionData, Int) -> List + ) { + if (path.isEmpty()) { + return + } + val jsonData = buildJsonObject { + putJsonArray("path") { + path.forEach { actionData -> + addJsonArray { + addJsonArray { + actionData.queue.zip(actionData.blockIds).forEach { (stateFeatures, blockId) -> + addJsonArray { + getAllFeatures(stateFeatures, actionData, blockId).forEach { + add(it) + } + } + } + } + add(actionData.chosenStateId) + add(actionData.reward) + if (mlConfig.logGraphFeatures) { + add(actionData.graphId) + addJsonArray { + actionData.blockIds.forEach { + add(it) + } + } + } + } + } + } + put("statementsCount", statementsCount) + if (mlConfig.logGraphFeatures) { + putJsonArray("graphFeatures") { + graphFeaturesList.forEach { graphFeatures -> + addJsonArray { + graphFeatures.forEach { nodeFeatures -> + addJsonArray { + jsonFormat.encodeToJsonElement(nodeFeatures).jsonObject.forEach { _, u -> + add(u) + } + } + } + } + } + } + putJsonArray("graphEdges") { + blockGraph.getEdges().toList().forEach { nodeList -> + addJsonArray { + nodeList.forEach { + add(it) + } + } + } + } + } + putJsonArray("probabilities") { + probabilities.forEach { queueProbabilities -> + addJsonArray { + queueProbabilities.forEach { probability -> + add(probability) + } + } + } + } + } + Path(filepath, "$filename.json").toFile() + .writeText(jsonFormat.encodeToString(jsonData)) + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesLoggingPathSelector.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesLoggingPathSelector.kt new file mode 100644 index 0000000000..ab4192ef42 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/FeaturesLoggingPathSelector.kt @@ -0,0 +1,355 @@ +package org.usvm.ps + +import org.usvm.* +import org.usvm.statistics.ApplicationGraph +import org.usvm.statistics.CoverageStatistics +import org.usvm.statistics.distances.CallStackDistanceCalculator +import org.usvm.statistics.distances.CfgStatistics +import org.usvm.util.LOG_BASE +import org.usvm.util.average +import org.usvm.util.getLast +import org.usvm.util.log + +open class FeaturesLoggingPathSelector, Statement, Method>( + private val pathsTreeRoot: PathsTrieNode, + private val coverageStatistics: CoverageStatistics, + cfgStatistics: CfgStatistics, + private val applicationGraph: ApplicationGraph, + private val mlConfig: MLConfig, + private val pathSelector: UPathSelector +) : UPathSelector { + protected val lru = mutableListOf() + + private val allStatements: List + private val visitedStatements = HashSet() + private var coveredStatementsCount = 0 + + internal val path = mutableListOf() + protected val probabilities = mutableListOf>() + + private val method: Method + + private var stepCount = 0 + + private val penalty = 0.0f + private var finishedStatesCount = 0u + private var allStatesCount = 0u + private val distanceCalculator = CallStackDistanceCalculator( + targets = coverageStatistics.getUncoveredStatements(), + cfgStatistics = cfgStatistics, + applicationGraph + ) + private val stateLastNewStatement = mutableMapOf() + private val statePathCoverage = mutableMapOf().withDefault { 0u } + private val stateForkCount = mutableMapOf().withDefault { 0u } + private val statementFinishCounts = mutableMapOf().withDefault { 0u } + private val distancesToExit: Map + private val forkCountsToExit: Map + private val minForkCountsToExit: Map + private val statementRepetitions = mutableMapOf().withDefault { 0u } + private val subpathCounts = mutableMapOf, UInt>().withDefault { 0u } + + protected val blockGraph: BlockGraph + protected val graphFeaturesList = mutableListOf>() + + private val featuresLogger: FeaturesLogger + + init { + coverageStatistics.addOnCoveredObserver { _, method, statement -> + distanceCalculator.removeTarget(method, statement) + } + allStatements = coverageStatistics.getUncoveredStatements().toList() + method = applicationGraph.methodOf(allStatements.first()) + val (tmpDistancesToExit, tmpForkCountsToExit) = getDistancesToExit() + distancesToExit = tmpDistancesToExit + forkCountsToExit = tmpForkCountsToExit + minForkCountsToExit = getMinForkCountsToExit() + blockGraph = BlockGraph( + applicationGraph, coverageStatistics, + applicationGraph.entryPoints(method).first(), forkCountsToExit, minForkCountsToExit + ) + graphFeaturesList.add(blockGraph.getGraphFeatures()) + featuresLogger = FeaturesLogger(method, blockGraph, mlConfig) + } + + private fun getDistancesToExit(): Array> { + val exits = applicationGraph.exitPoints(method) + val statementsQueue = ArrayDeque() + val distancesToExit = mutableMapOf().withDefault { 0u } + val forkCountsToExit = mutableMapOf().withDefault { 0u } + statementsQueue.addAll(exits) + while (statementsQueue.isNotEmpty()) { + val currentStatement = statementsQueue.removeFirst() + val distance = distancesToExit.getValue(currentStatement) + 1u + val lastForkCount = forkCountsToExit.getValue(currentStatement) + applicationGraph.predecessors(currentStatement).forEach { statement -> + if (distancesToExit.contains(statement)) { + return@forEach + } + distancesToExit[statement] = distance + val isFork = applicationGraph.successors(statement).count() > 1 + forkCountsToExit[statement] = lastForkCount + if (isFork) 1u else 0u + statementsQueue.add(currentStatement) + } + } + return arrayOf(distancesToExit, forkCountsToExit) + } + + private fun getMinForkCountsToExit(): Map { + val exits = applicationGraph.exitPoints(method) + val statementsQueue = ArrayDeque() + val forkCountsToExit = mutableMapOf().withDefault { 0u } + statementsQueue.addAll(exits) + while (statementsQueue.isNotEmpty()) { + val currentStatement = statementsQueue.removeFirst() + val lastForkCount = forkCountsToExit.getValue(currentStatement) + applicationGraph.predecessors(currentStatement).forEach { statement -> + val isFork = applicationGraph.successors(statement).count() > 1 + val newForkCount = lastForkCount + if (isFork) 1u else 0u + if (forkCountsToExit.contains(statement) || newForkCount > forkCountsToExit.getValue(statement)) { + return@forEach + } + forkCountsToExit[statement] = newForkCount + if (isFork) { + statementsQueue.add(currentStatement) + } else { + statementsQueue.addFirst(currentStatement) + } + } + } + return forkCountsToExit + } + + // Reward feature calculation, not actual reward + private fun getReward(state: State): Float { + val statement = state.currentStatement + if (statement === null || + (applicationGraph.successors(statement).toList().size + + applicationGraph.callees(statement).toList().size != 0) || + applicationGraph.methodOf(statement) != method || + state.callStack.size != 1 + ) { + return 0.0f + } + return coverageStatistics.getUncoveredStatements().toSet() + .intersect(state.reversedPath.asSequence().toSet()).size.toFloat() + } + + private fun getStateFeatures(state: State): StateFeatures { + val currentStatement = state.currentStatement!! + val currentBlock = blockGraph.getBlock(currentStatement) + val currentPath = state.reversedPath.asSequence().toList().reversed() + + val predecessorsCount = applicationGraph.predecessors(currentStatement).count() + val successorsCount = applicationGraph.successors(currentStatement).count() + val calleesCount = applicationGraph.callees(currentStatement).count() + val logicalConstraintsLength = state.pathConstraints.logicalConstraints.size + val stateTreeDepth = state.pathLocation.depth + val statementRepetitionLocal = currentPath.filter { statement -> + statement == currentStatement + }.size + val statementRepetitionGlobal = statementRepetitions.getValue(currentStatement) + val distanceToUncovered = distanceCalculator.calculateDistance(state.currentStatement, state.callStack) + val lastNewDistance = if (stateLastNewStatement.contains(state)) { + currentPath.size - stateLastNewStatement.getValue(state) + } else { + 1 / LOG_BASE - 1 // Equal to -1 after log + } + val pathCoverage = statePathCoverage.getValue(state) + val distanceToBlockEnd = (currentBlock?.path?.size ?: 1) - 1 - + (currentBlock?.path?.indexOf(currentStatement) ?: 0) + val distanceToExit = distancesToExit.getValue(currentStatement) + val forkCount = stateForkCount.getValue(state) + val statementFinishCount = statementFinishCounts.getValue(currentStatement) + val forkCountToExit = forkCountsToExit.getValue(currentStatement) + val minForkCountToExit = minForkCountsToExit.getValue(currentStatement) + val subpathCount2 = if (currentPath.size >= 2) subpathCounts.getValue(currentPath.getLast(2)) else 0u + val subpathCount4 = if (currentPath.size >= 4) subpathCounts.getValue(currentPath.getLast(4)) else 0u + val subpathCount8 = if (currentPath.size >= 8) subpathCounts.getValue(currentPath.getLast(8)) else 0u + + val reward = getReward(state) + + return StateFeatures( + predecessorsCount.log(), + successorsCount.log(), + calleesCount.log(), + logicalConstraintsLength.log(), + stateTreeDepth.log(), + statementRepetitionLocal.log(), + statementRepetitionGlobal.log(), + distanceToUncovered.log(), + lastNewDistance.log(), + pathCoverage.log(), + distanceToBlockEnd.log(), + distanceToExit.log(), + forkCount.log(), + statementFinishCount.log(), + forkCountToExit.log(), + minForkCountToExit.log(), + subpathCount2.log(), + subpathCount4.log(), + subpathCount8.log(), + reward.log(), + ) + } + + private fun getStateFeatureQueue(): List { + return lru.map { state -> + getStateFeatures(state) + } + } + + private fun getGlobalStateFeatures(stateFeatureQueue: List): GlobalStateFeatures { + val uncoveredStatements = coverageStatistics.getUncoveredStatements().toSet() + + val logFinishedStatesCount = finishedStatesCount.log() + val finishedStatesFraction = finishedStatesCount.toFloat() / allStatesCount.toFloat() + val totalCoverage = coverageStatistics.getTotalCoverage() / 100 + val visitedStatesFraction = visitedStatements.intersect(uncoveredStatements).size.toFloat() / allStatements.size + + return GlobalStateFeatures( + stateFeatureQueue.map { it.logLogicalConstraintsLength }.average(), + stateFeatureQueue.map { it.logStateTreeDepth }.average(), + stateFeatureQueue.map { it.logStatementRepetitionLocal }.average(), + stateFeatureQueue.map { it.logStatementRepetitionGlobal }.average(), + stateFeatureQueue.map { it.logDistanceToUncovered }.average(), + stateFeatureQueue.map { it.logLastNewDistance }.average(), + stateFeatureQueue.map { it.logPathCoverage }.average(), + stateFeatureQueue.map { it.logDistanceToBlockEnd }.average(), + stateFeatureQueue.map { it.logReward }.average(), + stateFeatureQueue.map { it.logSubpathCount2 }.average(), + stateFeatureQueue.map { it.logSubpathCount4 }.average(), + stateFeatureQueue.map { it.logSubpathCount8 }.average(), + logFinishedStatesCount, + finishedStatesFraction, + visitedStatesFraction, + totalCoverage, + ) + } + + protected open fun getExtraFeatures(): List { + return listOf() + } + + private fun getActionData( + stateFeatureQueue: List, + globalStateFeatures: GlobalStateFeatures, + chosenState: State + ): ActionData { + val stateId = lru.indexOfFirst { it.id == chosenState.id } + return ActionData( + stateFeatureQueue, + globalStateFeatures, + stateId, + 0.0f, + graphFeaturesList.lastIndex, + lru.map { it.currentStatement!! }.map { blockGraph.getBlock(it)?.id ?: -1 }, + getExtraFeatures() + ) + } + + internal open fun getAllFeatures(stateFeatures: StateFeatures, actionData: ActionData, blockId: Int): List { + return stateFeaturesToFloatList(stateFeatures) + globalStateFeaturesToFloatList(actionData.globalStateFeatures) + } + + private fun updateCoverage(state: State) { + val statePath = state.reversedPath.asSequence().toList().reversed() + + arrayOf(2, 4, 8).forEach { length -> + if (statePath.size < length) { + return@forEach + } + val subpath = statePath.getLast(length) + subpathCounts[subpath] = subpathCounts.getValue(subpath) + 1u + } + + val statement = state.currentStatement!! + statementRepetitions[statement] = statementRepetitions.getValue(statement) + 1u + visitedStatements.add(statement) + + if (applicationGraph.successors(statement).count() > 1) { + stateForkCount[state] = stateForkCount.getValue(state) + 1u + } + + if (coverageStatistics.getUncoveredStatements().contains(statement)) { + stateLastNewStatement[state] = statePath.size + statePathCoverage[state] = statePathCoverage.getValue(state) + 1u + } + } + + protected open fun getExtraNodeInfo(node: PathsTrieNode) = + node.states.joinToString(separator = "") { state -> ", ${state.id}" } + + private fun saveGraph() { + featuresLogger.saveGraph(pathsTreeRoot, stepCount) { node -> + getExtraNodeInfo(node) + } + stepCount += 1 + } + + fun savePath() { + featuresLogger.savePath(path, blockGraph, probabilities, allStatements.size, graphFeaturesList) + { stateFeatures, actionData, blockId -> + getAllFeatures(stateFeatures, actionData, blockId) + } + } + + internal fun beforePeek(): Pair, GlobalStateFeatures> { + if (mlConfig.graphUpdate == GraphUpdate.TestGeneration && (path.lastOrNull()?.reward ?: 0.0f) > 0.5f) { + graphFeaturesList.add(blockGraph.getGraphFeatures()) + } + val stateFeatureQueue = getStateFeatureQueue() + return Pair(stateFeatureQueue, getGlobalStateFeatures(stateFeatureQueue)) + } + + internal fun afterPeek( + state: State, + stateFeatureQueue: List, + globalStateFeatures: GlobalStateFeatures + ) { + val actionData = getActionData(stateFeatureQueue, globalStateFeatures, state) + path.add(actionData) + updateCoverage(state) + if (stepCount < 100) { + saveGraph() + } + lru.remove(state) + lru.add(state) + } + + override fun isEmpty(): Boolean { + pathSelector.isEmpty() + return lru.isEmpty() + } + + override fun peek(): State { + val (stateFeatureQueue, globalStateFeatures) = beforePeek() + val state = pathSelector.peek() + afterPeek(state, stateFeatureQueue, globalStateFeatures) + return state + } + + override fun update(state: State) { + pathSelector.update(state) + } + + override fun add(states: Collection) { + pathSelector.add(states) + lru.addAll(states) + allStatesCount += states.size.toUInt() + } + + override fun remove(state: State) { + pathSelector.remove(state) + lru.remove(state) + finishedStatesCount += 1u + state.reversedPath.asSequence().toSet().forEach { statement -> + statementFinishCounts[statement] = statementFinishCounts.getValue(statement) + 1u + } + + // Actual reward calculation, change it in accordance to metrics + val newCoveredStatementsCount = (allStatements.size - coverageStatistics.getUncoveredStatements().size) + path.last().reward = (newCoveredStatementsCount - coveredStatementsCount).toFloat() + coveredStatementsCount = newCoveredStatementsCount + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/MachineLearningPathSelector.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/MachineLearningPathSelector.kt new file mode 100644 index 0000000000..6b3f6d65b5 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/MachineLearningPathSelector.kt @@ -0,0 +1,238 @@ +package org.usvm.ps + +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtSession +import org.usvm.* +import org.usvm.statistics.ApplicationGraph +import org.usvm.statistics.CoverageStatistics +import org.usvm.statistics.distances.CfgStatistics +import org.usvm.util.prod +import java.io.File +import java.nio.FloatBuffer +import java.nio.LongBuffer +import java.text.DecimalFormat +import kotlin.io.path.Path +import kotlin.math.exp +import kotlin.random.Random + +open class MachineLearningPathSelector, Statement, Method>( + pathsTreeRoot: PathsTrieNode, + coverageStatistics: CoverageStatistics, + cfgStatistics: CfgStatistics, + applicationGraph: ApplicationGraph, + private val mlConfig: MLConfig, + private val defaultPathSelector: UPathSelector +) : FeaturesLoggingPathSelector( + pathsTreeRoot, + coverageStatistics, + cfgStatistics, + applicationGraph, + mlConfig, + defaultPathSelector, +) { + private var outputValues = listOf() + private val random = Random(System.nanoTime()) + private val gnnFeaturesList = mutableListOf>>() + private var lastStateFeatures = List(mlConfig.rnnStateShape.prod().toInt()) { 0.0f } + private var rnnFeatures = if (mlConfig.useRnn) List(mlConfig.rnnFeaturesCount) { 0.0f } else emptyList() + + private val env: OrtEnvironment = OrtEnvironment.getEnvironment() + private val actorModelPath = Path(mlConfig.gameEnvPath, "actor_model.onnx").toString() + private val gnnModelPath = Path(mlConfig.gameEnvPath, "gnn_model.onnx").toString() + private val rnnModelPath = Path(mlConfig.gameEnvPath, "rnn_cell.onnx").toString() + private var actorSession: OrtSession? = if (File(actorModelPath).isFile) + env.createSession(actorModelPath) else null + private var gnnSession: OrtSession? = if (mlConfig.useGnn) + env.createSession(gnnModelPath) else null + private var rnnSession: OrtSession? = if (mlConfig.useRnn) + env.createSession(rnnModelPath) else null + + override fun getExtraNodeInfo(node: PathsTrieNode) = + node.states.joinToString(separator = "") { state -> + ", ${DecimalFormat("0.00E0").format(outputValues.getOrElse(lru.indexOf(state)) { -1.0f })}" + } + + private fun chooseRandomId(probabilities: Collection): Int { + val randomNumber = random.nextFloat() + var probability = 0.0f + probabilities.withIndex().forEach { + probability += it.value + if (randomNumber < probability) { + return it.index + } + } + return probabilities.size - 1 + } + + private fun runGnn(): List> { + if (gnnFeaturesList.size == graphFeaturesList.size) { + return gnnFeaturesList.last() + } + if (gnnSession === null) { + gnnSession = env.createSession(gnnModelPath, OrtSession.SessionOptions()) + } + val graphFeatures = graphFeaturesList.last().map { blockFeaturesToList(it) } + val graphEdges = blockGraph.getEdges().toList() + val featuresShape = listOf(graphFeatures.size, graphFeatures.first().size) + val edgesShape = listOf(2, graphEdges.first().size) + val featuresDataBuffer = FloatBuffer.allocate(featuresShape.prod()) + graphFeatures.forEach { blockFeatures -> + blockFeatures.forEach { feature -> + featuresDataBuffer.put(feature) + } + } + featuresDataBuffer.rewind() + val edgesDataBuffer = LongBuffer.allocate(edgesShape.prod()) + graphEdges.forEach { nodes -> + nodes.forEach { node -> + edgesDataBuffer.put(node.toLong()) + } + } + edgesDataBuffer.rewind() + val featuresData = OnnxTensor.createTensor( + env, featuresDataBuffer, + featuresShape.map { it.toLong() }.toLongArray() + ) + val edgesData = OnnxTensor.createTensor( + env, edgesDataBuffer, + edgesShape.map { it.toLong() }.toLongArray() + ) + val result = gnnSession!!.run(mapOf(Pair("x", featuresData), Pair("edge_index", edgesData))) + val output = (result.get("output").get().value as Array<*>).map { + (it as FloatArray).toList() + } + gnnFeaturesList.add(output) + return output + } + + private fun runRnn(): List { + if (path.size == 0) { + return listOf() + } + if (rnnSession === null) { + rnnSession = env.createSession(rnnModelPath, OrtSession.SessionOptions()) + } + val lastActionData = path.last() + val lastChosenAction = lastActionData.chosenStateId + val gnnFeaturesCount = gnnFeaturesList.firstOrNull()?.first()?.size ?: 0 + val gnnFeatures = gnnFeaturesList.getOrNull(lastActionData.graphId)?.getOrNull( + lastActionData.blockIds[lastChosenAction] + ) + ?: List(gnnFeaturesCount) { 0.0f } + val lastActionFeatures = super.getAllFeatures( + lastActionData.queue[lastChosenAction], lastActionData, + lastActionData.blockIds[lastChosenAction] + ) + gnnFeatures + val lastActionShape = listOf(1, lastActionFeatures.size.toLong()) + val lastStateShape = mlConfig.rnnStateShape + val actionFeaturesDataBuffer = FloatBuffer.allocate(lastActionShape.prod().toInt()) + val stateFeaturesDataBuffer = FloatBuffer.allocate(lastStateShape.prod().toInt()) + lastActionFeatures.forEach { + actionFeaturesDataBuffer.put(it) + } + actionFeaturesDataBuffer.rewind() + lastStateFeatures.forEach { + stateFeaturesDataBuffer.put(it) + } + stateFeaturesDataBuffer.rewind() + val actionFeaturesData = OnnxTensor.createTensor(env, actionFeaturesDataBuffer, lastActionShape.toLongArray()) + val stateFeaturesData = OnnxTensor.createTensor(env, stateFeaturesDataBuffer, lastStateShape.toLongArray()) + val result = rnnSession!!.run(mapOf(Pair("input", actionFeaturesData), Pair("state_in", stateFeaturesData))) + lastStateFeatures = (result.get("state_out").get().value as Array<*>).flatMap { + ((it as Array<*>)[0] as FloatArray).toList() + } + rnnFeatures = (result.get("rnn_features").get().value as Array<*>).flatMap { + (it as FloatArray).toList() + } + return rnnFeatures + } + + private fun runActor(allFeaturesListFull: List>): Int { + val firstIndex = if (mlConfig.maxAttentionLength == -1) 0 else + maxOf(0, lru.size - mlConfig.maxAttentionLength) + val allFeaturesList = allFeaturesListFull.subList(firstIndex, lru.size) + val totalSize = allFeaturesList.size * allFeaturesList.first().size + val totalKnownSize = mlConfig.inputShape.prod() + val shape = mlConfig.inputShape.map { if (it != -1L) it else -totalSize / totalKnownSize }.toLongArray() + val dataBuffer = FloatBuffer.allocate(totalSize) + allFeaturesList.forEach { stateFeatures -> + stateFeatures.forEach { feature -> + dataBuffer.put(feature) + } + } + dataBuffer.rewind() + val data = OnnxTensor.createTensor(env, dataBuffer, shape) + val result = actorSession!!.run(mapOf(Pair("input", data))) + val output = (result.get("output").get().value as Array<*>).flatMap { (it as FloatArray).toList() } + outputValues = List(firstIndex) { -1.0f } + output + return firstIndex + when (mlConfig.postprocessing) { + Postprocessing.Argmax -> { + output.indices.maxBy { output[it] } + } + Postprocessing.Softmax -> { + val exponents = output.map { exp(it) } + val exponentsSum = exponents.sum() + val softmaxProbabilities = exponents.map { it / exponentsSum } + probabilities.add(softmaxProbabilities) + chooseRandomId(softmaxProbabilities) + } + else -> { + probabilities.add(output) + chooseRandomId(output) + } + } + } + + private fun peekWithOnnxRuntime( + stateFeatureQueue: List?, + globalStateFeatures: GlobalStateFeatures? + ): State { + if (stateFeatureQueue == null || globalStateFeatures == null) { + throw IllegalArgumentException("No features") + } + if (lru.size == 1) { + if (mlConfig.postprocessing != Postprocessing.Argmax) { + probabilities.add(listOf(1.0f)) + } + return lru[0] + } + val graphFeatures = gnnFeaturesList.lastOrNull() ?: listOf() + val blockFeaturesCount = graphFeatures.firstOrNull()?.size ?: 0 + val allFeaturesListFull = stateFeatureQueue.zip(lru).map { (stateFeatures, state) -> + stateFeaturesToFloatList(stateFeatures) + globalStateFeaturesToFloatList(globalStateFeatures) + + (blockGraph.getBlock(state.currentStatement!!)?.id?.let { graphFeatures.getOrNull(it) } + ?: List(blockFeaturesCount) { 0.0f }) + + rnnFeatures + } + return lru[runActor(allFeaturesListFull)] + } + + override fun getExtraFeatures(): List { + return rnnFeatures + } + + override fun getAllFeatures(stateFeatures: StateFeatures, actionData: ActionData, blockId: Int): List { + val gnnFeaturesCount = gnnFeaturesList.firstOrNull()?.first()?.size ?: 0 + val gnnFeatures = gnnFeaturesList.getOrNull(actionData.graphId)?.getOrNull(blockId) + ?: List(gnnFeaturesCount) { 0.0f } + return super.getAllFeatures(stateFeatures, actionData, blockId) + gnnFeatures + actionData.extraFeatures + } + + override fun peek(): State { + val (stateFeatureQueue, globalStateFeatures) = beforePeek() + if (mlConfig.useRnn) { + runRnn() + } + if (mlConfig.useGnn) { + runGnn() + } + val state = if (actorSession !== null) { + peekWithOnnxRuntime(stateFeatureQueue, globalStateFeatures) + } else { + defaultPathSelector.peek() + } + afterPeek(state, stateFeatureQueue, globalStateFeatures) + return state + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/ModifiedPathSelectorFactory.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/ModifiedPathSelectorFactory.kt new file mode 100644 index 0000000000..5708eae1c7 --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/ps/ModifiedPathSelectorFactory.kt @@ -0,0 +1,119 @@ +package org.usvm.ps + +import org.usvm.* +import org.usvm.ModifiedPathSelectionStrategy +import org.usvm.ModifiedUMachineOptions +import org.usvm.algorithms.DeterministicPriorityCollection +import org.usvm.algorithms.RandomizedPriorityCollection +import org.usvm.statistics.ApplicationGraph +import org.usvm.statistics.CoverageStatistics +import org.usvm.statistics.distances.* +import org.usvm.targets.UTarget +import kotlin.math.max +import kotlin.random.Random + +fun modifiedCreatePathSelector( + initialState: State, + options: ModifiedUMachineOptions, + applicationGraph: ApplicationGraph, + coverageStatistics: () -> CoverageStatistics? = { null }, + cfgStatistics: () -> CfgStatistics? = { null }, + @Suppress("UNUSED_PARAMETER") callGraphStatistics: () -> CallGraphStatistics? = { null }, + mlConfig: () -> MLConfig? = { null } +) : UPathSelector + where Target : UTarget, + State : UState<*, Method, Statement, *, Target, State> { + val strategies = options.pathSelectionStrategies + require(strategies.isNotEmpty()) { "At least one path selector strategy should be specified" } + + val random by lazy { Random(options.basicOptions.randomSeed) } + + val selectors = strategies.map { strategy -> + when (strategy) { + ModifiedPathSelectionStrategy.FEATURES_LOGGING -> FeaturesLoggingPathSelector( + requireNotNull(initialState.pathLocation.parent) { "Paths tree root is required for Features Logging path selector" }, + requireNotNull(coverageStatistics()) { "Coverage statistics is required for Features Logging path selector" }, + requireNotNull(cfgStatistics()) { "CFG statistics is required for Features Logging path selector" }, + applicationGraph, + requireNotNull(mlConfig()) { "ML config is required for Features Logging path selector" }, + when(requireNotNull(mlConfig()).defaultAlgorithm) { + Algorithm.BFS -> BfsPathSelector() + Algorithm.ForkDepthRandom -> createForkDepthPathSelector(random) + }, + ) + + ModifiedPathSelectionStrategy.MACHINE_LEARNING -> MachineLearningPathSelector( + requireNotNull(initialState.pathLocation.parent) { "Paths tree root is required for Machine Learning path selector" }, + requireNotNull(coverageStatistics()) { "Coverage statistics is required for Machine Learning path selector" }, + requireNotNull(cfgStatistics()) { "Distance statistics is required for Machine Learning path selector" }, + applicationGraph, + requireNotNull(mlConfig()) { "ML config is required for Machine Learning path selector" }, + when(requireNotNull(mlConfig()).defaultAlgorithm) { + Algorithm.BFS -> BfsPathSelector() + Algorithm.ForkDepthRandom -> createForkDepthPathSelector(random) + }, + ) + } + } + + val propagateExceptions = options.basicOptions.exceptionsPropagation + + selectors.singleOrNull()?.let { selector -> + val resultSelector = selector.wrapIfRequired(propagateExceptions) + resultSelector.add(listOf(initialState)) + return resultSelector + } + + require(selectors.size >= 2) { "Cannot create collaborative path selector from less than 2 selectors" } + + val selector = when (options.basicOptions.pathSelectorCombinationStrategy) { + PathSelectorCombinationStrategy.INTERLEAVED -> { + // Since all selectors here work as one, we can wrap an interleaved selector only. + val interleavedPathSelector = InterleavedPathSelector(selectors).wrapIfRequired(propagateExceptions) + interleavedPathSelector.add(listOf(initialState)) + interleavedPathSelector + } + + PathSelectorCombinationStrategy.PARALLEL -> { + // Here we should wrap all selectors independently since they work in parallel. + val wrappedSelectors = selectors.map { it.wrapIfRequired(propagateExceptions) } + + wrappedSelectors.first().add(listOf(initialState)) + wrappedSelectors.drop(1).forEach { + it.add(listOf(initialState.clone())) + } + + ParallelPathSelector(wrappedSelectors) + } + } + + return selector +} + +/** + * Wraps the selector into an [ExceptionPropagationPathSelector] if [propagateExceptions] is true. + */ +private fun > UPathSelector.wrapIfRequired(propagateExceptions: Boolean) = + if (propagateExceptions && this !is ExceptionPropagationPathSelector) { + ExceptionPropagationPathSelector(this) + } else { + this + } + +private fun > compareById(): Comparator = compareBy { it.id } + +private fun > createForkDepthPathSelector( + random: Random? = null, +): UPathSelector { + if (random == null) { + return WeightedPathSelector( + priorityCollectionFactory = { DeterministicPriorityCollection(Comparator.naturalOrder()) }, + weighter = { it.pathLocation.depth } + ) + } + + return WeightedPathSelector( + priorityCollectionFactory = { RandomizedPriorityCollection(compareById()) { random.nextDouble() } }, + weighter = { 1.0 / max(it.pathLocation.depth.toDouble(), 1.0) } + ) +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/CoverageCounterStatistics.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/CoverageCounterStatistics.kt new file mode 100644 index 0000000000..6c62ce8b1b --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/statistics/CoverageCounterStatistics.kt @@ -0,0 +1,28 @@ +package org.usvm.statistics + +import org.usvm.CoverageCounter +import org.usvm.UState + +class CoverageCounterStatistics>( + private val coverageStatistics: CoverageStatistics<*, *, State>, + private val coverageCounter: CoverageCounter, + private val methodFullName: String +) : UMachineObserver { + private val totalStatementsCount = coverageStatistics.getUncoveredStatements().size + private var totalCoverage = 0 + + init { + coverageCounter.addTest(methodFullName, totalStatementsCount.toFloat()) + } + + override fun onState(parent: State, forks: Sequence) { + coverageCounter.updateDiscounts(methodFullName) + } + + override fun onStateTerminated(state: State, stateReachable: Boolean) { + if (!stateReachable) return + val newTotalCoverage = totalStatementsCount - coverageStatistics.getUncoveredStatements().size + coverageCounter.updateResults(methodFullName, (newTotalCoverage - totalCoverage).toFloat()) + totalCoverage = newTotalCoverage + } +} diff --git a/usvm-ml-path-selection/src/main/kotlin/org/usvm/util/Utils.kt b/usvm-ml-path-selection/src/main/kotlin/org/usvm/util/Utils.kt new file mode 100644 index 0000000000..5c1bb0623a --- /dev/null +++ b/usvm-ml-path-selection/src/main/kotlin/org/usvm/util/Utils.kt @@ -0,0 +1,86 @@ +package org.usvm.util + +import org.jacodb.api.* +import org.jacodb.api.cfg.JcInst +import org.jacodb.api.ext.findFieldOrNull +import org.jacodb.api.ext.toType +import org.usvm.UConcreteHeapRef +import org.usvm.UExpr +import org.usvm.USort +import org.usvm.machine.JcContext +import org.usvm.machine.JcTransparentInstruction +import org.usvm.memory.ULValue +import org.usvm.memory.UWritableMemory +import org.usvm.uctx +import kotlin.reflect.KClass + +const val LOG_BASE = 1.42 + +fun Collection.prod(): Long { + return this.reduce { acc, l -> acc * l } +} + +fun Collection.prod(): Int { + return this.reduce { acc, l -> acc * l } +} + +fun Collection.average(): Float { + return this.sumOf { it.toDouble() }.toFloat() / this.size +} + +fun Number.log(): Float { + return kotlin.math.log(this.toDouble() + 1, LOG_BASE).toFloat() +} + +fun UInt.log(): Float { + return this.toDouble().log() +} + +fun List.getLast(count: Int): List { + return this.subList(this.size - count, this.size) +} + +fun String.escape(): String { + val result = StringBuilder(this.length) + this.forEach { ch -> + result.append( + when (ch) { + '\n' -> "\\n" + '\t' -> "\\t" + '\b' -> "\\b" + '\r' -> "\\r" + '\"' -> "\\\"" + '\'' -> "\\\'" + '\\' -> "\\\\" + '$' -> "\\$" + else -> ch + } + ) + } + return result.toString() +} + +fun JcContext.extractJcType(clazz: KClass<*>): JcType = cp.findTypeOrNull(clazz.qualifiedName!!)!! + +fun JcContext.extractJcRefType(clazz: KClass<*>): JcRefType = extractJcType(clazz) as JcRefType + +val JcClassOrInterface.enumValuesField: JcTypedField + get() = toType().findFieldOrNull("\$VALUES") ?: error("No \$VALUES field found for the enum type $this") + +@Suppress("UNCHECKED_CAST") +fun UWritableMemory<*>.write(ref: ULValue<*, *>, value: UExpr<*>) { + write(ref as ULValue<*, USort>, value as UExpr, value.uctx.trueExpr) +} + +internal fun UWritableMemory.allocHeapRef(type: JcType, useStaticAddress: Boolean): UConcreteHeapRef = + if (useStaticAddress) allocStatic(type) else allocConcrete(type) + +tailrec fun JcInst.originalInst(): JcInst = if (this is JcTransparentInstruction) originalInst.originalInst() else this + +fun getMethodFullName(method: Any?): String { + return if (method is JcMethod) { + "${method.enclosingClass.name}#${method.name}(${method.parameters.joinToString { it.type.typeName }})" + } else { + method.toString() + } +} diff --git a/usvm-ml-path-selection/src/test/kotlin/org/usvm/jarRunner.kt b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jarRunner.kt new file mode 100644 index 0000000000..477e2be03a --- /dev/null +++ b/usvm-ml-path-selection/src/test/kotlin/org/usvm/jarRunner.kt @@ -0,0 +1,430 @@ +package org.usvm + +import kotlinx.coroutines.* +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.* +import org.jacodb.api.JcClassOrInterface +import org.jacodb.api.JcMethod +import org.jacodb.api.ext.packageName +import org.usvm.machine.ModifiedJcMachine +import org.usvm.ps.GlobalStateFeatures +import org.usvm.ps.StateFeatures +import org.usvm.samples.JacoDBContainer +import org.usvm.util.getMethodFullName +import java.io.File +import kotlin.io.path.Path +import kotlin.io.path.nameWithoutExtension +import kotlin.system.measureTimeMillis + +fun jarLoad(jars: Set, classes: MutableMap>) { + jars.forEach { filePath -> + val file = Path(filePath).toFile() + val container = JacoDBContainer(key = filePath, classpath = listOf(file)) + val classNames = container.db.locations.flatMap { it.jcLocation?.classNames ?: listOf() } + classes[filePath] = mutableListOf() + classNames.forEach { className -> + container.cp.findClassOrNull(className)?.let { + classes[filePath]?.add(it) + } + } + } +} + +private class MainTestRunner( + private val config: MLConfig, + pathSelectionStrategies: List = + listOf(ModifiedPathSelectionStrategy.MACHINE_LEARNING), + timeoutMs: Long? = 20000 +) { + val coverageCounter = CoverageCounter(config) + + val options = ModifiedUMachineOptions().copy( + UMachineOptions().copy( + exceptionsPropagation = false, + timeoutMs = timeoutMs, + stepLimit = 1500u, + solverType = SolverType.Z3 + ), + pathSelectionStrategies + ) + + fun runTest(method: JcMethod, jarKey: String) { + ModifiedJcMachine(JacoDBContainer(jarKey).cp, options).use { jcMachine -> + jcMachine.analyze(method, emptyList(), coverageCounter, config) + } + } +} + +private val prettyJson = Json { prettyPrint = true } + +@OptIn(ExperimentalCoroutinesApi::class, ExperimentalSerializationApi::class) +fun calculate(config: MLConfig) { + val pathSelectorSets = if (config.mode == Mode.Test) + listOf( + listOf(ModifiedPathSelectionStrategy.MACHINE_LEARNING), + listOf(ModifiedPathSelectionStrategy.FEATURES_LOGGING), + ) + else listOf(listOf(ModifiedPathSelectionStrategy.MACHINE_LEARNING)) + val timeLimits = if (config.mode == Mode.Test) + listOf( + 1000, + 5000, + 20000, + ) + else listOf(20000) + + val jarClasses = mutableMapOf>() + jarLoad(config.inputJars.keys, jarClasses) + println("\nLOADING COMPLETE\n") + + val blacklist = Path(config.gameEnvPath, "blacklist.txt").toFile().let { + it.createNewFile() + it.readLines() + } + + val tests = mutableListOf() + var finishedTestsCount = 0 + + jarClasses.forEach { (key, classesList) -> + println("RUNNING TESTS FOR $key") + val allMethods = classesList.filter { cls -> + !cls.isAnnotation && !cls.isInterface && + config.inputJars.getValue(key).any { cls.packageName.contains(it) } && + !cls.name.contains("Test") + }.flatMap { cls -> + cls.declaredMethods.filter { method -> + method.enclosingClass == cls && getMethodFullName(method) !in blacklist && !method.isConstructor + } + }.sortedBy { getMethodFullName(it).hashCode() }.distinctBy { getMethodFullName(it) } + val orderedMethods = if (config.shuffleTests) allMethods.shuffled() else allMethods + + timeLimits.forEach { timeLimit -> + println(" RUNNING TESTS WITH ${timeLimit}ms TIME LIMIT") + pathSelectorSets.forEach { pathSelectors -> + println(" RUNNING TESTS WITH ${pathSelectors.joinToString("|")} PATH SELECTOR") + val statisticsFile = Path( + config.dataPath, + "statistics", + Path(key).nameWithoutExtension, + "${timeLimit}ms", + "${pathSelectors.joinToString(separator = "|") { it.toString() }}.json" + ).toFile() + statisticsFile.parentFile.mkdirs() + statisticsFile.createNewFile() + statisticsFile.writeText("") + + val testRunner = MainTestRunner(config, pathSelectors, timeLimit) + runBlocking(Dispatchers.IO.limitedParallelism(config.maxConcurrency)) { + orderedMethods.take((orderedMethods.size * config.dataConsumption / 100).toInt()) + .forEach { method -> + val test = launch { + try { + println(" Running test ${method.name}") + val time = measureTimeMillis { + testRunner.runTest(method, key) + } + println(" Test ${method.name} finished after ${time}ms") + finishedTestsCount += 1 + } catch (e: Exception) { + println(" $e") + } catch (e: NotImplementedError) { + println(" $e") + } + } + tests.add(test) + } + tests.joinAll() + } + + prettyJson.encodeToStream(testRunner.coverageCounter.getStatistics(), statisticsFile.outputStream()) + testRunner.coverageCounter.reset() + } + } + } + + println("\nALL $finishedTestsCount TESTS FINISHED\n") +} + +fun getJsonSchemes(config: MLConfig): Pair { + val jsonFormat = Json { + encodeDefaults = true + } + val jsonStateScheme: JsonArray = buildJsonArray { + addJsonArray { + jsonFormat.encodeToJsonElement(StateFeatures()).jsonObject.forEach { t, _ -> + add(t) + } + jsonFormat.encodeToJsonElement(GlobalStateFeatures()).jsonObject.forEach { t, _ -> + add(t) + } + if (config.useGnn) { + (0 until config.gnnFeaturesCount).forEach { + add("gnnFeature$it") + } + } + if (config.useRnn) { + (0 until config.rnnFeaturesCount).forEach { + add("rnnFeature$it") + } + } + } + add("chosenStateId") + add("reward") + if (config.logGraphFeatures) { + add("graphId") + add("blockIds") + } + } + val jsonTrajectoryScheme = buildJsonArray { + add("hash") + add("trajectory") + add("name") + add("statementsCount") + if (config.logGraphFeatures) { + add("graphFeatures") + add("graphEdges") + } + add("probabilities") + } + return Pair(jsonStateScheme, jsonTrajectoryScheme) +} + +@OptIn(ExperimentalSerializationApi::class) +fun aggregate(config: MLConfig) { + val resultDirname = config.dataPath + val resultFilename = "current_dataset.json" + val schemesFilename = "schemes.json" + val jsons = mutableListOf() + + val (jsonStateScheme, jsonTrajectoryScheme) = getJsonSchemes(config) + val schemesJson = buildJsonObject { + put("stateScheme", jsonStateScheme) + put("trajectoryScheme", jsonTrajectoryScheme) + } + val schemesFile = Path(resultDirname, schemesFilename).toFile() + schemesFile.parentFile.mkdirs() + prettyJson.encodeToStream(schemesJson, schemesFile.outputStream()) + + Path(resultDirname, "jsons").toFile().listFiles()?.forEach { file -> + if (!file.isFile || file.extension != "json") { + return@forEach + } + val json = Json.decodeFromString(file.readText()) + jsons.add(buildJsonObject { + put("json", json) + put("methodName", file.nameWithoutExtension) + put("methodHash", file.nameWithoutExtension.hashCode()) + }) + file.delete() + } + jsons.sortBy { it.jsonObject["methodName"].toString() } + + if (jsons.isEmpty()) { + println("NO JSONS FOUND") + return + } + + val bigJson = buildJsonObject { + put( + "stateScheme", jsonStateScheme + ) + put( + "trajectoryScheme", jsonTrajectoryScheme + ) + putJsonArray("paths") { + jsons.forEach { + addJsonArray { + add(it.jsonObject.getValue("methodHash")) + add(it.jsonObject.getValue("json").jsonObject.getValue("path")) + add(it.jsonObject.getValue("methodName")) + add(it.jsonObject.getValue("json").jsonObject.getValue("statementsCount")) + if (config.logGraphFeatures) { + add(it.jsonObject.getValue("json").jsonObject.getValue("graphFeatures")) + add(it.jsonObject.getValue("json").jsonObject.getValue("graphEdges")) + } + add(it.jsonObject.getValue("json").jsonObject.getValue("probabilities")) + } + } + } + } + + val resultFile = Path(resultDirname, resultFilename).toFile() + resultFile.parentFile.mkdirs() + Json.encodeToStream(bigJson, resultFile.outputStream()) + + println("\nAGGREGATION FINISHED IN DIRECTORY $resultDirname\n") +} + +fun createConfig(options: JsonObject): MLConfig { + val defaultConfig = MLConfig() + val config = MLConfig( + gameEnvPath = (options.getOrDefault( + "gameEnvPath", + JsonPrimitive(defaultConfig.gameEnvPath) + ) as JsonPrimitive).content, + dataPath = (options.getOrDefault( + "dataPath", + JsonPrimitive(defaultConfig.dataPath) + ) as JsonPrimitive).content, + defaultAlgorithm = Algorithm.valueOf( + (options.getOrDefault( + "defaultAlgorithm", + JsonPrimitive(defaultConfig.defaultAlgorithm.name) + ) as JsonPrimitive).content + ), + postprocessing = Postprocessing.valueOf( + (options.getOrDefault( + "postprocessing", + JsonPrimitive(defaultConfig.postprocessing.name) + ) as JsonPrimitive).content + ), + mode = Mode.valueOf( + (options.getOrDefault( + "mode", + JsonPrimitive(defaultConfig.mode.name) + ) as JsonPrimitive).content + ), + logFeatures = (options.getOrDefault( + "logFeatures", + JsonPrimitive(defaultConfig.logFeatures) + ) as JsonPrimitive).content.toBoolean(), + shuffleTests = (options.getOrDefault( + "shuffleTests", + JsonPrimitive(defaultConfig.shuffleTests) + ) as JsonPrimitive).content.toBoolean(), + discounts = (options.getOrDefault( + "discounts", JsonArray( + defaultConfig.discounts + .map { JsonPrimitive(it) }) + ) as JsonArray).map { (it as JsonPrimitive).content.toFloat() }, + inputShape = (options.getOrDefault( + "inputShape", JsonArray( + defaultConfig.inputShape + .map { JsonPrimitive(it) }) + ) as JsonArray).map { (it as JsonPrimitive).content.toLong() }, + maxAttentionLength = (options.getOrDefault( + "maxAttentionLength", + JsonPrimitive(defaultConfig.maxAttentionLength) + ) as JsonPrimitive).content.toInt(), + useGnn = (options.getOrDefault( + "useGnn", + JsonPrimitive(defaultConfig.useGnn) + ) as JsonPrimitive).content.toBoolean(), + dataConsumption = (options.getOrDefault( + "dataConsumption", + JsonPrimitive(defaultConfig.dataConsumption) + ) as JsonPrimitive).content.toFloat(), + hardTimeLimit = (options.getOrDefault( + "hardTimeLimit", + JsonPrimitive(defaultConfig.hardTimeLimit) + ) as JsonPrimitive).content.toInt(), + solverTimeLimit = (options.getOrDefault( + "solverTimeLimit", + JsonPrimitive(defaultConfig.solverTimeLimit) + ) as JsonPrimitive).content.toInt(), + maxConcurrency = (options.getOrDefault( + "maxConcurrency", + JsonPrimitive(defaultConfig.maxConcurrency) + ) as JsonPrimitive).content.toInt(), + graphUpdate = GraphUpdate.valueOf( + (options.getOrDefault( + "graphUpdate", + JsonPrimitive(defaultConfig.graphUpdate.name) + ) as JsonPrimitive).content + ), + logGraphFeatures = (options.getOrDefault( + "logGraphFeatures", + JsonPrimitive(defaultConfig.logGraphFeatures) + ) as JsonPrimitive).content.toBoolean(), + gnnFeaturesCount = (options.getOrDefault( + "gnnFeaturesCount", + JsonPrimitive(defaultConfig.gnnFeaturesCount) + ) as JsonPrimitive).content.toInt(), + useRnn = (options.getOrDefault( + "useRnn", + JsonPrimitive(defaultConfig.useRnn) + ) as JsonPrimitive).content.toBoolean(), + rnnStateShape = (options.getOrDefault( + "rnnStateShape", JsonArray( + defaultConfig.rnnStateShape + .map { JsonPrimitive(it) }) + ) as JsonArray).map { (it as JsonPrimitive).content.toLong() }, + rnnFeaturesCount = (options.getOrDefault( + "rnnFeaturesCount", + JsonPrimitive(defaultConfig.rnnFeaturesCount) + ) as JsonPrimitive).content.toInt(), + inputJars = (options.getOrDefault( + "inputJars", + JsonObject(defaultConfig.inputJars.mapValues { (_, value) -> + JsonArray(value.map { JsonPrimitive(it) }) + }) + ) as JsonObject).mapValues { (_, value) -> + (value as JsonArray).toList().map { (it as JsonPrimitive).content } + } + ) + + println("OPTIONS:") + println(" GAME ENV PATH: ${config.gameEnvPath}") + println(" DATA PATH: ${config.dataPath}") + println(" DEFAULT ALGORITHM: ${config.defaultAlgorithm}") + println(" POSTPROCESSING: ${config.postprocessing}") + println(" MODE: ${config.mode}") + println(" LOG FEATURES: ${config.logFeatures}") + println(" SHUFFLE TESTS: ${config.shuffleTests}") + println(" INPUT SHAPE: ${config.inputShape}") + println(" MAX ATTENTION LENGTH: ${config.maxAttentionLength}") + println(" USE GNN: ${config.useGnn}") + println(" DATA CONSUMPTION: ${config.dataConsumption}%") + println(" HARD TIME LIMIT: ${config.hardTimeLimit}ms") + println(" SOLVER TIME LIMIT: ${config.solverTimeLimit}ms") + println(" MAX CONCURRENCY: ${config.maxConcurrency}") + println(" GRAPH UPDATE: ${config.graphUpdate}") + println(" LOG GRAPH FEATURES: ${config.logGraphFeatures}") + println(" GNN FEATURES COUNT: ${config.gnnFeaturesCount}") + println(" USE RNN: ${config.useRnn}") + println(" RNN STATE SHAPE: ${config.rnnStateShape}") + println(" RNN FEATURES COUNT: ${config.rnnFeaturesCount}") + println(" INPUT JARS: ${config.inputJars}") + println() + + return config +} + +fun clear(dataPath: String) { + Path(dataPath, "jsons").toFile().listFiles()?.forEach { file -> + file.delete() + } +} + +fun main(args: Array) { + val options = args.getOrNull(0)?.let { File(it) }?.readText()?.let { + Json.decodeFromString(it) + } + val config = if (options != null) { + createConfig(options) + } else { + MLConfig() + } + + if (config.mode != Mode.Aggregation) { + clear(config.dataPath) + } + + if (config.mode in listOf(Mode.Calculation, Mode.Both, Mode.Test)) { + try { + calculate(config) + } catch (e: Throwable) { + e.printStackTrace() + clear(config.dataPath) + } + } + + if (config.mode in listOf(Mode.Aggregation, Mode.Both)) { + try { + aggregate(config) + } catch (e: Throwable) { + e.printStackTrace() + clear(config.dataPath) + } + } +} diff --git a/usvm-ml-path-selection/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt b/usvm-ml-path-selection/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt new file mode 100644 index 0000000000..b64f2983bc --- /dev/null +++ b/usvm-ml-path-selection/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt @@ -0,0 +1,49 @@ +package org.usvm.samples + +import kotlinx.coroutines.runBlocking +import org.jacodb.api.JcClasspath +import org.jacodb.api.JcDatabase +import org.jacodb.impl.JcSettings +import org.jacodb.impl.features.InMemoryHierarchy +import org.jacodb.impl.jacodb +import org.usvm.util.allClasspath +import java.io.File + +class JacoDBContainer( + classpath: List, + builder: JcSettings.() -> Unit, +) { + val db: JcDatabase + val cp: JcClasspath + + init { + val (db, cp) = runBlocking { + val db = jacodb(builder) + db to db.classpath(classpath) + } + this.db = db + this.cp = cp + runBlocking { + db.awaitBackgroundJobs() + } + } + + companion object { + private val keyToJacoDBContainer = HashMap() + + operator fun invoke( + key: Any?, + classpath: List = samplesClasspath, + builder: JcSettings.() -> Unit = defaultBuilder, + ): JacoDBContainer = + keyToJacoDBContainer.getOrPut(key) { JacoDBContainer(classpath, builder) } + + private val samplesClasspath = allClasspath.filter { it.name.contains("samples") } + + private val defaultBuilder: JcSettings.() -> Unit = { + useProcessJavaRuntime() + installFeatures(InMemoryHierarchy) + loadByteCode(samplesClasspath) + } + } +} diff --git a/usvm-ml-path-selection/src/test/kotlin/org/usvm/util/Util.kt b/usvm-ml-path-selection/src/test/kotlin/org/usvm/util/Util.kt new file mode 100644 index 0000000000..9780e55ee5 --- /dev/null +++ b/usvm-ml-path-selection/src/test/kotlin/org/usvm/util/Util.kt @@ -0,0 +1,15 @@ +package org.usvm.util + +import java.io.File + +val allClasspath: List + get() { + return classpath.map { File(it) } + } + +private val classpath: List + get() { + val classpath = System.getProperty("java.class.path") + return classpath.split(File.pathSeparatorChar) + .toList() + }