diff --git a/tools/hls-fuzzer/ConjunctionTypeSystem.h b/tools/hls-fuzzer/ConjunctionTypeSystem.h index 3f5b50745..5c4a8674f 100644 --- a/tools/hls-fuzzer/ConjunctionTypeSystem.h +++ b/tools/hls-fuzzer/ConjunctionTypeSystem.h @@ -31,6 +31,8 @@ class ConjunctionTypeSystemBase TypeSystem, Self>; using Context = typename Base::Context; + ConjunctionTypeSystemBase() = default; + /// Constructs a conjunctive typesystem from the instances of the /// sub-typesystems. explicit ConjunctionTypeSystemBase(SubTypeSystems &&...subTypeSystems) diff --git a/tools/hls-fuzzer/LimitTypeSystem.cpp b/tools/hls-fuzzer/LimitTypeSystem.cpp index cd9321544..8824cf4f2 100644 --- a/tools/hls-fuzzer/LimitTypeSystem.cpp +++ b/tools/hls-fuzzer/LimitTypeSystem.cpp @@ -1,8 +1,8 @@ #include "LimitTypeSystem.h" dynamatic::ProbabilityTable -dynamatic::gen::LimitTypeSystem::getExpressionProbabilityTable( - const LimitTypingContext &context) { +dynamatic::gen::details::DepthTypeSystem::getExpressionProbabilityTable( + const DepthTypingContext &context) { // Default probabilities for expressions. // Most expressions are 100 times more likely to be generated than a // constant. diff --git a/tools/hls-fuzzer/LimitTypeSystem.h b/tools/hls-fuzzer/LimitTypeSystem.h index 34f78157c..06dde390f 100644 --- a/tools/hls-fuzzer/LimitTypeSystem.h +++ b/tools/hls-fuzzer/LimitTypeSystem.h @@ -1,39 +1,42 @@ #ifndef DYNAMATIC_HLS_FUZZER_LIMITTYPESYSTEM #define DYNAMATIC_HLS_FUZZER_LIMITTYPESYSTEM +#include "ConjunctionTypeSystem.h" #include "TypeSystem.h" +#include "VisitorTypeSystem.h" #include namespace dynamatic::gen { -struct LimitTypingContext { + +namespace details { + +struct DepthTypingContext { std::size_t expressionDepth{}; std::size_t totalNumberOfStatements{}; }; -/// Typesystem used to enforce limits such as number of a specific AST nodes, -/// parameters, depth of expressions and so on. -class LimitTypeSystem : public TypeSystem { +class DepthTypeSystem : public TypeSystem { /// Returns a transfer function which increments 'field' of the context of /// 'from' before returning it. - template static auto incrementDepth() { return TransferFn( - [](LimitTypingContext context, auto &&...) { + [](DepthTypingContext context, auto &&...) { ++(context.*field); return context; }); } public: - explicit LimitTypeSystem(std::size_t maxExpressionDepth = 4, + explicit DepthTypeSystem(std::size_t maxExpressionDepth = 4, std::size_t maxTotalStatements = 10) : maxExpressionDepth(maxExpressionDepth), maxTotalStatements(maxTotalStatements) {} bool discardBinaryExpression(ast::BinaryExpression::Op, - const LimitTypingContext &context) const { + const DepthTypingContext &context) const { return context.expressionDepth >= maxExpressionDepth; } @@ -41,16 +44,16 @@ class LimitTypeSystem : public TypeSystem { getBinaryExpressionTransferFns(ast::BinaryExpression::Op op) override { return { /*lhs=*/incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*rhs=*/ incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*output=*/copyInputToOutput(), }; } bool discardUnaryExpression(ast::UnaryExpression::Op, - const LimitTypingContext &context) const { + const DepthTypingContext &context) const { return context.expressionDepth >= maxExpressionDepth; } @@ -58,12 +61,12 @@ class LimitTypeSystem : public TypeSystem { getUnaryExpressionTransferFns(ast::UnaryExpression::Op op) override { return { /*operand=*/incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*output=*/copyInputToOutput(), }; } - bool discardCastExpression(const LimitTypingContext &context) const { + bool discardCastExpression(const DepthTypingContext &context) const { return context.expressionDepth >= maxExpressionDepth; } @@ -72,12 +75,12 @@ class LimitTypeSystem : public TypeSystem { /*target type=*/copyFromInput(), /*operand=*/ incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*output=*/copyInputToOutput(), }; } - bool discardConditionalExpression(const LimitTypingContext &context) const { + bool discardConditionalExpression(const DepthTypingContext &context) const { return context.expressionDepth >= maxExpressionDepth; } @@ -87,18 +90,18 @@ class LimitTypeSystem : public TypeSystem { // subelements. return { /*condition=*/incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*true value=*/ incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*false value=*/ incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*output=*/copyInputToOutput(), }; } - bool discardArrayReadExpression(const LimitTypingContext &context) const { + bool discardArrayReadExpression(const DepthTypingContext &context) const { return context.expressionDepth >= maxExpressionDepth; } @@ -108,7 +111,7 @@ class LimitTypeSystem : public TypeSystem { /*array parameter=*/copyFromInput(), /*index=*/ incrementDepth(), + &DepthTypingContext::expressionDepth>(), /*output=*/copyInputToOutput(), }; } @@ -123,14 +126,14 @@ class LimitTypeSystem : public TypeSystem { OutputTransferFn( std::index_sequence{}, [](const ast::ArrayAssignmentStatement &, - LimitTypingContext context) { + DepthTypingContext context) { context.totalNumberOfStatements++; return context; }), }; } - bool discardStatementList(const LimitTypingContext &context) const { + bool discardStatementList(const DepthTypingContext &context) const { return context.totalNumberOfStatements >= maxTotalStatements; } @@ -147,8 +150,8 @@ class LimitTypeSystem : public TypeSystem { OutputTransferFn( std::index_sequence{}, - [](const ast::StatementList &, LimitTypingContext statement, - const LimitTypingContext &statementList) { + [](const ast::StatementList &, DepthTypingContext statement, + const DepthTypingContext &statementList) { // Regardless of which of the two was generated first, we can // extract the total number of statements by taking their maximum. statement.totalNumberOfStatements = @@ -167,7 +170,7 @@ class LimitTypeSystem : public TypeSystem { /*step=*/copyFromInput(), /*statements=*/ incrementDepth(), + &DepthTypingContext::totalNumberOfStatements>(), /*output=*/ copyToOutput(), @@ -175,11 +178,83 @@ class LimitTypeSystem : public TypeSystem { } static ProbabilityTable - getExpressionProbabilityTable(const LimitTypingContext &context); + getExpressionProbabilityTable(const DepthTypingContext &context); std::size_t maxExpressionDepth{}; std::size_t maxTotalStatements{}; }; + +struct ParamTypingContext { + std::size_t numScalarParam{}; + std::size_t numArrayParam{}; + + ParamTypingContext merge(const ParamTypingContext &rhs) const { + return {std::max(numScalarParam, rhs.numScalarParam), + std::max(numArrayParam, rhs.numArrayParam)}; + } +}; + +/// Type system that caps the maximum amount of scalar parameters, array +/// parameters and parameters in general. +class ParamTypeSystem + : public VisitorTypeSystem { +public: + explicit ParamTypeSystem(std::size_t maxScalarParam = 16, + std::size_t maxArrayParam = 8, + std::size_t maxParams = 256) + : maxScalarParam(maxScalarParam), maxArrayParam(maxArrayParam), + maxParams(maxParams) {} + + bool discardFreshScalarParameter(const ParamTypingContext &context) const { + return context.numScalarParam >= maxScalarParam || + context.numScalarParam + context.numArrayParam >= maxParams; + } + + TransferFnArray + getFreshScalarParameterTransferFns() override { + return { + copyFromInput(), + OutputTransferFn( + std::index_sequence{}, + [](const ast::ScalarParameter &, ParamTypingContext context) { + context.numScalarParam++; + return context; + }), + }; + } + + bool discardFreshArrayParameter(const ParamTypingContext &context) const { + return context.numArrayParam >= maxArrayParam || + context.numScalarParam + context.numArrayParam >= maxParams; + } + + TransferFnArray + getFreshArrayParameterTransferFns() override { + return { + copyFromInput(), + OutputTransferFn( + std::index_sequence{}, + [](const ast::ArrayParameter &, ParamTypingContext context) { + context.numArrayParam++; + return context; + }), + }; + } + +private: + std::size_t maxScalarParam{}; + std::size_t maxArrayParam{}; + std::size_t maxParams{}; +}; + +} // namespace details + +/// Typesystem used to enforce limits such as number of a specific AST nodes, +/// parameters, depth of expressions and so on. +class LimitTypeSystem final + : public ConjunctionTypeSystemBase< + LimitTypeSystem, details::DepthTypeSystem, details::ParamTypeSystem> { +}; } // namespace dynamatic::gen #endif diff --git a/tools/hls-fuzzer/VisitorTypeSystem.h b/tools/hls-fuzzer/VisitorTypeSystem.h new file mode 100644 index 000000000..875cb159e --- /dev/null +++ b/tools/hls-fuzzer/VisitorTypeSystem.h @@ -0,0 +1,188 @@ +#ifndef DYNAMATIC_HLS_FUZZER_VISITOR_TYPE_SYSTEM +#define DYNAMATIC_HLS_FUZZER_VISITOR_TYPE_SYSTEM + +#include "TypeSystem.h" + +namespace dynamatic::gen { + +/// Convenient base class for any type systems that are "pure" visitors. +/// These type systems: +/// * Do not care about the order that AST nodes are generated. +/// * Have a monotonic 'TypingContext' which only ever increases/decreases. +/// +/// One property of this type system is that the most recent transfer function +/// called is guaranteed to receive the maximum/minimum 'TypingContext' +/// instance. +/// This makes the type system especially useful to implement counters and +/// similar. +/// +/// The 'TypingContext' is required to contain a method with the signature: +/// 'TypingContext merge(const TypingContext& rhs) const' which can be used +/// to calculate the current maximum/minimum of all contexts generated so far. +template +class VisitorTypeSystem : public TypeSystem { + +public: + using Base = TypeSystem; + +protected: + /// Returns a 'TransferFn' which merges all present contexts of 'indices' and + /// the input context. + template + static auto getMergingTransferFn(std::index_sequence) { + return TransferFn( + [](TypingContext result, auto &&...args) -> TypingContext { + foreachInTuples( + [&](const auto &element) { + if constexpr (std::is_same_v, + const TypingContext *>) { + if (!element) + return; + result = result.merge(*element); + } + }, + std::forward_as_tuple(std::forward(args)...)); + return result; + }); + } + + /// Returns a 'TransferFn' which merges all present contexts including + /// the input context. + template + static auto getMergingTransferFn() { + return getMergingTransferFn( + std::make_index_sequence< + std::tuple_size_v>{}); + } + + /// Returns a 'OutputTransferFn' which merges all present contexts of + /// 'indices' and the input context. + template + static auto getMergingOutputTransferFn(std::index_sequence) { + return OutputTransferFn( + std::index_sequence{}, + [](const ASTNode &, TypingContext result, + const std::conditional_t(indices), TypingContext, + TypingContext> &...args) -> TypingContext { + foreachInTuples( + [&](const auto &element) { + if constexpr (std::is_same_v, + TypingContext>) { + result = result.merge(element); + } + }, + std::forward_as_tuple(std::forward(args)...)); + return result; + }); + } + + /// Returns a 'OutputTransferFn' which merges all present contexts including + /// the input context. + template + static auto getMergingOutputTransferFn() { + return getMergingOutputTransferFn( + std::make_index_sequence< + std::tuple_size_v>{}); + } + + /// Returns a transfer array only consisting of merging transfer functions + /// and output functions. + template + static TransferFnArray getMergingTransferFnArray() { + return std::tuple_cat( + mapTuples([](auto &&) { return getMergingTransferFn(); }, + getTupleOfIndices( + std::make_index_sequence< + std::tuple_size_v>{})), + std::tuple(getMergingOutputTransferFn())); + } + +public: + TransferFnArray getFunctionTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray getScalarTypeTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray getReturnTypeTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getReturnStatementTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getBinaryExpressionTransferFns(ast::BinaryExpression::Op op) override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getUnaryExpressionTransferFns(ast::UnaryExpression::Op op) override { + return getMergingTransferFnArray(); + } + + TransferFnArray getVariableTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray getCastExpressionTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getConditionalExpressionTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray getConstantTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getFreshScalarParameterTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getExistingScalarParameterTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getArrayReadExpressionTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getFreshArrayParameterTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getExistingArrayParameterTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getArrayAssignmentStatementTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray getStatementListTransferFns() override { + return getMergingTransferFnArray(); + } + + TransferFnArray + getStructuredForStatementTransferFns() override { + return getMergingTransferFnArray(); + } +}; + +} // namespace dynamatic::gen + +#endif