From a8478ac1434aacf7784ea6a5ba879c297f5f94b5 Mon Sep 17 00:00:00 2001 From: Jincheng Chen Date: Sat, 23 Oct 2021 15:55:03 +0000 Subject: [PATCH 1/3] add transform support for LightGBM by open source FreeForm2 library --- src/transform/CMakeLists.txt | 96 + .../CMakeLists.txt | 18 + .../inc/FreeForm2.h | 15 + .../inc/FreeForm2Assert.h | 41 + .../inc/FreeForm2Compiler.h | 60 + .../inc/FreeForm2CompilerFactory.h | 78 + .../inc/FreeForm2Executable.h | 93 + .../inc/FreeForm2ExternalData.h | 90 + .../inc/FreeForm2Features.h | 90 + .../inc/FreeForm2Program.h | 142 + .../inc/FreeForm2Result.h | 122 + .../inc/FreeForm2Type.h | 128 + .../inc/NeuralInputFreeForm2.h | 421 + .../libs/Backend/llvm/ArrayCodeGen.cpp | 275 + .../libs/Backend/llvm/ArrayCodeGen.h | 115 + .../libs/Backend/llvm/CMakeLists.txt | 36 + .../libs/Backend/llvm/CompilationState.cpp | 466 + .../libs/Backend/llvm/CompilationState.h | 223 + .../Backend/llvm/Extend/FreeForm2Support.cpp | 30 + .../Backend/llvm/Extend/FreeForm2Support.h | 14 + .../libs/Backend/llvm/Extend/JITEmitter.cpp | 1264 + .../libs/Backend/llvm/Extend/JITExtend.cpp | 695 + .../libs/Backend/llvm/Extend/JITExtend.h | 232 + .../libs/Backend/llvm/LlvmCodeGenUtils.cpp | 343 + .../libs/Backend/llvm/LlvmCodeGenUtils.h | 103 + .../libs/Backend/llvm/LlvmCodeGenerator.cpp | 2556 ++ .../libs/Backend/llvm/LlvmCodeGenerator.h | 203 + .../libs/Backend/llvm/LlvmCompiler.cpp | 783 + .../libs/Backend/llvm/LlvmCompiler.h | 220 + .../libs/Backend/llvm/LlvmRuntimeLibrary.cpp | 252 + .../libs/Backend/llvm/LlvmRuntimeLibrary.h | 65 + .../libs/Expression/Allocation.cpp | 67 + .../libs/Expression/Allocation.h | 64 + .../Expression/ArrayDereferenceExpression.cpp | 147 + .../Expression/ArrayDereferenceExpression.h | 53 + .../libs/Expression/ArrayLength.cpp | 54 + .../libs/Expression/ArrayLength.h | 38 + .../Expression/ArrayLiteralExpression.cpp | 450 + .../libs/Expression/ArrayLiteralExpression.h | 105 + .../libs/Expression/BinaryOperator.cpp | 130 + .../libs/Expression/BinaryOperator.h | 59 + .../libs/Expression/BlockExpression.cpp | 104 + .../libs/Expression/BlockExpression.h | 61 + .../libs/Expression/CMakeLists.txt | 53 + .../libs/Expression/Conditional.cpp | 103 + .../libs/Expression/Conditional.h | 43 + .../libs/Expression/ConvertExpression.cpp | 239 + .../libs/Expression/ConvertExpression.h | 137 + .../libs/Expression/DebugExpression.cpp | 78 + .../libs/Expression/DebugExpression.h | 47 + .../libs/Expression/Declaration.cpp | 104 + .../libs/Expression/Declaration.h | 64 + .../libs/Expression/Expression.cpp | 233 + .../libs/Expression/Expression.h | 143 + .../libs/Expression/Extern.cpp | 213 + .../libs/Expression/Extern.h | 80 + .../libs/Expression/FeatureSpec.cpp | 592 + .../libs/Expression/FeatureSpec.h | 268 + .../libs/Expression/Function.cpp | 255 + .../libs/Expression/Function.h | 130 + .../libs/Expression/LetExpression.cpp | 107 + .../libs/Expression/LetExpression.h | 53 + .../libs/Expression/LiteralExpression.cpp | 545 + .../libs/Expression/LiteralExpression.h | 221 + .../libs/Expression/Match.cpp | 356 + .../libs/Expression/Match.h | 164 + .../libs/Expression/MatchSub.h | 54 + .../Expression/MemberAccessExpression.cpp | 204 + .../libs/Expression/MemberAccessExpression.h | 75 + .../libs/Expression/Mutation.cpp | 248 + .../libs/Expression/Mutation.h | 87 + .../libs/Expression/OperatorExpression.cpp | 265 + .../libs/Expression/OperatorExpression.h | 120 + .../libs/Expression/PhiNode.cpp | 103 + .../libs/Expression/PhiNode.h | 48 + .../libs/Expression/Publish.cpp | 171 + .../libs/Expression/Publish.h | 93 + .../libs/Expression/RandExpression.cpp | 86 + .../libs/Expression/RandExpression.h | 54 + .../libs/Expression/RangeReduceExpression.cpp | 615 + .../libs/Expression/RangeReduceExpression.h | 229 + .../libs/Expression/RefExpression.cpp | 156 + .../libs/Expression/RefExpression.h | 85 + .../libs/Expression/SelectNth.cpp | 210 + .../libs/Expression/SelectNth.h | 107 + .../libs/Expression/SimpleExpressionOwner.h | 29 + .../libs/Expression/StateMachine.cpp | 717 + .../libs/Expression/StateMachine.h | 418 + .../libs/Expression/StreamData.cpp | 82 + .../libs/Expression/StreamData.h | 53 + .../libs/Expression/SymbolTable.cpp | 241 + .../libs/Expression/SymbolTable.h | 126 + .../libs/Expression/TypeUtil.cpp | 360 + .../libs/Expression/TypeUtil.h | 55 + .../libs/Expression/UnaryOperator.cpp | 105 + .../libs/Expression/UnaryOperator.h | 42 + .../libs/Expression/Visitor.h | 565 + .../libs/External/ArrayResult.h | 396 + .../libs/External/CMakeLists.txt | 34 + .../libs/External/Compiler.cpp | 33 + .../libs/External/Compiler.h | 25 + .../libs/External/Executable.cpp | 99 + .../libs/External/Executable.h | 48 + .../libs/External/FreeForm2ExternalData.cpp | 70 + .../libs/External/FreeForm2Result.cpp | 365 + .../libs/External/FreeForm2Type.cpp | 226 + .../libs/External/NeuralInputFreeForm2.cpp | 459 + .../libs/External/Program.cpp | 262 + .../libs/External/Program.h | 72 + .../libs/External/ResultIteratorImpl.h | 50 + .../libs/External/ValueResult.cpp | 448 + .../libs/External/ValueResult.h | 94 + .../SExpression/inc/FreeForm2Tokenizer.h | 201 + .../Parse/SExpression/inc/SExpressionParse.h | 67 + .../Parse/SExpression/libs/Arithmetic.cpp | 223 + .../libs/Parse/SExpression/libs/Arithmetic.h | 33 + .../libs/Parse/SExpression/libs/Bitwise.cpp | 37 + .../libs/Parse/SExpression/libs/Bitwise.h | 20 + .../Parse/SExpression/libs/CMakeLists.txt | 31 + .../SExpression/libs/ExpressionFactory.cpp | 26 + .../SExpression/libs/ExpressionFactory.h | 51 + .../SExpression/libs/FreeForm2Tokenizer.cpp | 489 + .../libs/Parse/SExpression/libs/Logic.cpp | 86 + .../libs/Parse/SExpression/libs/Logic.h | 27 + .../Parse/SExpression/libs/MiscFactory.cpp | 445 + .../libs/Parse/SExpression/libs/MiscFactory.h | 43 + .../libs/OperatorExpressionFactory.h | 86 + .../SExpression/libs/ProgramParseState.cpp | 152 + .../SExpression/libs/ProgramParseState.h | 129 + .../SExpression/libs/SExpressionParse.cpp | 1382 + .../libs/Shared/ArrayType.cpp | 282 + .../libs/Shared/ArrayType.h | 105 + .../libs/Shared/Attributes.h | 34 + .../libs/Shared/BFBFile.h | 202 + .../libs/Shared/CMakeLists.txt | 34 + .../libs/Shared/CompoundType.cpp | 30 + .../libs/Shared/CompoundType.h | 52 + .../libs/Shared/FreeForm2Assert.cpp | 38 + .../libs/Shared/FreeForm2Utils.cpp | 150 + .../libs/Shared/FreeForm2Utils.h | 88 + .../libs/Shared/FunctionType.cpp | 134 + .../libs/Shared/FunctionType.h | 69 + .../libs/Shared/ObjectType.cpp | 119 + .../libs/Shared/ObjectType.h | 70 + .../libs/Shared/QuietMetaStreams.h | 97 + .../libs/Shared/StateMachineType.cpp | 119 + .../libs/Shared/StateMachineType.h | 80 + .../libs/Shared/StructType.cpp | 186 + .../libs/Shared/StructType.h | 96 + .../libs/Shared/TypeImpl.cpp | 373 + .../libs/Shared/TypeImpl.h | 111 + .../libs/Shared/TypeManager.cpp | 1097 + .../libs/Shared/TypeManager.h | 315 + .../libs/Transform/AllocationVisitor.cpp | 165 + .../libs/Transform/AllocationVisitor.h | 73 + .../libs/Transform/CMakeLists.txt | 34 + .../libs/Transform/CopyingVisitor.cpp | 1581 + .../libs/Transform/CopyingVisitor.h | 159 + .../libs/Transform/FunctionInlineVisitor.cpp | 133 + .../libs/Transform/FunctionInlineVisitor.h | 38 + .../libs/Transform/NoOpVisitor.h | 90 + .../Transform/ObjectResolutionVisitor.cpp | 128 + .../libs/Transform/ObjectResolutionVisitor.h | 36 + .../Transform/OperandPromotionVisitor.cpp | 228 + .../libs/Transform/OperandPromotionVisitor.h | 40 + .../libs/Transform/ProcessFeaturesUsed.cpp | 17 + .../libs/Transform/ProcessFeaturesUsed.h | 30 + .../libs/Transform/TypeCheckingVisitor.cpp | 599 + .../libs/Transform/TypeCheckingVisitor.h | 92 + .../Transform/UniformExpressionVisitor.cpp | 581 + .../libs/Transform/UniformExpressionVisitor.h | 90 + .../test/CMakeLists.txt | 51 + .../test/FreeFormLibTest.cpp | 15 + .../test/FreeFormLibTestSet.h | 142 + .../test/SimpleFeatureMap.cpp | 85 + .../test/SimpleFeatureMap.h | 39 + .../NeuralTree.Library/CMakeLists.txt | 7 + src/transform/NeuralTree.Library/inc/CsHash.h | 157 + .../NeuralTree.Library/inc/FeaSpecConfig.h | 29 + .../NeuralTree.Library/inc/IFeatureMap.h | 34 + .../inc/INeuralNetFeatures.h | 21 + .../NeuralTree.Library/inc/MigratedApi.h | 19 + .../NeuralTree.Library/inc/NeuralInput.h | 633 + .../inc/NeuralInputFactory.h | 121 + .../inc/NeuralInputFreeForm2_types.h | 232 + .../inc/NeuralInput_types.h | 1261 + .../inc/UnionBondInput_types.h | 245 + .../NeuralTree.Library/inc/basic_types.h | 96 + .../NeuralTree.Library/src/CMakeLists.txt | 35 + .../NeuralTree.Library/src/CsHash.cpp | 135 + .../NeuralTree.Library/src/FeaSpecConfig.cpp | 144 + .../NeuralTree.Library/src/NeuralInput.cpp | 1944 + .../src/NeuralInputFactory.cpp | 353 + .../CMakeLists.txt | 13 + .../inc/FeatureMap.h | 18 + .../inc/InputComputation.h | 8 + .../inc/InputExtraction.h | 29 + .../src/CMakeLists.txt | 46 + .../src/FeatureMap.cpp | 64 + .../src/InputComputation.cpp | 50 + .../src/InputExtraction.cpp | 72 + .../src/InputExtractor.cpp | 174 + .../src/InputExtractor.h | 46 + .../src/LocalFactoryHolder.cpp | 23 + .../src/LocalFactoryHolder.h | 19 + .../src/MinimalFeatureMap.cpp | 84 + .../src/MinimalFeatureMap.h | 34 + .../test/CMakeLists.txt | 52 + .../test/NeuralTreeEvaluatorLibTest.cpp | 50 + .../test/data/TrainInputIni | 18 + .../TransformProcessor/CMakeLists.txt | 8 + .../TransformProcessor/CMakeLists.txt | 43 + .../TransformProcessor/FeatureEvaluator.cpp | 28 + .../TransformProcessor/FeatureEvaluator.h | 19 + .../FeatureEvaluatorExtendedInfo.h | 17 + .../IniFileParserInterface.cpp | 189 + .../IniFileParserInterface.h | 50 + .../TransformProcessor/Parser.h | 23 + .../TransformProcessor/TransformProcessor.cpp | 226 + .../TransformProcessor/TransformProcessor.h | 45 + .../TransformProcessorFeatureMap.cpp | 10 + .../TransformProcessorFeatureMap.h | 17 + .../TransformProcessor/test/CMakeLists.txt | 49 + .../TransformProcessor/test/Tests.cpp | 57 + src/transform/TransformProcessor/test/Tests.h | 7 + .../test/data/integration/ExpectedOutput.txt | 3 + .../test/data/integration/Header.tsv | 1 + .../test/data/integration/Input.tsv | 3 + .../data/integration/SmoothedTrainInputIni | 29642 ++++++++++++++++ 229 files changed, 71871 insertions(+) create mode 100644 src/transform/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Assert.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Compiler.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2CompilerFactory.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Executable.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2ExternalData.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Features.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Program.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Result.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Type.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/inc/NeuralInputFreeForm2.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITEmitter.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/MatchSub.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/SimpleExpressionOwner.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Expression/Visitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/ArrayResult.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2ExternalData.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Result.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Type.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/NeuralInputFreeForm2.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/Program.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/Program.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/ResultIteratorImpl.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/FreeForm2Tokenizer.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/SExpressionParse.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/FreeForm2Tokenizer.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/OperatorExpressionFactory.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/SExpressionParse.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/Attributes.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/BFBFile.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Assert.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/QuietMetaStreams.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/NoOpVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/test/CMakeLists.txt create mode 100644 src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTest.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTestSet.h create mode 100644 src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.cpp create mode 100644 src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.h create mode 100644 src/transform/NeuralTree.Library/CMakeLists.txt create mode 100644 src/transform/NeuralTree.Library/inc/CsHash.h create mode 100644 src/transform/NeuralTree.Library/inc/FeaSpecConfig.h create mode 100644 src/transform/NeuralTree.Library/inc/IFeatureMap.h create mode 100644 src/transform/NeuralTree.Library/inc/INeuralNetFeatures.h create mode 100644 src/transform/NeuralTree.Library/inc/MigratedApi.h create mode 100644 src/transform/NeuralTree.Library/inc/NeuralInput.h create mode 100644 src/transform/NeuralTree.Library/inc/NeuralInputFactory.h create mode 100644 src/transform/NeuralTree.Library/inc/NeuralInputFreeForm2_types.h create mode 100644 src/transform/NeuralTree.Library/inc/NeuralInput_types.h create mode 100644 src/transform/NeuralTree.Library/inc/UnionBondInput_types.h create mode 100644 src/transform/NeuralTree.Library/inc/basic_types.h create mode 100644 src/transform/NeuralTree.Library/src/CMakeLists.txt create mode 100644 src/transform/NeuralTree.Library/src/CsHash.cpp create mode 100644 src/transform/NeuralTree.Library/src/FeaSpecConfig.cpp create mode 100644 src/transform/NeuralTree.Library/src/NeuralInput.cpp create mode 100644 src/transform/NeuralTree.Library/src/NeuralInputFactory.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/CMakeLists.txt create mode 100644 src/transform/NeuralTreeEvaluator.Library/inc/FeatureMap.h create mode 100644 src/transform/NeuralTreeEvaluator.Library/inc/InputComputation.h create mode 100644 src/transform/NeuralTreeEvaluator.Library/inc/InputExtraction.h create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/CMakeLists.txt create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/FeatureMap.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/InputComputation.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/InputExtraction.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/InputExtractor.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/InputExtractor.h create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/LocalFactoryHolder.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/LocalFactoryHolder.h create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/MinimalFeatureMap.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/src/MinimalFeatureMap.h create mode 100644 src/transform/NeuralTreeEvaluator.Library/test/CMakeLists.txt create mode 100644 src/transform/NeuralTreeEvaluator.Library/test/NeuralTreeEvaluatorLibTest.cpp create mode 100644 src/transform/NeuralTreeEvaluator.Library/test/data/TrainInputIni create mode 100644 src/transform/TransformProcessor/CMakeLists.txt create mode 100644 src/transform/TransformProcessor/TransformProcessor/CMakeLists.txt create mode 100644 src/transform/TransformProcessor/TransformProcessor/FeatureEvaluator.cpp create mode 100644 src/transform/TransformProcessor/TransformProcessor/FeatureEvaluator.h create mode 100644 src/transform/TransformProcessor/TransformProcessor/FeatureEvaluatorExtendedInfo.h create mode 100644 src/transform/TransformProcessor/TransformProcessor/IniFileParserInterface.cpp create mode 100644 src/transform/TransformProcessor/TransformProcessor/IniFileParserInterface.h create mode 100644 src/transform/TransformProcessor/TransformProcessor/Parser.h create mode 100644 src/transform/TransformProcessor/TransformProcessor/TransformProcessor.cpp create mode 100644 src/transform/TransformProcessor/TransformProcessor/TransformProcessor.h create mode 100644 src/transform/TransformProcessor/TransformProcessor/TransformProcessorFeatureMap.cpp create mode 100644 src/transform/TransformProcessor/TransformProcessor/TransformProcessorFeatureMap.h create mode 100644 src/transform/TransformProcessor/test/CMakeLists.txt create mode 100644 src/transform/TransformProcessor/test/Tests.cpp create mode 100644 src/transform/TransformProcessor/test/Tests.h create mode 100644 src/transform/TransformProcessor/test/data/integration/ExpectedOutput.txt create mode 100644 src/transform/TransformProcessor/test/data/integration/Header.tsv create mode 100644 src/transform/TransformProcessor/test/data/integration/Input.tsv create mode 100644 src/transform/TransformProcessor/test/data/integration/SmoothedTrainInputIni diff --git a/src/transform/CMakeLists.txt b/src/transform/CMakeLists.txt new file mode 100644 index 000000000000..3269b7059bde --- /dev/null +++ b/src/transform/CMakeLists.txt @@ -0,0 +1,96 @@ +cmake_minimum_required(VERSION 3.15.0 FATAL_ERROR) + +project(transform) + +# following lib paths are required by linking directories not compiling. +set(LLVM_LIB + LLVMAnalysis + LLVMAsmParser + LLVMAsmPrinter + LLVMBitReader + LLVMBitWriter + LLVMCodeGen + LLVMCore + LLVMDebugInfo + LLVMExecutionEngine + LLVMInstCombine + LLVMInstrumentation + LLVMInterpreter + LLVMipa + LLVMipo + LLVMIRReader + LLVMJIT + LLVMLineEditor + LLVMLinker + LLVMLTO + LLVMMC + LLVMMCAnalysis + LLVMMCDisassembler + LLVMMCJIT + LLVMMCParser + LLVMObjCARCOpts + LLVMObject + LLVMOption + LLVMProfileData + LLVMRuntimeDyld + LLVMScalarOpts + LLVMSelectionDAG + LLVMSupport + LLVMTarget + LLVMTransformUtils + LLVMVectorize + LLVMX86AsmParser + LLVMX86AsmPrinter + LLVMX86CodeGen + LLVMX86Desc + LLVMX86Disassembler + LLVMX86Info + LLVMX86Utils + ) + +set(BOOST_LIB + boost_atomic + boost_chrono + boost_context + boost_coroutine + boost_date_time + boost_exception + boost_filesystem + boost_graph + boost_graph_parallel + boost_iostreams + boost_locale + boost_log + boost_log_setup + boost_math_c99 + boost_math_c99f + boost_math_c99l + boost_math_tr1 + boost_math_tr1f + boost_math_tr1l + boost_mpi + boost_prg_exec_monitor + boost_program_options + boost_random + boost_regex + boost_serialization + boost_system + boost_test_exec_monitor + boost_thread + boost_timer + boost_unit_test_framework + boost_wave + boost_wserialization + ) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +if(USE_DEBUG) + SET(CMAKE_BUILD_TYPE "Debug") + SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g -ggdb") + SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall") +endif() + +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/NeuralTree.Library) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/DynamicRank.FreeForm.Library) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/NeuralTreeEvaluator.Library) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/TransformProcessor) diff --git a/src/transform/DynamicRank.FreeForm.Library/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/CMakeLists.txt new file mode 100644 index 000000000000..b20c80974749 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.15.0 FATAL_ERROR) + +project(DynamicRank.FreeForm.Library) + +set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libs/Shared) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libs/Expression) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libs/Parse/SExpression/libs) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libs/Transform) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libs/Backend/llvm) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/libs/External) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/test) + +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/inc/ + DESTINATION include + FILES_MATCHING PATTERN "*.h" + ) diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2.h new file mode 100644 index 000000000000..632f7dd3bbfc --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2.h @@ -0,0 +1,15 @@ +// FreeForm2.h declares the main interface to the FreeForm2 expression +// evaluator, used for feature evaluation in the neural net/tree ensemble +// evaluator. + +#pragma once + +#ifndef FREEFORM2_FREEFORM2_H +#define FREEFORM2_FREEFORM2_H + +#include "FreeForm2Executable.h" +#include "FreeForm2Compiler.h" +#include "FreeForm2Program.h" + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Assert.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Assert.h new file mode 100644 index 000000000000..6a89afcfccbd --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Assert.h @@ -0,0 +1,41 @@ +#pragma once + +#ifndef FREEFORM2_ASSERT_H +#define FREEFORM2_ASSERT_H + +namespace FreeForm2 +{ + // Assert a condition, providing file and line of the assertion call. + // Note that this function throws on failure, rather than aborting (like + // standard c assert). + void ThrowAssert(bool p_condition, const char* p_file, unsigned int p_line); + + // Assert a condition, providing the expression, file and line of the + // assertion call. + // Note that this function throws on failure, rather than aborting (like + // standard c assert). + void ThrowAssert(bool p_condition, const char* p_expression, const char* p_file, unsigned int p_line); + + // Assert that this function call should not be reached during normal + // program execution. Note that this function throws on failure, rather + // than aborting. + // __declspec(noreturn) + void Unreachable(const char* p_file, unsigned int p_line); +}; + +// Macros for asserting. +#define FF2_ASSERT(cond) \ + /* Call the regular FreeForm2::ThrowAssert function with macro-generated \ + * parameters */ \ + FreeForm2::ThrowAssert((cond), \ + #cond, \ + __FILE__, \ + __LINE__) + +#define FF2_UNREACHABLE() \ + /* Call the regular FreeForm2::Unreachable function with macro-generated \ + * parameters */ \ + FreeForm2::Unreachable(__FILE__, \ + __LINE__) + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Compiler.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Compiler.h new file mode 100644 index 000000000000..c8f416c5f664 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Compiler.h @@ -0,0 +1,60 @@ +#pragma once + +#ifndef FREEFORM2_INC_COMPILER_H +#define FREEFORM2_INC_COMPILER_H + +#include +#include +#include +#include + +namespace DynamicRank +{ + class IFeatureMap; +} + +namespace FreeForm2 +{ + class CompilerImpl; + class ExternalDataManager; + class Program; + + // This class contains the results of the compilation process. Since the + // result of the compilation process is backend-dependent, this class + // should be downcast to the backend result type according to the compiler + // used. + class CompilerResults : boost::noncopyable + { + public: + virtual ~CompilerResults(); + }; + + // A compiler compiles multiple programs using the same set of compiler + // resources, which amortises costs across different programs. + class Compiler : boost::noncopyable + { + public: + Compiler(std::auto_ptr p_impl); + + ~Compiler(); + + // Compile the program producing a backend-dependent results obejct. + // This method, optionally producing debug output on stderr. + std::unique_ptr + Compile(const Program& p_program, bool p_debugOutput); + + // Default optimization level, used whenever the level is not explicitly specified. + static const unsigned int c_defaultOptimizationLevel = 0; + + private: + // Pointer to implementation class (pimpl idiom). + boost::scoped_ptr m_impl; + }; +} + +// TODO: Remove the following include (TFS# 453473). +// This is used to prevent a breaking API change (moving compiler factory into +// a separate header). +#include "FreeForm2CompilerFactory.h" + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2CompilerFactory.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2CompilerFactory.h new file mode 100644 index 000000000000..0657128c0807 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2CompilerFactory.h @@ -0,0 +1,78 @@ +#pragma once + +#ifndef FREEFORM2_COMPILER_FACTORY_H +#define FREEFORM2_COMPILER_FACTORY_H + +#include +#include + +namespace FreeForm2 +{ + namespace Cpp + { + class ExternResolver; + } + + class Compiler; + + // This class contains methods to instantiate the various types of + // compilers. + class CompilerFactory + { + public: + enum DestinationFunctionType + { + // typedef float (*DirectEvalFun)(StreamFeatureInput*, const FeatureType[], OutputType[]). + SingleDocumentEvaluation, + + // typedef float (*AggregatedEvalFun)(StreamFeatureInput*, const FeatureType[][], UInt32, UInt32, OutputType[]). + DocumentSetEvaluation, + }; + + // Create a compiler which takes a program and compiles it into an + // ExecutableCompilerResults object. p_optimizationLevel describes the + // degree of optimization to perform on the code with an integer, + // analagous to 'gcc -O p_optimizationLevel'. + static + std::unique_ptr + CreateExecutableCompiler(unsigned int p_optimizationLevel, + DestinationFunctionType p_destinationFunctionType = SingleDocumentEvaluation); + + // This method creates a Compiler object to compile a program to a + // C++/IFM target. The results of this compiler are defined in + // FreeForm2CppCompiler.h. + static + std::unique_ptr + CreateCppIFMCompiler(const Cpp::ExternResolver& p_resolver); + + // This method creates a Compiler object to compile a program to a + // C++/Barramundi target. The results of this compiler are defined + // in FreeForm2CppCompiler.h. + static + std::unique_ptr + CreateCppBarramundiCompiler(const Cpp::ExternResolver& p_resolver, + const std::string& p_metadataPath); + + // This method creates a Compiler object to compile a program to a + // C++/Barramundi target with debugging instrumentation present in the + // program. The printf command is the name of a printf-style function + // for use with debugging statements. The results of this compiler are + // defined in FreeForm2CppCompiler.h. + static + std::unique_ptr + CreateDebuggingCppBarramundiCompiler(const Cpp::ExternResolver& p_resolver, + const std::string& p_metadataPath, + const std::string& p_printfCommand); + + // This method creates a Compiler object to compile a program to a + // C++/FPGA target. The results of this compiler are defined + // in FreeForm2CppCompiler.h. + static + std::unique_ptr + CreateCppFpgaCompiler(const Cpp::ExternResolver& p_resolver, + const std::string& p_outputMappingPath, + const std::string& p_msdlPath); + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Executable.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Executable.h new file mode 100644 index 000000000000..843be651359d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Executable.h @@ -0,0 +1,93 @@ +#pragma once + +#ifndef FREEFORM2_INC_EXECUTABLE_H +#define FREEFORM2_INC_EXECUTABLE_H + +#include +#include +#include +#include +// Included for the CompilerResults object. +#include "FreeForm2Compiler.h" +#include + +class StreamFeatureInput; + +namespace FreeForm2 +{ + class ExecutableImpl; + class Result; + class Type; + + // An executable is something that can simply be executed, and not much + // more (knows its type, because that's unavoidable). + class Executable : boost::noncopyable + { + public: + // Features are 32-bit unsigned integer quantities generated by a + // variety of feature generators. + typedef UInt32 FeatureType; + + // Inputs take features and turn them into floating point numbers. + typedef double InputType; + typedef float OutputType; + + // Constructor, taking the implementation object. + Executable(std::auto_ptr p_impl); + + // Evaluate a program, returning a result. The StreamFeatureInput may + // be NULL, in which case default values are used for its members. + boost::shared_ptr Evaluate(StreamFeatureInput* p_input, + const FeatureType p_features[]) const; + + + // Evaluate a program of aggregated freeform syntax. + boost::shared_ptr + Evaluate(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache) const; + + // Attempt to get a function returning a pointer to a function that can + // be called directly. This will only work if the executable produces a + // floating point number, and various other conditions. + typedef float (*DirectEvalFun)(StreamFeatureInput*, const FeatureType[], OutputType[]); + DirectEvalFun EvaluationFunction() const; + + // Attempt to get a function returning a pointer to a function that can + // be called to evaluate aggregated freeform. + typedef float (*AggregatedEvalFun)(const FeatureType* const*, UInt32, UInt32, Int64*, OutputType[]); + AggregatedEvalFun AggregatedEvaluationFunction() const; + + // Get the output type of the executable. + const Type& GetType() const; + + // Get the size of external memory. + size_t GetExternalSize() const; + + // Implementation accessor. + const ExecutableImpl& GetImplementation() const; + + private: + // Pointer to implementation class (pimpl idiom). + boost::scoped_ptr m_impl; + }; + + // The class holds the results of the Compiler::Compile method for an + // Executable compiler. + class ExecutableCompilerResults : public CompilerResults + { + public: + // Construct a results object for an executable. + explicit ExecutableCompilerResults(const boost::shared_ptr& p_executable); + + // Get the executable held by this object. + const boost::shared_ptr& GetExecutable() const; + + private: + // The executable pointer passed to the constructor. + const boost::shared_ptr m_executable; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2ExternalData.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2ExternalData.h new file mode 100644 index 000000000000..f4f64865b6e6 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2ExternalData.h @@ -0,0 +1,90 @@ +#pragma once +#ifndef FREEFORM2_INC_EXTERNAL_DATA_H +#define FREEFORM2_INC_EXTERNAL_DATA_H + +#include +#include "FreeForm2Features.h" +#include "FreeForm2Type.h" +#include "FreeForm2Result.h" +#include +#include + +namespace FreeForm2 +{ + class TypeManager; + + // This class represents a piece of data, which is defined externally from + // the perspective of a program compiled by this library. + class ExternalData + { + public: + // Creates a run-time constant external data member. + ExternalData(const std::string& p_name, const TypeImpl& p_typeImpl); + + // Creates a compile-time constant external data member. + ExternalData(const std::string& p_name, const TypeImpl& p_typeImpl, ConstantValue p_value); + + // Force subclassing of this class. This serves two purposes: it + // encounrages implementors to add backend-specific data to the + // ExternalData class to group resource storage, and it disallows the + // library to copy ExternalData in case the client has subclassed + // ExternalData. + virtual ~ExternalData() = 0; + + // Get the name of the external data. + const std::string& GetName() const; + + // Get the type of the external data. + const TypeImpl& GetType() const; + + // Return a flag that determines if this external data is a + // compile-time constant. + bool IsCompileTimeConstant() const; + + // Return the compile-time constant value of this external data. If + // this method is not a compile-time constant, this method throws an + // exception. + ConstantValue GetCompileTimeValue() const; + + private: + // The name of the data. + std::string m_name; + + // The type of the data. + const TypeImpl* m_type; + + // This flag indicates whether or not this object is a compile-time + // constant value. + bool m_isCompileTimeConst; + + // This value is used only if this object is a compile-time constant + // and contains the constant value. + ConstantValue m_constantValue; + }; + + // This class provides a name-to-data mapping for external data members. + class ExternalDataManager + { + public: + // The default constructor initializes the type factory. + ExternalDataManager(); + virtual ~ExternalDataManager(); + + // This method returns an ExternalData object based on a name. If no + // data exists for the name, this method returns nullptr. + virtual const ExternalData* FindData(const std::string& p_name) const = 0; + + protected: + // Get the type factory for managing the types associated with this + // data manager. + TypeFactory& GetTypeFactory(); + + private: + // The type manager owns the types created with the above methods. + std::unique_ptr m_typeFactory; + + friend class TypeManager; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Features.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Features.h new file mode 100644 index 000000000000..12791020b083 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Features.h @@ -0,0 +1,90 @@ +#pragma once + +#ifndef FREEFORM2_INC_FEATURES_H +#define FREEFORM2_INC_FEATURES_H + +#include + +namespace FreeForm2 +{ + class TypeImpl; + + // The feature information struct provides a namespace for feature-related + // data declarations. + struct FeatureInformation + { + // This enum lists the type of features supported by the feature + // compiler. + enum FeatureType + { + MetaStreamFeature, + DerivedFeature, + AggregatedDerivedFeature, + AbInitioFeature + }; + }; + + // This namespace declares the names of required external data members for + // metastream features. + namespace RequiredMetaStreamData + { + // The number of query paths in the current query. + struct NumQueryPaths + { + static const std::string& GetName(); + static const TypeImpl& GetType(); + }; + + // The number of words in the query. + struct QueryLength + { + static const std::string& GetName(); + static const TypeImpl& GetType(); + }; + + // The index of word candidates per term in a specific query path. + struct QueryPathCandidates + { + static const std::string& GetName(); + }; + + // The stream data over which a metastream feature operates. + struct Stream + { + static const std::string& GetName(); + static const TypeImpl& GetType(); + }; + + // The number of tuples of interest per type. + struct TupleOfInterestCount + { + static const std::string& GetName(); + }; + + // The tuples of interest. + struct TuplesOfInterest + { + static const std::string& GetName(); + }; + + // Duplicate term information. + struct UnsafeDuplicateTermInformation + { + static const std::string& GetName(); + }; + } + + // This namespace declares the external data members required for + // compilation of derived features. + namespace RequiredDerivedFeatureData + { + // The data member representing the stream ID. + struct StreamID + { + static const std::string& GetName(); + static const TypeImpl& GetType(); + }; + } +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Program.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Program.h new file mode 100644 index 000000000000..0e0b7fe983aa --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Program.h @@ -0,0 +1,142 @@ +#pragma once + +#ifndef FREEFORM2_INC_PROGRAM_H +#define FREEFORM2_INC_PROGRAM_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace DynamicRank +{ + class IFeatureMap; + class INeuralNetFeatures; +} + +namespace FreeForm2 +{ + class Allocation; + class ProgramImpl; + class Expression; + class ExternalDataManager; + class Type; + + // A struct that holds information about the source location in a program. + struct SourceLocation + { + SourceLocation(); + SourceLocation(unsigned int p_lineNo, unsigned int p_lineOffset); + + unsigned int m_lineNo; + unsigned int m_lineOffset; + }; + + // This class is used to throw errors during parsing. + class ParseError : public std::runtime_error + { + public: + // Construct an exception with a message and source location. + ParseError(const std::string& p_message, const SourceLocation& p_sourceLocation); + + // Construct an exception based off an inner exception, with + // a source location. + ParseError(const std::exception& p_inner, const SourceLocation& p_sourceLocation); + + // Get the error message, without line information. + const std::string& GetMessage() const; + + // Get the line number where the parsing error occurred. + const SourceLocation& GetSourceLocation() const; + + private: + // The error message, without line information. + std::string m_message; + + // The source location where the parsing error occurred. + SourceLocation m_sourceLocation; + }; + + // A program is a representation of a program, that offers functionality in + // addition to being able to simply execute. + class Program : boost::noncopyable + { + public: + Program(std::auto_ptr p_impl); + + // Enumeration of available syntaxes. + enum Syntax + { + sexpression, + /* visage, */ + aggregatedSExpression, + }; + + // Parse a program from a string. The feature map defines the mapping + // between names and feature value slots in the p_features array passed + // to Executable::Evaluate. p_syntax dictates the syntax used for + // parsing. p_mustProduceFloat forces the program to + // produce a float as final result. If the result can be sensibly + // converted to a float (i.e. from an integer) it will be, otherwise an + // exception will be thrown. If p_debugOutput is not NULL, debugging + // information is written to this stream during parsing. The external + // data manager is an optional argument. If a manager is provided, + // extern lookups will be done via this object; otherwise, externs will + // be disallowed in the program. + // + // Note that a reference to the feature map is saved, and that object + // must persist through the lifetime of the program. + template + static boost::shared_ptr Parse(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput); + + // Get the output type of the program. + const Type& GetType() const; + + // Get the expression of the program. + const Expression& GetExpression() const; + + // Provide an interface to extract features used by this program, by + // calling Process on the provided INeuralNetFeatures object. + void ProcessFeaturesUsed(DynamicRank::INeuralNetFeatures& p_features) const; + + // Gets the list of all the allocations required by the program. + const std::vector>& GetAllocations() const; + + // Implementation accessors. + ProgramImpl& GetImplementation(); + const ProgramImpl& GetImplementation() const; + + private: + // Pointer to implementation class (pimpl idiom). + boost::scoped_ptr m_impl; + }; + + // Explicit instantiation of Program::Parse for all available syntaxes. + template + boost::shared_ptr + Program::Parse(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput); + + template + boost::shared_ptr + Program::Parse(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput); +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Result.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Result.h new file mode 100644 index 000000000000..8fa3267e1d9b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Result.h @@ -0,0 +1,122 @@ +#pragma once + +#ifndef FREEFORM2_RESULT_H +#define FREEFORM2_RESULT_H + +#include +#include +#include "FreeForm2Type.h" +#include +#include +#include + +namespace FreeForm2 +{ + // Class to encapsulate value and type into a result. + class ResultIterator; + class Result : boost::noncopyable + { + public: + typedef bool BoolType; + typedef Int64 IntType; + typedef UInt64 UInt64Type; + typedef Int32 Int32Type; + typedef UInt32 UInt32Type; + + // Generate minimum int for this size via bit twiddling. + const static IntType c_minInt + = (static_cast(1) << ((sizeof(IntType) * 8) - 1)); + + // Generate maximum int for this size by underflowing from c_minInt. + // (disable warning about integer overflow). +#pragma warning(push) +#pragma warning(disable : 4307) + const static IntType c_maxInt = c_minInt - 1; +#pragma warning(pop) + + typedef float FloatType; + + virtual ~Result(); + + // Compare two results (which must be of the same type, else an + // exception is thrown), returning the usual c convention of < 0, 0, + // > 0 to indicate less than, equal to or greater than the other result. + int Compare(const Result& p_other) const; + + // Print the result to the given output stream. + void Print(std::ostream& p_out) const; + + virtual const Type& GetType() const = 0; + virtual IntType GetInt() const = 0; + virtual UInt64Type GetUInt64() const = 0; + virtual Int32Type GetInt32() const = 0; + virtual UInt32Type GetUInt32() const = 0; + virtual FloatType GetFloat() const = 0; + virtual BoolType GetBool() const = 0; + virtual ResultIterator BeginArray() const = 0; + virtual ResultIterator EndArray() const = 0; + + // Compare two floats for equality, using the freeform2 standard + // tolerance. + static int CompareFloat(FloatType p_left, FloatType p_right); + }; + + // Facade class to help iterate over array elements in an abstract way. + // Note that boost::iterator_facade and virtual functions don't place nicely + // together, so we need a separate virtual class, and an iterator_facade class. + class ResultIteratorImpl; + class ResultIterator : public boost::iterator_facade + { + friend class boost::iterator_core_access; + + public: + ResultIterator(std::auto_ptr p_impl); + + ResultIterator(const ResultIterator& p_other); + + ~ResultIterator(); + + private: + // iterator_facade function to increment the iterator. + void increment(); + + // iterator_facade function to decrement the iterator. + void decrement(); + + // iterator_facade function to compare iterators. + bool equal(const ResultIterator& p_other) const; + + // iterator_facade function to get the current element. + const Result& dereference() const; + + // iterator_facade function to get the current element. + void advance(std::ptrdiff_t p_distance); + + // iterator_facade function to calculate distance to another iterator. + std::ptrdiff_t distance_to(const ResultIterator& p_other) const; + + // Pointer to virtual iterator implementation. + std::auto_ptr m_impl; + }; + + std::ostream& operator<<(std::ostream& p_out, const Result& p_result); + + // This union defines the various compile-time constant values in the + // compiler. + union ConstantValue + { + Result::IntType m_int; + Result::UInt64Type m_uint64; + Result::Int32Type m_int32; + Result::UInt32Type m_uint32; + Result::FloatType m_float; + Result::BoolType m_bool; + + Result::IntType GetInt(const TypeImpl& p_type) const; + }; + +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Type.h b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Type.h new file mode 100644 index 000000000000..3be9ac9fdc77 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/FreeForm2Type.h @@ -0,0 +1,128 @@ +#pragma once + +#ifndef FREEFORM2_TYPE_H +#define FREEFORM2_TYPE_H + +#include +#include +#include +#include +#include +#include + +namespace FreeForm2 +{ + class TypeImpl; + class TypeManager; + + class Type : public boost::equality_comparable + { + public: + // List of possible value types. Note that i'm not following the coding + // guidelines here (by capitalising each), because otherwise we get + // conflicts with the C++ type 'float', etc. + enum TypePrimitive + { + Float, + Int, + UInt64, + Int32, + UInt32, + Bool, + Array, + Struct, + Void, + Stream, + Word, + InstanceHeader, + BodyBlockHeader, + StateMachine, + Function, + Object, + + Unknown, + Invalid + }; + + // Construct a type from a TypeImpl. Note that no ownership transfer is + // implied by this, and the TypeImpl object must remain in scope for the + // lifetime of the Type. + explicit Type(const TypeImpl& p_impl); + + // Get the type primitive of this type. + TypePrimitive Primitive() const; + + // Returns a name for each type in the Type enumeration, or NULL for + // unrecognised types. + static const char* Name(TypePrimitive p_type); + + // Returns the type primitive corresponding to a name, the inverse of + // the Name() function. Returns Type::Invalid for unrecognised names, + // and is case sensitive. + static TypePrimitive ParsePrimitive(SIZED_STRING p_string); + + // Equality operator. + bool operator==(const Type& p_other) const; + + // Get implementation class. + const TypeImpl& GetImplementation() const; + + private: + Type(const Type& p_type); + void operator=(const Type& p_type); + + // Pointer to implementation (pimpl idiom). + const TypeImpl& m_impl; + }; + + // Output to std::ostream. + std::ostream& operator<<(std::ostream& p_out, const Type& p_type); + + // A type factory allows creation of arbitrary TypeImpl objects without + // exposing the type implementations. + class TypeFactory : boost::noncopyable + { + public: + // Construct a TypeFactory for an implementation. + TypeFactory(std::auto_ptr p_typeManager); + + // The members of a structure consists of a name-type pair. + typedef std::pair StructMember; + + // These methods produce a reference to a TypeImpl corresponding to a + // feature compiler type. These can be used to construct Type objects. + static const TypeImpl& GetFloatType(); + static const TypeImpl& GetIntType(); + static const TypeImpl& GetUInt64Type(); + static const TypeImpl& GetInt32Type(); + static const TypeImpl& GetUInt32Type(); + static const TypeImpl& GetBoolType(); + static const TypeImpl& GetVoidType(); + static const TypeImpl& GetStreamType(); + static const TypeImpl& GetWordType(); + static const TypeImpl& GetInstanceHeaderType(); + static const TypeImpl& GetBodyBlockHeaderType(); + const TypeImpl& GetArrayType(const TypeImpl& p_child, + const UInt32* p_dimensions, + UInt32 p_dimensionCount); + const TypeImpl& GetStructType(const std::string& p_name, + const StructMember* p_members, + UInt32 p_memberCount); + const TypeImpl& GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_parameters, + UInt32 p_parameterCount); + + // Find a type by type name. This should be the same string as when the + // type is written to a stream using the stream insertion operator. + const TypeImpl* FindType(const std::string& p_name) const; + + // Get the type manager associated with this factory. + const TypeManager& GetTypeManager() const; + + private: + // The type manager owns the types created with the above methods. + boost::scoped_ptr m_typeManager; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/inc/NeuralInputFreeForm2.h b/src/transform/DynamicRank.FreeForm.Library/inc/NeuralInputFreeForm2.h new file mode 100644 index 000000000000..ec408269b5de --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/inc/NeuralInputFreeForm2.h @@ -0,0 +1,421 @@ +#pragma once + +#ifndef FREEFORM2_NEURAL_INPUT_FREEFORM2_H +#define FREEFORM2_NEURAL_INPUT_FREEFORM2_H + +#include +#include +#include +#include +#include "FreeForm2.h" +#include "FreeForm2CompilerFactory.h" +#include "FreeForm2Executable.h" +#include +#include +#include "UnionBondInput_types.h" +#include "NeuralInputFreeForm2_types.h" + +namespace FreeForm2 +{ + class Program; + + // Match with FillBond, Loader reguster. Also match the model .ini. + static const char* c_freeform2_tranform = "freeform2"; + static const char* c_bulkinput_tranform = "bulkinput"; + + // a NeuralInput that compiles several FreeFormv2 expressions together into fewer functions that + // can be bulk evaluated. This is done to increase performance. + class BulkCompiledNeuralInput : public DynamicRank::BulkNeuralInput + { + public: + BulkCompiledNeuralInput(); + + BulkCompiledNeuralInput(std::vector> p_programs, + size_t p_numInputs, + const std::vector& p_features); + + virtual void Evaluate(const UInt32 p_input[], float p_output[]) const override; + + virtual void + GetAllAssociatedFeatures(std::vector& p_associatedFeaturesList) const override; + + virtual size_t GetNumberOfNeuralInputs() const; + + // Get the size of this object, including internal and external memory. + virtual size_t GetSize() const override; + + // Bond Serialization. + virtual void FillBond(DynamicRank::UnionBondInput& p_data) const override; + + virtual bool Equal(const DynamicRank::BulkNeuralInput* p_other) const override; + + BulkCompiledNeuralInput(const DynamicRank::UnionBondInput& p_data); + + // BulkCompiledNeuralInput never load from config.ini since it's constructed by ff2 compiler + // or reconstructed from UnionBondInput. + static BulkCompiledNeuralInput* + Load(DynamicRank::Config& p_config, + const char* p_section, + DynamicRank::IFeatureMap* p_featureMap, + const char* p_transform); + + protected: + // A std::vector of features used by the FreeForm expression. + std::vector m_features; + + // Executable program. + std::vector> m_execs; + + // Directly executable function. + std::vector m_funs; + + // Number of neural inputs. + size_t m_numInputs; + + private: + friend class NeuralInputFreeForm2; + + // Functions to do low-level serialization and deserialization of executable code. + static std::pair, size_t> SerializeBlob(const std::vector>& p_execs); + static std::pair, size_t> SerializePhoenixBlob(const std::vector>& p_execs); + static std::pair, size_t> SerializeLlvmBlob(const std::vector>& p_execs); + + static void DeserializeBlob(const unsigned char* p_buffer, + size_t p_length, + std::vector>& p_execs, + std::vector& p_funs); + }; + + + // NeuralInputFreeForm2 wraps freeform2 expression evaluation in a neural + // input, which makes it suitable for use in neural nets and tree ensembles. + class NeuralInputFreeForm2 : public DynamicRank::NeuralInput + { + public: + // Constructor, taking the input program, the name of the transform used + // to identify this neural input, the feature map to use when compiling + // the program. + NeuralInputFreeForm2(const std::string& p_input, + const char* p_transform, + DynamicRank::IFeatureMap& p_map); + + NeuralInputFreeForm2(const std::string& p_input, + const char* p_transform, + DynamicRank::IFeatureMap& p_map, + boost::shared_ptr p_program); + + // Evaluate. + virtual double Evaluate(UInt32 p_input[]) const override; + + // Train. + virtual bool Train(double p_learningRate, + double p_outputHigh, + double p_outputLow, + double p_outputDelta, + UInt32 p_inputHigh[], + UInt32 p_inputLow[]) override; + + // Get all associated features. + virtual void GetAllAssociatedFeatures(std::vector& p_associatedFeaturesList) const override; + + // Compile the freeform expression, making this input available for use. + // The provided compiler will be used if not NULL, a compiler will be + // created if not. + virtual void Compile(Compiler* p_compiler); + + // Returns the Program referenced by this NeuralInput. Used for serialization. + const Program& GetProgram() const; + + // Bond Serialization. + virtual void FillBond(DynamicRank::UnionBondInput& p_data) const override; + + virtual bool Equal(const DynamicRank::NeuralInput* p_other) const override; + + NeuralInputFreeForm2(const DynamicRank::UnionBondInput& p_data); + + // Interface to support batch serialization. + // Just for special batch serializaton of freeform2. + virtual bool IsFreeForm2() const override; + + virtual double GetMin() const override; + + virtual double GetMax() const override; + + virtual bool Save(FILE *p_out, size_t p_input, const DynamicRank::IFeatureMap& p_map) const override; + + // Function that loads a freeform2 neural input from a config + // file. Returns a pointer allocated via new, transferring + // ownership (as per NeuralInputFactory requirements). + static NeuralInputFreeForm2* + Load(//const apsdk::configuration::IConfiguration& p_config, + DynamicRank::Config& p_config, + const char* p_section, + DynamicRank::IFeatureMap* p_featureMap, + const char* p_transform); + + // Get the size of this object, including internal and external memory. + virtual size_t GetSize() const override; + + protected: + + // Gets a string representation of the input. This is used to log if an error occurs + // while compiling. + virtual std::string GetStringRepresentation() const; + + // Get the size of external memory. + size_t GetExternalSize() const; + + // Load a freeform2 neural input string from a config file. + static std::string + LoadProgram(DynamicRank::Config& p_config, + const char* p_section, + const DynamicRank::IFeatureMap* p_featureMap, + const char* p_transform); + + // Which transform was used to generate this input. + const char* m_transform; + + // Compiled program. + boost::shared_ptr m_program; + + // A std::vector of features used by the FreeForm expression. + std::vector m_features; + + // Executable program. + boost::shared_ptr m_exec; + + // Directly executable function. + Executable::DirectEvalFun m_fun; + + // Feature map used to compile feature. This should be a shared + // pointer, as we need it to live for the lifetime of this object. TFS + // task 55302 is open to fixing this issue. + DynamicRank::IFeatureMap* m_map; + + NeuralInputFreeForm2(); + private: + + // Init all members. + void Init(); + + // Program input, concatenated onto a single line. + std::string m_input; + }; + + + // Class to hide Compile's implementation. + class NeuralInputCompiler : boost::noncopyable + { + public: + static void Compile(const std::vector& p_inputs, + Compiler& p_compiler); + + private: + // Make sure nobody can create instances of this class. + NeuralInputCompiler(); + }; + + + // Class to load neural inputs from neural net/tree ensemble .ini files and then + // compile them. + template + class CompiledNeuralInputLoader : public DynamicRank::NeuralInputFactory::Loader + { + public: + // Constructor, taking the name of the transform which this loader will + // be responsible for. + CompiledNeuralInputLoader(const char* p_transform) + : m_transform(p_transform) + { + } + + // Functor to create a NeuralInput given appropriate inputs. + virtual DynamicRank::NeuralInput* + operator()(DynamicRank::Config& p_config, + const char* p_section, + DynamicRank::IFeatureMap& p_featureMap) const + { + T* t = T::Load(p_config, + p_section, + &p_featureMap, + m_transform); + NeuralInputFreeForm2* input = dynamic_cast(t); + + if (input != NULL) + { + // A stateful factory. + m_loaded.push_back(input); + } + + return input; + } + + virtual DynamicRank::NeuralInput* FromBond(const DynamicRank::UnionBondInput& p_data) const + { + // To handle all the cases we let all the cases' constructor from bond + // Take the DynamicRank::UnionBondInput as parameter. + T* t = new T(p_data); + DynamicRank::NeuralInput* t1 = dynamic_cast(t); + if (!t1) + { + delete t; + t = nullptr; + } + return t1; + } + + virtual DynamicRank::BulkNeuralInput* FromBulkBond(const DynamicRank::UnionBondInput& p_data) const + { + // To handle all the cases we let all the cases' constructor from bond + // Take the DynamicRank::UnionBondInput as parameter. + T* t = new T(p_data); + DynamicRank::BulkNeuralInput* t1 = dynamic_cast(t); + if (!t1) + { + delete t; + t = nullptr; + } + return t1; + } + + // Compile all freeform2 programs loaded through this factory. + void Compile(Compiler& p_compiler) const + { + Compile(m_loaded, p_compiler); + + m_loaded.clear(); + } + + // Static function to compile a given set of inputs. + static void Compile(const std::vector& p_inputs, + Compiler& p_compiler) + { + NeuralInputCompiler::Compile(p_inputs, p_compiler); + } + + + private: + // The transform this loader is loading. + const char* m_transform; + + // Vector of loaded compiled neural inputs, kept in order to allow compilation + // after loading. This is not great, but allows us to avoid individual + // compilation of each input, which takes excessive amounts of + // time. + mutable std::vector m_loaded; + }; + + + // A base class that groups a large amount of NeuralInputs and combines them into + // a few executable functions. They should have the same effect as running all the + // individual functions. + class CompiledBulkNeuralInputLoaderBase : public DynamicRank::BulkNeuralInputFactory + { + public: + // The reduction factor is how many functions should be compiled together at + // a time. A factor of 100 means that it will bulk compile in batches of 100. + static const UInt32 c_DefaultReductionFactor = 100; + + // p_reductionFactor must be greater or equal to 1. + CompiledBulkNeuralInputLoaderBase(UInt32 p_reductionFactor); + + virtual std::unique_ptr + ConvertToBulkInput(const std::vector& p_inputs, + DynamicRank::IFeatureMap& p_featureMap) const = 0; + + protected: + + // Load all neuralinput to a vector, extracted from ConvertToBulkInput for BulkAggregatedNeuralInput. + size_t + LoadProgramsForBulkInput(std::vector>& p_programs, + std::vector& p_features, + const std::vector& p_inputs, + DynamicRank::IFeatureMap& p_featureMap) const; + + private: + const UInt32 m_reductionFactor; + }; + + template + class CompiledBulkNeuralInputLoader : public CompiledBulkNeuralInputLoaderBase + { + public: + // p_reductionFactor must be greater or equal to 1. + CompiledBulkNeuralInputLoader(UInt32 p_reductionFactor) + : CompiledBulkNeuralInputLoaderBase(p_reductionFactor) + { + } + + virtual std::unique_ptr + ConvertToBulkInput(const std::vector& p_inputs, + DynamicRank::IFeatureMap& p_featureMap) const + { + std::vector features; + std::vector> programs; + size_t convertedInputsSize = LoadProgramsForBulkInput(programs, features, p_inputs, p_featureMap); + + return std::unique_ptr(new T(programs, + convertedInputsSize, + features)); + } + + }; + + // This loader is just for FromBulkBond for BulkNeuralInput. + template + class FromBulkBondLoader : public DynamicRank::NeuralInputFactory::Loader + { + public: + FromBulkBondLoader() {} + + virtual DynamicRank::NeuralInput* + operator()(DynamicRank::Config& p_config, + const char* p_section, + DynamicRank::IFeatureMap& p_featureMap) const + { + // Not implemented. + return nullptr; + } + + + virtual DynamicRank::NeuralInput* FromBond(const DynamicRank::UnionBondInput& p_data) const + { + // Not implemented. + return nullptr; + } + + virtual DynamicRank::BulkNeuralInput* FromBulkBond(const DynamicRank::UnionBondInput& p_data) const + { + // To handle all the cases we let all the cases' constructor from bond + // Take the DynamicRank::UnionBondInput as parameter. + T* t = new T(p_data); + DynamicRank::BulkNeuralInput* t1 = dynamic_cast(t); + if (!t1) + { + delete t; + t = nullptr; + } + return t1; + } + + }; + + // A DynamicRank::NeuralInputFactory with the NeuralInputFreeForm2 loader + // and CompiledBulkNeuralInputLoader pre-registered. + // CompiledBulkNeuralInputLoader registered mainly for FromBond logic. + class CompiledNeuralInputFactory : public DynamicRank::NeuralInputFactory + { + public: + CompiledNeuralInputFactory(); + + const CompiledNeuralInputLoader& GetFreeForm2Loader() const; + + private: + boost::shared_ptr> m_ff2Loader; + + // Register both loaders to help reconstruct tree from Bond. + boost::shared_ptr> m_bulkLoader; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.cpp new file mode 100644 index 000000000000..3ccfdc32e891 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.cpp @@ -0,0 +1,275 @@ +#include "ArrayCodeGen.h" + +#include "ArrayResult.h" +#include "CompilationState.h" +#include "FreeForm2Assert.h" +#include "LlvmCodeGenUtils.h" +#include "ResultIteratorImpl.h" +#include +#include "ValueResult.h" + +using namespace FreeForm2; + + +std::pair +FreeForm2::ArrayCodeGen::EncodeDimensions(const unsigned int* p_dimensions, + const unsigned int p_dimensionCount) +{ + ArrayBoundsType encoded = 0; + ArrayCountType count = 1; + size_t bitsNeeded = c_bitsPerFlatDimension * p_dimensionCount; + FF2_ASSERT(bitsNeeded <= sizeof(encoded) * 8); + + // Encode array dimensions into an integer. Note that we proceed from + // highest dimension to lowest, so that the first dimension is in the least + // significant bits. As we always index into that dimension first, we can + // then operate by a simple bitshift down on dereference. + for (unsigned int i = p_dimensionCount; i > 0; --i) + { + if (p_dimensions[i - 1] >= (1 << c_bitsPerFlatDimension)) + { + std::ostringstream err; + err << "Array (dimensions "; + for (unsigned int j = p_dimensionCount; j > 0; --j) + { + err << (j != p_dimensionCount ? ", " : "") << p_dimensions[j - 1]; + } + err << ") exceed the maximum single dimension size of " + << (1 << c_bitsPerFlatDimension); + throw std::runtime_error(err.str()); + } + + encoded = encoded << c_bitsPerFlatDimension; + encoded |= p_dimensions[i - 1]; + count *= p_dimensions[i - 1]; + } + + return std::make_pair(encoded, count); +} + + +std::pair +FreeForm2::ArrayCodeGen::EncodeDimensions(const ArrayType& p_type) +{ + return EncodeDimensions(p_type.GetDimensions(), p_type.GetDimensionCount()); +} + + +ArrayCodeGen::ArrayCountType +FreeForm2::ArrayCodeGen::DecodeDimensions(ArrayBoundsType p_bounds, + unsigned int p_numDimensions, + std::vector& p_dimensions) +{ + p_dimensions.clear(); + ArrayCountType numElements = 1; + p_dimensions.assign(p_numDimensions, 0); + const unsigned int mask = (1 << c_bitsPerFlatDimension) - 1; + + // Decode bounds. + for (unsigned int i = 0; i < p_numDimensions; ++i, p_bounds >>= c_bitsPerFlatDimension) + { + p_dimensions[i] = p_bounds & mask; + numElements *= p_dimensions[i]; + } + + return numElements; +} + + +LlvmCodeGenerator::CompiledValue& +FreeForm2::ArrayCodeGen::IssueReturn(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_array, + const ArrayType& p_arrayType, + LlvmCodeGenerator::CompiledValue& p_arraySpace) +{ + LlvmCodeGenerator::CompiledValue* bounds + = p_state.GetBuilder().CreateExtractValue(&p_array, boundsPosition); + CHECK_LLVM_RET(bounds); + LlvmCodeGenerator::CompiledValue* count + = p_state.GetBuilder().CreateExtractValue(&p_array, countPosition); + CHECK_LLVM_RET(count); + LlvmCodeGenerator::CompiledValue* pointer + = p_state.GetBuilder().CreateExtractValue(&p_array, pointerPosition); + CHECK_LLVM_RET(pointer); + + // Calculate number of bytes per element. + const TypeImpl& childType = p_arrayType.GetChildType().AsConstType(); + unsigned int bytes = p_state.GetSizeInBytes(&p_state.GetType(childType)); + + LlvmCodeGenerator::CompiledValue* elementSize + = llvm::ConstantInt::get(p_state.GetContext(), + llvm::APInt(sizeof(ArrayBoundsType) * 8, bytes)); + CHECK_LLVM_RET(elementSize); + + LlvmCodeGenerator::CompiledValue* copyBytes + = p_state.GetBuilder().CreateMul(elementSize, count); + CHECK_LLVM_RET(copyBytes); + + // Copy array into provided space. Note that first arg is + // destination, second is source, third is number of bytes to copy, + // final is guaranteed alignment (which we could optimise by + // guaranteeing and providing a higher value). + p_state.GetBuilder().CreateMemCpy(&p_arraySpace, pointer, copyBytes, 0); + + return *bounds; +} + + +template +boost::shared_ptr +FreeForm2::ArrayCodeGen::CreateArrayResult(const ArrayType& p_arrayType, + ArrayCodeGen::ArrayBoundsType p_bounds, + const boost::shared_array& p_space) +{ + // Decode bounds. + SharedDimensions dimensions(new std::vector()); + ArrayCodeGen::DecodeDimensions(p_bounds, + p_arrayType.GetDimensionCount(), + *dimensions); + return boost::shared_ptr(new ArrayResult(p_arrayType, + 0, + dimensions, + p_space.get(), + p_space)); +} + + +// Instantiate CreateArrayResult for needed types. +template +boost::shared_ptr +FreeForm2::ArrayCodeGen::CreateArrayResult( + const ArrayType& p_arrayType, + ArrayCodeGen::ArrayBoundsType p_bounds, + const boost::shared_array& p_space); + +template +boost::shared_ptr +FreeForm2::ArrayCodeGen::CreateArrayResult( + const ArrayType& p_arrayType, + ArrayCodeGen::ArrayBoundsType p_bounds, + const boost::shared_array& p_space); + +template +boost::shared_ptr +FreeForm2::ArrayCodeGen::CreateArrayResult( + const ArrayType& p_arrayType, + ArrayCodeGen::ArrayBoundsType p_bounds, + const boost::shared_array& p_space); + + +LlvmCodeGenerator::CompiledValue& +FreeForm2::ArrayCodeGen::CreateArray(CompilationState& p_state, + const ArrayType& p_arrayType, + LlvmCodeGenerator::CompiledValue& p_bounds, + LlvmCodeGenerator::CompiledValue& p_count, + LlvmCodeGenerator::CompiledValue& p_pointer) +{ + // Calculate array bounds. + llvm::Type& arrayType = p_state.GetType(p_arrayType); + + // Create structure with calculated bounds. + LlvmCodeGenerator::CompiledValue* undef = llvm::UndefValue::get(&arrayType); + CHECK_LLVM_RET(undef); + LlvmCodeGenerator::CompiledValue* structure + = p_state.GetBuilder().CreateInsertValue(undef, &p_bounds, boundsPosition); + CHECK_LLVM_RET(structure); + + // Add calculated element count. + structure + = p_state.GetBuilder().CreateInsertValue(structure, &p_count, countPosition); + CHECK_LLVM_RET(structure); + + structure = p_state.GetBuilder().CreateInsertValue(structure, &p_pointer, pointerPosition); + CHECK_LLVM_RET(structure); + + return *structure; +} + + +LlvmCodeGenerator::CompiledValue& +FreeForm2::ArrayCodeGen::CreateEmptyArray(CompilationState& p_state, + const ArrayType& p_arrayType) +{ + LlvmCodeGenerator::CompiledValue* boundsZero + = llvm::ConstantInt::get(p_state.GetContext(), + llvm::APInt(sizeof(ArrayBoundsType) * 8, 0)); + CHECK_LLVM_RET(boundsZero); + + LlvmCodeGenerator::CompiledValue* countZero + = llvm::ConstantInt::get(p_state.GetContext(), + llvm::APInt(sizeof(ArrayCountType) * 8, 0)); + CHECK_LLVM_RET(countZero); + + llvm::Type& childType = p_state.GetType(p_arrayType.GetChildType().AsConstType()); + llvm::Type* nullType = llvm::PointerType::get(&childType, 0); + CHECK_LLVM_RET(nullType); + LlvmCodeGenerator::CompiledValue* pointer = llvm::Constant::getNullValue(nullType); + CHECK_LLVM_RET(pointer); + + return CreateArray(p_state, p_arrayType, *boundsZero, *countZero, *pointer); +} + + +LlvmCodeGenerator::CompiledValue& +FreeForm2::ArrayCodeGen::MaskBounds(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_bounds) +{ + LlvmCodeGenerator::CompiledValue* mask + = llvm::ConstantInt::get(p_state.GetContext(), + llvm::APInt(sizeof(ArrayBoundsType) * 8, (1 << c_bitsPerFlatDimension) - 1)); + CHECK_LLVM_RET(mask); + LlvmCodeGenerator::CompiledValue* masked + = p_state.GetBuilder().CreateAnd(&p_bounds, mask); + CHECK_LLVM_RET(masked); + return *masked; +} + + + +LlvmCodeGenerator::CompiledValue& +FreeForm2::ArrayCodeGen::ShiftBounds(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_bounds, + unsigned int p_dimensions) +{ + LlvmCodeGenerator::CompiledValue* shift + = llvm::ConstantInt::get(p_state.GetContext(), + llvm::APInt(sizeof(ArrayBoundsType) * 8, + c_bitsPerFlatDimension * p_dimensions)); + CHECK_LLVM_RET(shift); + LlvmCodeGenerator::CompiledValue* shifted + = p_state.GetBuilder().CreateAShr(&p_bounds, shift); + CHECK_LLVM_RET(shifted); + return *shifted; +} + + +LlvmCodeGenerator::CompiledValue& +FreeForm2::ArrayCodeGen::UnshiftBound(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_bounds, + LlvmCodeGenerator::CompiledValue& p_newBound) +{ + llvm::Value* bitCount = llvm::ConstantInt::get(p_bounds.getType(), c_bitsPerFlatDimension); + CHECK_LLVM_RET(bitCount); + + llvm::Value* leftShift = p_state.GetBuilder().CreateShl(&p_bounds, bitCount); + CHECK_LLVM_RET(leftShift); + + llvm::Value* maxBound = llvm::ConstantInt::get(p_newBound.getType(), (1 << c_bitsPerFlatDimension) - 1); + CHECK_LLVM_RET(maxBound); + + llvm::Value* checkBound = p_state.GetBuilder().CreateICmpUGT(&p_newBound, maxBound); + CHECK_LLVM_RET(checkBound); + + llvm::Value* realBound = p_state.GetBuilder().CreateSelect(checkBound, maxBound, &p_newBound); + CHECK_LLVM_RET(realBound); + + if (realBound->getType()->getPrimitiveSizeInBits() < p_bounds.getType()->getPrimitiveSizeInBits()) + { + realBound = p_state.GetBuilder().CreateZExt(realBound, p_bounds.getType()); + CHECK_LLVM_RET(realBound); + } + + llvm::Value* finalBounds = p_state.GetBuilder().CreateOr(leftShift, realBound); + CHECK_LLVM_RET(finalBounds); + return *finalBounds; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.h new file mode 100644 index 000000000000..1716fde644ee --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/ArrayCodeGen.h @@ -0,0 +1,115 @@ +#pragma once + +#ifndef FREEFORM2_ARRAYCODEGEN_H +#define FREEFORM2_ARRAYCODEGEN_H + +#include "ArrayType.h" +#include +#include +#include +#include "LlvmCodeGenerator.h" +#include +#include + +namespace FreeForm2 +{ + class CompilationState; + + class ArrayCodeGen + { + public: + // Number of bits used to count members per flattened dimension. + // This limits the number of elements in an array dimension. + static const unsigned int c_bitsPerFlatDimension = 8; + + // Type used to represent encoded array bounds. + typedef UInt64 ArrayBoundsType; + + // Ensure that our representation can handle all arrays. + static_assert(sizeof(ArrayBoundsType) * 8 + >= c_bitsPerFlatDimension * ArrayType::c_maxDimensions, + "Please update ArrayBoundsType to reflect the max number of dimensions."); + + // Type used to calculate the total number of elements in an array. + // Note that we assume that counting the number of elements in an array + // will inherently occupy equal of fewer bits than the encoded bounds, + // which is safe unless we start playing tricks with the bounds + // representation. + typedef ArrayBoundsType ArrayCountType; + + // Enumeration that declares where the encoded bounds and space pointer + // are in the LLVM structure used to represent an array. + enum ArrayStructPosition + { + // Position of encoded array bounds. + boundsPosition, + + // Position of total array element count (in flattened elements). + countPosition, + + // Position of pointer to array space. + pointerPosition + }; + + // Encode array dimensions in an unsigned integer, returning + // the integer and the total number of elements in the array. + static std::pair + EncodeDimensions(const unsigned int* p_dimensions, + const unsigned int p_dimensionCount); + + // Encode array dimensions in an unsigned integer, returning the + // integer and the total number of elements in the array. + static std::pair + EncodeDimensions(const ArrayType& p_type); + + // Decode given number of array dimensions (p_numDimensions) from an + // unsigned integer (p_bounds), returning the the total array element count, + // and populating p_dimensions (which will be cleared first) with the arrays dimensions. + static ArrayCountType DecodeDimensions(ArrayBoundsType p_bounds, + unsigned int p_numDimensions, + std::vector& p_dimensions); + + // Issue LLVM code to return an array from a function. + static LlvmCodeGenerator::CompiledValue& + IssueReturn(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_array, + const ArrayType& p_arrayType, + LlvmCodeGenerator::CompiledValue& p_arraySpace); + + // Create an array result from a flattened array and calculated bounds. + template + static boost::shared_ptr + CreateArrayResult(const ArrayType& p_arrayType, + ArrayCodeGen::ArrayBoundsType p_bounds, + const boost::shared_array& p_space); + + // Create an empty array of the given type. + static LlvmCodeGenerator::CompiledValue& CreateArray(CompilationState& p_state, + const ArrayType& p_type, + LlvmCodeGenerator::CompiledValue& p_bounds, + LlvmCodeGenerator::CompiledValue& p_count, + LlvmCodeGenerator::CompiledValue& p_pointer); + + // Create an empty array of the given type. + static LlvmCodeGenerator::CompiledValue& CreateEmptyArray(CompilationState& p_state, + const ArrayType& p_arrayType); + + // Mask the given bounds to extract the top dimension. + static LlvmCodeGenerator::CompiledValue& MaskBounds(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_bounds); + + // Shift the given bounds down to remove one dimension. + static LlvmCodeGenerator::CompiledValue& ShiftBounds(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_bounds, + unsigned int p_dimensions); + + // Push a dimension to the beginning of the bit vector. This is + // effectively the opposite of ShiftBounds. + static LlvmCodeGenerator::CompiledValue& UnshiftBound(CompilationState& p_state, + LlvmCodeGenerator::CompiledValue& p_bounds, + LlvmCodeGenerator::CompiledValue& p_newBound); + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CMakeLists.txt new file mode 100644 index 000000000000..a20447bd9df8 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME DRFreeFormLlvmBackendLibrary) + +project(${PROJECT_NAME}) + +set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + +add_library(${PROJECT_NAME} STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/Extend/FreeForm2Support.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Extend/JITEmitter.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Extend/JITExtend.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/CompilationState.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ArrayCodeGen.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/LlvmCodeGenerator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/LlvmCodeGenUtils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/LlvmCompiler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/LlvmRuntimeLibrary.cpp +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/Extend + ${CMAKE_CURRENT_SOURCE_DIR}/../../../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../../NeuralTree.Library/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../Shared + ${CMAKE_CURRENT_SOURCE_DIR}/../../Expression + ${CMAKE_CURRENT_SOURCE_DIR}/../../External + ${CMAKE_CURRENT_SOURCE_DIR}/../../Transform + ${CMAKE_CURRENT_SOURCE_DIR}/../../Parse/SExpression/inc + ) + +install(TARGETS ${PROJECT_NAME} + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + ) \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.cpp new file mode 100644 index 000000000000..f59a94651186 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.cpp @@ -0,0 +1,466 @@ +#include "CompilationState.h" + +#include "ArrayCodeGen.h" +#include "ArrayType.h" +#include +#include +#include "Expression.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Result.h" +#include +#include +#include +#include +#include +#include +#include +#include "LlvmCodeGenUtils.h" +#include + +namespace +{ + std::auto_ptr + CreateEngine(llvm::Module& p_module, llvm::JITMemoryManager& p_memoryManager) + { + std::string builderErr; + llvm::EngineBuilder engineBuilder(&p_module); + engineBuilder.setJITMemoryManager(&p_memoryManager); + engineBuilder.setEngineKind(llvm::EngineKind::JIT); + engineBuilder.setErrorStr(&builderErr); + std::auto_ptr engine(engineBuilder.create()); + if (engine.get() == NULL) + { + std::ostringstream err; + err << "JIT builder error: " << builderErr; + throw std::runtime_error(err.str()); + } + return engine; + } +} + + +FreeForm2::CompilationState::CompilationState(llvm::JITMemoryManager& p_memoryManager) + : m_context(new llvm::LLVMContext()), + m_runtimeLibrary(new LlvmRuntimeLibrary(*m_context)), + m_builder(new llvm::IRBuilder<>(*m_context)), + m_intBits(sizeof(Result::IntType) * 8), + m_arrayBoundsType(llvm::IntegerType::get(*m_context, + sizeof(ArrayCodeGen::ArrayBoundsType) * 8)), + m_arrayCountType(llvm::IntegerType::get(*m_context, + sizeof(ArrayCodeGen::ArrayCountType) * 8)), + m_boolType(llvm::IntegerType::get(*m_context, 1)), + m_arrayBoolType(&CreateArrayType(m_boolType)), + m_intType(llvm::IntegerType::get(*m_context, m_intBits)), + m_arrayIntType(&CreateArrayType(m_intType)), + m_int32Type(llvm::IntegerType::get(*m_context, sizeof(Result::Int32Type) * 8)), + m_arrayInt32Type(&CreateArrayType(m_int32Type)), + m_floatType(llvm::Type::getFloatTy(*m_context)), + m_floatPtrType(llvm::Type::getFloatPtrTy(*m_context)), + m_arrayFloatType(&CreateArrayType(m_floatType)), + + // Use a two-bit integer to represent the freeform void type in LLVM. We + // can't use the LLVM void type because LLVM doesn't have a value that + // corresponds to that (is natively imperative), whereas we're taking the + // approach that void exists to remove the distinction between expressions + // and statements (convert-to-imperative approach). As a result, we rely + // on being able to instantiate void values, though they should be + // immediately discarded by the underlying backend (and we check this). + m_voidType(llvm::IntegerType::get(*m_context, 2)), + + m_featureType(llvm::IntegerType::get(*m_context, sizeof(Expression::FeatureType) * 8)), + m_featureArgument(nullptr), + m_arraySpace(nullptr), + m_previousOffset(nullptr), + m_aggregatedDocumentCount(nullptr), + m_aggregatedDocumentIndex(nullptr), + m_aggregatedCache(nullptr) +{ + BOOST_STATIC_ASSERT(sizeof(Result::FloatType) == sizeof(float)); + + CHECK_LLVM_RET(m_arrayBoundsType); + CHECK_LLVM_RET(m_boolType); + CHECK_LLVM_RET(m_intType); + CHECK_LLVM_RET(m_floatType); + CHECK_LLVM_RET(m_featureType); + + std::auto_ptr ownedModule(new llvm::Module("FreeForm2", *m_context)); + m_engine = CreateEngine(*ownedModule, p_memoryManager); + m_module = ownedModule.release(); + m_targetData = m_engine->getDataLayout(); + InitializeRuntimeLibrary(); +} + + +void +FreeForm2::CompilationState::InitializeRuntimeLibrary() +{ + m_runtimeLibrary->AddLibraryToModule(*m_module); + m_runtimeLibrary->AddExecutionMappings(*m_engine); +} + + +llvm::IRBuilder<>& +FreeForm2::CompilationState::GetBuilder() +{ + return *m_builder; +} + + +llvm::LLVMContext& +FreeForm2::CompilationState::GetContext() +{ + return *m_context; +} + + +llvm::Module& +FreeForm2::CompilationState::GetModule() +{ + return *m_module; +} + + +llvm::ExecutionEngine& +FreeForm2::CompilationState::GetExecutionEngine() +{ + return *m_engine; +} + + +const FreeForm2::LlvmRuntimeLibrary& +FreeForm2::CompilationState::GetRuntimeLibrary() const +{ + return *m_runtimeLibrary; +} + + +void +FreeForm2::CompilationState::SetVariableValue(VariableID p_id, + llvm::Value& p_value) +{ + m_variables[p_id] = &p_value; +} + + +llvm::Value* +FreeForm2::CompilationState::GetVariableValue(VariableID p_id) const +{ + auto find = m_variables.find(p_id); + FF2_ASSERT(find != m_variables.end()); + return find->second; +} + + +llvm::Value* +FreeForm2::CompilationState::GetFeatureArgument() const +{ + FF2_ASSERT(m_featureArgument != NULL); + return m_featureArgument; +} + + +void +FreeForm2::CompilationState::SetFeatureArgument(llvm::Value& p_val) +{ + m_featureArgument = &p_val; +} + + +llvm::Value& +FreeForm2::CompilationState::GetArrayReturnSpace() const +{ + FF2_ASSERT(m_arraySpace != NULL); + return *m_arraySpace; +} + + +void +FreeForm2::CompilationState::SetArrayReturnSpace(llvm::Value& p_value) +{ + m_arraySpace = &p_value; +} + + +llvm::Value& +FreeForm2::CompilationState::GetAggregatedDocumentCount() const +{ + return *m_aggregatedDocumentCount; +} + + +void +FreeForm2::CompilationState::SetAggregatedDocumentCount(llvm::Value& p_value) +{ + m_aggregatedDocumentCount = &p_value; +} + + +llvm::Value& +FreeForm2::CompilationState::GetAggregatedDocumentIndex() const +{ + return *m_aggregatedDocumentIndex; +} + + +void +FreeForm2::CompilationState::SetAggregatedDocumentIndex(llvm::Value& p_value) +{ + m_aggregatedDocumentIndex = &p_value; +} + + +llvm::Value& +FreeForm2::CompilationState::GetAggregatedCache() const +{ + return *m_aggregatedCache; +} + + +void +FreeForm2::CompilationState::SetAggregatedCache(llvm::Value& p_value) +{ + m_aggregatedCache = &p_value; +} + + +llvm::Value& +FreeForm2::CompilationState::GetFeatureArrayPointer() const +{ + return *m_featureArrayPointer; +} + + +void +FreeForm2::CompilationState::SetFeatureArrayPointer(llvm::Value& p_value) +{ + m_featureArrayPointer = &p_value; +} + + +llvm::IntegerType& +FreeForm2::CompilationState::GetArrayBoundsType() const +{ + return *m_arrayBoundsType; +} + + +llvm::IntegerType& +FreeForm2::CompilationState::GetArrayCountType() const +{ + return *m_arrayCountType; +} + + +llvm::Type& +FreeForm2::CompilationState::GetBoolType() const +{ + return *m_boolType; +} + + +llvm::IntegerType& +FreeForm2::CompilationState::GetIntType() const +{ + return *m_intType; +} + + +unsigned int +FreeForm2::CompilationState::GetIntBits() const +{ + return m_intBits; +} + + +llvm::IntegerType& +FreeForm2::CompilationState::GetInt32Type() const +{ + return *m_int32Type; +} + + +llvm::Type& +FreeForm2::CompilationState::GetFloatType() const +{ + return *m_floatType; +} + + +llvm::Type& +FreeForm2::CompilationState::GetFloatPtrType() const +{ + return *m_floatPtrType; +} + + +llvm::Type& +FreeForm2::CompilationState::GetVoidType() const +{ + return *m_voidType; +} + + +llvm::Type& +FreeForm2::CompilationState::GetFeatureType() const +{ + return *m_featureType; +} + + +unsigned int +FreeForm2::CompilationState::GetSizeInBytes(llvm::Type* p_type) const +{ + return static_cast(m_targetData->getTypeAllocSize(p_type)); +} + + +llvm::Type& +FreeForm2::CompilationState::GetType(const TypeImpl& p_type) const +{ + switch (p_type.Primitive()) + { + case Type::Int: + { + return GetIntType(); + } + + case Type::UInt32: __attribute__((__fallthrough__)); + case Type::Int32: + { + return GetInt32Type(); + } + + case Type::Float: + { + return GetFloatType(); + } + + case Type::Bool: + { + return GetBoolType(); + } + + case Type::Array: + { + const ArrayType& arrayType = static_cast(p_type); + switch (arrayType.GetChildType().Primitive()) + { + case Type::Int: + { + return *m_arrayIntType; + } + + case Type::UInt32: __attribute__((__fallthrough__)); + case Type::Int32: + { + return *m_arrayInt32Type; + } + + case Type::Float: + { + return *m_arrayFloatType; + } + + case Type::Bool: + { + return *m_arrayBoolType; + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } + } + + case Type::Void: + { + return GetVoidType(); + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } +} + + +llvm::Value& +FreeForm2::CompilationState::CreateZeroValue(const TypeImpl& p_type) +{ + switch (p_type.Primitive()) + { + case Type::Int: + { + llvm::Value* val = llvm::ConstantInt::get(m_intType, 0); + CHECK_LLVM_RET(val); + return *val; + } + + case Type::UInt32: __attribute__((__fallthrough__)); + case Type::Int32: + { + llvm::Value* val = llvm::ConstantInt::get(m_int32Type, 0); + CHECK_LLVM_RET(val); + return *val; + } + + case Type::Float: + { + llvm::Value* val = llvm::ConstantFP::get(m_floatType, 0); + CHECK_LLVM_RET(val); + return *val; + } + + case Type::Bool: + { + llvm::Value* val = llvm::ConstantInt::get(m_boolType, 0); + CHECK_LLVM_RET(val); + return *val; + } + + case Type::Array: + { + return ArrayCodeGen::CreateEmptyArray(*this, static_cast(p_type)); + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } +} + + +llvm::Value& +FreeForm2::CompilationState::CreateVoidValue() const +{ + LlvmCodeGenerator::CompiledValue* ret + = llvm::UndefValue::get(m_voidType); + CHECK_LLVM_RET(ret); + return *ret; +} + + +llvm::Type& +FreeForm2::CompilationState::CreateArrayType(llvm::Type* p_base) +{ + CHECK_LLVM_RET(m_arrayBoundsType); + CHECK_LLVM_RET(p_base); + llvm::Type* pointerType = llvm::PointerType::get(p_base, 0); + CHECK_LLVM_RET(pointerType); + + std::vector structure(1, m_arrayBoundsType); + FF2_ASSERT(structure.size() - 1 == ArrayCodeGen::boundsPosition); + + structure.push_back(m_arrayCountType); + FF2_ASSERT(structure.size() - 1 == ArrayCodeGen::countPosition); + + structure.push_back(pointerType); + FF2_ASSERT(structure.size() - 1 == ArrayCodeGen::pointerPosition); + + llvm::Type* result = llvm::StructType::get(*m_context, structure); + CHECK_LLVM_RET(result); + return *result; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.h new file mode 100644 index 000000000000..1d47df1b3a20 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/CompilationState.h @@ -0,0 +1,223 @@ +#pragma once + +#ifndef FREEFORM2_COMPILATION_STATE_H +#define FREEFORM2_COMPILATION_STATE_H + +#include +#include +#include +#include "LlvmRuntimeLibrary.h" +#include +#include +#include "Expression.h" +#include "FreeForm2Type.h" + +namespace llvm +{ + class ExecutionEngine; + class JITMemoryManager; + class LLVMContext; + class Module; + class DataLayout; +} + +namespace FreeForm2 +{ + // CompilationState tracks state used during compilation, notably + // the underlying LLVM state objects, and the symbol table. + class CompilationState : boost::noncopyable + { + public: + // Create a compilation state using a memory manager to create the + // ExecutionEngine. + CompilationState(llvm::JITMemoryManager& p_memoryManager); + + // Get the intermediate representation builder that we use to + // issue instructions. + llvm::IRBuilder<>& GetBuilder(); + + // Get the LLVMContext object, which basically acts as a big container + // object during code generation. + llvm::LLVMContext& GetContext(); + + // Get the LLVM module we're issuing code into (Module in the + // programming sense, of a related group of code). + llvm::Module& GetModule(); + + // Get the LLVM ExecutionEngine used to compile execute a program. + llvm::ExecutionEngine& GetExecutionEngine(); + + // Get the LlvmRuntimeLibrary for this state. + const LlvmRuntimeLibrary& GetRuntimeLibrary() const; + + // Get and set feature array value, which is passed as an argument to + // the top-level execution and we use in feature reference expressions. + // This function returns NULL if the feature argument has not been set. + llvm::Value* GetFeatureArgument() const; + void SetFeatureArgument(llvm::Value& p_val); + + // Get and set the value containing the array return space for the + // function. This is pre-allocated space into which a program copies + // an array return value. + llvm::Value& GetArrayReturnSpace() const; + void SetArrayReturnSpace(llvm::Value& p_value); + + // Get and set document count for aggregated freeforms. + llvm::Value& GetAggregatedDocumentCount() const; + void SetAggregatedDocumentCount(llvm::Value& p_value); + + // Get and set document index for aggregated freeforms. + llvm::Value& GetAggregatedDocumentIndex() const; + void SetAggregatedDocumentIndex(llvm::Value& p_value); + + // Get and set cache pointer for aggregated freeforms. + llvm::Value& GetAggregatedCache() const; + void SetAggregatedCache(llvm::Value& p_value); + + // Get and set array of pointers to feature arrays for aggregated freeforms. + llvm::Value& GetFeatureArrayPointer() const; + void SetFeatureArrayPointer(llvm::Value& p_value); + + // Get the integer type we're using for the array bounds type. + llvm::IntegerType& GetArrayBoundsType() const; + + // Get the integer type we're using for the array element count. + llvm::IntegerType& GetArrayCountType() const; + + // Get the boolean type we're using for the freeform bool type. + llvm::Type& GetBoolType() const; + + // Get the integer type we're using for the freeform int type. + llvm::IntegerType& GetIntType() const; + + // Get number of bits in Int type. + unsigned int GetIntBits() const; + + // Get the type used to represent signed and unsigned int32 types. + llvm::IntegerType& GetInt32Type() const; + + // Get the type we're using for the freeform float type. + llvm::Type& GetFloatType() const; + + // Get the type for a pointer to the freeform float type. + llvm::Type& GetFloatPtrType() const; + + // Get the integer type we're using for the freeform void type. + llvm::Type& GetVoidType() const; + + // Get number of bytes that a certain type needs to allocate. + unsigned int GetSizeInBytes(llvm::Type* p_type) const; + + // Get the LLVM type corresponding to a given freeform type. + llvm::Type& GetType(const TypeImpl& p_type) const; + + // LLVM type corresponding to the feature input type (doesn't have an + // exactly corresponding type in the freeform type system). + llvm::Type& GetFeatureType() const; + + // Get the LLVM type for the StreamFeatureInput object. + llvm::Type& GetStreamFeatureInputType() const; + + // Get the zero value (0, 0.0, false, []) for given type. + llvm::Value& CreateZeroValue(const TypeImpl& p_type); + + // Create a void value. + llvm::Value& CreateVoidValue() const; + + // Push a value onto the stack, returning the slot that this value was + // stored in. Note that the stack involved here is the compile-time + // dual of the stack of values assigned during parsing. + void SetVariableValue(VariableID p_id, llvm::Value& p_value); + + // Get a value from the stack. Note that the stack involved here is + // the compile-time dual of the stack of values assigned during parsing. + llvm::Value* GetVariableValue(VariableID p_id) const; + + // Number of bits in each field in a word. + static const unsigned int c_wordFieldBits = 32; + + private: + llvm::Type& CreateArrayType(llvm::Type* p_base); + + // Initialize the execution engine to contain references to all + // necessary runtime functions. + void InitializeRuntimeLibrary(); + + // LLVMContext, which tracks global variables and other + // 'program'-level constructs. + boost::shared_ptr m_context; + + // LLVM module, which roughly corresponds to a module in the + // programming sense, being a collection of functions. + llvm::Module* m_module; + + // This object is used to add and lookup runtime functions. + std::unique_ptr m_runtimeLibrary; + + // Object to help with building intermediate representation. + boost::shared_ptr> m_builder; + + // Execution engine. + std::auto_ptr m_engine; + + // Bit-counts of LLVM types. + unsigned int m_intBits; + + // Type that represents array bounds. + llvm::IntegerType* m_arrayBoundsType; + + // Type that represents array count. + llvm::IntegerType* m_arrayCountType; + + // LLVM types corresponding to freeform types. + llvm::Type* m_boolType; + llvm::Type* m_arrayBoolType; + llvm::IntegerType* m_intType; + llvm::Type* m_arrayIntType; + llvm::IntegerType* m_int32Type; + llvm::Type* m_arrayInt32Type; + llvm::Type* m_floatType; + llvm::Type* m_floatPtrType; + llvm::Type* m_arrayFloatType; + llvm::Type* m_voidType; + + // LLVM type corresponding to the feature input type (doesn't have an + // exactly corresponding type in the freeform type system). + llvm::IntegerType* m_featureType; + + // Feature array value, passed as a top-level arg to execution. + llvm::Value* m_featureArgument; + + // The LLVM value for a pre-allocated array return value. + llvm::Value* m_arraySpace; + + // The value of the current word of the match. + //llvm::Value* m_currentWord; + + // The value of the offset of the previous query word. + llvm::Value* m_previousOffset; + + // Document count for aggregated freeforms. + llvm::Value* m_aggregatedDocumentCount; + + // Document index for aggregated freeforms. + llvm::Value* m_aggregatedDocumentIndex; + + // Cache pointer for aggregated freeforms. + llvm::Value* m_aggregatedCache; + + // Array of pointers to feature arrays. + llvm::Value* m_featureArrayPointer; + + // Map of precalculated quantities, referred to by variable ID, as well + // as a boolean indication of whether the quantity is a reference (a + // pointer) or not. + std::map m_variables; + + // Holds information about the LLVM target, like sizes of structures in memory. + const llvm::DataLayout* m_targetData; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.cpp new file mode 100644 index 000000000000..bd8c636cc9a2 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.cpp @@ -0,0 +1,30 @@ +#include +#include "FreeForm2Support.h" +#include "JITExtend.h" + +namespace llvm +{ + struct FakeJITEmitter : public llvm::JITCodeEmitter { + void *MemMgr; + + // When outputting a function stub in the context of some other function, we + // save BufferBegin/BufferEnd/CurBufferPtr here. + uint8_t *SavedBufferBegin, *SavedBufferEnd, *SavedCurBufferPtr; + + // When reattempting to JIT a function after running out of space, we store + // the estimated size of the function we're trying to JIT here, so we can + // ask the memory manager for at least this much space. When we + // successfully emit the function, we reset this back to zero. + uintptr_t SizeEstimate; + + /// Relocations - These are the relocations that the function needs, as + /// emitted. + std::vector m_relocations; + }; + + const std::vector& GetMachineRelocations(const ExecutionEngine* p_engine) + { + const JIT* jit = static_cast(p_engine); + return GetJitMachineRelocations(jit); + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.h new file mode 100644 index 000000000000..ce9b7857ba4b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/FreeForm2Support.h @@ -0,0 +1,14 @@ + +#ifndef LLVM_FREEFORM2_SUPPORT_H +#define LLVM_FREEFORM2_SUPPORT_H + +#include +#include + +namespace llvm +{ + class ExecutionEngine; + + const std::vector& GetMachineRelocations(const ExecutionEngine* p_engine); +} +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITEmitter.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITEmitter.cpp new file mode 100644 index 000000000000..268cc2724d72 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITEmitter.cpp @@ -0,0 +1,1264 @@ +//===-- JITEmitter.cpp - Write machine code to executable memory ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines a MachineCodeEmitter object that is used by the JIT to +// write machine code to memory and remember where relocatable values are. +// +//===----------------------------------------------------------------------===// + +#include "JITExtend.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef NDEBUG +#include +#endif +using namespace llvm; + +#define DEBUG_TYPE "jit" + +STATISTIC(NumBytes, "Number of bytes of machine code compiled"); +STATISTIC(NumRelos, "Number of relocations applied"); +STATISTIC(NumRetries, "Number of retries with more memory"); + + +// A declaration may stop being a declaration once it's fully read from bitcode. +// This function returns true if F is fully read and is still a declaration. +static bool isNonGhostDeclaration(const Function *F) { + return F->isDeclaration() && !F->isMaterializable(); +} + +//===----------------------------------------------------------------------===// +// JIT lazy compilation code. +// +namespace { + class JITEmitter; + class JITResolverState; + + template + struct NoRAUWValueMapConfig : public ValueMapConfig { + typedef JITResolverState *ExtraData; + static void onRAUW(JITResolverState *, Value *Old, Value *New) { + llvm_unreachable("The JIT doesn't know how to handle a" + " RAUW on a value it has emitted."); + } + }; + + struct CallSiteValueMapConfig : public NoRAUWValueMapConfig { + typedef JITResolverState *ExtraData; + static void onDelete(JITResolverState *JRS, Function *F); + }; + + class JITResolverState { + public: + typedef ValueMap > + FunctionToLazyStubMapTy; + typedef std::map > CallSiteToFunctionMapTy; + typedef ValueMap, + CallSiteValueMapConfig> FunctionToCallSitesMapTy; + typedef std::map, void*> GlobalToIndirectSymMapTy; + private: + /// FunctionToLazyStubMap - Keep track of the lazy stub created for a + /// particular function so that we can reuse them if necessary. + FunctionToLazyStubMapTy FunctionToLazyStubMap; + + /// CallSiteToFunctionMap - Keep track of the function that each lazy call + /// site corresponds to, and vice versa. + CallSiteToFunctionMapTy CallSiteToFunctionMap; + FunctionToCallSitesMapTy FunctionToCallSitesMap; + + /// GlobalToIndirectSymMap - Keep track of the indirect symbol created for a + /// particular GlobalVariable so that we can reuse them if necessary. + GlobalToIndirectSymMapTy GlobalToIndirectSymMap; + +#ifndef NDEBUG + /// Instance of the JIT this ResolverState serves. + JIT *TheJIT; +#endif + + public: + JITResolverState(JIT *jit) : FunctionToLazyStubMap(this), + FunctionToCallSitesMap(this) { +#ifndef NDEBUG + TheJIT = jit; +#endif + } + + FunctionToLazyStubMapTy& getFunctionToLazyStubMap() { + return FunctionToLazyStubMap; + } + + GlobalToIndirectSymMapTy& getGlobalToIndirectSymMap() { + return GlobalToIndirectSymMap; + } + + std::pair LookupFunctionFromCallSite( + void *CallSite) const { + // The address given to us for the stub may not be exactly right, it + // might be a little bit after the stub. As such, use upper_bound to + // find it. + CallSiteToFunctionMapTy::const_iterator I = + CallSiteToFunctionMap.upper_bound(CallSite); + assert(I != CallSiteToFunctionMap.begin() && + "This is not a known call site!"); + --I; + return *I; + } + + void AddCallSite(void *CallSite, Function *F) { + bool Inserted = CallSiteToFunctionMap.insert( + std::make_pair(CallSite, F)).second; + (void)Inserted; + assert(Inserted && "Pair was already in CallSiteToFunctionMap"); + FunctionToCallSitesMap[F].insert(CallSite); + } + + void EraseAllCallSitesForPrelocked(Function *F); + + // Erases _all_ call sites regardless of their function. This is used to + // unregister the stub addresses from the StubToResolverMap in + // ~JITResolver(). + void EraseAllCallSitesPrelocked(); + }; + + /// JITResolver - Keep track of, and resolve, call sites for functions that + /// have not yet been compiled. + class JITResolver { + typedef JITResolverState::FunctionToLazyStubMapTy FunctionToLazyStubMapTy; + typedef JITResolverState::CallSiteToFunctionMapTy CallSiteToFunctionMapTy; + typedef JITResolverState::GlobalToIndirectSymMapTy GlobalToIndirectSymMapTy; + + /// LazyResolverFn - The target lazy resolver function that we actually + /// rewrite instructions to use. + TargetJITInfo::LazyResolverFn LazyResolverFn; + + JITResolverState state; + + /// ExternalFnToStubMap - This is the equivalent of FunctionToLazyStubMap + /// for external functions. TODO: Of course, external functions don't need + /// a lazy stub. It's actually here to make it more likely that far calls + /// succeed, but no single stub can guarantee that. I'll remove this in a + /// subsequent checkin when I actually fix far calls. + std::map ExternalFnToStubMap; + + /// revGOTMap - map addresses to indexes in the GOT + std::map revGOTMap; + unsigned nextGOTIndex; + + JITEmitter &JE; + + /// Instance of JIT corresponding to this Resolver. + JIT *TheJIT; + + public: + explicit JITResolver(JIT &jit, JITEmitter &je) + : state(&jit), nextGOTIndex(0), JE(je), TheJIT(&jit) { + LazyResolverFn = jit.getJITInfo().getLazyResolverFunction(JITCompilerFn); + } + + ~JITResolver(); + + /// getLazyFunctionStubIfAvailable - This returns a pointer to a function's + /// lazy-compilation stub if it has already been created. + void *getLazyFunctionStubIfAvailable(Function *F); + + /// getLazyFunctionStub - This returns a pointer to a function's + /// lazy-compilation stub, creating one on demand as needed. + void *getLazyFunctionStub(Function *F); + + /// getExternalFunctionStub - Return a stub for the function at the + /// specified address, created lazily on demand. + void *getExternalFunctionStub(void *FnAddr); + + /// getGlobalValueIndirectSym - Return an indirect symbol containing the + /// specified GV address. + void *getGlobalValueIndirectSym(GlobalValue *V, void *GVAddress); + + /// getGOTIndexForAddress - Return a new or existing index in the GOT for + /// an address. This function only manages slots, it does not manage the + /// contents of the slots or the memory associated with the GOT. + unsigned getGOTIndexForAddr(void *addr); + + /// JITCompilerFn - This function is called to resolve a stub to a compiled + /// address. If the LLVM Function corresponding to the stub has not yet + /// been compiled, this function compiles it first. + static void *JITCompilerFn(void *Stub); + }; + + class StubToResolverMapTy { + /// Map a stub address to a specific instance of a JITResolver so that + /// lazily-compiled functions can find the right resolver to use. + /// + /// Guarded by Lock. + std::map Map; + + /// Guards Map from concurrent accesses. + mutable sys::Mutex Lock; + + public: + /// Registers a Stub to be resolved by Resolver. + void RegisterStubResolver(void *Stub, JITResolver *Resolver) { + MutexGuard guard(Lock); + Map.insert(std::make_pair(Stub, Resolver)); + } + /// Unregisters the Stub when it's invalidated. + void UnregisterStubResolver(void *Stub) { + MutexGuard guard(Lock); + Map.erase(Stub); + } + /// Returns the JITResolver instance that owns the Stub. + JITResolver *getResolverFromStub(void *Stub) const { + MutexGuard guard(Lock); + // The address given to us for the stub may not be exactly right, it might + // be a little bit after the stub. As such, use upper_bound to find it. + // This is the same trick as in LookupFunctionFromCallSite from + // JITResolverState. + std::map::const_iterator I = Map.upper_bound(Stub); + assert(I != Map.begin() && "This is not a known stub!"); + --I; + return I->second; + } + /// True if any stubs refer to the given resolver. Only used in an assert(). + /// O(N) + bool ResolverHasStubs(JITResolver* Resolver) const { + MutexGuard guard(Lock); + for (std::map::const_iterator I = Map.begin(), + E = Map.end(); I != E; ++I) { + if (I->second == Resolver) + return true; + } + return false; + } + }; + /// This needs to be static so that a lazy call stub can access it with no + /// context except the address of the stub. + ManagedStatic StubToResolverMap; + + /// JITEmitter - The JIT implementation of the MachineCodeEmitter, which is + /// used to output functions to memory for execution. + class JITEmitter : public JITCodeEmitter { + JITMemoryManager *MemMgr; + + // When outputting a function stub in the context of some other function, we + // save BufferBegin/BufferEnd/CurBufferPtr here. + uint8_t *SavedBufferBegin, *SavedBufferEnd, *SavedCurBufferPtr; + + // When reattempting to JIT a function after running out of space, we store + // the estimated size of the function we're trying to JIT here, so we can + // ask the memory manager for at least this much space. When we + // successfully emit the function, we reset this back to zero. + uintptr_t SizeEstimate; + + /// Relocations - These are the relocations that the function needs, as + /// emitted. + std::vector Relocations; + + std::vector m_copyRelocations; + + /// MBBLocations - This vector is a mapping from MBB ID's to their address. + /// It is filled in by the StartMachineBasicBlock callback and queried by + /// the getMachineBasicBlockAddress callback. + std::vector MBBLocations; + + /// ConstantPool - The constant pool for the current function. + /// + MachineConstantPool *ConstantPool; + + /// ConstantPoolBase - A pointer to the first entry in the constant pool. + /// + void *ConstantPoolBase; + + /// ConstPoolAddresses - Addresses of individual constant pool entries. + /// + SmallVector ConstPoolAddresses; + + /// JumpTable - The jump tables for the current function. + /// + MachineJumpTableInfo *JumpTable; + + /// JumpTableBase - A pointer to the first entry in the jump table. + /// + void *JumpTableBase; + + /// Resolver - This contains info about the currently resolved functions. + JITResolver Resolver; + + /// LabelLocations - This vector is a mapping from Label ID's to their + /// address. + DenseMap LabelLocations; + + /// MMI - Machine module info for exception informations + MachineModuleInfo* MMI; + + // CurFn - The llvm function being emitted. Only valid during + // finishFunction(). + const Function *CurFn; + + /// Information about emitted code, which is passed to the + /// JITEventListeners. This is reset in startFunction and used in + /// finishFunction. + JITEvent_EmittedFunctionDetails EmissionDetails; + + struct EmittedCode { + void *FunctionBody; // Beginning of the function's allocation. + void *Code; // The address the function's code actually starts at. + void *ExceptionTable; + EmittedCode() : FunctionBody(nullptr), Code(nullptr), + ExceptionTable(nullptr) {} + }; + struct EmittedFunctionConfig : public ValueMapConfig { + typedef JITEmitter *ExtraData; + static void onDelete(JITEmitter *, const Function*); + static void onRAUW(JITEmitter *, const Function*, const Function*); + }; + ValueMap EmittedFunctions; + + DebugLoc PrevDL; + + /// Instance of the JIT + JIT *TheJIT; + + public: + JITEmitter(JIT &jit, JITMemoryManager *JMM, TargetMachine &TM) + : SizeEstimate(0), Resolver(jit, *this), MMI(nullptr), CurFn(nullptr), + EmittedFunctions(this), TheJIT(&jit) { + MemMgr = JMM ? JMM : JITMemoryManager::CreateDefaultMemManager(); + if (jit.getJITInfo().needsGOT()) { + MemMgr->AllocateGOT(); + // DEBUG_PRINT(dbgs() << "JIT is managing a GOT\n"); + } + + } + ~JITEmitter() { + delete MemMgr; + } + + JITResolver &getJITResolver() { return Resolver; } + + void startFunction(MachineFunction &F) override; + bool finishFunction(MachineFunction &F) override; + + void emitConstantPool(MachineConstantPool *MCP); + void initJumpTableInfo(MachineJumpTableInfo *MJTI); + void emitJumpTableInfo(MachineJumpTableInfo *MJTI); + + void startGVStub(const GlobalValue* GV, + unsigned StubSize, unsigned Alignment = 1); + void startGVStub(void *Buffer, unsigned StubSize); + void finishGVStub(); + void *allocIndirectGV(const GlobalValue *GV, const uint8_t *Buffer, + size_t Size, unsigned Alignment) override; + + /// allocateSpace - Reserves space in the current block if any, or + /// allocate a new one of the given size. + void *allocateSpace(uintptr_t Size, unsigned Alignment) override; + + /// allocateGlobal - Allocate memory for a global. Unlike allocateSpace, + /// this method does not allocate memory in the current output buffer, + /// because a global may live longer than the current function. + void *allocateGlobal(uintptr_t Size, unsigned Alignment) override; + + void addRelocation(const MachineRelocation &MR) override { + Relocations.push_back(MR); + } + + void StartMachineBasicBlock(MachineBasicBlock *MBB) override { + if (MBBLocations.size() <= (unsigned)MBB->getNumber()) + MBBLocations.resize((MBB->getNumber()+1)*2); + MBBLocations[MBB->getNumber()] = getCurrentPCValue(); + if (MBB->hasAddressTaken()) + TheJIT->addPointerToBasicBlock(MBB->getBasicBlock(), + (void*)getCurrentPCValue()); + // DEBUG_PRINT(dbgs() << "JIT: Emitting BB" << MBB->getNumber() << " at [" + // << (void*) getCurrentPCValue() << "]\n"); + } + + uintptr_t getConstantPoolEntryAddress(unsigned Entry) const override; + uintptr_t getJumpTableEntryAddress(unsigned Entry) const override; + + uintptr_t + getMachineBasicBlockAddress(MachineBasicBlock *MBB) const override { + assert(MBBLocations.size() > (unsigned)MBB->getNumber() && + MBBLocations[MBB->getNumber()] && "MBB not emitted!"); + return MBBLocations[MBB->getNumber()]; + } + + /// retryWithMoreMemory - Log a retry and deallocate all memory for the + /// given function. Increase the minimum allocation size so that we get + /// more memory next time. + void retryWithMoreMemory(MachineFunction &F); + + /// deallocateMemForFunction - Deallocate all memory for the specified + /// function body. + void deallocateMemForFunction(const Function *F); + + void processDebugLoc(DebugLoc DL, bool BeforePrintingInsn) override; + + void emitLabel(MCSymbol *Label) override { + LabelLocations[Label] = getCurrentPCValue(); + } + + DenseMap *getLabelLocations() override { + return &LabelLocations; + } + + uintptr_t getLabelAddress(MCSymbol *Label) const override { + assert(LabelLocations.count(Label) && "Label not emitted!"); + return LabelLocations.find(Label)->second; + } + + void setModuleInfo(MachineModuleInfo* Info) override { + MMI = Info; + } + + const std::vector& GetMachineRelocations() const + { + return m_copyRelocations; + } + + private: + void *getPointerToGlobal(GlobalValue *GV, void *Reference, + bool MayNeedFarStub); + void *getPointerToGVIndirectSym(GlobalValue *V, void *Reference); + }; +} + +void CallSiteValueMapConfig::onDelete(JITResolverState *JRS, Function *F) { + JRS->EraseAllCallSitesForPrelocked(F); +} + +void JITResolverState::EraseAllCallSitesForPrelocked(Function *F) { + FunctionToCallSitesMapTy::iterator F2C = FunctionToCallSitesMap.find(F); + if (F2C == FunctionToCallSitesMap.end()) + return; + StubToResolverMapTy &S2RMap = *StubToResolverMap; + for (SmallPtrSet::const_iterator I = F2C->second.begin(), + E = F2C->second.end(); I != E; ++I) { + S2RMap.UnregisterStubResolver(*I); + bool Erased = CallSiteToFunctionMap.erase(*I); + (void)Erased; + assert(Erased && "Missing call site->function mapping"); + } + FunctionToCallSitesMap.erase(F2C); +} + +void JITResolverState::EraseAllCallSitesPrelocked() { + StubToResolverMapTy &S2RMap = *StubToResolverMap; + for (CallSiteToFunctionMapTy::const_iterator + I = CallSiteToFunctionMap.begin(), + E = CallSiteToFunctionMap.end(); I != E; ++I) { + S2RMap.UnregisterStubResolver(I->first); + } + CallSiteToFunctionMap.clear(); + FunctionToCallSitesMap.clear(); +} + +JITResolver::~JITResolver() { + // No need to lock because we're in the destructor, and state isn't shared. + state.EraseAllCallSitesPrelocked(); + assert(!StubToResolverMap->ResolverHasStubs(this) && + "Resolver destroyed with stubs still alive."); +} + +/// getLazyFunctionStubIfAvailable - This returns a pointer to a function stub +/// if it has already been created. +void *JITResolver::getLazyFunctionStubIfAvailable(Function *F) { + MutexGuard locked(TheJIT->lock); + + // If we already have a stub for this function, recycle it. + return state.getFunctionToLazyStubMap().lookup(F); +} + +/// getFunctionStub - This returns a pointer to a function stub, creating +/// one on demand as needed. +void *JITResolver::getLazyFunctionStub(Function *F) { + MutexGuard locked(TheJIT->lock); + + // If we already have a lazy stub for this function, recycle it. + void *&Stub = state.getFunctionToLazyStubMap()[F]; + if (Stub) return Stub; + + // Call the lazy resolver function if we are JIT'ing lazily. Otherwise we + // must resolve the symbol now. + void *Actual = TheJIT->isCompilingLazily() + ? (void *)(intptr_t)LazyResolverFn : (void *)nullptr; + + // If this is an external declaration, attempt to resolve the address now + // to place in the stub. + if (isNonGhostDeclaration(F) || F->hasAvailableExternallyLinkage()) { + Actual = TheJIT->getPointerToFunction(F); + + // If we resolved the symbol to a null address (eg. a weak external) + // don't emit a stub. Return a null pointer to the application. + if (!Actual) return nullptr; + } + + TargetJITInfo::StubLayout SL = TheJIT->getJITInfo().getStubLayout(); + JE.startGVStub(F, SL.Size, SL.Alignment); + // Codegen a new stub, calling the lazy resolver or the actual address of the + // external function, if it was resolved. + Stub = TheJIT->getJITInfo().emitFunctionStub(F, Actual, JE); + JE.finishGVStub(); + + if (Actual != (void*)(intptr_t)LazyResolverFn) { + // If we are getting the stub for an external function, we really want the + // address of the stub in the GlobalAddressMap for the JIT, not the address + // of the external function. + TheJIT->updateGlobalMapping(F, Stub); + } + +// DEBUG_PRINT(dbgs() << "JIT: Lazy stub emitted at [" << Stub << "] for function '" +// << F->getName() << "'\n"); + + if (TheJIT->isCompilingLazily()) { + // Register this JITResolver as the one corresponding to this call site so + // JITCompilerFn will be able to find it. + StubToResolverMap->RegisterStubResolver(Stub, this); + + // Finally, keep track of the stub-to-Function mapping so that the + // JITCompilerFn knows which function to compile! + state.AddCallSite(Stub, F); + } else if (!Actual) { + // If we are JIT'ing non-lazily but need to call a function that does not + // exist yet, add it to the JIT's work list so that we can fill in the + // stub address later. + assert(!isNonGhostDeclaration(F) && !F->hasAvailableExternallyLinkage() && + "'Actual' should have been set above."); + TheJIT->addPendingFunction(F); + } + + return Stub; +} + +/// getGlobalValueIndirectSym - Return a lazy pointer containing the specified +/// GV address. +void *JITResolver::getGlobalValueIndirectSym(GlobalValue *GV, void *GVAddress) { + MutexGuard locked(TheJIT->lock); + + // If we already have a stub for this global variable, recycle it. + void *&IndirectSym = state.getGlobalToIndirectSymMap()[GV]; + if (IndirectSym) return IndirectSym; + + // Otherwise, codegen a new indirect symbol. + IndirectSym = TheJIT->getJITInfo().emitGlobalValueIndirectSym(GV, GVAddress, + JE); + +// DEBUG_PRINT(dbgs() << "JIT: Indirect symbol emitted at [" << IndirectSym +// << "] for GV '" << GV->getName() << "'\n"); + + return IndirectSym; +} + +/// getExternalFunctionStub - Return a stub for the function at the +/// specified address, created lazily on demand. +void *JITResolver::getExternalFunctionStub(void *FnAddr) { + // If we already have a stub for this function, recycle it. + void *&Stub = ExternalFnToStubMap[FnAddr]; + if (Stub) return Stub; + + TargetJITInfo::StubLayout SL = TheJIT->getJITInfo().getStubLayout(); + JE.startGVStub(nullptr, SL.Size, SL.Alignment); + Stub = TheJIT->getJITInfo().emitFunctionStub(nullptr, FnAddr, JE); + JE.finishGVStub(); + +// DEBUG_PRINT(dbgs() << "JIT: Stub emitted at [" << Stub +// << "] for external function at '" << FnAddr << "'\n"); + return Stub; +} + +unsigned JITResolver::getGOTIndexForAddr(void* addr) { + unsigned idx = revGOTMap[addr]; + if (!idx) { + idx = ++nextGOTIndex; + revGOTMap[addr] = idx; + // DEBUG_PRINT(dbgs() << "JIT: Adding GOT entry " << idx << " for addr [" + // << addr << "]\n"); + } + return idx; +} + +/// JITCompilerFn - This function is called when a lazy compilation stub has +/// been entered. It looks up which function this stub corresponds to, compiles +/// it if necessary, then returns the resultant function pointer. +void *JITResolver::JITCompilerFn(void *Stub) { + JITResolver *JR = StubToResolverMap->getResolverFromStub(Stub); + assert(JR && "Unable to find the corresponding JITResolver to the call site"); + + Function* F = nullptr; + void* ActualPtr = nullptr; + + { + // Only lock for getting the Function. The call getPointerToFunction made + // in this function might trigger function materializing, which requires + // JIT lock to be unlocked. + MutexGuard locked(JR->TheJIT->lock); + + // The address given to us for the stub may not be exactly right, it might + // be a little bit after the stub. As such, use upper_bound to find it. + std::pair I = + JR->state.LookupFunctionFromCallSite(Stub); + F = I.second; + ActualPtr = I.first; + } + + // If we have already code generated the function, just return the address. + void *Result = JR->TheJIT->getPointerToGlobalIfAvailable(F); + + if (!Result) { + // Otherwise we don't have it, do lazy compilation now. + + // If lazy compilation is disabled, emit a useful error message and abort. + if (!JR->TheJIT->isCompilingLazily()) { + report_fatal_error("LLVM JIT requested to do lazy compilation of" + " function '" + + F->getName() + "' when lazy compiles are disabled!"); + } + + // DEBUG_PRINT(dbgs() << "JIT: Lazily resolving function '" << F->getName() + // << "' In stub ptr = " << Stub << " actual ptr = " + // << ActualPtr << "\n"); + (void)ActualPtr; + + Result = JR->TheJIT->getPointerToFunction(F); + } + + // Reacquire the lock to update the GOT map. + MutexGuard locked(JR->TheJIT->lock); + + // We might like to remove the call site from the CallSiteToFunction map, but + // we can't do that! Multiple threads could be stuck, waiting to acquire the + // lock above. As soon as the 1st function finishes compiling the function, + // the next one will be released, and needs to be able to find the function it + // needs to call. + + // FIXME: We could rewrite all references to this stub if we knew them. + + // What we will do is set the compiled function address to map to the + // same GOT entry as the stub so that later clients may update the GOT + // if they see it still using the stub address. + // Note: this is done so the Resolver doesn't have to manage GOT memory + // Do this without allocating map space if the target isn't using a GOT + if(JR->revGOTMap.find(Stub) != JR->revGOTMap.end()) + JR->revGOTMap[Result] = JR->revGOTMap[Stub]; + + return Result; +} + +//===----------------------------------------------------------------------===// +// JITEmitter code. +// + +static GlobalObject *getSimpleAliasee(Constant *C) { + C = C->stripPointerCasts(); + return dyn_cast(C); +} + +void *JITEmitter::getPointerToGlobal(GlobalValue *V, void *Reference, + bool MayNeedFarStub) { + if (GlobalVariable *GV = dyn_cast(V)) + return TheJIT->getOrEmitGlobalVariable(GV); + + if (GlobalAlias *GA = dyn_cast(V)) { + // We can only handle simple cases. + if (GlobalValue *GV = getSimpleAliasee(GA->getAliasee())) + return TheJIT->getPointerToGlobal(GV); + return nullptr; + } + + // If we have already compiled the function, return a pointer to its body. + Function *F = cast(V); + + void *FnStub = Resolver.getLazyFunctionStubIfAvailable(F); + if (FnStub) { + // Return the function stub if it's already created. We do this first so + // that we're returning the same address for the function as any previous + // call. TODO: Yes, this is wrong. The lazy stub isn't guaranteed to be + // close enough to call. + return FnStub; + } + + // If we know the target can handle arbitrary-distance calls, try to + // return a direct pointer. + if (!MayNeedFarStub) { + // If we have code, go ahead and return that. + void *ResultPtr = TheJIT->getPointerToGlobalIfAvailable(F); + if (ResultPtr) return ResultPtr; + + // If this is an external function pointer, we can force the JIT to + // 'compile' it, which really just adds it to the map. + if (isNonGhostDeclaration(F) || F->hasAvailableExternallyLinkage()) + return TheJIT->getPointerToFunction(F); + } + + // Otherwise, we may need a to emit a stub, and, conservatively, we always do + // so. Note that it's possible to return null from getLazyFunctionStub in the + // case of a weak extern that fails to resolve. + return Resolver.getLazyFunctionStub(F); +} + +void *JITEmitter::getPointerToGVIndirectSym(GlobalValue *V, void *Reference) { + // Make sure GV is emitted first, and create a stub containing the fully + // resolved address. + void *GVAddress = getPointerToGlobal(V, Reference, false); + void *StubAddr = Resolver.getGlobalValueIndirectSym(V, GVAddress); + return StubAddr; +} + +void JITEmitter::processDebugLoc(DebugLoc DL, bool BeforePrintingInsn) { + if (DL.isUnknown()) return; + if (!BeforePrintingInsn) return; + + const LLVMContext &Context = EmissionDetails.MF->getFunction()->getContext(); + + if (DL.getScope(Context) != nullptr && PrevDL != DL) { + JITEvent_EmittedFunctionDetails::LineStart NextLine; + NextLine.Address = getCurrentPCValue(); + NextLine.Loc = DL; + EmissionDetails.LineStarts.push_back(NextLine); + } + + PrevDL = DL; +} + +static unsigned GetConstantPoolSizeInBytes(MachineConstantPool *MCP, + const DataLayout *TD) { + const std::vector &Constants = MCP->getConstants(); + if (Constants.empty()) return 0; + + unsigned Size = 0; + for (unsigned i = 0, e = Constants.size(); i != e; ++i) { + MachineConstantPoolEntry CPE = Constants[i]; + unsigned AlignMask = CPE.getAlignment() - 1; + Size = (Size + AlignMask) & ~AlignMask; + Type *Ty = CPE.getType(); + Size += TD->getTypeAllocSize(Ty); + } + return Size; +} + +void JITEmitter::startFunction(MachineFunction &F) { +// DEBUG_PRINT(dbgs() << "JIT: Starting CodeGen of Function " +// << F.getName() << "\n"); + + uintptr_t ActualSize = 0; + // Set the memory writable, if it's not already + MemMgr->setMemoryWritable(); + + if (SizeEstimate > 0) { + // SizeEstimate will be non-zero on reallocation attempts. + ActualSize = SizeEstimate; + } + + BufferBegin = CurBufferPtr = MemMgr->startFunctionBody(F.getFunction(), + ActualSize); + BufferEnd = BufferBegin+ActualSize; + EmittedFunctions[F.getFunction()].FunctionBody = BufferBegin; + + // Ensure the constant pool/jump table info is at least 4-byte aligned. + emitAlignment(16); + + emitConstantPool(F.getConstantPool()); + if (MachineJumpTableInfo *MJTI = F.getJumpTableInfo()) + initJumpTableInfo(MJTI); + + // About to start emitting the machine code for the function. + emitAlignment(std::max(F.getFunction()->getAlignment(), 8U)); + TheJIT->updateGlobalMapping(F.getFunction(), CurBufferPtr); + EmittedFunctions[F.getFunction()].Code = CurBufferPtr; + + MBBLocations.clear(); + + EmissionDetails.MF = &F; + EmissionDetails.LineStarts.clear(); +} + +bool JITEmitter::finishFunction(MachineFunction &F) { + if (CurBufferPtr == BufferEnd) { + // We must call endFunctionBody before retrying, because + // deallocateMemForFunction requires it. + MemMgr->endFunctionBody(F.getFunction(), BufferBegin, CurBufferPtr); + retryWithMoreMemory(F); + return true; + } + + if (MachineJumpTableInfo *MJTI = F.getJumpTableInfo()) + emitJumpTableInfo(MJTI); + + // FnStart is the start of the text, not the start of the constant pool and + // other per-function data. + uint8_t *FnStart = + (uint8_t *)TheJIT->getPointerToGlobalIfAvailable(F.getFunction()); + + // FnEnd is the end of the function's machine code. + uint8_t *FnEnd = CurBufferPtr; + + if (!Relocations.empty()) { + CurFn = F.getFunction(); + NumRelos += Relocations.size(); + + // Resolve the relocations to concrete pointers. + for (unsigned i = 0, e = Relocations.size(); i != e; ++i) { + MachineRelocation &MR = Relocations[i]; + void *ResultPtr = nullptr; + if (!MR.letTargetResolve()) { + if (MR.isExternalSymbol()) { + ResultPtr = TheJIT->getPointerToNamedFunction(MR.getExternalSymbol(), + false); + // DEBUG_PRINT(dbgs() << "JIT: Map \'" << MR.getExternalSymbol() << "\' to [" + // << ResultPtr << "]\n"); + + // If the target REALLY wants a stub for this function, emit it now. + if (MR.mayNeedFarStub()) { + ResultPtr = Resolver.getExternalFunctionStub(ResultPtr); + } + } else if (MR.isGlobalValue()) { + ResultPtr = getPointerToGlobal(MR.getGlobalValue(), + BufferBegin+MR.getMachineCodeOffset(), + MR.mayNeedFarStub()); + } else if (MR.isIndirectSymbol()) { + ResultPtr = getPointerToGVIndirectSym( + MR.getGlobalValue(), BufferBegin+MR.getMachineCodeOffset()); + } else if (MR.isBasicBlock()) { + ResultPtr = (void*)getMachineBasicBlockAddress(MR.getBasicBlock()); + } else if (MR.isConstantPoolIndex()) { + ResultPtr = + (void*)getConstantPoolEntryAddress(MR.getConstantPoolIndex()); + } else { + assert(MR.isJumpTableIndex()); + ResultPtr=(void*)getJumpTableEntryAddress(MR.getJumpTableIndex()); + } + + MR.setResultPointer(ResultPtr); + } + + // if we are managing the GOT and the relocation wants an index, + // give it one + if (MR.isGOTRelative() && MemMgr->isManagingGOT()) { + unsigned idx = Resolver.getGOTIndexForAddr(ResultPtr); + MR.setGOTIndex(idx); + if (((void**)MemMgr->getGOTBase())[idx] != ResultPtr) { + // DEBUG_PRINT(dbgs() << "JIT: GOT was out of date for " << ResultPtr + // << " pointing at " << ((void**)MemMgr->getGOTBase())[idx] + // << "\n"); + ((void**)MemMgr->getGOTBase())[idx] = ResultPtr; + } + } + } + + CurFn = nullptr; + TheJIT->getJITInfo().relocate(BufferBegin, &Relocations[0], + Relocations.size(), MemMgr->getGOTBase()); + } + + // Update the GOT entry for F to point to the new code. + if (MemMgr->isManagingGOT()) { + unsigned idx = Resolver.getGOTIndexForAddr((void*)BufferBegin); + if (((void**)MemMgr->getGOTBase())[idx] != (void*)BufferBegin) { + // DEBUG_PRINT(dbgs() << "JIT: GOT was out of date for " << (void*)BufferBegin + // << " pointing at " << ((void**)MemMgr->getGOTBase())[idx] + // << "\n"); + ((void**)MemMgr->getGOTBase())[idx] = (void*)BufferBegin; + } + } + + // CurBufferPtr may have moved beyond FnEnd, due to memory allocation for + // global variables that were referenced in the relocations. + MemMgr->endFunctionBody(F.getFunction(), BufferBegin, CurBufferPtr); + + if (CurBufferPtr == BufferEnd) { + retryWithMoreMemory(F); + return true; + } else { + // Now that we've succeeded in emitting the function, reset the + // SizeEstimate back down to zero. + SizeEstimate = 0; + } + + BufferBegin = CurBufferPtr = nullptr; + NumBytes += FnEnd-FnStart; + + // Invalidate the icache if necessary. + sys::Memory::InvalidateInstructionCache(FnStart, FnEnd-FnStart); + + TheJIT->NotifyFunctionEmitted(*F.getFunction(), FnStart, FnEnd-FnStart, + EmissionDetails); + + // Reset the previous debug location. + PrevDL = DebugLoc(); + +// DEBUG_PRINT(dbgs() << "JIT: Finished CodeGen of [" << (void*)FnStart +// << "] Function: " << F.getName() +// << ": " << (FnEnd-FnStart) << " bytes of text, " +// << Relocations.size() << " relocations\n"); + + m_copyRelocations = Relocations; + Relocations.clear(); + ConstPoolAddresses.clear(); + + // Mark code region readable and executable if it's not so already. + MemMgr->setMemoryExecutable(); + +// DEBUG_PRINT({ +// dbgs() << "JIT: Binary code:\n"; +// uint8_t* q = FnStart; +// for (int i = 0; q < FnEnd; q += 4, ++i) { +// if (i == 4) +// i = 0; +// if (i == 0) +// dbgs() << "JIT: " << (long)(q - FnStart) << ": "; +// bool Done = false; +// for (int j = 3; j >= 0; --j) { +// if (q + j >= FnEnd) +// Done = true; +// else +// dbgs() << (unsigned short)q[j]; +// } +// if (Done) +// break; +// dbgs() << ' '; +// if (i == 3) +// dbgs() << '\n'; +// } +// dbgs()<< '\n'; +// }); + + if (MMI) + MMI->EndFunction(); + + return false; +} + +void JITEmitter::retryWithMoreMemory(MachineFunction &F) { +// DEBUG_PRINT(dbgs() << "JIT: Ran out of space for native code. Reattempting.\n"); + Relocations.clear(); // Clear the old relocations or we'll reapply them. + ConstPoolAddresses.clear(); + ++NumRetries; + deallocateMemForFunction(F.getFunction()); + // Try again with at least twice as much free space. + SizeEstimate = (uintptr_t)(2 * (BufferEnd - BufferBegin)); + + for (MachineFunction::iterator MBB = F.begin(), E = F.end(); MBB != E; ++MBB){ + if (MBB->hasAddressTaken()) + TheJIT->clearPointerToBasicBlock(MBB->getBasicBlock()); + } +} + +/// deallocateMemForFunction - Deallocate all memory for the specified +/// function body. Also drop any references the function has to stubs. +/// May be called while the Function is being destroyed inside ~Value(). +void JITEmitter::deallocateMemForFunction(const Function *F) { + ValueMap::iterator + Emitted = EmittedFunctions.find(F); + if (Emitted != EmittedFunctions.end()) { + MemMgr->deallocateFunctionBody(Emitted->second.FunctionBody); + TheJIT->NotifyFreeingMachineCode(Emitted->second.Code); + + EmittedFunctions.erase(Emitted); + } +} + + +void *JITEmitter::allocateSpace(uintptr_t Size, unsigned Alignment) { + if (BufferBegin) + return JITCodeEmitter::allocateSpace(Size, Alignment); + + // create a new memory block if there is no active one. + // care must be taken so that BufferBegin is invalidated when a + // block is trimmed + BufferBegin = CurBufferPtr = MemMgr->allocateSpace(Size, Alignment); + BufferEnd = BufferBegin+Size; + return CurBufferPtr; +} + +void *JITEmitter::allocateGlobal(uintptr_t Size, unsigned Alignment) { + // Delegate this call through the memory manager. + return MemMgr->allocateGlobal(Size, Alignment); +} + +void JITEmitter::emitConstantPool(MachineConstantPool *MCP) { + if (TheJIT->getJITInfo().hasCustomConstantPool()) + return; + + const std::vector &Constants = MCP->getConstants(); + if (Constants.empty()) return; + + unsigned Size = GetConstantPoolSizeInBytes(MCP, TheJIT->getDataLayout()); + unsigned Align = MCP->getConstantPoolAlignment(); + ConstantPoolBase = allocateSpace(Size, Align); + ConstantPool = MCP; + + if (!ConstantPoolBase) return; // Buffer overflow. + +// DEBUG_PRINT(dbgs() << "JIT: Emitted constant pool at [" << ConstantPoolBase +// << "] (size: " << Size << ", alignment: " << Align << ")\n"); + + // Initialize the memory for all of the constant pool entries. + unsigned Offset = 0; + for (unsigned i = 0, e = Constants.size(); i != e; ++i) { + MachineConstantPoolEntry CPE = Constants[i]; + unsigned AlignMask = CPE.getAlignment() - 1; + Offset = (Offset + AlignMask) & ~AlignMask; + + uintptr_t CAddr = (uintptr_t)ConstantPoolBase + Offset; + ConstPoolAddresses.push_back(CAddr); + if (CPE.isMachineConstantPoolEntry()) { + // FIXME: add support to lower machine constant pool values into bytes! + report_fatal_error("Initialize memory with machine specific constant pool" + "entry has not been implemented!"); + } + TheJIT->InitializeMemory(CPE.Val.ConstVal, (void*)CAddr); + // DEBUG_PRINT(dbgs() << "JIT: CP" << i << " at [0x"; + // dbgs().write_hex(CAddr) << "]\n"); + + Type *Ty = CPE.Val.ConstVal->getType(); + Offset += TheJIT->getDataLayout()->getTypeAllocSize(Ty); + } +} + +void JITEmitter::initJumpTableInfo(MachineJumpTableInfo *MJTI) { + if (TheJIT->getJITInfo().hasCustomJumpTables()) + return; + if (MJTI->getEntryKind() == MachineJumpTableInfo::EK_Inline) + return; + + const std::vector &JT = MJTI->getJumpTables(); + if (JT.empty()) return; + + unsigned NumEntries = 0; + for (unsigned i = 0, e = JT.size(); i != e; ++i) + NumEntries += JT[i].MBBs.size(); + + unsigned EntrySize = MJTI->getEntrySize(*TheJIT->getDataLayout()); + + // Just allocate space for all the jump tables now. We will fix up the actual + // MBB entries in the tables after we emit the code for each block, since then + // we will know the final locations of the MBBs in memory. + JumpTable = MJTI; + JumpTableBase = allocateSpace(NumEntries * EntrySize, + MJTI->getEntryAlignment(*TheJIT->getDataLayout())); +} + +void JITEmitter::emitJumpTableInfo(MachineJumpTableInfo *MJTI) { + if (TheJIT->getJITInfo().hasCustomJumpTables()) + return; + + const std::vector &JT = MJTI->getJumpTables(); + if (JT.empty() || !JumpTableBase) return; + + + switch (MJTI->getEntryKind()) { + case MachineJumpTableInfo::EK_Inline: + return; + case MachineJumpTableInfo::EK_BlockAddress: { + // EK_BlockAddress - Each entry is a plain address of block, e.g.: + // .word LBB123 + assert(MJTI->getEntrySize(*TheJIT->getDataLayout()) == sizeof(void*) && + "Cross JIT'ing?"); + + // For each jump table, map each target in the jump table to the address of + // an emitted MachineBasicBlock. + intptr_t *SlotPtr = (intptr_t*)JumpTableBase; + + for (unsigned i = 0, e = JT.size(); i != e; ++i) { + const std::vector &MBBs = JT[i].MBBs; + // Store the address of the basic block for this jump table slot in the + // memory we allocated for the jump table in 'initJumpTableInfo' + for (unsigned mi = 0, me = MBBs.size(); mi != me; ++mi) + *SlotPtr++ = getMachineBasicBlockAddress(MBBs[mi]); + } + break; + } + + case MachineJumpTableInfo::EK_Custom32: + case MachineJumpTableInfo::EK_GPRel32BlockAddress: + case MachineJumpTableInfo::EK_LabelDifference32: { + assert(MJTI->getEntrySize(*TheJIT->getDataLayout()) == 4&&"Cross JIT'ing?"); + // For each jump table, place the offset from the beginning of the table + // to the target address. + int *SlotPtr = (int*)JumpTableBase; + + for (unsigned i = 0, e = JT.size(); i != e; ++i) { + const std::vector &MBBs = JT[i].MBBs; + // Store the offset of the basic block for this jump table slot in the + // memory we allocated for the jump table in 'initJumpTableInfo' + uintptr_t Base = (uintptr_t)SlotPtr; + for (unsigned mi = 0, me = MBBs.size(); mi != me; ++mi) { + uintptr_t MBBAddr = getMachineBasicBlockAddress(MBBs[mi]); + /// FIXME: USe EntryKind instead of magic "getPICJumpTableEntry" hook. + *SlotPtr++ = TheJIT->getJITInfo().getPICJumpTableEntry(MBBAddr, Base); + } + } + break; + } + case MachineJumpTableInfo::EK_GPRel64BlockAddress: + llvm_unreachable( + "JT Info emission not implemented for GPRel64BlockAddress yet."); + } +} + +void JITEmitter::startGVStub(const GlobalValue* GV, + unsigned StubSize, unsigned Alignment) { + SavedBufferBegin = BufferBegin; + SavedBufferEnd = BufferEnd; + SavedCurBufferPtr = CurBufferPtr; + + BufferBegin = CurBufferPtr = MemMgr->allocateStub(GV, StubSize, Alignment); + BufferEnd = BufferBegin+StubSize+1; +} + +void JITEmitter::startGVStub(void *Buffer, unsigned StubSize) { + SavedBufferBegin = BufferBegin; + SavedBufferEnd = BufferEnd; + SavedCurBufferPtr = CurBufferPtr; + + BufferBegin = CurBufferPtr = (uint8_t *)Buffer; + BufferEnd = BufferBegin+StubSize+1; +} + +void JITEmitter::finishGVStub() { + assert(CurBufferPtr != BufferEnd && "Stub overflowed allocated space."); + NumBytes += getCurrentPCOffset(); + BufferBegin = SavedBufferBegin; + BufferEnd = SavedBufferEnd; + CurBufferPtr = SavedCurBufferPtr; +} + +void *JITEmitter::allocIndirectGV(const GlobalValue *GV, + const uint8_t *Buffer, size_t Size, + unsigned Alignment) { + uint8_t *IndGV = MemMgr->allocateStub(GV, Size, Alignment); + memcpy(IndGV, Buffer, Size); + return IndGV; +} + +// getConstantPoolEntryAddress - Return the address of the 'ConstantNum' entry +// in the constant pool that was last emitted with the 'emitConstantPool' +// method. +// +uintptr_t JITEmitter::getConstantPoolEntryAddress(unsigned ConstantNum) const { + assert(ConstantNum < ConstantPool->getConstants().size() && + "Invalid ConstantPoolIndex!"); + return ConstPoolAddresses[ConstantNum]; +} + +// getJumpTableEntryAddress - Return the address of the JumpTable with index +// 'Index' in the jumpp table that was last initialized with 'initJumpTableInfo' +// +uintptr_t JITEmitter::getJumpTableEntryAddress(unsigned Index) const { + const std::vector &JT = JumpTable->getJumpTables(); + assert(Index < JT.size() && "Invalid jump table index!"); + + unsigned EntrySize = JumpTable->getEntrySize(*TheJIT->getDataLayout()); + + unsigned Offset = 0; + for (unsigned i = 0; i < Index; ++i) + Offset += JT[i].MBBs.size(); + + Offset *= EntrySize; + + return (uintptr_t)((char *)JumpTableBase + Offset); +} + +void JITEmitter::EmittedFunctionConfig::onDelete( + JITEmitter *Emitter, const Function *F) { + Emitter->deallocateMemForFunction(F); +} +void JITEmitter::EmittedFunctionConfig::onRAUW( + JITEmitter *, const Function*, const Function*) { + llvm_unreachable("The JIT doesn't know how to handle a" + " RAUW on a value it has emitted."); +} + +//===----------------------------------------------------------------------===// +// Public interface to this file +//===----------------------------------------------------------------------===// + +JITCodeEmitter *JIT::createEmitter(JIT &jit, JITMemoryManager *JMM, + TargetMachine &tm) { + return new JITEmitter(jit, JMM, tm); +} + +// getPointerToFunctionOrStub - If the specified function has been +// code-gen'd, return a pointer to the function. If not, compile it, or use +// a stub to implement lazy compilation if available. +// +void *JIT::getPointerToFunctionOrStub(Function *F) { + // If we have already code generated the function, just return the address. + if (void *Addr = getPointerToGlobalIfAvailable(F)) + return Addr; + + // Get a stub if the target supports it. + JITEmitter *JE = static_cast(getCodeEmitter()); + return JE->getJITResolver().getLazyFunctionStub(F); +} + +void JIT::updateFunctionStubUnlocked(Function *F) { + // Get the empty stub we generated earlier. + JITEmitter *JE = static_cast(getCodeEmitter()); + void *Stub = JE->getJITResolver().getLazyFunctionStub(F); + void *Addr = getPointerToGlobalIfAvailable(F); + assert(Addr != Stub && "Function must have non-stub address to be updated."); + + // Tell the target jit info to rewrite the stub at the specified address, + // rather than creating a new one. + TargetJITInfo::StubLayout layout = getJITInfo().getStubLayout(); + JE->startGVStub(Stub, layout.Size); + getJITInfo().emitFunctionStub(F, Addr, *getCodeEmitter()); + JE->finishGVStub(); +} + +/// freeMachineCodeForFunction - release machine code memory for given Function. +/// +void JIT::freeMachineCodeForFunction(Function *F) { + // Delete translation for this from the ExecutionEngine, so it will get + // retranslated next time it is used. + updateGlobalMapping(F, nullptr); + + // Free the actual memory for the function body and related stuff. + static_cast(JCE)->deallocateMemForFunction(F); +} + + +const std::vector& +llvm::GetJitMachineRelocations(const JIT* p_jit) +{ + const JITEmitter* emitter = static_cast(p_jit->getCodeEmitter()); + return emitter->GetMachineRelocations(); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.cpp new file mode 100644 index 000000000000..e0626a67ae9f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.cpp @@ -0,0 +1,695 @@ +//===-- JIT.cpp - LLVM Just in Time Compiler ------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This tool implements a just-in-time compiler for LLVM, allowing direct +// execution of LLVM bitcode in an efficient manner. +// +//===----------------------------------------------------------------------===// + +#include "JITExtend.h" +#include +#include +#include +// #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace llvm; + +#ifdef __APPLE__ +// Apple gcc defaults to -fuse-cxa-atexit (i.e. calls __cxa_atexit instead +// of atexit). It passes the address of linker generated symbol __dso_handle +// to the function. +// This configuration change happened at version 5330. +# include +# if defined(MAC_OS_X_VERSION_10_4) && \ + ((MAC_OS_X_VERSION_MIN_REQUIRED > MAC_OS_X_VERSION_10_4) || \ + (MAC_OS_X_VERSION_MIN_REQUIRED == MAC_OS_X_VERSION_10_4 && \ + __APPLE_CC__ >= 5330)) +# ifndef HAVE___DSO_HANDLE +# define HAVE___DSO_HANDLE 1 +# endif +# endif +#endif + +#if HAVE___DSO_HANDLE +extern void *__dso_handle __attribute__ ((__visibility__ ("hidden"))); +#endif + +namespace { + +static struct RegisterJIT { + RegisterJIT() { JIT::Register(); } +} JITRegistrator; + +} + +extern "C" void LLVMLinkInJIT() { +} + +/// createJIT - This is the factory method for creating a JIT for the current +/// machine, it does not fall back to the interpreter. This takes ownership +/// of the module. +ExecutionEngine *JIT::createJIT(Module *M, + std::string *ErrorStr, + JITMemoryManager *JMM, + bool GVsWithCode, + TargetMachine *TM) { + // Try to register the program as a source of symbols to resolve against. + // + // FIXME: Don't do this here. + sys::DynamicLibrary::LoadLibraryPermanently(nullptr, nullptr); + + // If the target supports JIT code generation, create the JIT. + if (TargetJITInfo *TJ = TM->getJITInfo()) { + return new JIT(M, *TM, *TJ, JMM, GVsWithCode); + } else { + if (ErrorStr) + *ErrorStr = "target does not support JIT code generation"; + return nullptr; + } +} + +namespace { +/// This class supports the global getPointerToNamedFunction(), which allows +/// bugpoint or gdb users to search for a function by name without any context. +class JitPool { + SmallPtrSet JITs; // Optimize for process containing just 1 JIT. + mutable sys::Mutex Lock; +public: + void Add(JIT *jit) { + MutexGuard guard(Lock); + JITs.insert(jit); + } + void Remove(JIT *jit) { + MutexGuard guard(Lock); + JITs.erase(jit); + } + void *getPointerToNamedFunction(const char *Name) const { + MutexGuard guard(Lock); + assert(JITs.size() != 0 && "No Jit registered"); + //search function in every instance of JIT + for (SmallPtrSet::const_iterator Jit = JITs.begin(), + end = JITs.end(); + Jit != end; ++Jit) { + if (Function *F = (*Jit)->FindFunctionNamed(Name)) + return (*Jit)->getPointerToFunction(F); + } + // The function is not available : fallback on the first created (will + // search in symbol of the current program/library) + return (*JITs.begin())->getPointerToNamedFunction(Name); + } +}; +ManagedStatic AllJits; +} +extern "C" { + // getPointerToNamedFunction - This function is used as a global wrapper to + // JIT::getPointerToNamedFunction for the purpose of resolving symbols when + // bugpoint is debugging the JIT. In that scenario, we are loading an .so and + // need to resolve function(s) that are being mis-codegenerated, so we need to + // resolve their addresses at runtime, and this is the way to do it. + void *getPointerToNamedFunction(const char *Name) { + return AllJits->getPointerToNamedFunction(Name); + } +} + +JIT::JIT(Module *M, TargetMachine &tm, TargetJITInfo &tji, + JITMemoryManager *jmm, bool GVsWithCode) + : ExecutionEngine(M), TM(tm), TJI(tji), + JMM(jmm ? jmm : JITMemoryManager::CreateDefaultMemManager()), + AllocateGVsWithCode(GVsWithCode), isAlreadyCodeGenerating(false) { + setDataLayout(TM.getDataLayout()); + + jitstate = new JITState(M); + + // Initialize JCE + JCE = createEmitter(*this, JMM, TM); + + // Register in global list of all JITs. + AllJits->Add(this); + + // Add target data + MutexGuard locked(lock); + FunctionPassManager &PM = jitstate->getPM(); + M->setDataLayout(TM.getDataLayout()); + PM.add(new DataLayoutPass(M)); + + // Turn the machine code intermediate representation into bytes in memory that + // may be executed. + if (TM.addPassesToEmitMachineCode(PM, *JCE, !getVerifyModules())) { + report_fatal_error("Target does not support machine code emission!"); + } + + // Initialize passes. + PM.doInitialization(); +} + +JIT::~JIT() { + // Cleanup. + AllJits->Remove(this); + delete jitstate; + delete JCE; + // JMM is a ownership of JCE, so we no need delete JMM here. + delete &TM; +} + +/// addModule - Add a new Module to the JIT. If we previously removed the last +/// Module, we need re-initialize jitstate with a valid Module. +void JIT::addModule(Module *M) { + MutexGuard locked(lock); + + if (Modules.empty()) { + assert(!jitstate && "jitstate should be NULL if Modules vector is empty!"); + + jitstate = new JITState(M); + + FunctionPassManager &PM = jitstate->getPM(); + M->setDataLayout(TM.getDataLayout()); + PM.add(new DataLayoutPass(M)); + + // Turn the machine code intermediate representation into bytes in memory + // that may be executed. + if (TM.addPassesToEmitMachineCode(PM, *JCE, !getVerifyModules())) { + report_fatal_error("Target does not support machine code emission!"); + } + + // Initialize passes. + PM.doInitialization(); + } + + ExecutionEngine::addModule(M); +} + +/// removeModule - If we are removing the last Module, invalidate the jitstate +/// since the PassManager it contains references a released Module. +bool JIT::removeModule(Module *M) { + bool result = ExecutionEngine::removeModule(M); + + MutexGuard locked(lock); + + if (jitstate && jitstate->getModule() == M) { + delete jitstate; + jitstate = nullptr; + } + + if (!jitstate && !Modules.empty()) { + jitstate = new JITState(Modules[0]); + + FunctionPassManager &PM = jitstate->getPM(); + M->setDataLayout(TM.getDataLayout()); + PM.add(new DataLayoutPass(M)); + + // Turn the machine code intermediate representation into bytes in memory + // that may be executed. + if (TM.addPassesToEmitMachineCode(PM, *JCE, !getVerifyModules())) { + report_fatal_error("Target does not support machine code emission!"); + } + + // Initialize passes. + PM.doInitialization(); + } + return result; +} + +/// run - Start execution with the specified function and arguments. +/// +GenericValue JIT::runFunction(Function *F, + const std::vector &ArgValues) { + assert(F && "Function *F was null at entry to run()"); + + void *FPtr = getPointerToFunction(F); + assert(FPtr && "Pointer to fn's code was null after getPointerToFunction"); + FunctionType *FTy = F->getFunctionType(); + Type *RetTy = FTy->getReturnType(); + + assert((FTy->getNumParams() == ArgValues.size() || + (FTy->isVarArg() && FTy->getNumParams() <= ArgValues.size())) && + "Wrong number of arguments passed into function!"); + assert(FTy->getNumParams() == ArgValues.size() && + "This doesn't support passing arguments through varargs (yet)!"); + + // Handle some common cases first. These cases correspond to common `main' + // prototypes. + if (RetTy->isIntegerTy(32) || RetTy->isVoidTy()) { + switch (ArgValues.size()) { + case 3: + if (FTy->getParamType(0)->isIntegerTy(32) && + FTy->getParamType(1)->isPointerTy() && + FTy->getParamType(2)->isPointerTy()) { + int (*PF)(int, char **, const char **) = + (int(*)(int, char **, const char **))(intptr_t)FPtr; + + // Call the function. + GenericValue rv; + rv.IntVal = APInt(32, PF(ArgValues[0].IntVal.getZExtValue(), + (char **)GVTOP(ArgValues[1]), + (const char **)GVTOP(ArgValues[2]))); + return rv; + } + break; + case 2: + if (FTy->getParamType(0)->isIntegerTy(32) && + FTy->getParamType(1)->isPointerTy()) { + int (*PF)(int, char **) = (int(*)(int, char **))(intptr_t)FPtr; + + // Call the function. + GenericValue rv; + rv.IntVal = APInt(32, PF(ArgValues[0].IntVal.getZExtValue(), + (char **)GVTOP(ArgValues[1]))); + return rv; + } + break; + case 1: + if (FTy->getParamType(0)->isIntegerTy(32)) { + GenericValue rv; + int (*PF)(int) = (int(*)(int))(intptr_t)FPtr; + rv.IntVal = APInt(32, PF(ArgValues[0].IntVal.getZExtValue())); + return rv; + } + if (FTy->getParamType(0)->isPointerTy()) { + GenericValue rv; + int (*PF)(char *) = (int(*)(char *))(intptr_t)FPtr; + rv.IntVal = APInt(32, PF((char*)GVTOP(ArgValues[0]))); + return rv; + } + break; + } + } + + // Handle cases where no arguments are passed first. + if (ArgValues.empty()) { + GenericValue rv; + switch (RetTy->getTypeID()) { + default: llvm_unreachable("Unknown return type for function call!"); + case Type::IntegerTyID: { + unsigned BitWidth = cast(RetTy)->getBitWidth(); + if (BitWidth == 1) + rv.IntVal = APInt(BitWidth, ((bool(*)())(intptr_t)FPtr)()); + else if (BitWidth <= 8) + rv.IntVal = APInt(BitWidth, ((char(*)())(intptr_t)FPtr)()); + else if (BitWidth <= 16) + rv.IntVal = APInt(BitWidth, ((short(*)())(intptr_t)FPtr)()); + else if (BitWidth <= 32) + rv.IntVal = APInt(BitWidth, ((int(*)())(intptr_t)FPtr)()); + else if (BitWidth <= 64) + rv.IntVal = APInt(BitWidth, ((int64_t(*)())(intptr_t)FPtr)()); + else + llvm_unreachable("Integer types > 64 bits not supported"); + return rv; + } + case Type::VoidTyID: + rv.IntVal = APInt(32, ((int(*)())(intptr_t)FPtr)()); + return rv; + case Type::FloatTyID: + rv.FloatVal = ((float(*)())(intptr_t)FPtr)(); + return rv; + case Type::DoubleTyID: + rv.DoubleVal = ((double(*)())(intptr_t)FPtr)(); + return rv; + case Type::X86_FP80TyID: + case Type::FP128TyID: + case Type::PPC_FP128TyID: + llvm_unreachable("long double not supported yet"); + case Type::PointerTyID: + return PTOGV(((void*(*)())(intptr_t)FPtr)()); + } + } + + // Okay, this is not one of our quick and easy cases. Because we don't have a + // full FFI, we have to codegen a nullary stub function that just calls the + // function we are interested in, passing in constants for all of the + // arguments. Make this function and return. + + // First, create the function. + FunctionType *STy=FunctionType::get(RetTy, false); + Function *Stub = Function::Create(STy, Function::InternalLinkage, "", + F->getParent()); + + // Insert a basic block. + BasicBlock *StubBB = BasicBlock::Create(F->getContext(), "", Stub); + + // Convert all of the GenericValue arguments over to constants. Note that we + // currently don't support varargs. + SmallVector Args; + for (unsigned i = 0, e = ArgValues.size(); i != e; ++i) { + Constant *C = nullptr; + Type *ArgTy = FTy->getParamType(i); + const GenericValue &AV = ArgValues[i]; + switch (ArgTy->getTypeID()) { + default: llvm_unreachable("Unknown argument type for function call!"); + case Type::IntegerTyID: + C = ConstantInt::get(F->getContext(), AV.IntVal); + break; + case Type::FloatTyID: + C = ConstantFP::get(F->getContext(), APFloat(AV.FloatVal)); + break; + case Type::DoubleTyID: + C = ConstantFP::get(F->getContext(), APFloat(AV.DoubleVal)); + break; + case Type::PPC_FP128TyID: + case Type::X86_FP80TyID: + case Type::FP128TyID: + C = ConstantFP::get(F->getContext(), APFloat(ArgTy->getFltSemantics(), + AV.IntVal)); + break; + case Type::PointerTyID: + void *ArgPtr = GVTOP(AV); + if (sizeof(void*) == 4) + C = ConstantInt::get(Type::getInt32Ty(F->getContext()), + (int)(intptr_t)ArgPtr); + else + C = ConstantInt::get(Type::getInt64Ty(F->getContext()), + (intptr_t)ArgPtr); + // Cast the integer to pointer + C = ConstantExpr::getIntToPtr(C, ArgTy); + break; + } + Args.push_back(C); + } + + CallInst *TheCall = CallInst::Create(F, Args, "", StubBB); + TheCall->setCallingConv(F->getCallingConv()); + TheCall->setTailCall(); + if (!TheCall->getType()->isVoidTy()) + // Return result of the call. + ReturnInst::Create(F->getContext(), TheCall, StubBB); + else + ReturnInst::Create(F->getContext(), StubBB); // Just return void. + + // Finally, call our nullary stub function. + GenericValue Result = runFunction(Stub, std::vector()); + // Erase it, since no other function can have a reference to it. + Stub->eraseFromParent(); + // And return the result. + return Result; +} + +void JIT::RegisterJITEventListener(JITEventListener *L) { + if (!L) + return; + MutexGuard locked(lock); + EventListeners.push_back(L); +} +void JIT::UnregisterJITEventListener(JITEventListener *L) { + if (!L) + return; + MutexGuard locked(lock); + std::vector::reverse_iterator I= + std::find(EventListeners.rbegin(), EventListeners.rend(), L); + if (I != EventListeners.rend()) { + std::swap(*I, EventListeners.back()); + EventListeners.pop_back(); + } +} +void JIT::NotifyFunctionEmitted( + const Function &F, + void *Code, size_t Size, + const JITEvent_EmittedFunctionDetails &Details) { + MutexGuard locked(lock); + for (unsigned I = 0, S = EventListeners.size(); I < S; ++I) { + EventListeners[I]->NotifyFunctionEmitted(F, Code, Size, Details); + } +} + +void JIT::NotifyFreeingMachineCode(void *OldPtr) { + MutexGuard locked(lock); + for (unsigned I = 0, S = EventListeners.size(); I < S; ++I) { + EventListeners[I]->NotifyFreeingMachineCode(OldPtr); + } +} + +/// runJITOnFunction - Run the FunctionPassManager full of +/// just-in-time compilation passes on F, hopefully filling in +/// GlobalAddress[F] with the address of F's machine code. +/// +void JIT::runJITOnFunction(Function *F, MachineCodeInfo *MCI) { + MutexGuard locked(lock); + + class MCIListener : public JITEventListener { + MachineCodeInfo *const MCI; + public: + MCIListener(MachineCodeInfo *mci) : MCI(mci) {} + void NotifyFunctionEmitted(const Function &, void *Code, size_t Size, + const EmittedFunctionDetails &) override { + MCI->setAddress(Code); + MCI->setSize(Size); + } + }; + MCIListener MCIL(MCI); + if (MCI) + RegisterJITEventListener(&MCIL); + + runJITOnFunctionUnlocked(F); + + if (MCI) + UnregisterJITEventListener(&MCIL); +} + +void JIT::runJITOnFunctionUnlocked(Function *F) { + assert(!isAlreadyCodeGenerating && "Error: Recursive compilation detected!"); + + jitTheFunctionUnlocked(F); + + // If the function referred to another function that had not yet been + // read from bitcode, and we are jitting non-lazily, emit it now. + while (!jitstate->getPendingFunctions().empty()) { + Function *PF = jitstate->getPendingFunctions().back(); + jitstate->getPendingFunctions().pop_back(); + + assert(!PF->hasAvailableExternallyLinkage() && + "Externally-defined function should not be in pending list."); + + jitTheFunctionUnlocked(PF); + + // Now that the function has been jitted, ask the JITEmitter to rewrite + // the stub with real address of the function. + updateFunctionStubUnlocked(PF); + } +} + +void JIT::jitTheFunctionUnlocked(Function *F) { + isAlreadyCodeGenerating = true; + jitstate->getPM().run(*F); + isAlreadyCodeGenerating = false; + + // clear basic block addresses after this function is done + getBasicBlockAddressMap().clear(); +} + +/// getPointerToFunction - This method is used to get the address of the +/// specified function, compiling it if necessary. +/// +void *JIT::getPointerToFunction(Function *F) { + + if (void *Addr = getPointerToGlobalIfAvailable(F)) + return Addr; // Check if function already code gen'd + + MutexGuard locked(lock); + + // Now that this thread owns the lock, make sure we read in the function if it + // exists in this Module. + std::string ErrorMsg; + if (F->Materialize(&ErrorMsg)) { + report_fatal_error("Error reading function '" + F->getName()+ + "' from bitcode file: " + ErrorMsg); + } + + // ... and check if another thread has already code gen'd the function. + if (void *Addr = getPointerToGlobalIfAvailable(F)) + return Addr; + + if (F->isDeclaration() || F->hasAvailableExternallyLinkage()) { + bool AbortOnFailure = !F->hasExternalWeakLinkage(); + void *Addr = getPointerToNamedFunction(F->getName(), AbortOnFailure); + addGlobalMapping(F, Addr); + return Addr; + } + + runJITOnFunctionUnlocked(F); + + void *Addr = getPointerToGlobalIfAvailable(F); + assert(Addr && "Code generation didn't add function to GlobalAddress table!"); + return Addr; +} + +void JIT::addPointerToBasicBlock(const BasicBlock *BB, void *Addr) { + MutexGuard locked(lock); + + BasicBlockAddressMapTy::iterator I = + getBasicBlockAddressMap().find(BB); + if (I == getBasicBlockAddressMap().end()) { + getBasicBlockAddressMap()[BB] = Addr; + } else { + // ignore repeats: some BBs can be split into few MBBs? + } +} + +void JIT::clearPointerToBasicBlock(const BasicBlock *BB) { + MutexGuard locked(lock); + getBasicBlockAddressMap().erase(BB); +} + +void *JIT::getPointerToBasicBlock(BasicBlock *BB) { + // make sure it's function is compiled by JIT + (void)getPointerToFunction(BB->getParent()); + + // resolve basic block address + MutexGuard locked(lock); + + BasicBlockAddressMapTy::iterator I = + getBasicBlockAddressMap().find(BB); + if (I != getBasicBlockAddressMap().end()) { + return I->second; + } else { + llvm_unreachable("JIT does not have BB address for address-of-label, was" + " it eliminated by optimizer?"); + } +} + +void *JIT::getPointerToNamedFunction(const std::string &Name, + bool AbortOnFailure){ + if (!isSymbolSearchingDisabled()) { + void *ptr = JMM->getPointerToNamedFunction(Name, false); + if (ptr) + return ptr; + } + + /// If a LazyFunctionCreator is installed, use it to get/create the function. + if (LazyFunctionCreator) + if (void *RP = LazyFunctionCreator(Name)) + return RP; + + if (AbortOnFailure) { + report_fatal_error("Program used external function '"+Name+ + "' which could not be resolved!"); + } + return nullptr; +} + + +/// getOrEmitGlobalVariable - Return the address of the specified global +/// variable, possibly emitting it to memory if needed. This is used by the +/// Emitter. +void *JIT::getOrEmitGlobalVariable(const GlobalVariable *GV) { + MutexGuard locked(lock); + + void *Ptr = getPointerToGlobalIfAvailable(GV); + if (Ptr) return Ptr; + + // If the global is external, just remember the address. + if (GV->isDeclaration() || GV->hasAvailableExternallyLinkage()) { +#if HAVE___DSO_HANDLE + if (GV->getName() == "__dso_handle") + return (void*)&__dso_handle; +#endif + Ptr = sys::DynamicLibrary::SearchForAddressOfSymbol(GV->getName()); + if (!Ptr) { + report_fatal_error("Could not resolve external global address: " + +GV->getName()); + } + addGlobalMapping(GV, Ptr); + } else { + // If the global hasn't been emitted to memory yet, allocate space and + // emit it into memory. + Ptr = getMemoryForGV(GV); + addGlobalMapping(GV, Ptr); + EmitGlobalVariable(GV); // Initialize the variable. + } + return Ptr; +} + +/// recompileAndRelinkFunction - This method is used to force a function +/// which has already been compiled, to be compiled again, possibly +/// after it has been modified. Then the entry to the old copy is overwritten +/// with a branch to the new copy. If there was no old copy, this acts +/// just like JIT::getPointerToFunction(). +/// +void *JIT::recompileAndRelinkFunction(Function *F) { + void *OldAddr = getPointerToGlobalIfAvailable(F); + + // If it's not already compiled there is no reason to patch it up. + if (!OldAddr) return getPointerToFunction(F); + + // Delete the old function mapping. + addGlobalMapping(F, nullptr); + + // Recodegen the function + runJITOnFunction(F); + + // Update state, forward the old function to the new function. + void *Addr = getPointerToGlobalIfAvailable(F); + assert(Addr && "Code generation didn't add function to GlobalAddress table!"); + TJI.replaceMachineCodeForFunction(OldAddr, Addr); + return Addr; +} + +/// getMemoryForGV - This method abstracts memory allocation of global +/// variable so that the JIT can allocate thread local variables depending +/// on the target. +/// +char* JIT::getMemoryForGV(const GlobalVariable* GV) { + char *Ptr; + + // GlobalVariable's which are not "constant" will cause trouble in a server + // situation. It's returned in the same block of memory as code which may + // not be writable. + if (isGVCompilationDisabled() && !GV->isConstant()) { + report_fatal_error("Compilation of non-internal GlobalValue is disabled!"); + } + + // Some applications require globals and code to live together, so they may + // be allocated into the same buffer, but in general globals are allocated + // through the memory manager which puts them near the code but not in the + // same buffer. + Type *GlobalType = GV->getType()->getElementType(); + size_t S = getDataLayout()->getTypeAllocSize(GlobalType); + size_t A = getDataLayout()->getPreferredAlignment(GV); + if (GV->isThreadLocal()) { + MutexGuard locked(lock); + Ptr = TJI.allocateThreadLocalMemory(S); + } else if (TJI.allocateSeparateGVMemory()) { + if (A <= 8) { + Ptr = (char*)malloc(S); + } else { + // Allocate S+A bytes of memory, then use an aligned pointer within that + // space. + Ptr = (char*)malloc(S+A); + unsigned MisAligned = ((intptr_t)Ptr & (A-1)); + Ptr = Ptr + (MisAligned ? (A-MisAligned) : 0); + } + } else if (AllocateGVsWithCode) { + Ptr = (char*)JCE->allocateSpace(S, A); + } else { + Ptr = (char*)JCE->allocateGlobal(S, A); + } + return Ptr; +} + +void JIT::addPendingFunction(Function *F) { + MutexGuard locked(lock); + jitstate->getPendingFunctions().push_back(F); +} + + +JITEventListener::~JITEventListener() {} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.h new file mode 100644 index 000000000000..0ce84e0aba95 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/Extend/JITExtend.h @@ -0,0 +1,232 @@ +//===-- JIT.h - Class definition for the JIT --------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the top-level JIT data structure. +// +//===----------------------------------------------------------------------===// + +#ifndef JITEXTEND_H +#define JITEXTEND_H + +#include +#include +#include +#include + +namespace llvm { + +class Function; +struct JITEvent_EmittedFunctionDetails; +class MachineCodeEmitter; +class MachineCodeInfo; +class TargetJITInfo; +class TargetMachine; + +class JITState { +private: + FunctionPassManager PM; // Passes to compile a function + Module *M; // Module used to create the PM + + /// PendingFunctions - Functions which have not been code generated yet, but + /// were called from a function being code generated. + std::vector > PendingFunctions; + +public: + explicit JITState(Module *M) : PM(M), M(M) {} + + FunctionPassManager &getPM() { + return PM; + } + + Module *getModule() const { return M; } + std::vector > &getPendingFunctions() { + return PendingFunctions; + } +}; + + +class JIT : public ExecutionEngine { + /// types + typedef ValueMap + BasicBlockAddressMapTy; + /// data + TargetMachine &TM; // The current target we are compiling to + TargetJITInfo &TJI; // The JITInfo for the target we are compiling to + JITCodeEmitter *JCE; // JCE object + JITMemoryManager *JMM; + std::vector EventListeners; + + /// AllocateGVsWithCode - Some applications require that global variables and + /// code be allocated into the same region of memory, in which case this flag + /// should be set to true. Doing so breaks freeMachineCodeForFunction. + bool AllocateGVsWithCode; + + /// True while the JIT is generating code. Used to assert against recursive + /// entry. + bool isAlreadyCodeGenerating; + + JITState *jitstate; + + /// BasicBlockAddressMap - A mapping between LLVM basic blocks and their + /// actualized version, only filled for basic blocks that have their address + /// taken. + BasicBlockAddressMapTy BasicBlockAddressMap; + + + JIT(Module *M, TargetMachine &tm, TargetJITInfo &tji, + JITMemoryManager *JMM, bool AllocateGVsWithCode); +public: + ~JIT(); + + static void Register() { + JITCtor = createJIT; + } + + /// getJITInfo - Return the target JIT information structure. + /// + TargetJITInfo &getJITInfo() const { return TJI; } + + /// create - Create an return a new JIT compiler if there is one available + /// for the current target. Otherwise, return null. + /// + static ExecutionEngine *create(Module *M, + std::string *Err, + JITMemoryManager *JMM, + CodeGenOpt::Level OptLevel = + CodeGenOpt::Default, + bool GVsWithCode = true, + Reloc::Model RM = Reloc::Default, + CodeModel::Model CMM = CodeModel::JITDefault) { + return ExecutionEngine::createJIT(M, Err, JMM, OptLevel, GVsWithCode, + RM, CMM); + } + + void addModule(Module *M) override; + + /// removeModule - Remove a Module from the list of modules. Returns true if + /// M is found. + bool removeModule(Module *M) override; + + /// runFunction - Start execution with the specified function and arguments. + /// + GenericValue runFunction(Function *F, + const std::vector &ArgValues) override; + + /// getPointerToNamedFunction - This method returns the address of the + /// specified function by using the MemoryManager. As such it is only + /// useful for resolving library symbols, not code generated symbols. + /// + /// If AbortOnFailure is false and no function with the given name is + /// found, this function silently returns a null pointer. Otherwise, + /// it prints a message to stderr and aborts. + /// + void *getPointerToNamedFunction(const std::string &Name, + bool AbortOnFailure = true) override; + + // CompilationCallback - Invoked the first time that a call site is found, + // which causes lazy compilation of the target function. + // + static void CompilationCallback(); + + /// getPointerToFunction - This returns the address of the specified function, + /// compiling it if necessary. + /// + void *getPointerToFunction(Function *F) override; + + /// addPointerToBasicBlock - Adds address of the specific basic block. + void addPointerToBasicBlock(const BasicBlock *BB, void *Addr); + + /// clearPointerToBasicBlock - Removes address of specific basic block. + void clearPointerToBasicBlock(const BasicBlock *BB); + + /// getPointerToBasicBlock - This returns the address of the specified basic + /// block, assuming function is compiled. + void *getPointerToBasicBlock(BasicBlock *BB) override; + + /// getOrEmitGlobalVariable - Return the address of the specified global + /// variable, possibly emitting it to memory if needed. This is used by the + /// Emitter. + void *getOrEmitGlobalVariable(const GlobalVariable *GV) override; + + /// getPointerToFunctionOrStub - If the specified function has been + /// code-gen'd, return a pointer to the function. If not, compile it, or use + /// a stub to implement lazy compilation if available. + /// + void *getPointerToFunctionOrStub(Function *F) override; + + /// recompileAndRelinkFunction - This method is used to force a function + /// which has already been compiled, to be compiled again, possibly + /// after it has been modified. Then the entry to the old copy is overwritten + /// with a branch to the new copy. If there was no old copy, this acts + /// just like JIT::getPointerToFunction(). + /// + void *recompileAndRelinkFunction(Function *F) override; + + /// freeMachineCodeForFunction - deallocate memory used to code-generate this + /// Function. + /// + void freeMachineCodeForFunction(Function *F) override; + + /// addPendingFunction - while jitting non-lazily, a called but non-codegen'd + /// function was encountered. Add it to a pending list to be processed after + /// the current function. + /// + void addPendingFunction(Function *F); + + /// getCodeEmitter - Return the code emitter this JIT is emitting into. + /// + JITCodeEmitter *getCodeEmitter() const { return JCE; } + + static ExecutionEngine *createJIT(Module *M, + std::string *ErrorStr, + JITMemoryManager *JMM, + bool GVsWithCode, + TargetMachine *TM); + + // Run the JIT on F and return information about the generated code + void runJITOnFunction(Function *F, MachineCodeInfo *MCI = nullptr) override; + + void RegisterJITEventListener(JITEventListener *L) override; + void UnregisterJITEventListener(JITEventListener *L) override; + + TargetMachine *getTargetMachine() override { return &TM; } + + /// These functions correspond to the methods on JITEventListener. They + /// iterate over the registered listeners and call the corresponding method on + /// each. + void NotifyFunctionEmitted( + const Function &F, void *Code, size_t Size, + const JITEvent_EmittedFunctionDetails &Details); + void NotifyFreeingMachineCode(void *OldPtr); + + BasicBlockAddressMapTy & + getBasicBlockAddressMap() { + return BasicBlockAddressMap; + } + + +private: + static JITCodeEmitter *createEmitter(JIT &J, JITMemoryManager *JMM, + TargetMachine &tm); + void runJITOnFunctionUnlocked(Function *F); + void updateFunctionStubUnlocked(Function *F); + void jitTheFunctionUnlocked(Function *F); + +protected: + + /// getMemoryforGV - Allocate memory for a global variable. + char* getMemoryForGV(const GlobalVariable* GV) override; + +}; + +const std::vector& GetJitMachineRelocations(const JIT* p_jit); + +} // End llvm namespace + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.cpp new file mode 100644 index 000000000000..99d6df608cde --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.cpp @@ -0,0 +1,343 @@ +#include "LlvmCodeGenUtils.h" + +#include "CompilationState.h" +#include "FreeForm2.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Type.h" +#include +#include +#include +#include +#include +#include + +namespace +{ + using namespace FreeForm2; + + void + LLVMRetFailed(const char* p_file, + unsigned int p_line, + const char* p_description) + { + std::ostringstream err; + err << "Received NULL " << p_description << " from LLVM at " + << p_file << ":" << p_line; + throw std::runtime_error(err.str()); + } + + + std::string + NameBlock(const char* p_description, const char* p_block) + { + std::string ret(p_description); + ret += ": "; + ret += p_block; + return ret; + } + + + // Convert to bool by comparing the source value to zero. + llvm::Value* ConvertToBool(llvm::Value* p_value, + const TypeImpl& p_sourceType, + CompilationState& p_state) + { + llvm::Value& zero = p_state.CreateZeroValue(p_sourceType); + if (p_sourceType.IsIntegerType()) + { + return p_state.GetBuilder().CreateICmpNE(p_value, &zero); + } + else if (p_sourceType.IsFloatingPointType()) + { + return p_state.GetBuilder().CreateFCmpONE(p_value, &zero); + } + else + { + FF2_UNREACHABLE(); + } + } + + + // Convert any FreeForm2 primitive to any int type. + llvm::Value* ConvertToInt(llvm::Value* p_value, + const TypeImpl& p_sourceType, + const TypeImpl& p_destType, + CompilationState& p_state) + { + FF2_ASSERT(p_destType.IsIntegerType()); + llvm::Type& destType = p_state.GetType(p_destType); + const llvm::Type& sourceType = p_state.GetType(p_sourceType); + + if (p_sourceType.IsIntegerType()) + { + FF2_ASSERT(destType.getPrimitiveSizeInBits() > 0 + && sourceType.getPrimitiveSizeInBits() > 0); + + // For integer types of the same size, LLVM does not associate + // signed-ness with the type, so this case is an identity + // conversion. + if (destType.getPrimitiveSizeInBits() == sourceType.getPrimitiveSizeInBits()) + { + return p_value; + } + else if (p_destType.IsSigned() && p_sourceType.IsSigned()) + { + return p_state.GetBuilder().CreateSExt(p_value, &destType); + } + else + { + return p_state.GetBuilder().CreateZExt(p_value, &destType); + } + } + else if (p_sourceType.Primitive() == Type::Bool) + { + return p_state.GetBuilder().CreateZExt(p_value, &destType); + } + else if (p_sourceType.IsFloatingPointType()) + { + if (p_destType.IsSigned()) + { + return p_state.GetBuilder().CreateFPToSI(p_value, &destType); + } + else + { + return p_state.GetBuilder().CreateFPToUI(p_value, &destType); + } + } + else + { + FF2_UNREACHABLE(); + } + } + + + // Convert to a value to a floating-point type. Right now, there is only + // Type::Float, but this method is generic enough to support more floating- + // point types in the future. This method does not support identity + // conversion. + llvm::Value* ConvertToFloat(llvm::Value* p_value, + const TypeImpl& p_sourceType, + const TypeImpl& p_destType, + CompilationState& p_state) + { + llvm::Type& destType = p_state.GetType(p_destType); + const llvm::Type& sourceType = p_state.GetType(p_sourceType); + + if (p_sourceType.IsIntegerType() || p_sourceType.Primitive() == Type::Bool) + { + if (p_sourceType.IsSigned()) + { + return p_state.GetBuilder().CreateSIToFP(p_value, &destType); + } + else + { + return p_state.GetBuilder().CreateUIToFP(p_value, &destType); + } + } + else if (p_sourceType.IsFloatingPointType()) + { + // LLVM differentiates floating-point types based on their + // TypeID attributes, not on mantissa/exponent size, bit size, etc. + // For now, assert that they are the same LLVM type ID to prevent + // any issues. In the future, use fptrunc and fpext commands to + // convert between floating-point types. + FF2_ASSERT(destType.getTypeID() == sourceType.getTypeID()); + return p_value; + } + else + { + FF2_UNREACHABLE(); + } + } +} + + +FreeForm2::GenerateConditional::GenerateConditional(CompilationState& p_state, + llvm::Value& p_cond, + const char* p_description) + : m_state(&p_state), + m_description(p_description), + m_then(llvm::BasicBlock::Create(p_state.GetContext(), + llvm::Twine(NameBlock(p_description, "then")), + p_state.GetBuilder().GetInsertBlock()->getParent())), + m_else(llvm::BasicBlock::Create(p_state.GetContext(), + llvm::Twine(NameBlock(p_description, "else")))), + m_finalBlock(llvm::BasicBlock::Create(p_state.GetContext(), + llvm::Twine(NameBlock(p_description, "final")))) +{ + // Other blocks are checked by Case constructor. + CHECK_LLVM_RET(m_finalBlock); + + p_state.GetBuilder().CreateCondBr(&p_cond, m_then.m_block, m_else.m_block); + + m_then.m_block->moveAfter(p_state.GetBuilder().GetInsertBlock()); + + // Set up the builder to generate code into the 'then' block. + p_state.GetBuilder().SetInsertPoint(m_then.m_block); +} + + +void +FreeForm2::GenerateConditional::FinishThen(llvm::Value* p_value) +{ + CHECK_LLVM_RET(p_value); + + if (m_state->GetBuilder().GetInsertBlock()->getTerminator() == NULL) + { + m_then.m_value = p_value; + + // Finish up the 'then' block. + m_state->GetBuilder().CreateBr(m_finalBlock); + } + else + { + m_then.m_value = NULL; + } + + // Update 'then' block in case other codegen has altered it. + m_then.m_block = m_state->GetBuilder().GetInsertBlock(); + + // Set up the builder to generate code into the 'else' block. + llvm::Function* fun = m_state->GetBuilder().GetInsertBlock()->getParent(); + fun->getBasicBlockList().push_back(m_else.m_block); + m_else.m_block->moveAfter(m_then.m_block); + m_state->GetBuilder().SetInsertPoint(m_else.m_block); +} + + +void +FreeForm2::GenerateConditional::FinishElse(llvm::Value* p_value) +{ + CHECK_LLVM_RET(p_value); + + if (m_state->GetBuilder().GetInsertBlock()->getTerminator() == NULL) + { + m_else.m_value = p_value; + + // Finish up the 'else' block. + m_state->GetBuilder().CreateBr(m_finalBlock); + } + else + { + m_else.m_value = NULL; + } + + // Update 'else' block in case other codegen has altered it. + m_else.m_block = m_state->GetBuilder().GetInsertBlock(); +} + + +llvm::Value& +FreeForm2::GenerateConditional::Finish(llvm::Type& p_type) +{ + // Set up the builder to generate code into the final block. + llvm::Function* fun = m_state->GetBuilder().GetInsertBlock()->getParent(); + fun->getBasicBlockList().push_back(m_finalBlock); + m_finalBlock->moveAfter(m_else.m_block); + m_state->GetBuilder().SetInsertPoint(m_finalBlock); + + if (m_then.m_value != NULL && m_else.m_value != NULL) + { + llvm::PHINode* phi = m_state->GetBuilder().CreatePHI(&p_type, 2); + CHECK_LLVM_RET(phi); + phi->addIncoming(m_then.m_value, m_then.m_block); + phi->addIncoming(m_else.m_value, m_else.m_block); + return *phi; + } + else if (m_then.m_value != NULL) + { + return *m_then.m_value; + } + else if (m_else.m_value != NULL) + { + return *m_else.m_value; + } + else + { + return m_state->CreateVoidValue(); + } +} + + +llvm::Value& +FreeForm2::ValueConversion::Do(llvm::Value& p_value, + const TypeImpl& p_sourceType, + const TypeImpl& p_destType, + CompilationState& p_state) +{ + llvm::Value* ret = nullptr; + if (p_sourceType.IsSameAs(p_destType, true)) + { + ret = &p_value; + } + else if (p_destType.IsIntegerType()) + { + ret = ConvertToInt(&p_value, p_sourceType, p_destType, p_state); + } + else if (p_destType.IsFloatingPointType()) + { + ret = ConvertToFloat(&p_value, p_sourceType, p_destType, p_state); + } + else if (p_destType.Primitive() == Type::Bool) + { + ret = ConvertToBool(&p_value, p_sourceType, p_state); + } + CHECK_LLVM_RET(ret); + return *ret; +} + + +FreeForm2::GenerateConditional::Case::Case(llvm::BasicBlock* p_block) + : m_block(p_block), + m_value(NULL) +{ + CHECK_LLVM_RET(p_block); +} + + +void +FreeForm2::CheckLLVMRet(const llvm::Value* p_value, + const char* p_file, + unsigned int p_line) +{ + if (p_value == NULL) + { + LLVMRetFailed(p_file, p_line, "value"); + } +} + + +void +FreeForm2::CheckLLVMRet(const llvm::Type* p_type, + const char* p_file, + unsigned int p_line) +{ + if (p_type == NULL) + { + LLVMRetFailed(p_file, p_line, "type"); + } +} + + +void +FreeForm2::CheckLLVMRet(const llvm::BasicBlock* p_block, + const char* p_file, + unsigned int p_line) +{ + if (p_block == NULL) + { + LLVMRetFailed(p_file, p_line, "basic block"); + } +} + + +void +FreeForm2::CheckLLVMRet(const llvm::Instruction* p_ins, + const char* p_file, + unsigned int p_line) +{ + if (p_ins == NULL) + { + LLVMRetFailed(p_file, p_line, "instruction"); + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.h new file mode 100644 index 000000000000..b61cd3d3d405 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenUtils.h @@ -0,0 +1,103 @@ +#pragma once + +#ifndef FREEFORM2_CODEGENUTILS_H +#define FREEFORM2_CODEGENUTILS_H + +#include "FreeForm2Type.h" + +namespace llvm +{ + class Type; + class Value; + class BasicBlock; + class LLVMContext; + class Instruction; +} + +namespace FreeForm2 +{ + class CompilationState; + + class GenerateConditional + { + public: + // Constructor, taking the compilation state, conditional + // value and (optional) description of the conditional. + GenerateConditional(CompilationState& p_state, + llvm::Value& p_cond, + const char* p_description); + + // Note: call functions in the order presented. Pointers are + // checked for NULL, with exceptions thrown in that case. + + // Finish a 'then' block (you'll want to have generated some code + // between construction and FinishThen). + void FinishThen(llvm::Value* p_value); + + // Finish an 'else' block (you'll want to have generated some code + // between FinishThen and FinishElse). + void FinishElse(llvm::Value* p_value); + + // Finish generating the conditional, returning the value of the + // conditional. Call this last, and don't call anything else afterward. + llvm::Value& Finish(llvm::Type& p_type); + + private: + // Compilation state used for generation. + CompilationState* m_state; + + // (Optional) description of the conditional. + const char* m_description; + + // Structure holding information about conditional cases. + struct Case + { + Case(llvm::BasicBlock* p_block); + + llvm::BasicBlock* m_block; + llvm::Value* m_value; + }; + + // 'then' case. + Case m_then; + + // 'else' case. + Case m_else; + + // Basic block that merges all the cases together. + llvm::BasicBlock* m_finalBlock; + }; + + // This class offers a mechanism to convert LLVM value types among the + // various TypeImpl types supported in FF2. + class ValueConversion + { + public: + // Convert a value from a source type to a destination type. Depending + // on the type, zero or more expressions may be added to the + // compilation state. + static llvm::Value& Do(llvm::Value& p_value, + const TypeImpl& p_sourceType, + const TypeImpl& p_destType, + CompilationState& p_state); + }; + + // Macro to help with the constant checking of return values + // required to use LLVM. +#define CHECK_LLVM_RET(val) \ + if (val == NULL) \ + { \ + FreeForm2::CheckLLVMRet(val, __FILE__, __LINE__); \ + } + + // Functions to support CHECK_LLVM_RET above. + void CheckLLVMRet(const llvm::Value* p_value, const char* p_file, unsigned int p_line); + void CheckLLVMRet(const llvm::Type* p_type, const char* p_file, unsigned int p_line); + void CheckLLVMRet(const llvm::BasicBlock* p_block, const char* p_file, unsigned int p_line); + void CheckLLVMRet(const llvm::Instruction* p_ins, const char* p_file, unsigned int p_line); +}; + +#endif + + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.cpp new file mode 100644 index 000000000000..c8cb82c25cb6 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.cpp @@ -0,0 +1,2556 @@ +#include "LlvmCodeGenerator.h" + +#include "ArrayCodeGen.h" +#include "Allocation.h" +#include "ArrayLength.h" +#include "ArrayLiteralExpression.h" +#include "ArrayDereferenceExpression.h" +#include "BinaryOperator.h" +#include "BlockExpression.h" +#include "LlvmCodeGenUtils.h" +#include "ConvertExpression.h" +#include "Conditional.h" +#include "Declaration.h" +#include "FeatureSpec.h" +#include "FreeForm2.h" +#include "FreeForm2Assert.h" +#include "LetExpression.h" +#include "LiteralExpression.h" +#include +#include +#include +#include +#include +#include +#include "OperatorExpression.h" +#include "RandExpression.h" +#include "RangeReduceExpression.h" +#include "RefExpression.h" +#include "SelectNth.h" +#include "UnaryOperator.h" +#include +#include + +using namespace FreeForm2; + +namespace +{ +LlvmCodeGenerator::CompiledValue& +CompileFloatEquality(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right) +{ + llvm::IRBuilder<>& builder = m_state.GetBuilder(); + + llvm::Type* type = p_left.getType(); + CHECK_LLVM_RET(type); + FF2_ASSERT(p_right.getType() && type->getTypeID() == p_right.getType()->getTypeID()); + + LlvmCodeGenerator::CompiledValue* small = llvm::ConstantFP::get(type, 10E-9); + CHECK_LLVM_RET(small); + LlvmCodeGenerator::CompiledValue* negSmall = llvm::ConstantFP::get(type, -10E-9); + CHECK_LLVM_RET(negSmall); + + // Determine whether the right expression is small. + LlvmCodeGenerator::CompiledValue* rightSmallCmp + = builder.CreateFCmp(llvm::CmpInst::FCMP_OLT, &p_right, small); + CHECK_LLVM_RET(rightSmallCmp); + LlvmCodeGenerator::CompiledValue* rightNegSmallCmp + = builder.CreateFCmp(llvm::CmpInst::FCMP_OGT, &p_right, negSmall); + CHECK_LLVM_RET(rightNegSmallCmp); + LlvmCodeGenerator::CompiledValue* rightIsSmall + = builder.CreateAnd(rightSmallCmp, rightNegSmallCmp); + CHECK_LLVM_RET(rightIsSmall); + + GenerateConditional cond(m_state, + *rightIsSmall, + "Approximate fp cmp: right small?"); + + // Determine whether the left expression is small. + LlvmCodeGenerator::CompiledValue* leftSmallCmp + = builder.CreateFCmp(llvm::CmpInst::FCMP_OLT, &p_left, small); + CHECK_LLVM_RET(leftSmallCmp); + LlvmCodeGenerator::CompiledValue* leftNegSmallCmp + = builder.CreateFCmp(llvm::CmpInst::FCMP_OGT, &p_left, negSmall); + CHECK_LLVM_RET(leftNegSmallCmp); + LlvmCodeGenerator::CompiledValue* leftIsSmall + = builder.CreateAnd(leftSmallCmp, leftNegSmallCmp); + CHECK_LLVM_RET(leftIsSmall); + cond.FinishThen(leftIsSmall); + + // Determine whether the difference between left and right is small. + LlvmCodeGenerator::CompiledValue* diff = builder.CreateFSub(&p_left, &p_right); + CHECK_LLVM_RET(diff); + LlvmCodeGenerator::CompiledValue* normDiff = builder.CreateFDiv(diff, &p_right); + CHECK_LLVM_RET(normDiff); + LlvmCodeGenerator::CompiledValue* diffSmallCmp + = builder.CreateFCmp(llvm::CmpInst::FCMP_OLT, normDiff, small); + CHECK_LLVM_RET(diffSmallCmp); + LlvmCodeGenerator::CompiledValue* diffNegSmallCmp + = builder.CreateFCmp(llvm::CmpInst::FCMP_OGT, normDiff, negSmall); + CHECK_LLVM_RET(diffNegSmallCmp); + LlvmCodeGenerator::CompiledValue* diffIsSmall + = builder.CreateAnd(diffSmallCmp, diffNegSmallCmp); + CHECK_LLVM_RET(diffIsSmall); + cond.FinishElse(diffIsSmall); + + return cond.Finish(m_state.GetBoolType()); +} + + +// Create a value which will evaluate to a random number in the range [0, 1.0]. +llvm::Value& CreateRandomFloat(CompilationState& p_state, + llvm::Type* p_floatType) +{ + llvm::IRBuilder<>& builder = p_state.GetBuilder(); + + llvm::Function* rand = p_state.GetRuntimeLibrary().FindFunction(CStackSizedString("rand")); + CHECK_LLVM_RET(rand); + + llvm::Value* ret = builder.CreateCall(rand); + CHECK_LLVM_RET(ret); + + if (p_floatType + && p_floatType->getPrimitiveSizeInBits() < rand->getReturnType()->getPrimitiveSizeInBits()) + { + ret = builder.CreateFPTrunc(ret, p_floatType); + CHECK_LLVM_RET(ret); + } + else if (p_floatType + && p_floatType->getPrimitiveSizeInBits() > rand->getReturnType()->getPrimitiveSizeInBits()) + { + ret = builder.CreateFPExt(ret, p_floatType); + CHECK_LLVM_RET(ret); + } + + return *ret; +} + +} + + +FreeForm2::LlvmCodeGenerator::LlvmCodeGenerator( + CompilationState& p_state, + const AllocationVector& p_allocations, + CompilerFactory::DestinationFunctionType p_destinationFunctionType) + : m_state(p_state), + m_allocations(p_allocations), + m_destinationFunctionType(p_destinationFunctionType), + m_returnType(nullptr), + m_returnValue(nullptr), + m_function(nullptr) +{ +} + + +LlvmCodeGenerator::CompiledValue* +FreeForm2::LlvmCodeGenerator::GetResult() +{ + return m_stack.top(); +} + +llvm::Function* +FreeForm2::LlvmCodeGenerator::GetFuction() const +{ + return m_function; +} + +void +FreeForm2::LlvmCodeGenerator::Visit(const SelectNthExpression& p_expr) +{ + // Ensure that the number of selection elements can be represented in an + // Int type. + const llvm::APInt maxInt = llvm::APInt::getSignedMaxValue(m_state.GetIntBits()); + const size_t sizetBitSize = sizeof(size_t) * 8; + llvm::APInt numChildren(sizetBitSize, p_expr.GetNumChildren() - 1); + if (m_state.GetIntBits() < sizetBitSize) + { + llvm::APInt trunc = numChildren.trunc(m_state.GetIntBits()); + FF2_ASSERT(trunc.getActiveBits() == numChildren.getActiveBits()); + numChildren = std::move(trunc); + } + + // Check that the unsigned Int version of numChildren is no larger than the + // max signed Int. + FF2_ASSERT(maxInt.uge(numChildren)); + + CompiledValue* index = m_stack.top(); + m_stack.pop(); + + FF2_ASSERT(p_expr.GetIndex().GetType().IsIntegerType()); + + // Promote index to an Int type. + index = &ValueConversion::Do(*index, + p_expr.GetIndex().GetType(), + TypeImpl::GetIntInstance(true), + m_state); + + CompiledValue* high = m_stack.top(); + m_stack.pop(); + + CompiledValue* select = high; + + // Loop backward through all children that aren't the + // last, creating a linear chained select. Note that this + // handles the single-child case by simply returning the + // value of that child. + for (llvm::APInt i(m_state.GetIntBits(), 1); i.slt(numChildren); ++i) + { + CompiledValue* current = m_stack.top(); + m_stack.pop(); + + const llvm::APInt child = numChildren - i - 1; + CompiledValue* childValue = llvm::ConstantInt::get(&m_state.GetIntType(), child); + CHECK_LLVM_RET(childValue); + CompiledValue* cond = m_state.GetBuilder().CreateICmpSLE(index, childValue); + CHECK_LLVM_RET(cond); + select = m_state.GetBuilder().CreateSelect(cond, current, select); + CHECK_LLVM_RET(select); + } + + // Check whether index is out-of-bounds. We use two's-complement-based + // cleverness to do this, by doing an unsigned comparison to the length of + // the array, and relying on two's complement representation to make + // negative numbers larger (in unsigned values) than the maximum index. + CompiledValue* bounds = llvm::ConstantInt::get(&m_state.GetIntType(), numChildren); + CHECK_LLVM_RET(bounds); + CompiledValue* inBounds = m_state.GetBuilder().CreateICmpULT(index, bounds); + CHECK_LLVM_RET(inBounds); + + // Select between value created before (in-bounds) and zero value + // (out-of-bounds). + CompiledValue* boundsSelect + = m_state.GetBuilder().CreateSelect(inBounds, + select, + &m_state.CreateZeroValue(p_expr.GetType())); + CHECK_LLVM_RET(boundsSelect); + + m_stack.push(boundsSelect); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const SelectRangeExpression& p_expr) +{ + llvm::IRBuilder<>& builder = m_state.GetBuilder(); + CompiledValue* array = m_stack.top(); + m_stack.pop(); + + CompiledValue* count = m_stack.top(); + m_stack.pop(); + + CompiledValue* start = m_stack.top(); + m_stack.pop(); + + FF2_ASSERT(p_expr.GetType().Primitive() == Type::Array); + const ArrayType& destType = static_cast(p_expr.GetType()); + + CompiledValue* bounds = builder.CreateExtractValue(array, ArrayCodeGen::boundsPosition); + CHECK_LLVM_RET(bounds); + + CompiledValue& dimensionBound = ArrayCodeGen::MaskBounds(m_state, *bounds); + + // Guard against a start value that is past the end of the array or less + // than zero. The less-than-zero condition is checked implicitly by doing + // unsigned comparison (start would become larger than dimensionBound). + CompiledValue* validateStart = builder.CreateICmpULT(start, &dimensionBound); + CHECK_LLVM_RET(validateStart); + + CompiledValue* countZero = llvm::ConstantInt::get(count->getType(), 0); + CHECK_LLVM_RET(countZero); + + // Guard against a count value that is less than or equal to zero. If the + // count falls into this range, the returned array should be empty. + CompiledValue* validateCount = builder.CreateICmpSGT(count, countZero); + CHECK_LLVM_RET(validateCount); + + CompiledValue* guardCondition = builder.CreateAnd(validateStart, validateCount); + CHECK_LLVM_RET(guardCondition); + + GenerateConditional startGuard(m_state, *guardCondition, "SelectRange start guard"); + + CompiledValue* srcCount = builder.CreateExtractValue(array, ArrayCodeGen::countPosition); + CHECK_LLVM_RET(srcCount); + + CompiledValue* subArrayCount = builder.CreateUDiv(srcCount, &dimensionBound); + CHECK_LLVM_RET(subArrayCount); + + // Compute the new pointer to the first element. + CompiledValue* oldPtr = builder.CreateExtractValue(array, ArrayCodeGen::pointerPosition); + CHECK_LLVM_RET(oldPtr); + + CompiledValue* startElemIndex = builder.CreateMul(subArrayCount, start); + CHECK_LLVM_RET(startElemIndex); + + CompiledValue* newPtr = builder.CreateInBoundsGEP(oldPtr, startElemIndex); + CHECK_LLVM_RET(newPtr); + + // Correct the count if necessary. + CompiledValue* maxCount = builder.CreateSub(&dimensionBound, start); + CHECK_LLVM_RET(maxCount); + + CompiledValue* checkCount = builder.CreateICmpSGT(count, maxCount); + CHECK_LLVM_RET(checkCount); + + CompiledValue* correctCount = builder.CreateSelect(checkCount, maxCount, count); + CHECK_LLVM_RET(correctCount); + + // Get the new bounds bit vector. + CompiledValue& subArrayBounds = ArrayCodeGen::ShiftBounds(m_state, *bounds, 1); + + CompiledValue& newBounds = ArrayCodeGen::UnshiftBound(m_state, subArrayBounds, *correctCount); + + // Get the new array count. + CompiledValue* newCount = builder.CreateMul(subArrayCount, correctCount); + CHECK_LLVM_RET(newCount); + + CompiledValue& newArray = ArrayCodeGen::CreateArray(m_state, destType, newBounds, *newCount, *newPtr); + + startGuard.FinishThen(&newArray); + + startGuard.FinishElse(&ArrayCodeGen::CreateEmptyArray(m_state, destType)); + + CompiledValue& returnVal = startGuard.Finish(*newArray.getType()); + m_stack.push(&returnVal); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ArrayLengthExpression& p_expr) +{ + CompiledValue* array = m_stack.top(); + m_stack.pop(); + + CompiledValue* bounds + = m_state.GetBuilder().CreateExtractValue(array, ArrayCodeGen::boundsPosition); + CHECK_LLVM_RET(bounds); + + // Length of the array we're dereferencing is in the least significant bits. + CompiledValue& length = ArrayCodeGen::MaskBounds(m_state, *bounds); + FF2_ASSERT(length.getType() && length.getType()->isIntegerTy()); + + llvm::Type& retType = m_state.GetType(p_expr.GetType()); + FF2_ASSERT(retType.isIntegerTy()); + + // Truncate the 64-bit return value to the return size, which should be 32 + // bits. + FF2_ASSERT(retType.getPrimitiveSizeInBits() < length.getType()->getPrimitiveSizeInBits()); + CompiledValue* ret = m_state.GetBuilder().CreateTrunc(&length, &retType); + CHECK_LLVM_RET(ret); + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ArrayDereferenceExpression& p_expr) +{ + LlvmCodeGenerator::CompiledValue* index = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* array = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* bounds + = m_state.GetBuilder().CreateExtractValue(array, ArrayCodeGen::boundsPosition); + CHECK_LLVM_RET(bounds); + + // Length of the array we're dereferencing is in the least significant bits. + LlvmCodeGenerator::CompiledValue& arrayLength = ArrayCodeGen::MaskBounds(m_state, *bounds); + + // Check whether index is out-of-bounds. We use two's-complement-based + // cleverness to do this, by doing an unsigned comparison to the length of + // the array, and relying on two's complement representation to make + // negative numbers larger (in unsigned values) than the maximum index. + + // Compare index to length. + LlvmCodeGenerator::CompiledValue* inBounds + = m_state.GetBuilder().CreateICmpULT(index, &arrayLength); + CHECK_LLVM_RET(inBounds); + + GenerateConditional inBoundsCase(m_state, *inBounds, "array dereference guard"); + + LlvmCodeGenerator::CompiledValue* val = NULL; + LlvmCodeGenerator::CompiledValue* pointer = m_state.GetBuilder().CreateExtractValue(array, ArrayCodeGen::pointerPosition); + CHECK_LLVM_RET(pointer); + if (p_expr.GetType().Primitive() == Type::Array) + { + const ArrayType& arrayType = static_cast(p_expr.GetType()); + + LlvmCodeGenerator::CompiledValue& resultBounds = ArrayCodeGen::ShiftBounds(m_state, *bounds, 1); + + // The count of the sub-array being dereferenced is the count of the + // current array divided by the bounds of the current array. This is + // used to compute the offset of the base of the sub-array. + CompiledValue* arrayCount + = m_state.GetBuilder().CreateExtractValue(array, ArrayCodeGen::countPosition); + CHECK_LLVM_RET(arrayCount); + + CompiledValue* subArrayCount = m_state.GetBuilder().CreateUDiv(arrayCount, &arrayLength); + CHECK_LLVM_RET(subArrayCount); + + CompiledValue* indexSelect = m_state.GetBuilder().CreateMul(subArrayCount, index); + CHECK_LLVM_RET(indexSelect); + + LlvmCodeGenerator::CompiledValue* elementPtr + = m_state.GetBuilder().CreateInBoundsGEP(pointer, indexSelect); + CHECK_LLVM_RET(elementPtr); + + val = &ArrayCodeGen::CreateArray(m_state, arrayType, resultBounds, *subArrayCount, *elementPtr); + } + else + { + // Lookup index into pointer. + LlvmCodeGenerator::CompiledValue* elementPtr + = m_state.GetBuilder().CreateInBoundsGEP(pointer, index); + CHECK_LLVM_RET(elementPtr); + + val = m_state.GetBuilder().CreateLoad(elementPtr); + } + CHECK_LLVM_RET(val); + + inBoundsCase.FinishThen(val); + + // If the index wasn't in bounds, give back zero. + inBoundsCase.FinishElse(&m_state.CreateZeroValue(p_expr.GetType())); + llvm::Value* ret = &inBoundsCase.Finish(m_state.GetType(p_expr.GetType())); + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitReference(const ArrayDereferenceExpression& p_expr) +{ + // References are not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Allocate(const Allocation& p_allocation) +{ + switch (p_allocation.GetAllocationType()) + { + case Allocation::ArrayLiteral: + { + FF2_ASSERT(p_allocation.GetType().Primitive() == Type::Array); + const ArrayType& arrayInfo = static_cast(p_allocation.GetType()); + + PreAllocatedArray preAllocatedArray; + + // Calculate array bounds. + std::pair encoded + = ArrayCodeGen::EncodeDimensions(arrayInfo); + preAllocatedArray.m_bounds + = llvm::ConstantInt::get(m_state.GetContext(), + llvm::APInt(sizeof(encoded.first) * 8, encoded.first)); + CHECK_LLVM_RET(preAllocatedArray.m_bounds); + + // Add calculated element count. + FF2_ASSERT(encoded.second == arrayInfo.GetMaxElements()); + FF2_ASSERT(encoded.second == p_allocation.GetNumChildren()); + preAllocatedArray.m_count + = llvm::ConstantInt::get(m_state.GetContext(), + llvm::APInt(sizeof(encoded.second) * 8, encoded.second)); + CHECK_LLVM_RET(preAllocatedArray.m_count); + + preAllocatedArray.m_array = NULL; + llvm::Type& elementType = m_state.GetType(arrayInfo.GetChildType()); + if (arrayInfo.GetMaxElements() > 0) + { + // Add array pointer to structure. + LlvmCodeGenerator::CompiledValue* arraySize + = llvm::ConstantInt::get(m_state.GetContext(), + llvm::APInt(sizeof(encoded.second) * 8, + arrayInfo.GetMaxElements())); + CHECK_LLVM_RET(arraySize); + preAllocatedArray.m_array = m_state.GetBuilder().CreateAlloca(&elementType, arraySize); + } + else + { + // Handle zero-length case by assigning NULL, as alloca with zero bytes + // invokes undefined LLVM behaviour. + llvm::Type* nullType = llvm::PointerType::get(&elementType, 0); + CHECK_LLVM_RET(nullType); + preAllocatedArray.m_array = llvm::Constant::getNullValue(nullType); + } + CHECK_LLVM_RET(preAllocatedArray.m_array); + + m_allocatedArrays.insert(std::make_pair(p_allocation.GetAllocationId(), preAllocatedArray)); + + break; + } + + // Declarations not implemented in FF2. + case Allocation::Declaration: __attribute__((__fallthrough__)); + + // Literal streams not implemented in FF2 + case Allocation::LiteralStream: __attribute__((__fallthrough__)); + + default: + { + FF2_UNREACHABLE(); + } + } +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ArrayLiteralExpression& p_expr) +{ + FF2_ASSERT(m_allocatedArrays.find(p_expr.GetId()) != m_allocatedArrays.end()); + FF2_ASSERT(p_expr.GetType().Primitive() == Type::Array); + + PreAllocatedArray preAllocatedArray = m_allocatedArrays[p_expr.GetId()]; + + // Populate structure with values. + for (unsigned int i = 0; i < p_expr.GetNumChildren(); i++) + { + LlvmCodeGenerator::CompiledValue* element = m_state.GetBuilder().CreateConstInBoundsGEP1_32(preAllocatedArray.m_array, i); + CHECK_LLVM_RET(element); + m_state.GetBuilder().CreateStore(m_stack.top(), element); + m_stack.pop(); + } + + m_stack.push(&ArrayCodeGen::CreateArray(m_state, + static_cast(p_expr.GetType()), + *preAllocatedArray.m_bounds, + *preAllocatedArray.m_count, + *preAllocatedArray.m_array)); +} + + +bool +FreeForm2::LlvmCodeGenerator::AlternativeVisit(const ConditionalExpression& p_expr) +{ + p_expr.GetCondition().Accept(*this); + GenerateConditional cond(m_state, *m_stack.top(), "if"); + m_stack.pop(); + + p_expr.GetThen().Accept(*this); + cond.FinishThen(m_stack.top()); + m_stack.pop(); + + p_expr.GetElse().Accept(*this); + cond.FinishElse(m_stack.top()); + m_stack.pop(); + + m_stack.push(&cond.Finish(m_state.GetType(p_expr.GetType()))); + + return true; +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConditionalExpression& p_expr) +{ + // This should have been managed by AlternativeVisit. + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToFloatExpression& p_expr) +{ + CompiledValue* child = m_stack.top(); + m_stack.pop(); + + m_stack.push(&ValueConversion::Do(*child, p_expr.GetChildType(), p_expr.GetType(), m_state)); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToIntExpression& p_expr) +{ + CompiledValue* child = m_stack.top(); + m_stack.pop(); + + m_stack.push(&ValueConversion::Do(*child, p_expr.GetChildType(), p_expr.GetType(), m_state)); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToUInt64Expression& p_expr) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToInt32Expression& p_expr) +{ + CompiledValue* child = m_stack.top(); + m_stack.pop(); + + m_stack.push(&ValueConversion::Do(*child, p_expr.GetChildType(), p_expr.GetType(), m_state)); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToUInt32Expression& p_expr) +{ + CompiledValue* child = m_stack.top(); + m_stack.pop(); + + m_stack.push(&ValueConversion::Do(*child, p_expr.GetChildType(), p_expr.GetType(), m_state)); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToBoolExpression& p_expr) +{ + LlvmCodeGenerator::CompiledValue* child = m_stack.top(); + m_stack.pop(); + + m_stack.push(&ValueConversion::Do(*child, p_expr.GetChildType(), p_expr.GetType(), m_state)); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ConvertToImperativeExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const DeclarationExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const DirectPublishExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ExternExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralIntExpression& p_expr) +{ + const llvm::Type& returnType = m_state.GetType(p_expr.GetType()); + FF2_ASSERT(returnType.isIntegerTy()); + const llvm::APInt value(returnType.getPrimitiveSizeInBits(), + static_cast(p_expr.GetConstantValue().m_int), + true); + CompiledValue* ret = llvm::ConstantInt::get(m_state.GetContext(), value); + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralUInt64Expression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralInt32Expression& p_expr) +{ + const llvm::Type& returnType = m_state.GetType(p_expr.GetType()); + FF2_ASSERT(returnType.isIntegerTy()); + + // Sign-extend the value to 64-bits; direct cast to UInt64 would lose data. + const Int64 intValue = p_expr.GetConstantValue().m_int32; + const llvm::APInt value(returnType.getPrimitiveSizeInBits(), static_cast(intValue), true); + CompiledValue* ret = llvm::ConstantInt::get(m_state.GetContext(), value); + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralUInt32Expression& p_expr) +{ + const llvm::Type& returnType = m_state.GetType(p_expr.GetType()); + FF2_ASSERT(returnType.isIntegerTy()); + const llvm::APInt value(returnType.getPrimitiveSizeInBits(), + p_expr.GetConstantValue().m_int, + false); + CompiledValue* ret = llvm::ConstantInt::get(m_state.GetContext(), value); + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralFloatExpression& p_expr) +{ + CompiledValue* ret + = llvm::ConstantFP::get(m_state.GetContext(), + llvm::APFloat(p_expr.GetConstantValue().m_float)); + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralBoolExpression& p_expr) +{ + CompiledValue* ret + = llvm::ConstantInt::get(m_state.GetContext(), + llvm::APInt(1, p_expr.GetConstantValue().m_bool ? 1 : 0)); + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralVoidExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralStreamExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralWordExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LiteralInstanceHeaderExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +bool +FreeForm2::LlvmCodeGenerator::AlternativeVisit(const LetExpression& p_expr) +{ + for (size_t i = 0; i + 1 < p_expr.GetNumChildren(); i++) + { + p_expr.GetBound()[i].second->Accept(*this); + + LlvmCodeGenerator::CompiledValue* value = m_stack.top(); + m_stack.pop(); + + // Store value for later use. + m_state.SetVariableValue(p_expr.GetBound()[i].first, *value); + } + + p_expr.GetValue().Accept(*this); + + return true; +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const LetExpression&) +{ + // Handled by AlternativeVisit. + Unreachable(__FILE__, __LINE__); +} + + +bool +FreeForm2::LlvmCodeGenerator::AlternativeVisit(const BlockExpression& p_expr) +{ + const auto children = static_cast(p_expr.GetNumChildren()); + FF2_ASSERT(children == p_expr.GetNumChildren()); + + for (unsigned int i = 0; i < children; i++) + { + p_expr.GetChild(i).Accept(*this); + } + + // Block expressions visit children from top-to-bottom. This means that + // the last child is on top of the stack. + LlvmCodeGenerator::CompiledValue& ret = *m_stack.top(); + + // Remove all other values from the stack. + for (unsigned int i = 0; i < children; i++) + { + m_stack.pop(); + } + + // Return result value to the stack. + m_stack.push(&ret); + + return true; +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const BlockExpression&) +{ + // This should have been managed by AlternativeVisit. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const MutationExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const MatchExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +// void +// FreeForm2::LlvmCodeGenerator::Visit(const MatchWordExpression&) +// { +// // Not supported in FF2. +// FF2_UNREACHABLE(); +// } + + +// void +// FreeForm2::LlvmCodeGenerator::Visit(const MatchLiteralExpression&) +// { +// // Not supported in FF2. +// FF2_UNREACHABLE(); +// } + + +// void +// FreeForm2::LlvmCodeGenerator::Visit(const MatchCurrentWordExpression&) +// { +// // Not supported in FF2. +// FF2_UNREACHABLE(); +// } + + +void +FreeForm2::LlvmCodeGenerator::Visit(const MatchOperatorExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const MatchGuardExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const MatchBindExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const MemberAccessExpression&) +{ + // Not yet supported. + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitReference(const MemberAccessExpression&) +{ + // Not yet supported. + Unreachable(__FILE__, __LINE__); +} + + +//void +//FreeForm2::LlvmCodeGenerator::Visit(const ObjectMethodExpression& p_expr) +//{ +// // Not supported in FF2. +// FF2_UNREACHABLE(); +//} + + +//void +//FreeForm2::LlvmCodeGenerator::Visit(const NeuralInputResultExpression& p_expr) +//{ +// // Only outputting to a float array is allowed at this time. +// FF2_ASSERT(p_expr.m_child.GetType().Primitive() == Type::Float); +// +// // Treat output as an array of float +// auto cast +// = m_state.GetBuilder().CreatePointerCast(&m_state.GetArrayReturnSpace(), +// &m_state.GetFloatPtrType()); +// CHECK_LLVM_RET(cast); +// +// // Get pointer to the target element in the output array. +// auto element +// = m_state.GetBuilder().CreateConstInBoundsGEP1_32(cast, +// p_expr.m_index); +// CHECK_LLVM_RET(element); +// +// // Store value in the target. +// m_state.GetBuilder().CreateStore(m_stack.top(), element); +//} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const PhiNodeExpression& p_expr) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const PublishExpression& p_expr) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const FeatureRefExpression& p_expr) +{ + FF2_ASSERT(m_state.GetFeatureArgument() != NULL); + + // Call below encodes size of m_index, so we check it's what we expect. + llvm::Value* address + = m_state.GetBuilder().CreateConstInBoundsGEP1_32(m_state.GetFeatureArgument(), + p_expr.m_index, + llvm::Twine("feature array access")); + CHECK_LLVM_RET(address); + llvm::Value* val = m_state.GetBuilder().CreateLoad(address); + CHECK_LLVM_RET(val); + + // We ensure that there's enough space in the integer type to take the + // features, plus at least a sign bit. Zero-extend the feature value into a + // full integer. + BOOST_STATIC_ASSERT(sizeof(Result::IntType) > sizeof(Expression::FeatureType)); + FF2_ASSERT(p_expr.GetType().Primitive() == Type::Int); + val = m_state.GetBuilder().CreateZExt(val, &m_state.GetIntType()); + CHECK_LLVM_RET(val); + m_stack.push(val); +} + + +//void +//FreeForm2::LlvmCodeGenerator::Visit(const FSMExpression&) +//{ +// // Not supported in FF2. +// Unreachable(__FILE__, __LINE__); +//} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const FunctionExpression&) +{ + // Not supported in FF2. + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const FunctionCallExpression&) +{ + // Not supported in FF2. + Unreachable(__FILE__, __LINE__); +} + + +bool +FreeForm2::LlvmCodeGenerator::AlternativeVisit(const RangeReduceExpression& p_expr) +{ + llvm::IRBuilder<>& builder = m_state.GetBuilder(); + FF2_ASSERT(p_expr.GetReduceId() != VariableID::c_invalidID); + + p_expr.GetLow().Accept(*this); + CompiledValue& index = *m_stack.top(); + m_stack.pop(); + + p_expr.GetHigh().Accept(*this); + CompiledValue& limit = *m_stack.top(); + m_stack.pop(); + + CompiledValue* step = llvm::ConstantInt::get(m_state.GetContext(), + llvm::APInt(m_state.GetIntBits(), 1)); + CHECK_LLVM_RET(step); + + p_expr.GetInitial().Accept(*this); + CompiledValue& initial = *m_stack.top(); + m_stack.pop(); + + // Save current state. + llvm::BasicBlock* originalBlock = builder.GetInsertBlock(); + llvm::Function* currentFun = originalBlock->getParent(); + + // Create a block for the accumulator and loop-var phi nodes, as well as + // the condition. + llvm::BasicBlock* condBlock + = llvm::BasicBlock::Create(m_state.GetContext(), llvm::Twine("range-reduce-condition"), currentFun); + CHECK_LLVM_RET(condBlock); + + // Create and insert a basic block that encapsulates the loop, and a block + // for the code after the loop. + llvm::BasicBlock* loopBlock + = llvm::BasicBlock::Create(m_state.GetContext(), llvm::Twine("range-reduce-loop"), currentFun); + CHECK_LLVM_RET(loopBlock); + llvm::BasicBlock* afterLoopBlock + = llvm::BasicBlock::Create(m_state.GetContext(), llvm::Twine("range-reduce-after"), currentFun); + CHECK_LLVM_RET(afterLoopBlock); + + const CompiledValue* startBranch = builder.CreateBr(condBlock); + CHECK_LLVM_RET(startBranch); + + builder.SetInsertPoint(condBlock); + + // Create a PHI node to unify the values of the loop index, and one to unify + // values of the accumulator variable. + llvm::PHINode* loopVar = builder.CreatePHI(index.getType(), 2, llvm::Twine("range-reduce-loop-var")); + CHECK_LLVM_RET(loopVar); + loopVar->addIncoming(&index, originalBlock); + m_state.SetVariableValue(p_expr.GetStepId(), *loopVar); + + FF2_ASSERT(p_expr.GetType().Primitive() != Type::Void); + llvm::PHINode* accVar = builder.CreatePHI(initial.getType(), 2, llvm::Twine("range-reduce-acc-var")); + CHECK_LLVM_RET(accVar); + accVar->addIncoming(&initial, originalBlock); + m_state.SetVariableValue(p_expr.GetReduceId(), *accVar); + + // Visit the initial condition (that the low limit is less than the high limit). + // This guards against iterating through the loop once under these conditions. + CompiledValue* initialCond = builder.CreateICmpSLT(loopVar, &limit); + CHECK_LLVM_RET(initialCond); + + builder.CreateCondBr(initialCond, loopBlock, afterLoopBlock); + builder.SetInsertPoint(loopBlock); + + p_expr.GetReduceExpression().Accept(*this); + + LlvmCodeGenerator::CompiledValue& loopValue = *m_stack.top(); + m_stack.pop(); + + // Add step to the loop variable. + llvm::Type* types[] = { &m_state.GetIntType() }; + llvm::Function* add = llvm::Intrinsic::getDeclaration(&m_state.GetModule(), + llvm::Intrinsic::sadd_with_overflow, + types); + CHECK_LLVM_RET(add); + CompiledValue* addCall = builder.CreateCall2(add, loopVar, step); + CHECK_LLVM_RET(addCall); + CompiledValue* inc = builder.CreateExtractValue(addCall, 0); + CHECK_LLVM_RET(inc); + + CompiledValue* jumpCond = builder.CreateExtractValue(addCall, 1); + CHECK_LLVM_RET(jumpCond); + + // Create a conditional branch to the condition block; overflow causes + // a loop to break. + const CompiledValue* jump + = builder.CreateCondBr(jumpCond, afterLoopBlock, condBlock); + CHECK_LLVM_RET(jump); + + // Add new values to PHI nodes. + llvm::BasicBlock* loopEndBlock = builder.GetInsertBlock(); + loopVar->addIncoming(inc, loopEndBlock); + + builder.SetInsertPoint(afterLoopBlock); + + accVar->addIncoming(&loopValue, loopEndBlock); + + // Create a PHI node in case the loop was skipped entirely + llvm::PHINode* end = builder.CreatePHI(initial.getType(), 2, llvm::Twine("range-reduce-skip-acc")); + end->addIncoming(accVar, condBlock); + end->addIncoming(&loopValue, loopEndBlock); + + m_stack.push(end); + return true; +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ForEachLoopExpression&) +{ + // TODO: Implement (TFS #461742) + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ComplexRangeLoopExpression&) +{ + // TODO: Implement (TFS #461742) + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const RangeReduceExpression& p_expr) +{ + // Handled by AlternativeVisit. + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const UnaryOperatorExpression& p_expr) +{ + switch (p_expr.m_op) + { + case UnaryOperator::minus: + { + VisitUnaryMinus(p_expr); + break; + } + + case UnaryOperator::log: + { + VisitUnaryLog(p_expr, false); + break; + } + + case UnaryOperator::log1: + { + VisitUnaryLog(p_expr, true); + break; + } + + case UnaryOperator::abs: + { + VisitUnaryAbs(p_expr); + break; + } + + case UnaryOperator::_not: + case UnaryOperator::bitnot: + { + VisitUnaryNot(p_expr); + break; + } + + case UnaryOperator::round: + { + VisitUnaryRound(p_expr); + break; + } + + case UnaryOperator::trunc: + { + VisitUnaryTrunc(p_expr); + break; + } + + default: + { + Unreachable(__FILE__, __LINE__); + break; + } + }; +} + + +void +FreeForm2::LlvmCodeGenerator::VisitUnaryMinus(const UnaryOperatorExpression& p_expr) +{ + LlvmCodeGenerator::CompiledValue* child = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* neg; + + if (p_expr.GetType().IsIntegerType()) + { + neg = m_state.GetBuilder().CreateNeg(child); + } + else if (p_expr.GetType().IsFloatingPointType()) + { + neg = m_state.GetBuilder().CreateFNeg(child); + } + else + { + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(neg); + m_stack.push(neg); +} + + +LlvmCodeGenerator::CompiledValue& +CompileLogCall(CompilationState& p_state, + llvm::ArrayRef p_type, + LlvmCodeGenerator::CompiledValue& p_child) +{ + llvm::Function* fun = llvm::Intrinsic::getDeclaration(&p_state.GetModule(), + llvm::Intrinsic::log, + p_type); + CHECK_LLVM_RET(fun); + LlvmCodeGenerator::CompiledValue* log = p_state.GetBuilder().CreateCall(fun, &p_child); + CHECK_LLVM_RET(log); + return *log; +} + + +void +FreeForm2::LlvmCodeGenerator::VisitUnaryLog(const UnaryOperatorExpression& p_expr, + bool p_addOne) +{ + FF2_ASSERT(p_expr.GetType().IsFloatingPointType()); + FF2_ASSERT(p_expr.m_child.GetType().IsFloatingPointType() + || p_expr.m_child.GetType().IsIntegerType()); + + LlvmCodeGenerator::CompiledValue* arg = m_stack.top(); + m_stack.pop(); + + // Ensure the argument is a float. + arg = &ValueConversion::Do(*arg, p_expr.m_child.GetType(), p_expr.GetType(), m_state); + + llvm::Type& floatType = m_state.GetType(p_expr.GetType()); + if (p_addOne) + { + LlvmCodeGenerator::CompiledValue* one + = llvm::ConstantFP::get(&floatType, 1); + CHECK_LLVM_RET(one); + arg = m_state.GetBuilder().CreateFAdd(arg, one); + CHECK_LLVM_RET(arg); + } + + // Guard against values zero or less, for which we always return + // negative infinity. + LlvmCodeGenerator::CompiledValue* zero + = llvm::ConstantFP::get(&floatType, 0); + CHECK_LLVM_RET(zero); + LlvmCodeGenerator::CompiledValue* cmp + = m_state.GetBuilder().CreateFCmpOGT(arg, zero); + CHECK_LLVM_RET(cmp); + + LlvmCodeGenerator::CompiledValue* log + = &CompileLogCall(m_state, &floatType, *arg); + + // Select the value to return. + LlvmCodeGenerator::CompiledValue* negInfinity + = llvm::ConstantFP::getInfinity(&floatType, true); + CHECK_LLVM_RET(negInfinity); + LlvmCodeGenerator::CompiledValue* select + = m_state.GetBuilder().CreateSelect(cmp, log, negInfinity); + CHECK_LLVM_RET(select); + + m_stack.push(select); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitUnaryNot(const UnaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().Primitive() == Type::Bool + || p_expr.GetType().IsIntegerType()); + LlvmCodeGenerator::CompiledValue* child = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* ret = m_state.GetBuilder().CreateNot(child); + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitUnaryAbs(const UnaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().IsFloatingPointType() || p_expr.GetType().IsIntegerType()); + + CompiledValue* child = m_stack.top(); + m_stack.pop(); + + CompiledValue* zero = &m_state.CreateZeroValue(p_expr.GetType()); + CompiledValue* cond = nullptr; + CompiledValue* neg = nullptr; + + if (p_expr.GetType().IsIntegerType()) + { + // Generate the condition. + cond = m_state.GetBuilder().CreateICmpSGE(child, zero, llvm::Twine("abs test")); + + // Create negation. + neg = m_state.GetBuilder().CreateSub(zero, child); + } + else + { + // Generate the condition. + cond = m_state.GetBuilder().CreateFCmpOGE(child, zero, llvm::Twine("abs test")); + + // Create negation. + neg = m_state.GetBuilder().CreateFSub(zero, child); + } + CHECK_LLVM_RET(cond); + CHECK_LLVM_RET(neg); + + // Select between them. + LlvmCodeGenerator::CompiledValue* select + = m_state.GetBuilder().CreateSelect(cond, child, neg); + CHECK_LLVM_RET(select); + + m_stack.push(select); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitUnaryRound(const UnaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().IsIntegerType()); + FF2_ASSERT(p_expr.m_child.GetType().IsFloatingPointType() + || p_expr.m_child.GetType().IsIntegerType()); + if (p_expr.m_child.GetType().IsIntegerType()) + { + FF2_ASSERT(p_expr.GetType().IsSameAs(p_expr.m_child.GetType(), true)); + + // No work is necessary to round an int. + return; + } + + CompiledValue* child = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = NULL; + llvm::IRBuilder<>& builder = m_state.GetBuilder(); + + llvm::Type& floatType = m_state.GetType(p_expr.m_child.GetType()); + llvm::Type& intType = m_state.GetType(p_expr.GetType()); + + CompiledValue* zero = llvm::ConstantFP::get(&floatType, 0.0); + CHECK_LLVM_RET(zero); + CompiledValue* cmp = builder.CreateFCmpOGE(child, zero, llvm::Twine("round comparison")); + CHECK_LLVM_RET(cmp); + + CompiledValue* half = llvm::ConstantFP::get(&floatType, 0.5); + CHECK_LLVM_RET(half); + + GenerateConditional cond(m_state, *cmp, "round test"); + + // Generate condition where value is over zero, and we add 0.5 to push + // toward infinity. + CompiledValue* plus = builder.CreateFAdd(child, half); + CHECK_LLVM_RET(plus); + CompiledValue* plusret = builder.CreateFPToSI(plus, &intType); + CHECK_LLVM_RET(plusret); + cond.FinishThen(plusret); + + // Generate condition where value is under zero, and we subtract 0.5 + // to push toward infinity. + CompiledValue* minus = builder.CreateFSub(child, half); + CHECK_LLVM_RET(minus); + CompiledValue* minusret = builder.CreateFPToSI(minus, &intType); + CHECK_LLVM_RET(minusret); + + cond.FinishElse(minusret); + + ret = &cond.Finish(intType); + + m_stack.push(ret); +} + + +void FreeForm2::LlvmCodeGenerator::VisitUnaryTrunc(const UnaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().IsIntegerType()); + if (p_expr.m_child.GetType().IsIntegerType()) + { + // This operator covers only floating-point-to-int truncation. Int-to- + // int truncation is covert by the ConvertTo*Expressions. + FF2_ASSERT(p_expr.m_child.GetType().IsSameAs(p_expr.GetType(), true)); + return; + } + else + { + FF2_ASSERT(p_expr.m_child.GetType().IsFloatingPointType()); + } + + LlvmCodeGenerator::CompiledValue* child = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret + = &ValueConversion::Do(*child, p_expr.m_child.GetType(), p_expr.GetType(), m_state); + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const BinaryOperatorExpression& p_expr) +{ + // All the operands in the binary expression are pushed into the stack + // left-to-right. The binary operators expect the operands to be + // in reverse order, so a queue is used to invert that part of the stack. + std::queue tmpQueue; + + for (size_t i = 0; i < p_expr.GetNumChildren(); i++) + { + tmpQueue.push(m_stack.top()); + m_stack.pop(); + } + + for (size_t i = 0; i < p_expr.GetNumChildren(); i++) + { + m_stack.push(tmpQueue.front()); + tmpQueue.pop(); + } + + for (size_t i = 1; i < p_expr.GetNumChildren(); i++) + { + switch (p_expr.GetOperator()) + { + case BinaryOperator::plus: + { + VisitPlus(p_expr); + break; + } + + case BinaryOperator::minus: + { + VisitMinus(p_expr); + break; + } + + case BinaryOperator::multiply: + { + VisitMultiply(p_expr); + break; + } + + case BinaryOperator::divides: + { + VisitDivides(p_expr); + break; + } + + case BinaryOperator::mod: + { + VisitMod(p_expr); + break; + } + + case BinaryOperator::_and: + case BinaryOperator::_bitand: + { + VisitAnd(p_expr); + break; + } + + case BinaryOperator::_or: + case BinaryOperator::_bitor: + { + VisitOr(p_expr); + break; + } + + case BinaryOperator::log: + { + VisitLog(p_expr); + break; + } + + case BinaryOperator::pow: + { + VisitPow(p_expr); + break; + } + + case BinaryOperator::max: + { + VisitMaxMin(p_expr, false); + break; + } + + case BinaryOperator::min: + { + VisitMaxMin(p_expr, true); + break; + } + + case BinaryOperator::eq: + { + VisitEqual(p_expr, false); + break; + } + + case BinaryOperator::neq: + { + VisitEqual(p_expr, true); + break; + } + + case BinaryOperator::lt: + { + VisitCompare(p_expr, true, false); + break; + } + + case BinaryOperator::lte: + { + VisitCompare(p_expr, true, true); + break; + } + + case BinaryOperator::gt: + { + VisitCompare(p_expr, false, false); + break; + } + + case BinaryOperator::gte: + { + VisitCompare(p_expr, false, true); + break; + } + + default: + { + Unreachable(__FILE__, __LINE__); + break; + } + }; + } +} + + + + +void +FreeForm2::LlvmCodeGenerator::VisitPlus(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = NULL; + if (p_expr.GetType().IsIntegerType()) + { + ret = m_state.GetBuilder().CreateAdd(left, right); + } + else if (p_expr.GetType().IsFloatingPointType()) + { + ret = m_state.GetBuilder().CreateFAdd(left, right); + } + else + { + // Shouldn't get here. + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitMinus(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = NULL; + if (p_expr.GetType().IsIntegerType()) + { + ret = m_state.GetBuilder().CreateSub(left, right); + } + else if (p_expr.GetType().IsFloatingPointType()) + { + ret = m_state.GetBuilder().CreateFSub(left, right); + } + else + { + // Shouldn't get here. + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitMultiply(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + LlvmCodeGenerator::CompiledValue* left = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* right = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* ret = NULL; + if (p_expr.GetType().IsIntegerType()) + { + ret = m_state.GetBuilder().CreateMul(left, right); + } + else if (p_expr.GetType().IsFloatingPointType()) + { + ret = m_state.GetBuilder().CreateFMul(left, right); + } + else + { + // Shouldn't get here. + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitAnd(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + LlvmCodeGenerator::CompiledValue* left = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* right = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* ret + = m_state.GetBuilder().CreateAnd(left, right); + CHECK_LLVM_RET(ret); + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitOr(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + LlvmCodeGenerator::CompiledValue* left = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* right = m_stack.top(); + m_stack.pop(); + + LlvmCodeGenerator::CompiledValue* ret + = m_state.GetBuilder().CreateOr(left, right); + CHECK_LLVM_RET(ret); + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitLog(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().IsFloatingPointType()); + FF2_ASSERT(p_expr.GetChildType().IsFloatingPointType() + || p_expr.GetChildType().IsIntegerType()); + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + if (!p_expr.GetChildType().IsFloatingPointType()) + { + left = &ValueConversion::Do(*left, p_expr.GetChildType(), p_expr.GetType(), m_state); + right = &ValueConversion::Do(*right, p_expr.GetChildType(), p_expr.GetType(), m_state); + } + + // Guard against values zero or less, for which we always return + // negative infinity. + CompiledValue* zero = &m_state.CreateZeroValue(p_expr.GetType()); + CompiledValue* cmpLeft = m_state.GetBuilder().CreateFCmpOGT(left, zero); + CHECK_LLVM_RET(cmpLeft); + CompiledValue* cmpRight = m_state.GetBuilder().CreateFCmpOGT(right, zero); + CHECK_LLVM_RET(cmpRight); + CompiledValue* cmp = m_state.GetBuilder().CreateAnd(cmpLeft, cmpRight, llvm::Twine("log guard")); + + // Create the log. + llvm::Type& floatType = m_state.GetType(p_expr.GetType()); + CompiledValue* logLeft = &CompileLogCall(m_state, &floatType, *left); + CompiledValue* logRight = &CompileLogCall(m_state, &floatType, *right); + CompiledValue* log = m_state.GetBuilder().CreateFDiv(logLeft, logRight); + + // Select the value to return. + CompiledValue* negInfinity = llvm::ConstantFP::getInfinity(&floatType, true); + CHECK_LLVM_RET(negInfinity); + CompiledValue* select = m_state.GetBuilder().CreateSelect(cmp, log, negInfinity); + CHECK_LLVM_RET(select); + m_stack.push(select); +} + + +// Typedef for a function that combines two values into a returned value. +typedef LlvmCodeGenerator::CompiledValue& (*OperatorFun)(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right); + + +static LlvmCodeGenerator::CompiledValue& +GenerateSDivInstruction(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right) +{ + LlvmCodeGenerator::CompiledValue* div + = m_state.GetBuilder().CreateSDiv(&p_left, &p_right); + CHECK_LLVM_RET(div); + return *div; +} + + +static LlvmCodeGenerator::CompiledValue& +GenerateUDivInstruction(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right) +{ + LlvmCodeGenerator::CompiledValue* div + = m_state.GetBuilder().CreateUDiv(&p_left, &p_right); + CHECK_LLVM_RET(div); + return *div; +} + + +static LlvmCodeGenerator::CompiledValue& +GenerateSModInstruction(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right) +{ + LlvmCodeGenerator::CompiledValue* mod + = m_state.GetBuilder().CreateSRem(&p_left, &p_right); + CHECK_LLVM_RET(mod); + return *mod; +} + + +static LlvmCodeGenerator::CompiledValue& +GenerateUModInstruction(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right) +{ + LlvmCodeGenerator::CompiledValue* mod + = m_state.GetBuilder().CreateURem(&p_left, &p_right); + CHECK_LLVM_RET(mod); + return *mod; +} + + +LlvmCodeGenerator::CompiledValue& +CompileGuardedDivMod(CompilationState& m_state, + LlvmCodeGenerator::CompiledValue& p_left, + LlvmCodeGenerator::CompiledValue& p_right, + OperatorFun p_operator, + LlvmCodeGenerator::CompiledValue& underflowResult) +{ + // Assert that both operands are integer types of the same size. + FF2_ASSERT(p_left.getType() && p_right.getType() + && p_left.getType()->isIntegerTy() + && p_right.getType()->isIntegerTy(p_left.getType()->getPrimitiveSizeInBits())); + llvm::IntegerType& intType = llvm::cast(*p_left.getType()); + + // Division/Modulus is complicated in the freeforms language by the + // fact that we guard for division by zero using conditionals. + + // Generate the test for divide-by-zero. + LlvmCodeGenerator::CompiledValue* zero = llvm::ConstantInt::get(&intType, 0); + CHECK_LLVM_RET(zero); + LlvmCodeGenerator::CompiledValue* zeroCond + = m_state.GetBuilder().CreateICmpNE(zero, + &p_right, + llvm::Twine("div-by-zero guard")); + CHECK_LLVM_RET(zeroCond); + + GenerateConditional divzero(m_state, *zeroCond, "div-by-zero guard"); + + // Generate some constants. + LlvmCodeGenerator::CompiledValue* negOne + = llvm::ConstantInt::getSigned(&intType, -1); + CHECK_LLVM_RET(negOne); + LlvmCodeGenerator::CompiledValue* minInt + = llvm::ConstantInt::get(p_left.getType(), + llvm::APInt::getSignedMinValue(intType.getBitWidth())); + CHECK_LLVM_RET(minInt); + + // Check whether we're dividing MIN_INT by -1 (which causes a + // strange underflow condition). + LlvmCodeGenerator::CompiledValue* minIntCond + = m_state.GetBuilder().CreateICmpNE(&p_left, minInt); + CHECK_LLVM_RET(minIntCond); + LlvmCodeGenerator::CompiledValue* negOneCond + = m_state.GetBuilder().CreateICmpNE(&p_right, negOne); + CHECK_LLVM_RET(negOneCond); + LlvmCodeGenerator::CompiledValue* underFlowCond + = m_state.GetBuilder().CreateOr(minIntCond, negOneCond); + CHECK_LLVM_RET(underFlowCond); + + // Select between underflow result and regular result. Note that we + // need to generate the instruction here, via a function pointer, to + // avoid evaluating that instruction when we are dividing by zero. + GenerateConditional underflow(m_state, *underFlowCond, "underflow guard"); + underflow.FinishThen(&p_operator(m_state, p_left, p_right)); + underflow.FinishElse(&underflowResult); + + // Select between result of underflow comparison and div-by-zero result + // (defined to be zero). + divzero.FinishThen(&underflow.Finish(intType)); + divzero.FinishElse(zero); + return divzero.Finish(intType); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitDivides(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = nullptr; + + if (p_expr.GetType().IsIntegerType()) + { + llvm::Type& type = m_state.GetType(p_expr.GetType()); + FF2_ASSERT(type.isIntegerTy()); + + llvm::APInt max; + OperatorFun op; + if (p_expr.GetType().IsSigned()) + { + max = llvm::APInt::getSignedMaxValue(type.getPrimitiveSizeInBits()); + op = GenerateSDivInstruction; + } + else + { + max = llvm::APInt::getMaxValue(type.getPrimitiveSizeInBits()); + op = GenerateUDivInstruction; + } + CompiledValue* maxInt = llvm::ConstantInt::get(&type, max); + CHECK_LLVM_RET(maxInt); + + ret = &CompileGuardedDivMod(m_state, *left, *right, op, *maxInt); + } + else if (p_expr.GetType().IsFloatingPointType()) + { + ret = m_state.GetBuilder().CreateFDiv(left, right); + } + else + { + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(ret); + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitMod(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = nullptr; + + if (p_expr.GetType().IsIntegerType()) + { + CompiledValue& zero = m_state.CreateZeroValue(p_expr.GetType()); + if (p_expr.GetType().IsSigned()) + { + ret = &CompileGuardedDivMod(m_state, *left, *right, GenerateSModInstruction, zero); + } + else + { + ret = &CompileGuardedDivMod(m_state, *left, *right, GenerateUModInstruction, zero); + } + } + else if (p_expr.GetType().IsFloatingPointType()) + { + ret = m_state.GetBuilder().CreateFRem(left, right); + CHECK_LLVM_RET(ret); + } + else + { + FF2_UNREACHABLE(); + } + + m_stack.push(ret); +} + +void +FreeForm2::LlvmCodeGenerator::VisitPow(const BinaryOperatorExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + FF2_ASSERT(p_expr.GetType().IsIntegerType() + || p_expr.GetType().IsFloatingPointType()); + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = nullptr; + + if (p_expr.GetChildType().IsIntegerType()) + { + const TypeImpl& floatType = TypeImpl::GetFloatInstance(true); + left = &ValueConversion::Do(*left, p_expr.GetChildType(), floatType, m_state); + right = &ValueConversion::Do(*right, p_expr.GetChildType(), floatType, m_state); + } + else + { + FF2_ASSERT(left->getType() + && right->getType() + && left->getType()->getTypeID() == right->getType()->getTypeID()); + } + + // Create a reference to the LLVM intrinsic 'pow', which + // does floating point power raising (note that we supply + // the floating point type we're interested in as a + // parameter, to specify exactly which 'pow' we want). + llvm::Type* floatType = left->getType(); + llvm::Function* fun + = llvm::Intrinsic::getDeclaration(&m_state.GetModule(), + llvm::Intrinsic::pow, + floatType); + CHECK_LLVM_RET(fun); + ret = m_state.GetBuilder().CreateCall2(fun, left, right); + CHECK_LLVM_RET(ret); + + if (p_expr.GetType().IsIntegerType()) + { + ret = &ValueConversion::Do(*ret, TypeImpl::GetFloatInstance(true), p_expr.GetType(), m_state); + } + + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitMaxMin(const BinaryOperatorExpression& p_expr, + bool p_minimum) +{ + FF2_ASSERT(p_expr.GetType() == p_expr.GetChildType()); + FF2_ASSERT(p_expr.GetChildType().IsIntegerType() + || p_expr.GetChildType().IsFloatingPointType()); + + // Compare the two values. + const char* descr = p_minimum ? "min compare" : "max compare"; + + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* cmp = nullptr; + CompiledValue* select = nullptr; + + if (p_expr.GetType().IsIntegerType()) + { + cmp = m_state.GetBuilder().CreateICmpSGT(left, right, llvm::Twine(descr)); + } + else + { + cmp = m_state.GetBuilder().CreateFCmpOGT(left, right, llvm::Twine(descr)); + } + CHECK_LLVM_RET(cmp); + select = m_state.GetBuilder().CreateSelect(cmp, + p_minimum ? right : left, + p_minimum ? left : right); + CHECK_LLVM_RET(select); + m_stack.push(select); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitEqual(const BinaryOperatorExpression& p_expr, + bool p_inequality) +{ + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = nullptr; + if (p_expr.GetChildType().IsIntegerType() + || p_expr.GetChildType().Primitive() == Type::Bool) + { + ret = m_state.GetBuilder().CreateICmp(llvm::CmpInst::ICMP_EQ, left, right); + } + else if (p_expr.GetChildType().IsFloatingPointType()) + { + ret = &CompileFloatEquality(m_state, *left, *right); + } + else + { + // Shouldn't get here. + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(ret); + + if (p_inequality) + { + ret = m_state.GetBuilder().CreateNot(ret); + CHECK_LLVM_RET(ret); + } + m_stack.push(ret); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitCompare(const BinaryOperatorExpression& p_expr, + bool p_less, + bool p_equal) +{ + CompiledValue* left = m_stack.top(); + m_stack.pop(); + + CompiledValue* right = m_stack.top(); + m_stack.pop(); + + CompiledValue* ret = nullptr; + llvm::CmpInst::Predicate predicate = llvm::CmpInst::BAD_ICMP_PREDICATE; + + if (p_expr.GetChildType().IsIntegerType() + || p_expr.GetChildType().Primitive() == Type::Bool) + { + if (p_less) + { + predicate = p_equal ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_SLT; + } + else + { + predicate = p_equal ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_SGT; + } + + ret = m_state.GetBuilder().CreateICmp(predicate, left, right); + } + else if (p_expr.GetChildType().IsFloatingPointType()) + { + if (p_less) + { + predicate = p_equal ? llvm::CmpInst::FCMP_OLE : llvm::CmpInst::FCMP_OLT; + } + else + { + predicate = p_equal ? llvm::CmpInst::FCMP_OGE : llvm::CmpInst::FCMP_OGT; + } + + ret = m_state.GetBuilder().CreateFCmp(predicate, left, right); + if (p_equal) + { + CHECK_LLVM_RET(ret); + + LlvmCodeGenerator::CompiledValue& approx + = CompileFloatEquality(m_state, *left, *right); + ret = m_state.GetBuilder().CreateOr(ret, &approx); + } + } + else + { + // Shouldn't get here. + FF2_UNREACHABLE(); + } + CHECK_LLVM_RET(ret); + m_stack.push(ret); +} + + +llvm::Function* +FreeForm2::LlvmCodeGenerator::CreateFeatureFunction(const TypeImpl& p_returnType) +{ + // Create top-level function. + llvm::Type* returnType = nullptr; + std::vector arguments; + + if (m_destinationFunctionType == CompilerFactory::SingleDocumentEvaluation) + { + // The streamFeatureInput argument will not be used, so use any pointer + // for its type. + llvm::PointerType* streamFeatureInput + = llvm::PointerType::get(&m_state.GetIntType(), 0); + CHECK_LLVM_RET(streamFeatureInput); + arguments.push_back(streamFeatureInput); + + // Create feature array input type. + llvm::PointerType* featurePointerType + = llvm::PointerType::get(&m_state.GetFeatureType(), 0); + CHECK_LLVM_RET(featurePointerType); + arguments.push_back(featurePointerType); + } + else if (m_destinationFunctionType == CompilerFactory::DocumentSetEvaluation) + { + // Create feature array input type. + llvm::PointerType* featurePointerType + = llvm::PointerType::get(llvm::PointerType::get(&m_state.GetFeatureType(), 0), 0); + CHECK_LLVM_RET(featurePointerType); + arguments.push_back(featurePointerType); + + // Create document index input type. + llvm::IntegerType* docIndexType + =&m_state.GetInt32Type(); + CHECK_LLVM_RET(docIndexType); + arguments.push_back(docIndexType); + + // Create document count input type. + llvm::IntegerType* docCountType + = &m_state.GetInt32Type(); + CHECK_LLVM_RET(docCountType); + arguments.push_back(docCountType); + + // Create cache pointer type + llvm::PointerType* cachePointerType + = llvm::PointerType::get(&m_state.GetIntType(), 0); + CHECK_LLVM_RET(cachePointerType); + arguments.push_back(cachePointerType); + } + + if (p_returnType.Primitive() == Type::Array) + { + // If return type is an array, pass in an argument to copy resulting, + // flattened array into. Return value indicates dynamic bounds. + const ArrayType& info = static_cast(p_returnType); + llvm::PointerType* space + = llvm::PointerType::get(&m_state.GetType(info.GetChildType()), 0); + CHECK_LLVM_RET(space); + if (m_destinationFunctionType == CompilerFactory::SingleDocumentEvaluation) + { + arguments.push_back(space); + } + returnType = &m_state.GetArrayBoundsType(); + } + else + { + // Otherwise, pass in dummy arg (pointer to int type). + llvm::PointerType* space = llvm::PointerType::get(&m_state.GetIntType(), 0); + CHECK_LLVM_RET(space); + if (m_destinationFunctionType == CompilerFactory::SingleDocumentEvaluation) + { + arguments.push_back(space); + } + + returnType = &m_state.GetType(p_returnType); + } + + llvm::FunctionType* funType + = llvm::FunctionType::get(returnType, arguments, false); + CHECK_LLVM_RET(funType); + + llvm::Function* fun = llvm::Function::Create(funType, + llvm::Function::ExternalLinkage, + llvm::Twine(""), + &m_state.GetModule()); + CHECK_LLVM_RET(fun); + + // Create a basic block within the function, and codegen into it. + llvm::BasicBlock* block + = llvm::BasicBlock::Create(m_state.GetContext(), llvm::Twine("entry"), fun); + m_state.GetBuilder().SetInsertPoint(block); + + if (m_destinationFunctionType == CompilerFactory::SingleDocumentEvaluation) + { + // Push function arguments onto the value stack. + llvm::Function::arg_iterator iter = fun->arg_begin(); + FF2_ASSERT(iter != fun->arg_end()); + iter->setName(llvm::Twine("arg1")); + ++iter; + + FF2_ASSERT(iter != fun->arg_end()); + m_state.SetFeatureArgument(*iter); + iter->setName(llvm::Twine("p_features")); + ++iter; + + // Keep array return space reference. + FF2_ASSERT(iter != fun->arg_end()); + llvm::Value& arraySpace = *iter; + m_state.SetArrayReturnSpace(arraySpace); + iter->setName(llvm::Twine("p_output")); + ++iter; + FF2_ASSERT(iter == fun->arg_end()); + } + else if (m_destinationFunctionType == CompilerFactory::DocumentSetEvaluation) + { + // Push function arguments onto the value stack. + llvm::Function::arg_iterator iter = fun->arg_begin(); + FF2_ASSERT(iter != fun->arg_end()); + m_state.SetFeatureArrayPointer(*iter); + iter->setName(llvm::Twine("p_featureArray")); + ++iter; + + FF2_ASSERT(iter != fun->arg_end()); + m_state.SetAggregatedDocumentIndex(*iter); + iter->setName(llvm::Twine("p_index")); + ++iter; + + FF2_ASSERT(iter != fun->arg_end()); + m_state.SetAggregatedDocumentCount(*iter); + iter->setName(llvm::Twine("p_count")); + ++iter; + + FF2_ASSERT(iter != fun->arg_end()); + m_state.SetAggregatedCache(*iter); + iter->setName(llvm::Twine("p_cache")); + ++iter; + FF2_ASSERT(iter == fun->arg_end()); + + llvm::Value* featureArrayPointer + = m_state.GetBuilder().CreateInBoundsGEP(&m_state.GetFeatureArrayPointer(), + &m_state.GetAggregatedDocumentIndex(), + llvm::Twine("init feature array")); + CHECK_LLVM_RET(featureArrayPointer); + + llvm::Value* featureArray = m_state.GetBuilder().CreateLoad(featureArrayPointer); + CHECK_LLVM_RET(featureArray); + + m_state.SetFeatureArgument(*featureArray); + + m_documentContextStack.push(featureArray); + } + + return fun; +} + + +void +FreeForm2::LlvmCodeGenerator::CreateAllocations() +{ + typedef std::vector> AllocationVector; + for (size_t i = 0; i < m_allocations.size(); i++) + { + Allocate(*m_allocations[i]); + } +} + + +FreeForm2::LlvmCodeGenerator::CompiledValue& +FreeForm2::LlvmCodeGenerator::CreateReturn(const TypeImpl& p_type) +{ + CompiledValue& value = *m_returnValue; + + llvm::Value* ret = NULL; + if (p_type.Primitive() == Type::Array) + { + llvm::Value& arraySpace = m_state.GetArrayReturnSpace(); + const ArrayType& arrayType = static_cast(p_type); + + // Need to copy array result into array result argument, return bounds. + ret = m_state.GetBuilder().CreateRet(&ArrayCodeGen::IssueReturn(m_state, + value, + arrayType, + arraySpace)); + } + else + { + ret = m_state.GetBuilder().CreateRet(&value); + } + CHECK_LLVM_RET(ret); + return *ret; +} + + +bool +FreeForm2::LlvmCodeGenerator::AlternativeVisit(const FeatureSpecExpression& p_expr) +{ + // Generate the feature body. + p_expr.GetBody().Accept(*this); + + return true; +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const FeatureSpecExpression& p_expr) +{ + // This is handled in AlternativeVisit + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const FeatureGroupSpecExpression&) +{ + // Not supported in FF2. + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ReturnExpression&) +{ + // Not supported by FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const StreamDataExpression&) +{ + // Not supported by FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const UpdateStreamDataExpression&) +{ + // Not supported by FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const VariableRefExpression& p_expr) +{ + llvm::Value* value = m_state.GetVariableValue(p_expr.GetId()); + FF2_ASSERT(value != NULL); + m_stack.push(value); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitReference(const VariableRefExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitReference(const ThisExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::VisitReference(const UnresolvedAccessExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +//bool +//IsFeatureSpecMissing(const FreeForm2::Expression& p_expr) +//{ +// const FreeForm2::BlockExpression* block +// = dynamic_cast(&p_expr); +// if (!block || block->GetNumChildren() <= 0) +// { +// return false; +// } +// +// const FreeForm2::NeuralInputResultExpression* input +// = dynamic_cast(&block->GetChild(0)); +// if (!input || input->GetNumChildren() <= 0) +// { +// return false; +// } +// +// +// const FreeForm2::FeatureSpecExpression* feature +// = dynamic_cast(&input->m_child); +// +// return feature == nullptr; +//} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ImportFeatureExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const StateExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const StateMachineExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ExecuteMachineExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ExecuteStreamRewritingStateMachineGroupExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ExecuteMachineGroupExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const YieldExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const RandFloatExpression& p_expr) +{ + CompiledValue& randValue = CreateRandomFloat(m_state, &m_state.GetType(p_expr.GetType())); + m_stack.push(&randValue); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const RandIntExpression& p_expr) +{ + llvm::IRBuilder<>& builder = m_state.GetBuilder(); + CompiledValue* upperBound = m_stack.top(); + m_stack.pop(); + + CompiledValue* lowerBound = m_stack.top(); + m_stack.pop(); + + CompiledValue& randValue = CreateRandomFloat(m_state, nullptr); + + CompiledValue* range = builder.CreateSub(upperBound, lowerBound); + CHECK_LLVM_RET(range); + + CompiledValue* floatRange = builder.CreateSIToFP(range, randValue.getType()); + CHECK_LLVM_RET(floatRange); + + CompiledValue* valInRange = builder.CreateFMul(floatRange, &randValue); + CHECK_LLVM_RET(valInRange); + + CompiledValue* intValInRange = builder.CreateFPToSI(valInRange, &m_state.GetType(p_expr.GetType())); + CHECK_LLVM_RET(intValInRange); + + CompiledValue* finalValue = builder.CreateAdd(intValInRange, lowerBound); + CHECK_LLVM_RET(finalValue); + + m_stack.push(finalValue); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const ThisExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const UnresolvedAccessExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const TypeInitializerExpression&) +{ + // Not supported in FF2. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const AggregateContextExpression&) +{ + // TODO: Implement me. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::LlvmCodeGenerator::Visit(const DebugExpression&) +{ + // Not supported in FreeForm2. + FF2_UNREACHABLE(); +} + + +llvm::Function& +FreeForm2::LlvmCodeGenerator::Compile( + const Expression& p_expr, + CompilationState& p_state, + const AllocationVector& p_allocations, + CompilerFactory::DestinationFunctionType p_destinationFunctionType) +{ + LlvmCodeGenerator visitor(p_state, p_allocations, p_destinationFunctionType); + visitor.m_returnType = &p_expr.GetType(); + visitor.m_function = visitor.CreateFeatureFunction(*visitor.m_returnType); + visitor.CreateAllocations(); + + p_expr.Accept(visitor); + + if (!visitor.m_returnValue) + { + visitor.m_returnValue = visitor.m_stack.top(); + } + + visitor.CreateReturn(*visitor.m_returnType); + FF2_ASSERT(llvm::Function::classof(visitor.m_function)); + return *visitor.m_function; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.h new file mode 100644 index 000000000000..91275acb775d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCodeGenerator.h @@ -0,0 +1,203 @@ +#pragma once + +#ifndef FREEFORM2_LLVMCODEGENERATOR_H +#define FREEFORM2_LLVMCODEGENERATOR_H + +#include +#include +#include +#include "CompilationState.h" +#include "Expression.h" +#include +#include +#include +#include "Visitor.h" + +namespace llvm +{ + class Value; + class BasicBlock; + class PHINode; +} + +namespace FreeForm2 +{ + class Allocation; + + // LlvmCodeGenerator, an interface to implement the visitor pattern. + class LlvmCodeGenerator : public Visitor + { + public: + typedef llvm::Value CompiledValue; + typedef std::vector> AllocationVector; + + // Compile an expression tree down to LLVM-IR. + static llvm::Function& Compile(const Expression& p_expr, + CompilationState& p_state, + const AllocationVector& p_allocations, + CompilerFactory::DestinationFunctionType p_destinationFunctionType); + + LlvmCodeGenerator(CompilationState& p_state, + const AllocationVector& p_allocations, + CompilerFactory::DestinationFunctionType p_destinationFunctionType); + + // Provide the implementation of the Visitor pattern to generate + // LLVM code for each expression and operator. + void Visit(const LiteralFloatExpression& p_expr) override; + void Visit(const SelectNthExpression& p_expr) override; + void Visit(const SelectRangeExpression& p_expr) override; + void Visit(const ArrayLengthExpression& p_expr) override; + void Visit(const ArrayDereferenceExpression& p_expr) override; + void Visit(const ArrayLiteralExpression& p_expr) override; + bool AlternativeVisit(const ConditionalExpression& p_expr) override; + void Visit(const ConditionalExpression& p_expr) override; + void Visit(const ConvertToFloatExpression& p_expr) override; + void Visit(const ConvertToIntExpression& p_expr) override; + void Visit(const ConvertToUInt64Expression& p_expr) override; + void Visit(const ConvertToInt32Expression& p_expr) override; + void Visit(const ConvertToUInt32Expression& p_expr) override; + void Visit(const ConvertToBoolExpression& p_expr) override; + void Visit(const ConvertToImperativeExpression& p_expr) override; + void Visit(const DeclarationExpression& p_expr) override; + void Visit(const DirectPublishExpression& p_expr) override; + void Visit(const ExternExpression& p_expr) override; + void Visit(const LiteralIntExpression& p_expr) override; + void Visit(const LiteralUInt64Expression& p_expr) override; + void Visit(const LiteralInt32Expression& p_expr) override; + void Visit(const LiteralUInt32Expression& p_expr) override; + void Visit(const LiteralBoolExpression& p_expr) override; + void Visit(const LiteralVoidExpression& p_expr) override; + void Visit(const LiteralStreamExpression& p_expr) override; + void Visit(const LiteralWordExpression& p_expr) override; + void Visit(const LiteralInstanceHeaderExpression& p_expr) override; + bool AlternativeVisit(const LetExpression& p_expr) override; + void Visit(const LetExpression& p_expr) override; + void Visit(const MutationExpression& p_expr) override; + void Visit(const MatchExpression& p_expr) override; + void Visit(const MatchOperatorExpression&) override; + void Visit(const MatchGuardExpression&) override; + void Visit(const MatchBindExpression&) override; + void Visit(const MemberAccessExpression&) override; + void Visit(const PhiNodeExpression&) override; + void Visit(const PublishExpression&) override; + bool AlternativeVisit(const BlockExpression& p_expr) override; + void Visit(const BlockExpression& p_expr) override; + void Visit(const FeatureRefExpression& p_expr) override; + void Visit(const FunctionExpression& p_expr) override; + void Visit(const FunctionCallExpression& p_expr) override; + bool AlternativeVisit(const RangeReduceExpression& p_expr) override; + void Visit(const RangeReduceExpression& p_expr) override; + void Visit(const ForEachLoopExpression& p_expr) override; + void Visit(const ComplexRangeLoopExpression& p_expr) override; + void Visit(const UnaryOperatorExpression& p_expr) override; + void Visit(const BinaryOperatorExpression& p_expr) override; + bool AlternativeVisit(const FeatureSpecExpression& p_expr) override; + void Visit(const FeatureSpecExpression& p_expr) override; + void Visit(const FeatureGroupSpecExpression& p_expr) override; + void Visit(const ReturnExpression& p_expr) override; + void Visit(const StreamDataExpression& p_expr) override; + void Visit(const UpdateStreamDataExpression& p_expr) override; + void Visit(const VariableRefExpression& p_expr) override; + void Visit(const ImportFeatureExpression& p_expr) override; + void Visit(const StateExpression& p_expr) override; + void Visit(const StateMachineExpression& p_expr) override; + void Visit(const ExecuteMachineExpression& p_expr) override; + void Visit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) override; + void Visit(const ExecuteMachineGroupExpression& p_expr) override; + void Visit(const YieldExpression& p_expr) override; + void Visit(const RandFloatExpression& p_expr) override; + void Visit(const RandIntExpression& p_expr) override; + void Visit(const ThisExpression& p_expr) override; + void Visit(const UnresolvedAccessExpression& p_expr) override; + void Visit(const TypeInitializerExpression& p_expr) override; + void Visit(const AggregateContextExpression& p_expr) override; + void Visit(const DebugExpression& p_expr) override; + + void VisitReference(const ArrayDereferenceExpression& p_expr) override; + void VisitReference(const VariableRefExpression& p_expr) override; + void VisitReference(const MemberAccessExpression& p_expr) override; + void VisitReference(const ThisExpression& p_expr) override; + void VisitReference(const UnresolvedAccessExpression& p_expr) override; + + // Gets the compiled version of the whole object tree. + LlvmCodeGenerator::CompiledValue* GetResult(); + + llvm::Function* GetFuction() const; + + private: + // Helper functions for unary operator expressions. + void VisitUnaryMinus(const UnaryOperatorExpression& p_expr); + void VisitUnaryLog(const UnaryOperatorExpression& p_expr, bool p_addOne); + void VisitUnaryNot(const UnaryOperatorExpression& p_expr); + void VisitUnaryAbs(const UnaryOperatorExpression& p_expr); + void VisitUnaryRound(const UnaryOperatorExpression& p_expr); + void VisitUnaryTrunc(const UnaryOperatorExpression& p_expr); + + // Helper functions for binary operator expressions. + void VisitPlus(const BinaryOperatorExpression& p_expr); + void VisitMinus(const BinaryOperatorExpression& p_expr); + void VisitMultiply(const BinaryOperatorExpression& p_expr); + void VisitDivides(const BinaryOperatorExpression& p_expr); + void VisitMod(const BinaryOperatorExpression& p_expr); + void VisitAnd(const BinaryOperatorExpression& p_expr); + void VisitOr(const BinaryOperatorExpression& p_expr); + void VisitLog(const BinaryOperatorExpression& p_expr); + void VisitPow(const BinaryOperatorExpression& p_expr); + void VisitMaxMin(const BinaryOperatorExpression& p_expr, bool p_minimum); + void VisitEqual(const BinaryOperatorExpression& p_expr, bool p_inequality); + void VisitCompare(const BinaryOperatorExpression& p_expr, bool p_less, bool p_equal); + + // Helper function for array dereference expressions. + void VisitArrayDereference(const ArrayDereferenceExpression& p_expr, + bool p_isRef); + + // Allocate memory for the given allocation. + void Allocate(const Allocation& p_allocation); + + // Create an LLVM function to wrap a Derived/FeatureSpecSpec. + llvm::Function* CreateFeatureFunction(const TypeImpl& p_returnType); + + // Create a return value of the specified type. Array space is provided + // for returning an array. + CompiledValue& CreateReturn(const TypeImpl& p_type); + + // Generate all allocation code stored in the program's allocation vector. + void CreateAllocations(); + + // A stack of intermediate expressions that have already been compiled. + std::stack m_stack; + + // A stack that keeps track potision of the document context. + // This is for aggregated freeform, do not use it in ffv2. + std::stack m_documentContextStack; + + // Holds the underlying LLVM state objects, and the symbol table. + CompilationState& m_state; + + // A reference to the program being compiled. + const AllocationVector& m_allocations; + + // Holds the state needed for decoupling array allocation and initialization. + struct PreAllocatedArray + { + LlvmCodeGenerator::CompiledValue* m_bounds; + LlvmCodeGenerator::CompiledValue* m_count; + LlvmCodeGenerator::CompiledValue* m_array; + }; + + // A mapping to get the PreallocatedArray structure for each ArrayLiteralExpression. + std::map m_allocatedArrays; + + // A mapping to get the pre-allocated CompiledValue for each expression. + std::map m_allocatedValues; + + // Destination Function Type. + CompilerFactory::DestinationFunctionType m_destinationFunctionType; + + const TypeImpl* m_returnType; + CompiledValue* m_returnValue; + llvm::Function* m_function; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.cpp new file mode 100644 index 000000000000..dc8cb7fb6c0a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.cpp @@ -0,0 +1,783 @@ +#include "LlvmCompiler.h" + +#include "ArrayCodeGen.h" +#include "ArrayType.h" +#include +#include +#include +#include "CompilationState.h" +#include "Executable.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Type.h" +#include "FreeForm2Utils.h" +#include "LlvmCodeGenUtils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "Program.h" +#include +#include "ValueResult.h" +#include +#include "FreeForm2Support.h" + +using namespace FreeForm2; + +#define __stdcall + +// Chkstk method signature needed to identify __chkstk call. +// This appears within low-level IR if the Freeform expression incurs +// a large amount of stack usage. +// extern "C" void __stdcall __chkstk(size_t); + + +// Copied from LLVM code. +namespace X86 +{ + /// RelocationType - An enum for the x86 relocation codes. Note that + /// the terminology here doesn't follow x86 convention - word means + /// 32-bit and dword means 64-bit. The relocations will be treated + /// by JIT or ObjectCode emitters, this is transparent to the x86 code + /// emitter but JIT and ObjectCode will treat them differently + enum RelocationType { + /// reloc_pcrel_word - PC relative relocation, add the relocated value to + /// the value already in memory, after we adjust it for where the PC is. + reloc_pcrel_word = 0, + + /// reloc_picrel_word - PIC base relative relocation, add the relocated + /// value to the value already in memory, after we adjust it for where the + /// PIC base is. + reloc_picrel_word = 1, + + /// reloc_absolute_word - absolute relocation, just add the relocated + /// value to the value already in memory. + reloc_absolute_word = 2, + + /// reloc_absolute_word_sext - absolute relocation, just add the relocated + /// value to the value already in memory. In object files, it represents a + /// value which must be sign-extended when resolving the relocation. + reloc_absolute_word_sext = 3, + + /// reloc_absolute_dword - absolute relocation, just add the relocated + /// value to the value already in memory. + reloc_absolute_dword = 4 + }; +} + +namespace +{ + using namespace FreeForm2; + + // Table of external functions to be called from LLVM-generated code. + static const void* s_externalFunctions[] = + { + pow, //std::powf, + log, //std::logf, + floor, //std::floorf, + ceil, //std::ceilf, + fmod, //std::fmodf, + tanh, //std::tanhf, + std::rand, + // __chkstk, + FreeForm2GetRandomValue + }; + + // Table of names for external functions (must match size and order of previous table) + // The callers use these s_externalNames to match s_externalFunctions. Any rename may result in failure in external function matching. + static const char* s_externalNames[] = + { + "powf", + "logf", + "floorf", + "ceilf", + "fmodf", + "tanhf", + "rand", + // "__chkstk", + "FreeForm2GetRandomValue" + }; + + void ConvertRelocations(std::vector& p_dest, + const std::vector& p_src, + const LlvmExecutableImpl::FunctionInfo& p_func, + const std::vector& p_externals) + { + for (const auto& mr : p_src) + { + LlvmExecutableImpl::RelocationInfo relocation = { 0, 0, 0 }; + relocation.m_offset = static_cast(mr.getMachineCodeOffset()); + FF2_ASSERT(relocation.m_offset < p_func.m_length); + if (mr.getRelocationType() == X86::reloc_pcrel_word) + { + // Jump relative to PC, need no relocation + continue; + } + FF2_ASSERT(mr.getRelocationType() == X86::reloc_absolute_dword); + uint64_t relocated = *(reinterpret_cast(p_func.m_start + relocation.m_offset)); + const uint8_t* result = reinterpret_cast(mr.getResultPointer()); + + if (p_func.m_start <= result && result < p_func.m_start + p_func.m_length) + { + relocation.m_type = LlvmExecutableImpl::RelocationInfo::Internal; + relocation.m_delta = static_cast(reinterpret_cast(relocated) - p_func.m_start); + } + else + { + const void* pFreeForm2GetRandomValue = reinterpret_cast(&FreeForm2GetRandomValue); + relocation.m_type = LlvmExecutableImpl::RelocationInfo::External; + const void* relocatedFun = reinterpret_cast(relocated); + const auto found = std::find(p_externals.cbegin(), p_externals.cend(), relocatedFun); + if (found != p_externals.cend()) + { + relocation.m_delta = static_cast(std::distance(p_externals.cbegin(), found)); + } + else + { + FF2_ASSERT(relocatedFun == pFreeForm2GetRandomValue); + relocation.m_delta = static_cast(p_externals.size() - 1); + } + } + + p_dest.push_back(relocation); + } + } +} + +namespace FreeForm2 +{ + // Wrapper class around LLVM default memory, that allows us to persist the + // memory manager beyond the lifetime of the owning module. + class PersistentJITMemoryManager : public llvm::JITMemoryManager + { + public: + PersistentJITMemoryManager() + : m_manager(llvm::JITMemoryManager::CreateDefaultMemManager()) + { + m_fun.m_start = nullptr; + m_fun.m_length = 0; + } + + + // Return the delegated manager. + boost::shared_ptr + GetDelegate() + { + return m_manager; + } + + // Implement all virtual functions declared by JITMemoryManager by + // passing them to the delegate. + + + virtual void + setMemoryWritable() override + { + return m_manager->setMemoryWritable(); + } + + + virtual void + setMemoryExecutable() override + { + return m_manager->setMemoryExecutable(); + } + + + virtual void + setPoisonMemory(bool p_poison) override + { + return m_manager->setPoisonMemory(p_poison); + } + + + virtual void + AllocateGOT() override + { + return m_manager->AllocateGOT(); + } + + + virtual uint8_t* + getGOTBase() const override + { + return m_manager->getGOTBase(); + } + + + virtual uint8_t* + startFunctionBody(const llvm::Function* p_f, uintptr_t& p_actualSize) override + { + return m_manager->startFunctionBody(p_f, p_actualSize); + } + + + virtual uint8_t* + allocateStub(const llvm::GlobalValue* p_f, unsigned p_stubSize, unsigned p_alignment) override + { + return m_manager->allocateStub(p_f, p_stubSize, p_alignment); + } + + + virtual void + endFunctionBody(const llvm::Function* p_f, uint8_t* p_functionStart, uint8_t* p_functionEnd) override + { + m_fun.m_start = p_functionStart; + m_fun.m_length = p_functionEnd - p_functionStart; + return m_manager->endFunctionBody(p_f, p_functionStart, p_functionEnd); + } + + + virtual uint8_t* + allocateSpace(intptr_t p_size, unsigned p_alignment) override + { + return m_manager->allocateSpace(p_size, p_alignment); + } + + + virtual uint8_t* + allocateGlobal(uintptr_t p_size, unsigned p_alignment) override + { + return m_manager->allocateGlobal(p_size, p_alignment); + } + + + virtual void + deallocateFunctionBody(void* p_body) override + { + return m_manager->deallocateFunctionBody(p_body); + } + + virtual uint8_t* + allocateCodeSection(uintptr_t p_size, unsigned p_alignment, unsigned p_sectionID, llvm::StringRef p_sectionName) override + { + return m_manager->allocateCodeSection(p_size, p_alignment, p_sectionID, p_sectionName); + } + + virtual uint8_t* + allocateDataSection(uintptr_t p_size, unsigned p_alignment, unsigned p_sectionID, llvm::StringRef p_sectionName, bool p_isReadOnly) override + { + return m_manager->allocateDataSection(p_size, p_alignment, p_sectionID, p_sectionName, p_isReadOnly); + } + + virtual bool + finalizeMemory(std::string* p_errMsg) override + { + return m_manager->finalizeMemory(p_errMsg); + } + + // Function length and size. + LlvmExecutableImpl::FunctionInfo m_fun; + + private: + + // Delegate manager. + boost::shared_ptr m_manager; + }; + + // Critical section used to single-thread entry to LLVM code generation. + // CRITSEC s_llvmCriticalSection; + std::mutex s_llvmCriticalSection; + bool s_use_llvmCriticalSection = true; +} + +class ConditionalAutoCriticalSection //: public INoHeapInstance +{ +private: + // CRITSEC* m_pCritSec; + std::mutex* m_pCritSec; + bool m_yes = true; + + ConditionalAutoCriticalSection(const ConditionalAutoCriticalSection&) = delete; + ConditionalAutoCriticalSection& operator=(const ConditionalAutoCriticalSection&) = delete; + +public: + // ConditionalAutoCriticalSection(CRITSEC* pCritSec, bool pYes = true) + ConditionalAutoCriticalSection(std::mutex* pCritSec, bool pYes = true) + : m_pCritSec(pCritSec) + , m_yes(pYes) + { + if (m_yes) + { + this->m_pCritSec->lock(); + } + } + + ~ConditionalAutoCriticalSection() + { + if (m_yes) + { + this->m_pCritSec->unlock(); + } + } +}; + +FreeForm2::LlvmCompilerImpl::LlvmCompilerImpl(unsigned int p_optimizationLevel, + CompilerFactory::DestinationFunctionType p_destinationFunctionType) + : m_optimizationLevel(p_optimizationLevel), + m_destinationFunctionType(p_destinationFunctionType), + m_persistentMemoryManager(nullptr) +{ + ConditionalAutoCriticalSection cs(&s_llvmCriticalSection, s_use_llvmCriticalSection); + llvm::InitializeNativeTarget(); + + std::unique_ptr persistent(new PersistentJITMemoryManager()); + m_persistentMemoryManager = persistent.get(); + + m_state.reset(new CompilationState(*persistent)); + + // Arrange our memory managers, so that the persistent memory + // manager is owned by the module, and we have a shared pointer to + // the delegated manager. + m_memoryManager = persistent->GetDelegate(); + persistent.release(); + + // Initialize and run a function pass manager, for optimization. + m_functionPassManager.reset(new llvm::FunctionPassManager(&m_state->GetModule())); + + llvm::PassManagerBuilder builder; + builder.OptLevel = p_optimizationLevel; + builder.populateFunctionPassManager(*m_functionPassManager); + + m_functionPassManager->doInitialization(); +} + + +FreeForm2::LlvmCompilerImpl::~LlvmCompilerImpl() +{ + ConditionalAutoCriticalSection cs(&s_llvmCriticalSection, s_use_llvmCriticalSection); + + // Reset the state variables inside the critical section. + m_memoryManager.reset(); + //m_persistentMemoryManager.release(); + m_functionPassManager.reset(); + m_state.reset(); +} + + +CompilationState& +LlvmCompilerImpl::GetState() +{ + return *m_state; +} + + +boost::shared_ptr +LlvmCompilerImpl::GetMemoryManager() +{ + return m_memoryManager; +} + + +llvm::FunctionPassManager& +LlvmCompilerImpl::GetFunctionPassManager() +{ + return *m_functionPassManager; +} + + +unsigned int +LlvmCompilerImpl::GetOptimizationLevel() const +{ + return m_optimizationLevel; +} + + +const PersistentJITMemoryManager& +LlvmCompilerImpl::GetPersistentMemoryManager() const +{ + FF2_ASSERT(m_persistentMemoryManager != nullptr); + return *m_persistentMemoryManager; +} + + +std::unique_ptr +FreeForm2::LlvmCompilerImpl::Compile(const ProgramImpl& p_program, bool p_debugOutput) +{ + std::auto_ptr execImpl(new LlvmExecutableImpl(*this, + p_program, + p_debugOutput, + m_destinationFunctionType)); + auto exec = boost::make_shared(execImpl); + + return std::unique_ptr(new ExecutableCompilerResults(exec)); +} + + +FreeForm2::LlvmExecutableImpl::LlvmExecutableImpl(FreeForm2::LlvmCompilerImpl& p_compiler, + const FreeForm2::ProgramImpl& p_program, + bool p_dumpExecutable, + CompilerFactory::DestinationFunctionType p_destinationFunctionType) + : m_type(p_program.GetType().GetImplementation()), + m_map(p_program.GetFeatureMap()), + m_destinationFunctionType(p_destinationFunctionType) +{ + // Single-thread access to LLVM, pending work to allow + // safe-multithreaded access. + ConditionalAutoCriticalSection cs(&s_llvmCriticalSection, s_use_llvmCriticalSection); + + CompilationState& state = p_compiler.GetState(); + + llvm::Function& fun = LlvmCodeGenerator::Compile(p_program.GetExpression(), + state, + p_program.GetAllocations(), + m_destinationFunctionType); + + if (p_dumpExecutable) + { + state.GetModule().dump(); + } + + // Verify the function. Note that, despite the name, verifyFunction + // returns true if the function is corrupt. + std::string verifierMessage; + llvm::raw_string_ostream messageStream(verifierMessage); + if (llvm::verifyFunction(fun, &messageStream)) + { + // Print the LLVM IR for the module. + std::string module; + llvm::raw_string_ostream moduleStream(module); + state.GetModule().print(moduleStream, NULL); + std::ostringstream err; + err << "Error verifying LLVM function ('" << verifierMessage + << "'): " << std::endl << moduleStream.str() << std::endl; + std::cout << err.str() << std::endl; + throw std::runtime_error(err.str()); + } + + p_compiler.GetFunctionPassManager().run(fun); + + if (p_compiler.GetOptimizationLevel() > 0 && p_dumpExecutable) + { + state.GetModule().dump(); + } + + // Get pointer to function for later execution. + m_fun = state.GetExecutionEngine().getPointerToFunction(&fun); + + m_memoryManager = p_compiler.GetMemoryManager(); + m_funcInfo = p_compiler.GetPersistentMemoryManager().m_fun; + + std::vector externals; + for (const auto& name : s_externalNames) + { + externals.push_back(state.GetExecutionEngine().getPointerToNamedFunction(name, false)); + } + + ConvertRelocations(m_relocations, + llvm::GetMachineRelocations(&state.GetExecutionEngine()), + GetFuncInfo(), + externals); +} + + +FreeForm2::LlvmExecutableImpl::~LlvmExecutableImpl() +{ + ConditionalAutoCriticalSection cs(&s_llvmCriticalSection, s_use_llvmCriticalSection); + + // Reset the memory manager inside the critical section. + m_memoryManager.reset(); +} + + +boost::shared_ptr +FreeForm2::LlvmExecutableImpl::Evaluate(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[]) const +{ + FF2_ASSERT(m_destinationFunctionType == FreeForm2::CompilerFactory::SingleDocumentEvaluation); + switch (m_type.Primitive()) + { + // Return Result::IntType for any integer result for Phoenix compatibility. + case Type::UInt32: __attribute__((__fallthrough__)); + case Type::Int32: __attribute__((__fallthrough__)); + case Type::Int: + { + typedef Result::IntType (*TypedFun)(StreamFeatureInput* p_input, + const Executable::FeatureType*, + Result::IntType*); + TypedFun fun = reinterpret_cast(m_fun); + return boost::shared_ptr( + new ValueResult(EvaluateInternal( + fun, p_input, p_features, NULL, __FILE__, __LINE__))); + } + + case Type::Float: + { + typedef Result::FloatType (*TypedFun)(StreamFeatureInput* p_input, + const Executable::FeatureType*, + Result::IntType*); + TypedFun fun = reinterpret_cast(m_fun); + return boost::shared_ptr( + new ValueResult(EvaluateInternal( + fun, p_input, p_features, NULL, __FILE__, __LINE__))); + } + + case Type::Bool: + { + typedef bool (*TypedFun)(StreamFeatureInput* p_input, + const Executable::FeatureType*, + Result::IntType*); + TypedFun fun = reinterpret_cast(m_fun); + return boost::shared_ptr( + new ValueResult(EvaluateInternal( + fun, p_input, p_features, NULL, __FILE__, __LINE__))); + } + + case Type::Array: + { + const ArrayType& arrayType = static_cast(m_type.GetImplementation()); + switch (arrayType.GetChildType().Primitive()) + { + case Type::Int: + { + std::pair> + ret = EvaluateArray(p_input, + p_features, + __FILE__, + __LINE__); + return ArrayCodeGen::CreateArrayResult(arrayType, + ret.first, + ret.second); + } + + case Type::Float: + { + std::pair> + ret = EvaluateArray(p_input, + p_features, + __FILE__, + __LINE__); + return ArrayCodeGen::CreateArrayResult(arrayType, + ret.first, + ret.second); + } + + case Type::Bool: + { + std::pair> + ret = EvaluateArray(p_input, p_features, __FILE__, __LINE__); + return ArrayCodeGen::CreateArrayResult(arrayType, + ret.first, + ret.second); + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } + break; + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } +} + + +boost::shared_ptr +FreeForm2::LlvmExecutableImpl::Evaluate(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache) const +{ + FF2_ASSERT(m_destinationFunctionType == FreeForm2::CompilerFactory::DocumentSetEvaluation); + switch (m_type.Primitive()) + { + case Type::Int: + case Type::Int32: + case Type::UInt32: + { + typedef Result::IntType (*TypedFun) (const Executable::FeatureType* const*, + UInt32, + UInt32, + Int64*); + TypedFun fun = reinterpret_cast(m_fun); + return boost::shared_ptr( + new ValueResult(EvaluateInternal( + fun, p_features, p_currentDocument, p_documentCount, p_cache, __FILE__, __LINE__))); + } + + case Type::Float: + { + typedef Result::FloatType (*TypedFun) (const Executable::FeatureType* const*, + UInt32, + UInt32, + Int64*); + TypedFun fun = reinterpret_cast(m_fun); + return boost::shared_ptr( + new ValueResult(EvaluateInternal( + fun, p_features, p_currentDocument, p_documentCount, p_cache, __FILE__, __LINE__))); + } + + case Type::Bool: + { + typedef bool (*TypedFun) (const Executable::FeatureType* const*, + UInt32, + UInt32, + Int64*); + TypedFun fun = reinterpret_cast(m_fun); + return boost::shared_ptr( + new ValueResult(EvaluateInternal( + fun, p_features, p_currentDocument, p_documentCount, p_cache, __FILE__, __LINE__))); + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } +} + + +Executable::DirectEvalFun +FreeForm2::LlvmExecutableImpl::EvaluationFunction() const +{ + // DirectEvalFun assumes float return for freeform2 float type. + BOOST_STATIC_ASSERT(sizeof(Result::FloatType) == sizeof(float)); + if (m_type.Primitive() == Type::Float) + { + return reinterpret_cast(m_fun); + } + else + { + return NULL; + } +} + + +Executable::AggregatedEvalFun +FreeForm2::LlvmExecutableImpl::AggregatedEvaluationFunction() const +{ + FF2_ASSERT(m_destinationFunctionType == FreeForm2::CompilerFactory::DocumentSetEvaluation); + // Aggregated EvalFun assumes float return for freeform2 float type. + BOOST_STATIC_ASSERT(sizeof(Result::FloatType) == sizeof(float)); + if (m_type.Primitive() == Type::Float) + { + return reinterpret_cast(m_fun); + } + else + { + return nullptr; + } +} + + +const Type& +FreeForm2::LlvmExecutableImpl::GetType() const +{ + return m_type; +} + + +// Get the size of external memory. +size_t +FreeForm2::LlvmExecutableImpl::GetExternalSize() const +{ + if (m_memoryManager.get()) + { + // This will give us an approximation of the memory allocated by LLVM, albeit one that is almost certainly correct for our purposes. + const size_t sizeOfMemoryAllocatedByLlvm = (m_memoryManager->GetNumCodeSlabs() * m_memoryManager->GetDefaultCodeSlabSize()) + + (m_memoryManager->GetNumDataSlabs() * m_memoryManager->GetDefaultDataSlabSize()); + + // JITMemoryManager is a shared resource by all executables (an executable belongs to one neural input). + // Each executable will report an equal part of the shared memory. + return (sizeOfMemoryAllocatedByLlvm + sizeof(llvm::JITMemoryManager)) / m_memoryManager.use_count(); + } + else + { + return sizeof(llvm::JITMemoryManager); + } +} + + +template +ReturnType +FreeForm2::LlvmExecutableImpl::EvaluateInternal(ReturnType(*p_fun)(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache), + const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache, + const char* p_sourceFile, + unsigned int p_sourceLine) const +{ + return p_fun(p_features, p_currentDocument, p_documentCount, p_cache); +} + + +// Function to take flattened array results of templated type, and reify +// them into a multi-dimensional array result structure. +template +std::pair> +FreeForm2::LlvmExecutableImpl::EvaluateArray(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[], + const char* p_sourceFile, + unsigned int p_sourceLine) const +{ + typedef ArrayCodeGen::ArrayBoundsType (*TypedFun)(StreamFeatureInput*, + const Executable::FeatureType*, + T*); + TypedFun fun = reinterpret_cast(m_fun); + const ArrayType& arrayType = static_cast(m_type.GetImplementation()); + boost::shared_array space(new T[arrayType.GetMaxElements()]); + return std::make_pair(EvaluateInternal(fun, + p_input, + p_features, + space.get(), + p_sourceFile, + p_sourceLine), + space); +} + + +template +ReturnType +FreeForm2::LlvmExecutableImpl::EvaluateInternal(ReturnType (*p_fun)(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[], + ArrayArgType* p_arraySpace), + StreamFeatureInput* p_input, + const Executable::FeatureType p_features[], + ArrayArgType* p_arraySpace, + const char* p_sourceFile, + unsigned int p_sourceLine) const +{ + try + { + return p_fun(p_input, p_features, p_arraySpace); + } + catch(std::exception e){ + + } + Unreachable(__FILE__, __LINE__); +} + + +const LlvmExecutableImpl::FunctionInfo& +LlvmExecutableImpl::GetFuncInfo() const +{ + return m_funcInfo; +} + + +bool +LlvmExecutableImpl::operator==(const FreeForm2::LlvmExecutableImpl& p_other) const +{ + // TODO: implement correct equality check for executable code with respect to relocations. + return true; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.h new file mode 100644 index 000000000000..b46bbd4ff4e9 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmCompiler.h @@ -0,0 +1,220 @@ +#pragma once + +#include "ArrayCodeGen.h" +#include +#include "Compiler.h" +#include "Executable.h" +#include "FreeForm2.h" +#include +#include "TypeImpl.h" + +namespace llvm +{ + namespace legacy { + class FunctionPassManager; + } + + using legacy::FunctionPassManager; + class JITMemoryManager; +} + +namespace FreeForm2 +{ + // class CodeHeap; + class CompilationState; + class LlvmExecutableImpl; + class PersistentJITMemoryManager; + + // This is to control whether we're usign llvm code gen single-threaded or not. + // Default value is true, which restricts llvm code gen to be single-threaded globally, + // due to LLVM's restrictions that may or may not exists in current version. + // TODO: we're setting this flag to false experimentally in restricted scenario (e.g. URPService TC2) + // to allow multi-threaded llvm code gen, so that we can alleviate some long-standing + // pain caused by extremely time-consuming compilation process of very large freeform files. + // After double confirming that the version of LLVM integrated in current code base is fully + // multi-thread safe, we shall remove this critical section all together. + extern bool s_use_llvmCriticalSection; + + class LlvmCompilerImpl : public CompilerImpl + { + public: + LlvmCompilerImpl(unsigned int p_optimizationLevel, + CompilerFactory::DestinationFunctionType p_destinationFunctionType); + virtual ~LlvmCompilerImpl(); + + // Compile the given program into an executable. + virtual + std::unique_ptr + Compile(const ProgramImpl& p_program, + bool p_debugOutput) override; + + CompilationState& GetState(); + boost::shared_ptr GetMemoryManager(); + llvm::FunctionPassManager& GetFunctionPassManager(); + unsigned int GetOptimizationLevel() const; + const PersistentJITMemoryManager& GetPersistentMemoryManager() const; + + private: + std::unique_ptr m_state; + + boost::shared_ptr m_memoryManager; + + std::unique_ptr m_functionPassManager; + + PersistentJITMemoryManager* m_persistentMemoryManager; + + unsigned int m_optimizationLevel; + + // Destination function type. + CompilerFactory::DestinationFunctionType m_destinationFunctionType; + }; + + // An ExectuableImpl is the implementation class for executables, which + // currently compiles and runs via LLVM. Supports serialization and + // deserialization of executable code. + // + // There are two ways for LlvmExecutableImpl to be created: + // * As a result of FreeForm compilation. In this case binary code is allocated + // by m_memoryManager and m_relocations contain relocation table which should be + // serialized with the code. + // * As a result of deserialization. In this case binary code is allocaed by m_codeHeap + // and relocation table precedes it in the memory. Re-serilaization just writes + // an extent of memory specified by m_funcInfo. + class LlvmExecutableImpl : public ExecutableImpl + { + public: + static const int c_serializedVersion = 3; + + // Descriptor for LLVM-generated code chunk. + struct FunctionInfo + { + uint8_t* m_start; + ptrdiff_t m_length; + }; + + // Code relocation descriptor. + struct RelocationInfo + { + // Relocation type. + enum Type + { + // m_delta is an offset inside LLVM-generated code. + Internal, + // m_dela is an index in the table of external functions. + External + }; + + // Relocation type. + uint32_t m_type; + + // Offset of relocated value from the m_start of FunctionInfo. + uint32_t m_offset; + + // See comment for enum Type. + uint32_t m_delta; + }; + + // Compile a program to executable code. + LlvmExecutableImpl(LlvmCompilerImpl& p_compiler, + const ProgramImpl& p_program, + bool p_dumpExecutable, + CompilerFactory::DestinationFunctionType p_destinationFunctionType); + + // Take and wrap deserialized executable code. + // p_binary and p_binarySize defines a serialized representation of function preceded by relocation table. + // LlvmExecutableImpl(unsigned char* p_binary, + // size_t p_binarySize, + // const TypeImpl& p_type, + // boost::shared_ptr p_codeHeap, + // CompilerFactory::DestinationFunctionType p_destinationFunctionType); + + virtual ~LlvmExecutableImpl(); + + virtual boost::shared_ptr + Evaluate(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[]) const override; + + virtual boost::shared_ptr + Evaluate(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache) const override; + + virtual Executable::DirectEvalFun + EvaluationFunction() const override; + + virtual Executable::AggregatedEvalFun + AggregatedEvaluationFunction() const override; + + virtual const Type& GetType() const override; + + // Get the size of external memory. + virtual size_t GetExternalSize() const override; + + const FunctionInfo& GetFuncInfo() const; + + // unsigned char* SerializeBinary(unsigned char* p_buffer) const; + + // size_t GetSerializedSize() const; + + // Compares this ExecutableImpl against another one. + bool operator==(const LlvmExecutableImpl& p_other) const; + + // static LlvmExecutableImpl* DeserializeBinary(unsigned char* p_binary, const boost::shared_ptr& p_codeHeap, size_t p_codeSize); + + private: + + // Function to support aggregated freefom. + template + ReturnType EvaluateInternal(ReturnType (*p_fun)(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache), + const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache, + const char* p_sourceFile, + unsigned int p_sourceLine) const; + + // Function to take flattened array results of templated type, and reify + // them into a multi-dimensional array result structure. + template + std::pair> + EvaluateArray(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[], + const char* p_sourceFile, + unsigned int p_sourceLine) const; + + template + ReturnType EvaluateInternal(ReturnType (*p_fun)(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[], + ArrayArgType* p_arraySpace), + StreamFeatureInput* p_input, + const Executable::FeatureType p_features[], + ArrayArgType* p_arraySpace, + const char* p_sourceFile, + unsigned int p_sourceLine) const; + + // Top-level type of the program. + Type m_type; + + // Generic pointer to generated function. + void* m_fun; + + // JIT memory manager that owns the memory pointed to by m_fun. + boost::shared_ptr m_memoryManager; + + // Feature map used to compile program. + DynamicRank::IFeatureMap& m_map; + + // Destination function type. + CompilerFactory::DestinationFunctionType m_destinationFunctionType; + + // Program compilation result descriptor. Used by 1st ctor only. + FunctionInfo m_funcInfo; + + // Executable relocation data. Used by 1st ctor only. + std::vector m_relocations; + }; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.cpp new file mode 100644 index 000000000000..3be031125efd --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.cpp @@ -0,0 +1,252 @@ +#include "LlvmRuntimeLibrary.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// #include +#include +#include + +using namespace FreeForm2; + +namespace FreeForm2 +{ + // static CRITSEC s_randLock; + static std::mutex s_randLock; + static std::minstd_rand0 s_randGenerator; + static bool s_isRandSeeded = false; + static std::uniform_real_distribution<> s_randDistribution; +} + + // Get a random number in the range [0, 1.0). This function uses a simple + // LCPRNG. +extern "C" double FreeForm2GetRandomValue() +{ + // ::AutoCriticalSection lock(&FreeForm2::s_randLock); + s_randLock.lock(); + if (!FreeForm2::s_isRandSeeded) + { + FreeForm2::s_randGenerator.seed(GetTickCount()); + FreeForm2::s_isRandSeeded = true; + } + double value = FreeForm2::s_randDistribution(FreeForm2::s_randGenerator); + s_randLock.unlock(); + return value; + // return FreeForm2::s_randDistribution(FreeForm2::s_randGenerator); +} + +namespace +{ + + struct GlobalEntry + { + // Create a GlobalEntry for a name, GlobalValue, and mapping tuple. + GlobalEntry(const char* p_name, llvm::GlobalValue* p_value, void* p_mapping) + : m_name(p_name), + m_value(p_value), + m_mapping(p_mapping) + { + } + + // The name of the global. + std::string m_name; + + // The actual llvm value type. + llvm::GlobalValue* m_value; + + // The pointer to which the value will be mapped. + void* m_mapping; + }; + + // Create a ValueAndMapping function for std::rand. + GlobalEntry CreateRand(llvm::LLVMContext& p_context) + { + llvm::Type* randRet = llvm::Type::getDoubleTy(p_context); + FF2_ASSERT(randRet->getPrimitiveSizeInBits() == sizeof(decltype(FreeForm2GetRandomValue())) * 8); + llvm::FunctionType* randSig = llvm::FunctionType::get(randRet, false); + llvm::Function* randFunc = llvm::Function::Create(randSig, + llvm::GlobalValue::ExternalLinkage, + llvm::Twine("rand")); + return GlobalEntry("rand", randFunc, &FreeForm2GetRandomValue); + } + + // Compare SIZED_STRING values for a less-than ordering. + struct SizedStrLess + { + bool operator()(SIZED_STRING p_left, SIZED_STRING p_right) const + { + int val = std::char_traits::compare(p_left.pcData, + p_right.pcData, + std::min(p_left.cbData, p_right.cbData)); + if (!val) + { + return p_left.cbData < p_right.cbData; + } + else + { + return val < 0; + } + } + }; +} + +class LlvmRuntimeLibrary::Impl final +{ +public: + Impl(llvm::LLVMContext& p_context) + : m_context(p_context) + { + Initialize(); + } + + // Delete copy constructor and assignment operator. + Impl(const Impl&) = delete; + Impl& operator=(const Impl&) = delete; + + // Add all runtime symbols to the specified module. This method implements + // LlvmRuntimeLibrary::AddLibraryToModule. + void AddLibraryToModule(llvm::Module& p_module) const + { + for (const auto& entry : m_globals) + { + llvm::GlobalValue* value = entry.second.first; + if (llvm::Function* func = llvm::dyn_cast(value)) + { + p_module.getFunctionList().push_back(func); + } + else if (llvm::GlobalVariable* var = llvm::dyn_cast(value)) + { + p_module.getGlobalList().push_back(var); + } + else + { + FF2_UNREACHABLE(); + } + } + } + + // Add global value mappings to an exeuction engine. This method implements + // LlvmRuntimeLibrary::AddExecutionMappings. + void AddExecutionMappings(llvm::ExecutionEngine& p_engine) const + { + for (const auto& entry : m_globals) + { + const llvm::GlobalValue* const value = entry.second.first; + void* const mapping = entry.second.second; + p_engine.updateGlobalMapping(value, mapping); + } + } + + // Look up a GlobalValue by name. This implements + // LlvmRuntimeLibrary::FundValue. + llvm::GlobalValue* FindValue(SIZED_STRING p_name) const + { + const auto find = m_globals.find(p_name); + if (find == m_globals.end()) + { + return nullptr; + } + else + { + const GlobalAndMapping& value = find->second; + return value.first; + } + } + + // Find a runtime function with the specified name. This implements + // LlvmRuntimeLibrary::FindFunction. + llvm::Function* FindFunction(SIZED_STRING p_name) const + { + return llvm::dyn_cast_or_null(FindValue(p_name)); + } + +private: + // Initialize the+ runtime library. + void Initialize(); + + // Add a GlobalEntry to the mapping structures. + void AddEntry(const GlobalEntry& p_entry); + + // A pair containing the LLVM GlobalValue and the mapping pointer. + typedef std::pair GlobalAndMapping; + + // A mapping of global name to GlobalEntry struct. + std::map m_globals; + + // The storage for names of the globals. + std::forward_list m_globalNames; + + // The LLVMContext passed to the constructor. + llvm::LLVMContext& m_context; +}; + + +void LlvmRuntimeLibrary::Impl::Initialize() +{ + AddEntry(CreateRand(m_context)); +} + + +void LlvmRuntimeLibrary::Impl::AddEntry(const GlobalEntry& p_entry) +{ + m_globalNames.push_front(p_entry.m_name); + const std::string& name = m_globalNames.front(); + const SIZED_STRING sizedName = CStackSizedString(name.c_str(), name.size()); + + auto pairIterBool + = m_globals.emplace(sizedName, GlobalAndMapping(p_entry.m_value, p_entry.m_mapping)); + FF2_ASSERT(pairIterBool.second && "Global already exists"); +} + + +LlvmRuntimeLibrary::LlvmRuntimeLibrary(llvm::LLVMContext& p_context) + : m_impl(new Impl(p_context)) +{ +} + + +// This empty destructor is required so that the compiled can find Impl::~Impl. +// If it is not explicitly defined in this file, the compiler will not call +// the Impl destructor and will issue a warning. +LlvmRuntimeLibrary::~LlvmRuntimeLibrary() +{ +} + + +void +LlvmRuntimeLibrary::AddLibraryToModule(llvm::Module& p_module) const +{ + m_impl->AddLibraryToModule(p_module); +} + + +void +LlvmRuntimeLibrary::AddExecutionMappings(llvm::ExecutionEngine& p_engine) const +{ + m_impl->AddExecutionMappings(p_engine); +} + + +llvm::GlobalValue* +LlvmRuntimeLibrary::FindValue(SIZED_STRING p_name) const +{ + return m_impl->FindValue(p_name); +} + + +llvm::Function* +LlvmRuntimeLibrary::FindFunction(SIZED_STRING p_name) const +{ + return m_impl->FindFunction(p_name); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.h b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.h new file mode 100644 index 000000000000..a3c83d5196b8 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Backend/llvm/LlvmRuntimeLibrary.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include + +namespace llvm +{ + class LLVMContext; + class ExecutionEngine; + class Module; + class GlobalValue; + class Function; +} + +namespace FreeForm2 +{ + class LlvmRuntimeLibrary : boost::noncopyable + { + public: + // Create a LLVM Runtime Library using the specified LLVMContext. + LlvmRuntimeLibrary(llvm::LLVMContext& p_context); + + // Destroy the implementation. + ~LlvmRuntimeLibrary(); + + // Add all runtime symbols to the specified module. This is similar to + // adding forward declarations to a .cpp file. + void AddLibraryToModule(llvm::Module& p_module) const; + + // Add global value mappings to an exeuction engine, which is necessary + // when linking a module which uses the runtime library. + void AddExecutionMappings(llvm::ExecutionEngine& p_engine) const; + + // Look up a GlobalValue by name. GlobalValues generally include + // external variables and functions; see the LLVM documentation for + // more information. If the value is not found, this method returns + // null. + llvm::GlobalValue* FindValue(SIZED_STRING p_name) const; + + // Find a runtime function with the specified name. This is a + // specialization of FindValue for Functions. If the function is not + // found, or if the GlobalValue is not a function, this method returns + // null. + llvm::Function* FindFunction(SIZED_STRING p_name) const; + + private: + // The implementation of this class is hidden. + class Impl; + + // Pointer to the implementation. + std::unique_ptr m_impl; + }; +} + +extern "C" double FreeForm2GetRandomValue(); + +inline unsigned long GetTickCount() +{ + struct timespec ts; + + clock_gettime(CLOCK_MONOTONIC, &ts); + + return (ts.tv_sec * 1000 + ts.tv_nsec / 1000000); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.cpp new file mode 100644 index 000000000000..090adbe9efe0 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.cpp @@ -0,0 +1,67 @@ +#include "Allocation.h" + +#include "ArrayType.h" +#include "FreeForm2Assert.h" + +using namespace FreeForm2; + +FreeForm2::Allocation::Allocation(AllocationType p_allocType, + VariableID p_id, + const TypeImpl& p_type) + : m_allocType(p_allocType), + m_id(p_id), + m_type(p_type), + m_children(0) +{ + if (p_allocType == ArrayLiteral) + { + FF2_ASSERT(m_type.Primitive() == Type::Array); + m_children = 1; + + const ArrayType& arrayType = static_cast(p_type); + for (UInt32 i = 0; i < arrayType.GetDimensionCount(); i++) + { + m_children *= arrayType.GetDimensions()[i]; + } + } +} + + +FreeForm2::Allocation::Allocation(AllocationType p_allocType, + VariableID p_id, + const TypeImpl& p_type, + size_t p_children) + : m_allocType(p_allocType), + m_id(p_id), + m_type(p_type), + m_children(p_children) +{ +} + + +FreeForm2::Allocation::AllocationType +FreeForm2::Allocation::GetAllocationType() const +{ + return m_allocType; +} + + +const TypeImpl& +FreeForm2::Allocation::GetType() const +{ + return m_type; +} + + +VariableID +FreeForm2::Allocation::GetAllocationId() const +{ + return m_id; +} + + +size_t +FreeForm2::Allocation::GetNumChildren() const +{ + return m_children; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.h new file mode 100644 index 000000000000..20cb3c536b17 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Allocation.h @@ -0,0 +1,64 @@ +#pragma once +#ifndef FREEFORM2_ARRAYALLOCATIONWRAPPER_H +#define FREEFORM2_ARRAYALLOCATIONWRAPPER_H + +#include "Expression.h" +#include "Visitor.h" +#include + +namespace FreeForm2 +{ + class AllocationExpression; + + // Wraps the original expression and decouples array allocation from + // initialization. + class Allocation : public boost::noncopyable + { + public: + enum AllocationType + { + ArrayLiteral, + FeatureArray, + ExternArray, + LiteralStream, + LiteralWord, + Declaration + }; + + Allocation(AllocationType p_allocType, + VariableID p_id, + const TypeImpl& p_type); + + Allocation(AllocationType p_allocType, + VariableID p_id, + const TypeImpl& p_type, + size_t p_children); + + // Gets the type of the element to be allocated. + const TypeImpl& GetType() const; + + // Gets the number of children of the allocation. + size_t GetNumChildren() const; + + // Gets the type of the allocation. + AllocationType GetAllocationType() const; + + // Gets the identifier of the array. + VariableID GetAllocationId() const; + + private: + // The type of the allocation. + const AllocationType m_allocType; + + // The number of children. + size_t m_children; + + // The array identificator. + const VariableID m_id; + + // The type of the element to be allocated. + const TypeImpl& m_type; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.cpp new file mode 100644 index 000000000000..3ee6d4b2ba42 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.cpp @@ -0,0 +1,147 @@ +#include "ArrayDereferenceExpression.h" + +#include "ArrayType.h" +#include "FreeForm2Assert.h" +#include "RefExpression.h" +#include "SimpleExpressionOwner.h" +#include +#include "TypeManager.h" +#include "TypeUtil.h" +#include "Visitor.h" + +namespace +{ + const FreeForm2::TypeImpl& + DerefType(const FreeForm2::TypeImpl& p_arrayType, const FreeForm2::SourceLocation& p_sourceLocation) + { + if (p_arrayType.Primitive() != FreeForm2::Type::Array) + { + std::ostringstream err; + err << "The array operand in an array dereference expression " + << "is not an array (instead, it is a " + << p_arrayType << ")"; + throw FreeForm2::ParseError(err.str(), p_sourceLocation); + } + + const FreeForm2::ArrayType& arrayType = static_cast(p_arrayType); + return arrayType.GetDerefType(); + } +} + + +FreeForm2::ArrayDereferenceExpression::ArrayDereferenceExpression(const Annotations& p_annotations, + const Expression& p_array, + const Expression& p_index, + size_t p_version) + : Expression(p_annotations), + m_type(DerefType(p_array.GetType(), p_annotations.m_sourceLocation)), + m_array(p_array), + m_index(p_index), + m_version(p_version) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayDereferenceExpression::GetType() const +{ + if (!m_index.GetType().IsIntegerType()) + { + std::ostringstream err; + err << "The index operand in an array dereference expression " + << "is not an integer type (instead, it is a " + << m_index.GetType() << ")"; + throw ParseError(err.str(), GetSourceLocation()); + } + + return m_type; +} + + +size_t +FreeForm2::ArrayDereferenceExpression::GetNumChildren() const +{ + return 2; +} + + +void +FreeForm2::ArrayDereferenceExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_array.Accept(p_visitor); + m_index.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +void +FreeForm2::ArrayDereferenceExpression::AcceptReference(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisitReference(*this)) + { + m_array.AcceptReference(p_visitor); + m_index.Accept(p_visitor); + + p_visitor.VisitReference(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::Expression& +FreeForm2::ArrayDereferenceExpression::GetArray() const +{ + return m_array; +} + + +const FreeForm2::Expression& +FreeForm2::ArrayDereferenceExpression::GetIndex() const +{ + return m_index; +} + + +size_t +FreeForm2::ArrayDereferenceExpression::GetVersion() const +{ + return m_version; +} + + +FreeForm2::VariableID +FreeForm2::ArrayDereferenceExpression::GetBaseArrayId() const +{ + const ArrayDereferenceExpression* deref = this; + const Expression* array = &GetArray(); + + while (deref != nullptr) + { + array = &deref->GetArray(); + deref = dynamic_cast(array); + } + + const VariableRefExpression* base + = dynamic_cast(array); + + if (base != nullptr) + { + return base->GetId(); + } + else + { + // The base array is a literal. + return VariableID::c_invalidID; + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.h new file mode 100644 index 000000000000..f137bbf95c4b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayDereferenceExpression.h @@ -0,0 +1,53 @@ +#pragma once + +#ifndef FREEFORM2_ARRAY_DEREFERENCE_EXPRESSION_H +#define FREEFORM2_ARRAY_DEREFERENCE_EXPRESSION_H + +#include "Expression.h" + +namespace FreeForm2 +{ + class ArrayType; + class ProgramParseState; + + // An array-dereference expression removes a dimension from an array. + class ArrayDereferenceExpression : public Expression + { + public: + ArrayDereferenceExpression(const Annotations& p_annotations, + const Expression& p_array, + const Expression& p_index, + size_t p_version); + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual void AcceptReference(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + const Expression& GetArray() const; + const Expression& GetIndex() const; + size_t GetVersion() const; + + // Gets the VariableID of the array object, regardless of the number + // of dereferences. + VariableID GetBaseArrayId() const; + + private: + // Dereferenced type of this expression. + const TypeImpl& m_type; + + // Array expression being dereferenced. + const Expression& m_array; + + // Index supplied. + const Expression& m_index; + + // A unique version number associated with a particular + // value for this variable. + const size_t m_version; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.cpp new file mode 100644 index 000000000000..f4f952428255 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.cpp @@ -0,0 +1,54 @@ +#include "ArrayLength.h" + +#include "FreeForm2Assert.h" +#include "Expression.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include + +FreeForm2::ArrayLengthExpression::ArrayLengthExpression(const Annotations& p_annotations, + const Expression& p_array) + : Expression(p_annotations), + m_array(p_array) +{ +} + + +void +FreeForm2::ArrayLengthExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_array.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayLengthExpression::GetType() const +{ + if (m_array.GetType().Primitive() != Type::Array) + { + std::ostringstream err; + err << "Argument to array-length expression must be " + << "an array (got type '" + << m_array.GetType() << "')"; + throw ParseError(err.str(), GetSourceLocation()); + } + + return TypeImpl::GetUInt32Instance(true); +} + + +size_t +FreeForm2::ArrayLengthExpression::GetNumChildren() const +{ + return 1; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.h new file mode 100644 index 000000000000..531bb14770bd --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLength.h @@ -0,0 +1,38 @@ +#pragma once + +#ifndef FREEFORM2_ARRAY_LENGTH_H +#define FREEFORM2_ARRAY_LENGTH_H + +#include "Expression.h" + +namespace FreeForm2 +{ + class ArrayLengthExpression : public Expression + { + public: + // Construct an array length expression given the array in question. + ArrayLengthExpression(const Annotations& p_annotations, + const Expression& p_array); + + virtual void Accept(Visitor& p_visitor) const override; + + // Return the type of an array-length expression (int). + virtual const TypeImpl& GetType() const override; + + // Return the number of child nodes for this expression. + virtual size_t GetNumChildren() const override; + + // Get the array + const Expression& GetArray() const + { + return m_array; + } + + private: + // Array that we're calculating the length of. + const Expression& m_array; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.cpp new file mode 100644 index 000000000000..e848ad8e955c --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.cpp @@ -0,0 +1,450 @@ +#include "ArrayLiteralExpression.h" + +#include "Allocation.h" +#include "ArrayType.h" +#include +#include +#include "ConvertExpression.h" +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include "TypeManager.h" +#include "TypeUtil.h" +#include + + +FreeForm2::ArrayLiteralExpression::ArrayLiteralExpression( + const Annotations& p_annotations, + const TypeImpl& p_annotatedType, + const std::vector& p_children, + VariableID p_id, + TypeManager& p_typeManager) + : Expression(p_annotations), + m_isFlat(false), + m_type(NULL), + m_numChildren(static_cast(p_children.size())), + m_id(p_id) +{ + for (unsigned int i = 0; i < p_children.size(); i++) + { + m_children[i] = p_children[i]; + } + + m_type = &UnifyTypes(p_annotatedType, p_typeManager); +} + + +FreeForm2::ArrayLiteralExpression::ArrayLiteralExpression( + const Annotations& p_annotations, + const ArrayType& p_type, + const std::vector& p_elements, + VariableID p_id) + : Expression(p_annotations), + m_type(&p_type), + m_isFlat(true), + m_numChildren(static_cast(p_elements.size())), + m_id(p_id) +{ + FF2_ASSERT(p_type.Primitive() == Type::Array); + FF2_ASSERT(m_type->GetChildType().Primitive() != Type::Array); + + for (unsigned int i = 0; i < p_elements.size(); i++) + { + m_children[i] = p_elements[i]; + } +} + + +const FreeForm2::ArrayType& +FreeForm2::ArrayLiteralExpression::UnifyTypes(const TypeImpl& p_annotatedType, + TypeManager& p_typeManager) +{ + if (m_numChildren == 0) + { + const unsigned int dimensions[] = { 0 }; + return p_typeManager.GetArrayType(p_annotatedType, false, 1, dimensions, 0); + } + + const TypeImpl* childType = &m_children[0]->GetType(); + + for (unsigned int i = 0; i < m_numChildren; i++) + { + const TypeImpl& nextType = m_children[i]->GetType(); + if (!childType->IsSameAs(nextType, true)) + { + const TypeImpl& unifiedType + = TypeUtil::Unify(*childType, nextType, p_typeManager, true, false); + if (!unifiedType.IsValid()) + { + std::ostringstream err; + err << "Can't unify " << childType << " and " << nextType; + throw ParseError(err.str(), GetSourceLocation()); + } + childType = &unifiedType; + } + else + { + childType = childType->IsConst() ? childType : &nextType; + } + } + + const TypeImpl* unifiedType = NULL; + const TypeImpl* inferredType = NULL; + std::vector dimensions; + unsigned int maxElements = 0; + + if (childType->Primitive() == Type::Array) + { + FF2_ASSERT(!IsFlat()); + const ArrayType* childArray = static_cast(childType); + + dimensions.push_back(m_numChildren); + dimensions.insert(dimensions.end(), + childArray->GetDimensions(), + childArray->GetDimensions() + childArray->GetDimensionCount()); + + inferredType = &childArray->GetChildType(); + unifiedType = &TypeUtil::Unify(p_annotatedType, + childArray->GetChildType(), + p_typeManager, + false, + false); + for (unsigned int i = 0; i < m_numChildren; i++) + { + FF2_ASSERT(m_children[i]->GetType().Primitive() == Type::Array); + maxElements += static_cast(m_children[i]->GetType()).GetMaxElements(); + } + } + else + { + unifiedType = &TypeUtil::Unify(*childType, p_annotatedType, p_typeManager, false, false); + inferredType = childType; + maxElements = m_numChildren; + dimensions.push_back(m_numChildren); + } + + if (!unifiedType->IsValid()) + { + std::ostringstream err; + err << "Annotated array type (" + << p_annotatedType + << ") did not match inferred array type (" + << *inferredType + << ")"; + throw ParseError(err.str(), GetSourceLocation()); + } + + // Force array literals to be non-const. This is required for intialization + // of array variables. This would theoretically allow an assignment to an + // array literal; however, this is disallowed by Visage grammar. + return p_typeManager.GetArrayType(*unifiedType, + false, + static_cast(dimensions.size()), + &dimensions[0], + maxElements); +} + + +void +FreeForm2::ArrayLiteralExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < GetNumChildren(); i++) + { + size_t idx = GetNumChildren() - i - 1; + + m_children[idx]->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +void +FreeForm2::ArrayLiteralExpression::AcceptReference(Visitor& p_visitor) const +{ + std::ostringstream err; + err << "Array literals cannot be used as l-values."; + throw ParseError(err.str(), GetSourceLocation()); +} + + +boost::shared_ptr +FreeForm2::ArrayLiteralExpression::Alloc(const Annotations& p_annotations, + const TypeImpl& p_childType, + const std::vector& p_children, + VariableID p_id, + TypeManager& p_typeManager) +{ + size_t bytes = sizeof(ArrayLiteralExpression) + + (p_children.size() - 1) * sizeof(Expression*); + + if (p_children.size() > ArrayType::c_maxElementsPerDimension) + { + std::ostringstream err; + err << "Array literals cannot have more than " << ArrayType::c_maxElementsPerDimension << " elements per dimension."; + throw ParseError(err.str(), p_annotations.m_sourceLocation); + } + + // Allocate a shared_ptr that deletes an ArrayLiteralExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + ArrayLiteralExpression(p_annotations, p_childType, p_children, p_id, p_typeManager), DeleteAlloc); + return exp; +} + + +boost::shared_ptr +FreeForm2::ArrayLiteralExpression::Alloc(const Annotations& p_annotations, + const ArrayType& p_type, + const std::vector& p_children, + VariableID p_id) +{ + size_t bytes = sizeof(ArrayLiteralExpression) + + (p_children.size() - 1) * sizeof(Expression*); + + for (unsigned int i = 0; i < p_type.GetDimensionCount(); i++) + { + if (p_type.GetDimensions()[i] > ArrayType::c_maxElementsPerDimension) + { + std::ostringstream err; + err << "Array literals cannot have more than " << ArrayType::c_maxElementsPerDimension << " elements per dimension."; + throw ParseError(err.str(), p_annotations.m_sourceLocation); + } + } + + // Allocate a shared_ptr that deletes an ArrayLiteralExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + ArrayLiteralExpression(p_annotations, p_type, p_children, p_id), DeleteAlloc); + return exp; +} + + +const FreeForm2::Expression* const* +FreeForm2::ArrayLiteralExpression::Begin() const +{ + return &m_children[0]; +} + + +const FreeForm2::Expression* const* +FreeForm2::ArrayLiteralExpression::End() const +{ + return &m_children[GetNumChildren()]; +} + + +void +FreeForm2::ArrayLiteralExpression::DeleteAlloc(ArrayLiteralExpression* p_allocated) +{ + // Manually call dtor for array literal expression. + p_allocated->~ArrayLiteralExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +const FreeForm2::ArrayLiteralExpression& +FreeForm2::ArrayLiteralExpression::Flatten(SimpleExpressionOwner& p_owner) const +{ + return Flatten(p_owner, NULL, NULL); +} + + +const FreeForm2::ArrayLiteralExpression& +FreeForm2::ArrayLiteralExpression::Flatten(SimpleExpressionOwner& p_owner, + const TypeImpl* p_annotatedType, + TypeManager* p_typeManager) const +{ + // We've chosen a pointer+length-based representation of an array + // (pascal-style, sort-of). That means we can't use 'compositional' + // arrays, where multi-dimensional arrays are simply single-dimension + // arrays that hold other single-dimension arrays. As such, we need to + // ensure that every array is 'square', in that each element in each + // dimension holds the same number of elements. Thus, we can infer the + // size of a multi-dimensional array by knowing the number of + // sub-sub-elements per sub-element, and the number of sub-elements, + // making our pointer+length representation practical. + + FF2_ASSERT(!IsFlat()); + + // Check for empty arrays with no annotations. + if (p_annotatedType && p_annotatedType->Primitive() == Type::Unknown) + { + p_annotatedType = NULL; + } + + if (m_numChildren == 0 && !p_annotatedType) + { + throw ParseError("Can't infer array element type from empty array.", GetSourceLocation()); + } + + // Stack of unflattened array elements, along with the number of + // dimensions of elements they contain. + std::vector> + stack(1, std::make_pair(m_type->GetDimensionCount(), this)); + + // Sizes of each dimension. + std::vector dimensions(m_type->GetDimensionCount(), 0); + + // Indication, for each dimension, of whether we currently know how big + // the dimension is. + std::vector dimensionSizeKnown(m_type->GetDimensionCount(), false); + + std::vector elementCount(m_type->GetDimensionCount(), 0); + + // Flattened vector of elements. + std::vector elements; + elements.reserve(m_type->GetMaxElements()); + + while (!stack.empty()) + { + FF2_ASSERT(stack.back().second != NULL); + const ArrayLiteralExpression& current = *stack.back().second; + unsigned int currentDimension = stack.back().first; + unsigned int numChildren = static_cast(current.GetNumChildren()); + stack.pop_back(); + + // Check dimensions. + if (dimensionSizeKnown[currentDimension - 1]) + { + if (numChildren != dimensions[currentDimension - 1]) + { + std::ostringstream err; + err << "Array element "; + for (unsigned int i = 0; i < currentDimension; i++) + { + err << (i != 0 ? ", " : "") << elementCount[i]; + } + err << " was expected to be a literal array (as all " + << "non-leaf elements of a literal array must be), " + << "but was not"; + throw ParseError(err.str(), GetSourceLocation()); + } + } + else + { + dimensionSizeKnown[currentDimension - 1] = true; + dimensions[currentDimension - 1] = numChildren; + } + + if (currentDimension > 1) + { + // Check that it's an array literal. Note that we iterate + // through elements in reverse order, so that they come off the + // stack in the correct order. + const Expression* const* iter = current.End(); + while (iter != current.Begin()) + { + --iter; + + // We need to ensure that the child is a literal array, so + // that we can check it's square (and hence giving us + // assurance that our simple representation is valid). This + // might be better done with a type-indication returning + // member function, but does require dynamic type information, + // short of making ArrayLiteralExpression aware of the + // types of children it has (not a good idea). + const ArrayLiteralExpression* child + = dynamic_cast(*iter); + if (child != NULL) + { + stack.push_back(std::make_pair(currentDimension - 1, child)); + } + else + { + std::ostringstream err; + err << "Array element "; + for (unsigned int i = 0; i < currentDimension; i++) + { + err << (i != 0 ? ", " : "") << elementCount[i]; + } + err << " was expected to be a literal array (as all " + << "non-leaf elements of a literal array must be), " + << "but was not"; + } + } + } + else + { + // We're down to elements of the array, save them. + FF2_ASSERT(currentDimension == 1); + for (const Expression* const* iter = current.Begin(); + iter != current.End(); + ++iter) + { + elements.push_back(*iter); + } + } + + // Keep track of the number of elements processed, for decent error + // messages. + elementCount[currentDimension - 1]++; + } + + // Ensure all dimensions are set. + for (unsigned int i = 0; i < dimensions.size(); i++) + { + FF2_ASSERT(dimensionSizeKnown[i]); + } + + FF2_ASSERT(elements.size() == m_type->GetMaxElements()); + + const ArrayType* newType = m_type; + if (p_annotatedType != NULL) + { + FF2_ASSERT(p_typeManager != NULL); + const TypeImpl& type + = TypeUtil::Unify(*p_annotatedType, m_type->GetChildType(), *p_typeManager, false, false); + if (type != m_type->GetChildType()) + { + FF2_ASSERT(m_type->IsFixedSize()); + newType = &p_typeManager->GetArrayType(type, + m_type->IsConst(), + m_type->GetDimensionCount(), + m_type->GetDimensions(), + m_type->GetMaxElements()); + } + } + + boost::shared_ptr flat = Alloc(GetAnnotations(), *newType, elements, m_id); + p_owner.AddExpression(flat); + return *flat; +} + + +bool +FreeForm2::ArrayLiteralExpression::IsFlat() const +{ + return m_isFlat; +} + + +size_t +FreeForm2::ArrayLiteralExpression::GetNumChildren() const +{ + return m_numChildren; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayLiteralExpression::GetType() const +{ + return *m_type; +} + + +FreeForm2::VariableID +FreeForm2::ArrayLiteralExpression::GetId() const +{ + return m_id; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.h new file mode 100644 index 000000000000..19686194cbde --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ArrayLiteralExpression.h @@ -0,0 +1,105 @@ +#pragma once + +#ifndef FREEFORM2_ARRAY_LITERAL_EXPRESSION_H +#define FREEFORM2_ARRAY_LITERAL_EXPRESSION_H + +#include "ArrayType.h" +#include "Expression.h" +#include + +// (array-literal [ ... ]) or (array-literal [ ...] type) + +namespace FreeForm2 +{ + class SimpleExpressionOwner; + class ProgramParseState; + class TypeManager; + class Visitor; + + // An array-literal expression generates an array literal. + class ArrayLiteralExpression : public Expression + { + public: + // Methods inherited from Expression. + size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual void AcceptReference(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + // Custom allocator for this expression type. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const TypeImpl& p_annotatedType, + const std::vector& p_children, + VariableID p_id, + TypeManager& p_typeManager); + + // Custom allocator for a flat array. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const ArrayType& p_type, + const std::vector& p_children, + VariableID p_id); + + // Iterate over children. + const Expression* const* Begin() const; + const Expression* const* End() const; + + // Flatten array, throwing an exception for non-square arrays, and for + // arrays that contain non-literal arrays. If an annotated type is + // provided to the latter function, the annotated type will be unified + // with the child type of the array. In this case, both p_annotatedType + // and p_typeManager must be non-NULL. + const ArrayLiteralExpression& Flatten(SimpleExpressionOwner& p_owner) const; + const ArrayLiteralExpression& Flatten(SimpleExpressionOwner& p_owner, + const TypeImpl* p_annotatedType, + TypeManager* p_typeManager) const; + + // Return a flag indicating whether or not this array has been + // flattened. If this is true, all children of this expression are of + // non-array type. + bool IsFlat() const; + + // Gets the integer identificator for this array literal. + VariableID GetId() const; + + private: + // Construct an array literal with an annotated type. + ArrayLiteralExpression(const Annotations& p_annotations, + const TypeImpl& p_annotatedType, + const std::vector& p_children, + VariableID p_id, + TypeManager& p_typeManager); + + // Construct a flat array. + ArrayLiteralExpression(const Annotations& p_annotations, + const ArrayType& p_type, + const std::vector& p_elements, + VariableID p_id); + + // Unify child types into the final array type. + const ArrayType& UnifyTypes(const TypeImpl& p_annotatedType, TypeManager& p_typeManager); + + // The type of this array literal. + const ArrayType* m_type; + + // Custom deallocator for this expression type, suitable for + // use as a shared_ptr destructor. + static void DeleteAlloc(ArrayLiteralExpression* p_allocated); + + // Whether this array has been flattened. + bool m_isFlat; + + // Number of children of this array. + unsigned int m_numChildren; + + // An unique integer id for the current expression. + const VariableID m_id; + + // Array of children of this node, allocated using struct hack. + const Expression* m_children[1]; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.cpp new file mode 100644 index 000000000000..9934382927e0 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.cpp @@ -0,0 +1,130 @@ +#include "BinaryOperator.h" + +#include "FreeForm2Assert.h" +#include "FreeForm2Type.h" +#include "TypeImpl.h" +#include "TypeUtil.h" +#include + +static +std::vector +GetSupportedOperandTypes(FreeForm2::BinaryOperator::Operation p_op) +{ + using FreeForm2::BinaryOperator; + std::vector types; + types.reserve(3); + + switch (p_op) + { + // The following operators support only bool: + case BinaryOperator::_and: + case BinaryOperator::_or: + types.push_back(FreeForm2::Type::Bool); + break; + + // The following operators support bool, int and float: + case BinaryOperator::eq: + case BinaryOperator::neq: + types.push_back(FreeForm2::Type::Bool); + __attribute__((__fallthrough__)); + + // The following operators support int and float: + case BinaryOperator::lt: __attribute__((__fallthrough__)); + case BinaryOperator::lte: __attribute__((__fallthrough__)); + case BinaryOperator::gt: __attribute__((__fallthrough__)); + case BinaryOperator::gte: __attribute__((__fallthrough__)); + case BinaryOperator::max: __attribute__((__fallthrough__)); + case BinaryOperator::min: __attribute__((__fallthrough__)); + case BinaryOperator::plus: __attribute__((__fallthrough__)); + case BinaryOperator::minus: __attribute__((__fallthrough__)); + case BinaryOperator::multiply: __attribute__((__fallthrough__)); + case BinaryOperator::divides: __attribute__((__fallthrough__)); + case BinaryOperator::mod: __attribute__((__fallthrough__)); + case BinaryOperator::pow: __attribute__((__fallthrough__)); + case BinaryOperator::log: + types.push_back(FreeForm2::Type::Int); + types.push_back(FreeForm2::Type::UInt64); + types.push_back(FreeForm2::Type::Int32); + types.push_back(FreeForm2::Type::UInt32); + types.push_back(FreeForm2::Type::Float); + break; + + // Bit operations only support ints. + case BinaryOperator::_bitand: __attribute__((__fallthrough__)); + case BinaryOperator::_bitor: __attribute__((__fallthrough__)); + case BinaryOperator::bitshiftleft: __attribute__((__fallthrough__)); + case BinaryOperator::bitshiftright: + types.push_back(FreeForm2::Type::Int); + types.push_back(FreeForm2::Type::UInt64); + types.push_back(FreeForm2::Type::Int32); + types.push_back(FreeForm2::Type::UInt32); + break; + + default: + FreeForm2::Unreachable(__FILE__, __LINE__); + } + + return types; +} + + +const FreeForm2::TypeImpl& +FreeForm2::BinaryOperator::GetBestOperandType(Operation p_operator, + const TypeImpl& p_operandType) +{ + if (p_operandType.Primitive() == Type::Unknown) + { + return p_operandType; + } + + const std::vector types = GetSupportedOperandTypes(p_operator); + + if (std::find(types.begin(), types.end(), p_operandType.Primitive()) != types.end()) + { + return p_operandType; + } + else + { + return TypeImpl::GetInvalidType(); + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::BinaryOperator::GetResultType(Operation p_operator, + const TypeImpl& p_operandType) +{ + switch (p_operator) + { + case BinaryOperator::max: __attribute__((__fallthrough__)); + case BinaryOperator::min: __attribute__((__fallthrough__)); + case BinaryOperator::_and: __attribute__((__fallthrough__)); + case BinaryOperator::_or: __attribute__((__fallthrough__)); + case BinaryOperator::plus: __attribute__((__fallthrough__)); + case BinaryOperator::minus: __attribute__((__fallthrough__)); + case BinaryOperator::multiply: __attribute__((__fallthrough__)); + case BinaryOperator::mod: __attribute__((__fallthrough__)); + case BinaryOperator::pow: __attribute__((__fallthrough__)); + case BinaryOperator::_bitand: __attribute__((__fallthrough__)); + case BinaryOperator::_bitor: __attribute__((__fallthrough__)); + case BinaryOperator::bitshiftleft: __attribute__((__fallthrough__)); + case BinaryOperator::bitshiftright: __attribute__((__fallthrough__)); + case BinaryOperator::divides: + return p_operandType; + + case BinaryOperator::log: + return TypeImpl::GetFloatInstance(true); + + case BinaryOperator::eq: __attribute__((__fallthrough__)); + case BinaryOperator::neq: __attribute__((__fallthrough__)); + case BinaryOperator::lt: __attribute__((__fallthrough__)); + case BinaryOperator::lte: __attribute__((__fallthrough__)); + case BinaryOperator::gt: __attribute__((__fallthrough__)); + case BinaryOperator::gte: + return TypeImpl::GetBoolInstance(true); + + default: + FreeForm2::Unreachable(__FILE__, __LINE__); + } +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.h new file mode 100644 index 000000000000..fd5965c4eaf2 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BinaryOperator.h @@ -0,0 +1,59 @@ +#pragma once + +#ifndef FREEFORM2_BINARY_OPERATOR_H +#define FREEFORM2_BINARY_OPERATOR_H + +#include "Expression.h" +#include "FreeForm2Type.h" + +namespace FreeForm2 +{ + class BinaryOperator + { + public: + enum Operation + { + plus, + minus, + multiply, + divides, + mod, + max, + min, + pow, + log, + + eq, + neq, + lt, + lte, + gt, + gte, + + _and, + _or, + + _bitand, + _bitor, + bitshiftleft, + bitshiftright, + + invalid + }; + + // Select the best operand type for an operator. Best is defined in + // terms of TypeUtil::SelectBestType. If no valid type is found, an + // invalid TypeImpl is returned. + static const TypeImpl& GetBestOperandType(Operation p_operator, + const TypeImpl& p_operandType); + + // Return the type of a binary operator result given an operator and + // an operand type. If the operand type is not a valid operand type for + // the operator, the return type is undefined. + static const TypeImpl& GetResultType(Operation p_operator, + const TypeImpl& p_operandType); + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.cpp new file mode 100644 index 000000000000..039f0ba9320a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.cpp @@ -0,0 +1,104 @@ +#include "BlockExpression.h" + +#include "FreeForm2Assert.h" +#include "Visitor.h" + +boost::shared_ptr +FreeForm2::BlockExpression::Alloc(const Annotations& p_annotations, + const Expression** p_children, + unsigned int p_numChildren, + unsigned int p_numBound) +{ + FF2_ASSERT(p_numChildren > 0); + + size_t bytes = sizeof(BlockExpression) + (p_numChildren - 1) * sizeof(Expression*); + + // Allocate a shared_ptr that deletes an BlockExpression + // allocated in a char[]. + boost::shared_ptr exp; + exp.reset(new (new char[bytes]) BlockExpression(p_annotations, p_children, p_numChildren, p_numBound), + DeleteAlloc); + return exp; +} + + +const FreeForm2::TypeImpl& +FreeForm2::BlockExpression::GetType() const +{ + return *m_returnType; +} + + +size_t +FreeForm2::BlockExpression::GetNumChildren() const +{ + return m_numChildren; +} + + +unsigned int +FreeForm2::BlockExpression::GetNumBound() const +{ + return m_numBound; +} + + +void +FreeForm2::BlockExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (unsigned int i = 0; i < m_numChildren; i++) + { + m_children[i]->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::Expression& +FreeForm2::BlockExpression::GetChild(unsigned int p_index) const +{ + return *m_children[p_index]; +} + + +FreeForm2::BlockExpression::BlockExpression(const Annotations& p_annotations, + const Expression** p_children, + unsigned int p_numChildren, + unsigned int p_numBound) + : Expression(p_annotations), + m_numChildren(p_numChildren), + m_numBound(p_numBound), + m_returnType(NULL) +{ + FF2_ASSERT(m_numChildren > 0); + m_returnType = &p_children[p_numChildren - 1]->GetType().AsConstType(); + + // We rely on the custom allocator Alloc to provide enough space + // for all of the children. + for (unsigned int i = 0; i < m_numChildren; i++) + { + m_children[i] = p_children[i]; + } +} + + +void +FreeForm2::BlockExpression::DeleteAlloc(BlockExpression* p_allocated) +{ + // Manually call dtor for block expression. + p_allocated->~BlockExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.h new file mode 100644 index 000000000000..63654621e2cd --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/BlockExpression.h @@ -0,0 +1,61 @@ +#pragma once + +#ifndef FREEFORM2_BLOCK_EXPRESSION_H +#define FREEFORM2_BLOCK_EXPRESSION_H + +#include +#include "Expression.h" + +namespace FreeForm2 +{ + // A block expression is a series of expressions. + class BlockExpression : public Expression + { + public: + // Allocate a block expression for the given array of child expressions, + // with p_numBound indicating the number of symbols bound by (and not + // scoped within) the immediate children of the block expression. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const Expression** p_children, + unsigned int p_numChildren, + unsigned int p_numBound); + + // Return the number of symbols bound by immediate children of this + // block expression, and left open. + unsigned int GetNumBound() const; + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + // Gets the p_index-th child of this expression. + const Expression& GetChild(unsigned int p_index) const; + + private: + // Private ctor: use Alloc to create block expressions. + BlockExpression(const Annotations& p_annotations, + const Expression** p_children, + unsigned int p_numChildren, + unsigned int p_numBound); + + // Custom deallocation method. + static void DeleteAlloc(BlockExpression* p_allocated); + + // The return type of the block. + const TypeImpl* m_returnType; + + // Number of children of this block. + unsigned int m_numChildren; + + // Number of symbols left bound by children of the block expression. + unsigned int m_numBound; + + // Children, allocated via struct hack. + const Expression* m_children[1]; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/CMakeLists.txt new file mode 100644 index 000000000000..4422aa80c171 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/CMakeLists.txt @@ -0,0 +1,53 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME DRFreeFormExpressionLibrary) + +project(${PROJECT_NAME}) + +set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + +add_library(${PROJECT_NAME} STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/Allocation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ArrayDereferenceExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ArrayLength.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ArrayLiteralExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/BinaryOperator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/BlockExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Conditional.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ConvertExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/DebugExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Declaration.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Expression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Extern.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FeatureSpec.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Function.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/LetExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/LiteralExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Match.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/MemberAccessExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Mutation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/OperatorExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/PhiNode.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Publish.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/RandExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/RangeReduceExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/RefExpression.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/SelectNth.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/StateMachine.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/StreamData.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/SymbolTable.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/TypeUtil.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/UnaryOperator.cpp +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../NeuralTree.Library/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../Shared + ) + +install(TARGETS ${PROJECT_NAME} + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + ) \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.cpp new file mode 100644 index 000000000000..877abb89cf2f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.cpp @@ -0,0 +1,103 @@ +#include "Conditional.h" + +#include "Expression.h" +#include "FreeForm2.h" +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include + + +FreeForm2::ConditionalExpression::ConditionalExpression(const Annotations& p_annotations, + const Expression& p_condition, + const Expression& p_then, + const Expression& p_else) + : Expression(p_annotations), + m_condition(p_condition), + m_then(p_then), + m_else(p_else) +{ +} + + +const FreeForm2::Expression& +FreeForm2::ConditionalExpression::GetCondition() const +{ + return m_condition; +} + + +const FreeForm2::Expression& +FreeForm2::ConditionalExpression::GetThen() const +{ + return m_then; +} + + +const FreeForm2::Expression& +FreeForm2::ConditionalExpression::GetElse() const +{ + return m_else; +} + + +void +FreeForm2::ConditionalExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_else.Accept(p_visitor); + m_then.Accept(p_visitor); + m_condition.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConditionalExpression::GetType() const +{ + if (m_condition.GetType().Primitive() != Type::Bool) + { + std::ostringstream err; + err << "Condition of type '" + << m_condition.GetType() + << "' supplied to if expression as condition " + << "(expected boolean)"; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (!m_then.GetType().IsSameAs(m_else.GetType(), true)) + { + if (m_then.GetType().IsIntegerType() + && m_else.GetType().IsIntegerType()) + { + return TypeImpl::GetIntInstance(true); + } + else + { + std::ostringstream err; + err << "'then' (supplied '" + << m_then.GetType() + << "' and 'else' (supplied '" + << m_else.GetType() + << "' clauses of condition must have matching types."; + throw ParseError(err.str(), GetSourceLocation()); + } + } + + return m_then.GetType().AsConstType(); +} + + +size_t +FreeForm2::ConditionalExpression::GetNumChildren() const +{ + return 3; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.h new file mode 100644 index 000000000000..14f0eda8acc3 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Conditional.h @@ -0,0 +1,43 @@ +#pragma once + +#ifndef FREEFORM2_CONDITIONAL_H +#define FREEFORM2_CONDITIONAL_H + +#include "Expression.h" + +namespace FreeForm2 +{ + class ConditionalExpression : public Expression + { + public: + ConditionalExpression(const Annotations& p_annotations, + const Expression& p_condition, + const Expression& p_then, + const Expression& p_else); + + virtual void Accept(Visitor& p_visitor) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Gets the condition expression for this conditional. + const Expression& GetCondition() const; + + // Gets the then expression for this conditional. + const Expression& GetThen() const; + + // Gets the else expression for this conditional. + const Expression& GetElse() const; + + private: + // Condition used to choose between then/else. + const Expression& m_condition; + + // Value if condition is true. + const Expression& m_then; + + // Value if condition is false. + const Expression& m_else; + }; +}; + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.cpp new file mode 100644 index 000000000000..14f7c3d2e903 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.cpp @@ -0,0 +1,239 @@ +#include "ConvertExpression.h" + +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include "TypeUtil.h" +#include +#include + +static +void +CheckType(const FreeForm2::TypeImpl& p_type, + const FreeForm2::TypeImpl& p_child, + const FreeForm2::SourceLocation& p_sourceLocation) +{ + if (!FreeForm2::TypeUtil::IsConvertible(p_child, p_type) + && p_child.Primitive() != FreeForm2::Type::Unknown) + { + std::ostringstream err; + err << "Expression type " << p_child + << " cannot be converted to type " << p_type; + throw FreeForm2::ParseError(err.str(), p_sourceLocation); + } +} + + +FreeForm2::ConversionExpression::ConversionExpression(const Annotations& p_annotations, + const Expression& p_child) + : Expression(p_annotations), + m_child(p_child) +{ +} + + +FreeForm2::ConversionExpression::~ConversionExpression() +{ +} + + +size_t +FreeForm2::ConversionExpression::GetNumChildren() const +{ + return 1; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConversionExpression::GetChildType() const +{ + return m_child.GetType(); +} + + +const FreeForm2::Expression& +FreeForm2::ConversionExpression::GetChild() const +{ + return m_child; +} + + +template +void FreeForm2::ConversionExpression::AcceptDerived(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(static_cast(*this))) + { + m_child.Accept(p_visitor); + + p_visitor.Visit(static_cast(*this)); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::ConvertToFloatExpression::ConvertToFloatExpression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToFloatExpression::GetType() const +{ + const TypeImpl& type = TypeImpl::GetFloatInstance(true); + CheckType(type, GetChildType(), GetSourceLocation()); + + return type; +} + + +void +FreeForm2::ConvertToFloatExpression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} + + +FreeForm2::ConvertToIntExpression::ConvertToIntExpression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToIntExpression::GetType() const +{ + const TypeImpl& type = TypeImpl::GetIntInstance(true); + CheckType(type, GetChildType(), GetSourceLocation()); + + return type; +} + + +void +FreeForm2::ConvertToIntExpression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} + + +FreeForm2::ConvertToUInt64Expression::ConvertToUInt64Expression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToUInt64Expression::GetType() const +{ + const TypeImpl& type = TypeImpl::GetUInt64Instance(true); + CheckType(type, GetChildType(), GetSourceLocation()); + + return type; +} + + +void +FreeForm2::ConvertToUInt64Expression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} + + +FreeForm2::ConvertToInt32Expression::ConvertToInt32Expression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToInt32Expression::GetType() const +{ + const TypeImpl& type = TypeImpl::GetInt32Instance(true); + CheckType(type, GetChildType(), GetSourceLocation()); + + return type; +} + + +void +FreeForm2::ConvertToInt32Expression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} + + +FreeForm2::ConvertToUInt32Expression::ConvertToUInt32Expression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToUInt32Expression::GetType() const +{ + const TypeImpl& type = TypeImpl::GetUInt32Instance(true); + CheckType(type, GetChildType(), GetSourceLocation()); + + return type; +} + + +void +FreeForm2::ConvertToUInt32Expression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} + + +FreeForm2::ConvertToBoolExpression::ConvertToBoolExpression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToBoolExpression::GetType() const +{ + const TypeImpl& type = TypeImpl::GetBoolInstance(true); + CheckType(type, GetChildType(), GetSourceLocation()); + + return type; +} + + +void +FreeForm2::ConvertToBoolExpression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} + + +FreeForm2::ConvertToImperativeExpression::ConvertToImperativeExpression(const Annotations& p_annotations, + const Expression& p_child) + : ConversionExpression(p_annotations, p_child) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::ConvertToImperativeExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +void +FreeForm2::ConvertToImperativeExpression::Accept(Visitor& p_visitor) const +{ + AcceptDerived(p_visitor); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.h new file mode 100644 index 000000000000..4d46ccdc508a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/ConvertExpression.h @@ -0,0 +1,137 @@ +#pragma once + +#ifndef FREEFORM2_CONVERTEXPRESSION_H +#define FREEFORM2_CONVERTEXPRESSION_H + +#include "Expression.h" + +namespace FreeForm2 +{ + // Base class to represent a type conversion. + class ConversionExpression : public Expression + { + public: + // Create a conversion expression, converting the expression given + // expression. The derived class must specify the conversion type with + // the GetType method. + ConversionExpression(const Annotations& p_annotations, + const Expression& p_child); + virtual ~ConversionExpression(); + + virtual size_t GetNumChildren() const override; + + const TypeImpl& GetChildType() const; + + const Expression& GetChild() const; + + protected: + // Accept a visitor for a derived class. + template + void AcceptDerived(Visitor& p_visitor) const; + + private: + // Child expression to convert. + const Expression& m_child; + }; + + // Class to convert to a float. + class ConvertToFloatExpression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to float. + ConvertToFloatExpression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + // Class to convert to an integer. + class ConvertToIntExpression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to int. Any floating point data is truncated. + ConvertToIntExpression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + // Class to convert to an uint64. + class ConvertToUInt64Expression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to uint64. Any floating point data is truncated. + ConvertToUInt64Expression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + // Class to convert to an int32. + class ConvertToInt32Expression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to int32. Any data is truncated. + ConvertToInt32Expression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + // Class to convert to an uint32. + class ConvertToUInt32Expression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to int32. Any data is truncated. + ConvertToUInt32Expression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + // Class to convert to a boolean. + class ConvertToBoolExpression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to bool. + ConvertToBoolExpression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + // Class to issue code for a given expression, then discard the resulting + // value and return void. + class ConvertToImperativeExpression : public ConversionExpression + { + public: + // Create a conversion expression, taking the expression to convert + // to imperative. + ConvertToImperativeExpression(const Annotations& p_annotations, + const Expression& p_child); + + virtual const TypeImpl& GetType() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.cpp new file mode 100644 index 000000000000..b40fa0e44ab5 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.cpp @@ -0,0 +1,78 @@ +#include "DebugExpression.h" + +#include "ArrayType.h" +#include "FreeForm2Assert.h" +#include +#include "TypeImpl.h" +#include "Visitor.h" + +using namespace FreeForm2; + +DebugExpression::DebugExpression(const Annotations& p_annotations, + const Expression& p_child, + const std::string& p_childText) + : Expression(p_annotations), + m_child(p_child), + m_childText(p_childText) +{ +} + + +void +DebugExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_child.Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const TypeImpl& +DebugExpression::GetType() const +{ + const TypeImpl& childType = m_child.GetType(); + const TypeImpl* checkType = &childType; + if (childType.Primitive() == Type::Array) + { + const ArrayType& type = static_cast(childType); + checkType = &type.GetChildType(); + } + + if (!checkType->IsLeafType()) + { + std::ostringstream err; + err << "Cannot debug the expression " << m_childText + << " of type " << *checkType + << ". Only arrays and primitive types are supported."; + throw std::runtime_error(err.str()); + } + + return TypeImpl::GetVoidInstance(); +} + + +size_t +DebugExpression::GetNumChildren() const +{ + return 1; +} + + +const Expression& +DebugExpression::GetChild() const +{ + return m_child; +} + + +const std::string& +DebugExpression::GetChildText() const +{ + return m_childText; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.h new file mode 100644 index 000000000000..6a5bee520bae --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/DebugExpression.h @@ -0,0 +1,47 @@ +#pragma once + +#ifndef FREEFORM2_DEBUG_EXPRESSION_H +#define FREEFORM2_DEBUG_EXPRESSION_H + +#include "Expression.h" +#include + +namespace FreeForm2 +{ + // A DebugExpression assists developers by allowing expressions to be + // debugged. The exact method of debugging depends on other compiler + // settings, but generally debug instrumentation should provide the + // original text of the expression being debugged, and the value of that + // expression. + class DebugExpression : public Expression + { + public: + // Construct a DebugExpression which will provide debugging + // capabilities for an expression. The child text refers to the + // original text of the expression to debug. + DebugExpression(const Annotations& p_annotations, + const Expression& p_child, + const std::string& p_childText); + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Get the child expression. + const Expression& GetChild() const; + + // Get the text of the child expression to be printed for debugging + // purposes. + const std::string& GetChildText() const; + + private: + // The child expression. + const Expression& m_child; + + // The original text of the child expression. + std::string m_childText; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.cpp new file mode 100644 index 000000000000..0201da3b300b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.cpp @@ -0,0 +1,104 @@ +#include "Declaration.h" + +#include "ArrayType.h" +#include "FreeForm2Assert.h" +#include +//#include +#include "TypeUtil.h" +#include "Visitor.h" + +FreeForm2::DeclarationExpression::DeclarationExpression(const Annotations& p_annotations, + const TypeImpl& p_type, + const Expression& p_init, + bool p_voidValue, + VariableID p_id, + size_t p_version) + : Expression(p_annotations), + m_declType(p_type), + m_init(p_init), + m_voidValue(p_voidValue), + m_id(p_id), + m_version(p_version) +{ +} + + +void +FreeForm2::DeclarationExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_init.Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::DeclarationExpression::GetType() const +{ + if (m_declType.Primitive() != Type::Unknown && !TypeUtil::IsAssignable(m_declType, m_init.GetType())) + { + std::ostringstream err; + err << "Declaration initializer (of type " + << m_init.GetType() + << ") does not match declared type of variable (" << m_declType << ")"; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (m_voidValue) + { + return TypeImpl::GetVoidInstance(); + } + else + { + FF2_ASSERT(m_init.GetType().Primitive() != Type::Unknown); + FF2_ASSERT(m_init.GetType().Primitive() != Type::Void); + return m_init.GetType().AsConstType(); + } +} + + +size_t +FreeForm2::DeclarationExpression::GetNumChildren() const +{ + return 1; +} + + +const FreeForm2::Expression& +FreeForm2::DeclarationExpression::GetInit() const +{ + return m_init; +} + + +bool FreeForm2::DeclarationExpression::HasVoidValue() const +{ + return m_voidValue; +} + + +const FreeForm2::TypeImpl& +FreeForm2::DeclarationExpression::GetDeclaredType() const +{ + return m_declType; +} + + +FreeForm2::VariableID +FreeForm2::DeclarationExpression::GetId() const +{ + return m_id; +} + + +size_t +FreeForm2::DeclarationExpression::GetVersion() const +{ + return m_version; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.h new file mode 100644 index 000000000000..5b1a7abb1a56 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Declaration.h @@ -0,0 +1,64 @@ +#pragma once + +#ifndef FREEFORM2_DECLARATION_H +#define FREEFORM2_DECLARATION_H + +#include "Expression.h" + +namespace FreeForm2 +{ + class DeclarationExpression : public Expression + { + public: + // Create a declaration expression from a type (which may be + // Type::Unknown) and an initialiser. p_voidValue controls whether + // the DeclarationExpression evaluates to a void value (imperatively), + // or to the p_init expression (functionally). + DeclarationExpression(const Annotations& p_annotations, + const TypeImpl& p_type, + const Expression& p_init, + bool p_voidValue, + VariableID p_id, + size_t p_version); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Whether this expression evaluates to void, or the value m_init; + bool HasVoidValue() const; + + // Return the type of the variable declared (not the type of this + // declaration expression, which is significantly different). + const TypeImpl& GetDeclaredType() const; + + // Return the initialization expression. + const Expression& GetInit() const; + + // Gets this declaration's unique identifier and value version. + VariableID GetId() const; + size_t GetVersion() const; + + private: + + // Type of the variable declared by this expression. (Note: this is + // *not* the type of the declaration expression); + const TypeImpl& m_declType; + + // Initialisation expression. + const Expression& m_init; + + // Whether this expression evaluates to void, or the value m_init; + bool m_voidValue; + + // A unique identificator to allow separation of allocation and usage. + const VariableID m_id; + + // A unique version number associated with a particular + // value for this variable. + const size_t m_version; + }; +}; + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.cpp new file mode 100644 index 000000000000..01629cde4c28 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.cpp @@ -0,0 +1,233 @@ +#include "Expression.h" + +#include +#include "FreeForm2Assert.h" +#include +#include + +using namespace FreeForm2; + +const FreeForm2::VariableID FreeForm2::VariableID::c_invalidID = { MAX_UINT32 }; + +bool +FreeForm2::VariableID::operator==(VariableID p_other) const +{ + return m_value == p_other.m_value; +} + + +bool +FreeForm2::VariableID::operator!=(VariableID p_other) const +{ + return m_value != p_other.m_value; +} + + +bool +FreeForm2::VariableID::operator<(VariableID p_other) const +{ + return m_value < p_other.m_value; +} + + +FreeForm2::Result::IntType +FreeForm2::ConstantValue::GetInt(const TypeImpl& p_type) const +{ + switch (p_type.Primitive()) + { + case Type::Int: return m_int; + case Type::UInt64: return m_uint64; + case Type::Int32: return m_int32; + case Type::UInt32: return m_uint32; + default: Unreachable(__FILE__, __LINE__); + } +} + + +FreeForm2::Expression::Expression(const Annotations& p_annotations) + : m_annotations(p_annotations) +{ +} + + +FreeForm2::Expression::~Expression() +{ +} + + +void +FreeForm2::Expression::AcceptReference(Visitor&) const +{ + throw ParseError("Invalid l-value", GetSourceLocation()); +} + + +bool +FreeForm2::Expression::IsConstant() const +{ + return false; +} + + +FreeForm2::ConstantValue +FreeForm2::Expression::GetConstantValue() const +{ + // The following assertion should fail, unless a child class implemented + // IsConstant without overriding GetConstantValue, in which case the + // second assertion will trip. + FF2_ASSERT(IsConstant()); + FF2_ASSERT(false && "Expression-type class must override both IsConstant and GetConstantValue"); + Unreachable(__FILE__, __LINE__); +} + + +const FreeForm2::ValueBounds& +FreeForm2::Expression::GetValueBounds() const +{ + return m_annotations.m_valueBounds; +} + + +const FreeForm2::SourceLocation& +FreeForm2::Expression::GetSourceLocation() const +{ + return m_annotations.m_sourceLocation; +} + + +const FreeForm2::Annotations& +FreeForm2::Expression::GetAnnotations() const +{ + return m_annotations; +} + + +FreeForm2::ExpressionOwner::~ExpressionOwner() +{ +} + + +FreeForm2::ValueBounds::ValueBounds() + : m_lower(std::numeric_limits::min()), + m_upper(std::numeric_limits::max()) +{ +} + + +FreeForm2::ValueBounds::ValueBounds(const TypeImpl& p_type) +{ + switch (p_type.Primitive()) + { + case Type::Int32: + { + m_lower = std::numeric_limits::min(); + m_upper = std::numeric_limits::max(); + break; + } + + case Type::UInt32: + { + m_lower = std::numeric_limits::min(); + m_upper = std::numeric_limits::max(); + break; + } + + default: + { + m_lower = std::numeric_limits::min(); + m_upper = std::numeric_limits::max(); + break; + } + } +} + + +FreeForm2::ValueBounds::ValueBounds(const TypeImpl& p_type, ConstantValue p_value) +{ + switch (p_type.Primitive()) + { + case Type::Int32: + { + m_lower = m_upper = p_value.m_int32; + break; + } + + case Type::UInt32: + { + m_lower = m_upper = p_value.m_uint32; + break; + } + + case Type::Int: + { + m_lower = m_upper = p_value.m_int; + break; + } + + case Type::UInt64: + { + if (p_value.m_uint64 + <= static_cast(std::numeric_limits::max())) + { + m_lower = m_upper = static_cast(p_value.m_uint64); + } + else + { + m_lower = std::numeric_limits::min(); + m_upper = std::numeric_limits::max(); + } + break; + } + + default: + { + m_lower = std::numeric_limits::min(); + m_upper = std::numeric_limits::max(); + break; + } + } +} + + +FreeForm2::ValueBounds::ValueBounds(FreeForm2::Result::IntType p_lower, FreeForm2::Result::IntType p_upper) + : m_lower(p_lower), + m_upper(p_upper) +{ +} + + +bool +FreeForm2::ValueBounds::operator==(const FreeForm2::ValueBounds& p_other) const +{ + return m_lower == p_other.m_lower && m_upper == p_other.m_upper; +} + + +bool +FreeForm2::ValueBounds::operator!=(const FreeForm2::ValueBounds& p_other) const +{ + return m_lower != p_other.m_lower || m_upper != p_other.m_upper; +} + + +const FreeForm2::ValueBounds FreeForm2::ValueBounds::c_empty( + std::numeric_limits::max(), + std::numeric_limits::min()); + + +FreeForm2::Annotations::Annotations() +{ +} + + +FreeForm2::Annotations::Annotations(FreeForm2::SourceLocation p_sourceLocation) + : m_sourceLocation(p_sourceLocation) +{ +} + + +FreeForm2::Annotations::Annotations(FreeForm2::SourceLocation p_sourceLocation, FreeForm2::ValueBounds p_valueBounds) + : m_sourceLocation(p_sourceLocation), + m_valueBounds(p_valueBounds) +{ +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.h new file mode 100644 index 000000000000..916ada6f1bcf --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Expression.h @@ -0,0 +1,143 @@ +#pragma once + +#ifndef FREEFORM2_EXPRESSION_H +#define FREEFORM2_EXPRESSION_H + +#include +#include +#include "FreeForm2.h" +#include "FreeForm2Executable.h" +#include "FreeForm2Result.h" +#include +#include "TypeImpl.h" + +namespace DynamicRank +{ + class IFeatureMap; +} + +namespace FreeForm2 +{ + class Visitor; + + // This structure represents a strongly-typed identifier for variable + // bindings. + struct VariableID + { + unsigned int m_value; + + // Comparison operators for IDs. + bool operator==(VariableID p_other) const; + bool operator!=(VariableID p_other) const; + bool operator<(VariableID p_other) const; + + static const VariableID c_invalidID; + }; + + + // Stream insertion operator for VariableIDs. + template + std::basic_ostream& + operator<<(std::basic_ostream& p_out, VariableID p_id) + { + return p_out << p_id.m_value; + } + + + // Class that owns expressions generated by an Expression::Parse. + class ExpressionOwner : boost::noncopyable + { + public: + virtual ~ExpressionOwner() = 0; + }; + + + // A struct that holds the static range of valid values for an expression. + struct ValueBounds + { + ValueBounds(); + ValueBounds(const TypeImpl& p_type); + ValueBounds(const TypeImpl& p_type, ConstantValue p_value); + ValueBounds(Result::IntType p_lower, Result::IntType p_upper); + + bool operator==(const ValueBounds& p_other) const; + bool operator!=(const ValueBounds& p_other) const; + + Result::IntType m_lower; + Result::IntType m_upper; + + // An empty range, defined in such a way that the Union and Intersection operators + // are well defined. + static const ValueBounds c_empty; + }; + + + // A struct that holds all the annotations for an expression node. + struct Annotations + { + Annotations(); + Annotations(SourceLocation p_sourceLocation); + Annotations(SourceLocation p_sourceLocation, ValueBounds p_valueBounds); + + ValueBounds m_valueBounds; + SourceLocation m_sourceLocation; + }; + + + // Expression class, that can be used to evaluate free form2 expressions. + class Expression : boost::noncopyable + { + public: + typedef Executable::FeatureType FeatureType; + + typedef Executable::InputType InputType; + + explicit Expression(const Annotations& p_annotations); + + virtual ~Expression(); + + // Invoke the visitor pattern over the expression tree. + virtual void Accept(Visitor&) const = 0; + + // Invoke the visitor pattern over the expression tree, with the added + // caveat that the visitor is expected to produce a reference instead of + // a value. + virtual void AcceptReference(Visitor&) const; + + // Returns the type of this expression. + virtual const TypeImpl& GetType() const = 0; + + // Returns the number of children of this expression. + virtual size_t GetNumChildren() const = 0; + + // Returns a flag indicating whether or not this expression is a + // compile-time constant. + virtual bool IsConstant() const; + + // Retrieve the value of this expression as a compile-time constant. + // The member of Literal used is determined by the type of the + // Expression. If this expression is not a compile-time constant, this + // method throws an exception. + virtual ConstantValue GetConstantValue() const; + + // Gets the annotations for the current node. + const Annotations& GetAnnotations() const; + + // Retrieve the compile-time inferred interval of possible values for this + // expression. This information is calculated from the type of the expression, + // constants, binary operations and foreach loops. Since this is a compile-time + // assertion, many expressions won't be able to correctly determine the bounds + // and instead will return a default value of + // (std::numeric_limits::min(), std::numeric_limits::max()). + virtual const ValueBounds& GetValueBounds() const; + + // Retrieve the source code location for this particular expression. + virtual const SourceLocation& GetSourceLocation() const; + + protected: + // The annotations (assertions and source location) for this expression node. + const Annotations m_annotations; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.cpp new file mode 100644 index 000000000000..01cd30b8f85f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.cpp @@ -0,0 +1,213 @@ +#include "Extern.h" + +#include "ArrayType.h" +#include +#include "FreeForm2Assert.h" +#include "FreeForm2ExternalData.h" +#include "FreeForm2Result.h" +#include "ObjectType.h" +#include +#include "TypeManager.h" +#include "TypeImpl.h" +#include "Visitor.h" + +namespace +{ + // This class acts as an ExternalData implementor for external objects. + // This implementation relies on the backend to know how to resolve the + // objects, and not to call the user-defined resolver for objects. + class ExternalObject : public FreeForm2::ExternalData + { + public: + ExternalObject(const char* p_name) + : ExternalData(p_name, + *FreeForm2::TypeManager::GetGlobalTypeManager().GetTypeInfo(p_name)) + { + } + + virtual ~ExternalObject() + { + } + }; +} + +// This struct holds information for resolving built-in objects. +struct FreeForm2::ExternExpression::BuiltInObject +{ + BuiltInObject(const char* p_name) : m_object(p_name) + { + }; + + ExternalObject m_object; +}; + +namespace +{ + typedef FreeForm2::ExternExpression::BuiltInObject BuiltInObject; + static const BuiltInObject c_numberOfTuplesCommonImpl("NumberOfTuplesCommon"); + static const BuiltInObject c_numberOfTuplesCommonNoDuplicateImpl("NumberOfTuplesCommonNoDuplicate"); + static const BuiltInObject c_numberOfTuplesInTriplesCommonImpl("NumberOfTuplesInTriplesCommon"); + static const BuiltInObject c_alterationAndTermWeightImpl("AlterationAndTermWeightingCalculator"); + static const BuiltInObject c_alterationWeightImpl("AlterationWeightingCalculator"); + static const BuiltInObject c_trueNearDoubleQueueImpl("TrueNearDoubleQueue"); + static const BuiltInObject c_boundedQueueImpl("BoundedQueue"); +} + +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_numberOfTuplesCommonObject = c_numberOfTuplesCommonImpl; +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_numberOfTuplesCommonNoDuplicateObject = c_numberOfTuplesCommonNoDuplicateImpl; +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_numberOfTuplesInTriplesCommonObject = c_numberOfTuplesInTriplesCommonImpl; +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_alterationAndTermWeightObject = c_alterationAndTermWeightImpl; +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_alterationWeightObject = c_alterationWeightImpl; +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_trueNearDoubleQueueObject = c_trueNearDoubleQueueImpl; +const FreeForm2::ExternExpression::BuiltInObject& +FreeForm2::ExternExpression::c_boundedQueueObject = c_boundedQueueImpl; + +const FreeForm2::ExternExpression::BuiltInObject* +FreeForm2::ExternExpression::GetObjectByName(const std::string& p_name) +{ + static const FreeForm2::ExternExpression::BuiltInObject* const c_objects[] + = { &c_numberOfTuplesCommonObject, + &c_numberOfTuplesCommonNoDuplicateObject, + &c_numberOfTuplesInTriplesCommonObject, + &c_alterationAndTermWeightObject, + &c_alterationWeightObject, + &c_trueNearDoubleQueueObject, + &c_boundedQueueObject }; + + BOOST_FOREACH(const FreeForm2::ExternExpression::BuiltInObject* const obj, c_objects) + { + + if (p_name == obj->m_object.GetName()) + { + return obj; + } + } + return nullptr; +} + + +const FreeForm2::ExternalData& +FreeForm2::ExternExpression::GetObjectData(const BuiltInObject& p_object) +{ + return p_object.m_object; +} + + +FreeForm2::ExternExpression::ExternExpression( + const Annotations& p_annotations, + const ExternalData& p_data, + const TypeImpl& p_declaredType, + VariableID p_id, + TypeManager& p_typeManager) + : Expression(Annotations(p_annotations.m_sourceLocation, + p_data.IsCompileTimeConstant() + ? ValueBounds(p_data.GetType(), p_data.GetCompileTimeValue()) + : ValueBounds(p_data.GetType()))), + m_data(p_data), + m_id(p_id) +{ + if (m_data.GetType() != p_declaredType) + { + std::ostringstream err; + err << "Incorrect type for external data member " + << m_data.GetName() << ". Expected type " + << m_data.GetType() << "; found type " + << p_declaredType; + throw ParseError(err.str(), GetSourceLocation()); + } +} + + +FreeForm2::ExternExpression::ExternExpression( + const Annotations& p_annotations, + const BuiltInObject& p_object, + VariableID p_id, + TypeManager& p_typeManager) + : Expression(Annotations(p_annotations.m_sourceLocation, + p_object.m_object.IsCompileTimeConstant() + ? ValueBounds(p_object.m_object.GetType(), p_object.m_object.GetCompileTimeValue()) + : ValueBounds(p_object.m_object.GetType()))), + m_data(p_object.m_object), + m_id(p_id) +{ +} + + +FreeForm2::ExternExpression::ExternExpression(const Annotations& p_annotations, + const ExternalData& p_data, + const TypeImpl& p_type) + : Expression(Annotations(p_annotations.m_sourceLocation, + p_data.IsCompileTimeConstant() + ? ValueBounds(p_data.GetType(), p_data.GetCompileTimeValue()) + : ValueBounds(p_data.GetType()))), + m_data(p_data), + m_id(VariableID::c_invalidID) +{ + if (m_data.GetType() != p_type) + { + std::ostringstream err; + err << "Incorrect type for external data member " + << m_data.GetName() << ". Expected type " + << m_data.GetType() << "; found type " + << p_type; + throw ParseError(err.str(), GetSourceLocation()); + } +} + + +void +FreeForm2::ExternExpression::Accept(Visitor& p_visitor) const +{ + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } +} + + +size_t +FreeForm2::ExternExpression::GetNumChildren() const +{ + return 0; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ExternExpression::GetType() const +{ + return m_data.GetType(); +} + + +bool +FreeForm2::ExternExpression::IsConstant() const +{ + return m_data.IsCompileTimeConstant(); +} + + +FreeForm2::ConstantValue +FreeForm2::ExternExpression::GetConstantValue() const +{ + return m_data.GetCompileTimeValue(); +} + + +const FreeForm2::ExternalData& +FreeForm2::ExternExpression::GetData() const +{ + return m_data; +} + + +FreeForm2::VariableID +FreeForm2::ExternExpression::GetId() const +{ + return m_id; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.h new file mode 100644 index 000000000000..3bef7923eccb --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Extern.h @@ -0,0 +1,80 @@ +#pragma once + +#ifndef FREEFORM2_EXTERN_H +#define FREEFORM2_EXTERN_H + +#include "Expression.h" +#include + +namespace FreeForm2 +{ + class TypeManager; + class ExternalData; + + // Extern expressions represent an external data member declared in the + // program. External data is associated with a member of the + // ExternalData::DataType enum. + class ExternExpression : public Expression + { + public: + // Opaque structure for resolving external data objects. These objects + // are not currently supported by the API and must be built into the + // compiler. + struct BuiltInObject; + static const BuiltInObject& c_numberOfTuplesCommonObject; + static const BuiltInObject& c_numberOfTuplesCommonNoDuplicateObject; + static const BuiltInObject& c_numberOfTuplesInTriplesCommonObject; + static const BuiltInObject& c_alterationAndTermWeightObject; + static const BuiltInObject& c_alterationWeightObject; + static const BuiltInObject& c_trueNearDoubleQueueObject; + static const BuiltInObject& c_boundedQueueObject; + + // Find a BuiltInObject by name. This function returns nullptr if not + // object exists for the name. + static const BuiltInObject* GetObjectByName(const std::string& p_name); + static const ExternalData& GetObjectData(const BuiltInObject& p_object); + + // Create an external data reference for the specified piece of data + // of a declared type. + ExternExpression(const Annotations& p_annotations, + const ExternalData& p_data, + const TypeImpl& p_declaredType, + VariableID p_id, + TypeManager& p_typeManager); + + // Create an external data reference for a basic type external data + // member. + ExternExpression(const Annotations& p_annotations, + const ExternalData& p_data, + const TypeImpl& p_declaredType); + + // Create an extern expression for an object. + ExternExpression(const Annotations& p_annotations, + const BuiltInObject& p_object, + VariableID p_id, + TypeManager& p_typeManager); + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + virtual bool IsConstant() const override; + virtual ConstantValue GetConstantValue() const override; + + // Return the data enum entry associated with this extern. + const ExternalData& GetData() const; + + // Get the array allocation ID associated with this extern. If this + // extern is not an array type, the return value is not defined. + VariableID GetId() const; + + private: + // Name of the extern variable. + const ExternalData& m_data; + + // Allocation ID for the array for this extern, if applicable. + const VariableID m_id; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.cpp new file mode 100644 index 000000000000..d6588e7a07e1 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.cpp @@ -0,0 +1,592 @@ +#include "FeatureSpec.h" + +#include "ArrayType.h" +#include +#include +#include +#include +#include +#include +#include "FreeForm2Assert.h" +#include +#include +#include +#include +#include "TypeManager.h" +#include "TypeUtil.h" +#include "Visitor.h" +#include "MigratedApi.h" + +namespace +{ + void EscapeString(std::string& p_str) + { + static const char c_escapeSequences[][3] = { + "\\\"", "\\n", "\\t", "\\\\" + }; + static const char c_replaceSequences[][2] = { + "\"", "\n", "\t", "\\" + }; + static_assert(countof(c_escapeSequences) == countof(c_replaceSequences), + "Replacement sequence arrays must be parallel"); + + typedef boost::iterator_range ConstCharPtrRange; + for (size_t i = 0; i < countof(c_escapeSequences); i++) + { + // Construct string objects to prevent the compiler from whining + // about unsafe parameters. + const std::string escapeSequence(c_escapeSequences[i]); + const std::string replaceSequence(c_replaceSequences[i]); + boost::replace_all(p_str, escapeSequence, replaceSequence); + } + } +} + + +FreeForm2::FeatureSpecExpression::FeatureName::FeatureName() + : m_isParameterized(false) +{ +} + + +FreeForm2::FeatureSpecExpression::FeatureName::FeatureName(const std::string& p_name) + : m_name(p_name), + m_isParameterized(false) +{ +} + + +FreeForm2::FeatureSpecExpression::FeatureName::FeatureName(const std::string& p_name, + const ParameterMap& p_parameters) + : m_name(p_name), + m_params(p_parameters), + m_isParameterized(true) +{ +} + + +FreeForm2::FeatureSpecExpression::FeatureName +FreeForm2::FeatureSpecExpression::FeatureName::Parse(const std::string& p_name, + bool p_isParameterized, + const std::string& p_parameterization, + const SourceLocation& p_location) +{ + if (!p_isParameterized) + { + return FeatureName(p_name); + } + else + { + FeatureName name(p_name); + name.m_isParameterized = true; + if (p_parameterization.size() > 0) + { + try + { + std::string params = p_parameterization; + EscapeString(params); + + FF2_ASSERT(params.front() == '"' && params.back() == '"'); + auto paramRange = boost::make_iterator_range(params.cbegin() + 1, params.cend() - 1); + + std::vector values; + for (auto iter = boost::make_split_iterator(paramRange, boost::first_finder(",")); + !iter.eof(); + ++iter) + { + typedef boost::iterator_range StringRange; + const StringRange range = *iter; + values.clear(); + + // Note that boost::equals does approximately the same thing + // as the bind expression, but for some reason it causes a + // compiler warning about unsafe parameters. + boost::split(values, range, boost::bind(&boost::is_equal::operator(), + boost::is_equal(), + _1, + '=')); + FF2_ASSERT(values.size() == 2); + + const std::string& paramName = values[0]; + const std::string& paramValue = values[1]; + name.m_params.insert(std::make_pair(paramName, paramValue)); + } + return name; + } + catch (...) + { + std::ostringstream err; + err << "Unable to parse parameterization " << p_parameterization + << "; expected format is \"=,...\""; + throw FreeForm2::ParseError(err.str(), p_location); + } + } + else + { + return name; + } + } +} + + +const std::string& +FreeForm2::FeatureSpecExpression::FeatureName::GetName() const +{ + return m_name; +} + + +const FreeForm2::FeatureSpecExpression::FeatureName::ParameterMap& +FreeForm2::FeatureSpecExpression::FeatureName::GetParameters() const +{ + return m_params; +} + + +bool +FreeForm2::FeatureSpecExpression::FeatureName::IsParameterized() const +{ + return m_isParameterized; +} + + +FreeForm2::SymbolTable::Symbol +FreeForm2::FeatureSpecExpression::FeatureName::GetSymbol() const +{ + if (m_isParameterized) + { + if (m_paramStr.empty()) + { + std::ostringstream out; + out << "\""; + bool first = true; + BOOST_FOREACH(const Parameter& param, m_params) + { + if (!first) + { + out << ','; + } + out << param.first << '=' << param.second; + first = false; + } + out << "\""; + m_paramStr = out.str(); + } + return SymbolTable::Symbol(CStackSizedString(m_name.c_str()), + CStackSizedString(m_paramStr.c_str())); + } + else + { + return SymbolTable::Symbol(CStackSizedString(m_name.c_str())); + } +} + + +bool +FreeForm2::FeatureSpecExpression::FeatureName::operator==(const FeatureName& p_other) const +{ + return GetName() == p_other.GetName() + && IsParameterized() == p_other.IsParameterized() + && GetParameters() == p_other.GetParameters(); +} + + +bool +FreeForm2::FeatureSpecExpression::FeatureName::operator<(const FeatureName& p_other) const +{ + const int nameCompare = GetName().compare(p_other.GetName()); + if (nameCompare == 0) + { + if (IsParameterized() == p_other.IsParameterized()) + { + return GetParameters() < p_other.GetParameters(); + } + else + { + return !IsParameterized(); + } + } + else + { + return nameCompare < 0; + } +} + + +std::ostream& +operator<<(std::ostream& p_out, const FreeForm2::FeatureSpecExpression::FeatureName& p_name) +{ + return p_out << p_name.GetSymbol().ToString(); +} + + +bool +FreeForm2::FeatureSpecExpression::IgnoreParameterLess::operator()( + const FeatureName& p_left, + const FeatureName& p_right) const +{ + return p_left.GetName() < p_right.GetName(); +} + + +FreeForm2::FeatureSpecExpression::FeatureSpecExpression( + const Annotations& p_annotations, + const boost::shared_ptr p_publishFeatureMap, + const Expression& p_body, + FeatureSpecType p_featureSpecType, + bool p_returnsValue) + : Expression(p_annotations), + m_publishFeatureMap(p_publishFeatureMap), + m_body(p_body), + m_featureSpecType(p_featureSpecType), + m_returnsValue(p_returnsValue) +{ + FF2_ASSERT(p_publishFeatureMap != NULL && p_publishFeatureMap->size() > 0); + + // Check to make sure that the published types are valid and non-void. + for (const auto& featureNameToType : *p_publishFeatureMap) + { + if (featureNameToType.second.Primitive() != Type::Unknown + && featureNameToType.second.IsValid() + && featureNameToType.second.Primitive() == Type::Void) + { + std::ostringstream err; + err << "FeatureSpecExpression cannot have feature of type: '" + << featureNameToType.second << "'"; + throw ParseError(err.str(), GetSourceLocation()); + } + } +} + + +void +FreeForm2::FeatureSpecExpression::Accept(Visitor& p_visitor) const +{ + if (!p_visitor.AlternativeVisit(*this)) + { + m_body.Accept(p_visitor); + + p_visitor.Visit(*this); + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::FeatureSpecExpression::GetType() const +{ + if (m_returnsValue) + { + const TypeImpl& returnType = m_publishFeatureMap->begin()->second; + + if (!TypeUtil::IsAssignable(returnType, m_body.GetType())) + { + std::ostringstream err; + err << "Expected feature to return type " << returnType + << ", but found return type " << m_body.GetType(); + throw ParseError(err.str(), GetSourceLocation()); + } + return returnType; + } + else + { + // If features are published, this FeatureSpec should be of type void. + if (m_body.GetType().Primitive() != Type::Void) + { + std::ostringstream err; + err << "Last statement of feature spec should be of type void"; + throw ParseError(err.str(), m_body.GetSourceLocation()); + } + return m_body.GetType(); + } +} + + +size_t +FreeForm2::FeatureSpecExpression::GetNumChildren() const +{ + return 1; +} + + +boost::shared_ptr +FreeForm2::FeatureSpecExpression::GetPublishFeatureMap() const +{ + return m_publishFeatureMap; +} + + +const FreeForm2::Expression& +FreeForm2::FeatureSpecExpression::GetBody() const +{ + return m_body; +} + + +bool +FreeForm2::FeatureSpecExpression::IsDerived() const +{ + return m_featureSpecType == DerivedFeature || m_featureSpecType == AggregatedDerivedFeature; +} + + +FreeForm2::FeatureSpecExpression::FeatureSpecType +FreeForm2::FeatureSpecExpression::GetFeatureSpecType() const +{ + return m_featureSpecType; +} + + +FreeForm2::ImportFeatureExpression::ImportFeatureExpression( + const Annotations& p_annotations, + const FreeForm2::FeatureSpecExpression::FeatureName& p_featureName, + const std::vector& p_dimensions, + VariableID p_id, + TypeManager& p_typeManager) + : Expression(p_annotations), + m_featureName(p_featureName), + m_type(p_typeManager.GetArrayType( + TypeImpl::GetUInt32Instance(true), + true, + static_cast(p_dimensions.size()), + &p_dimensions[0], + std::accumulate(p_dimensions.begin(), p_dimensions.end(), 1u, std::multiplies()))), + m_id(p_id) +{ +} + + +FreeForm2::ImportFeatureExpression::ImportFeatureExpression( + const Annotations& p_annotations, + const FreeForm2::FeatureSpecExpression::FeatureName& p_featureName, + VariableID p_id) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(TypeImpl::GetUInt32Instance(true)))), + m_featureName(p_featureName), + m_type(TypeImpl::GetUInt32Instance(true)), + m_id(p_id) +{ +} + + +void +FreeForm2::ImportFeatureExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ImportFeatureExpression::GetType() const +{ + return m_type; +} + + +size_t +FreeForm2::ImportFeatureExpression::GetNumChildren() const +{ + return 0; +} + + +const FreeForm2::FeatureSpecExpression::FeatureName& +FreeForm2::ImportFeatureExpression::GetFeatureName() const +{ + return m_featureName; +} + + +FreeForm2::VariableID +FreeForm2::ImportFeatureExpression::GetId() const +{ + return m_id; +} + + +FreeForm2::FeatureGroupSpecExpression::FeatureGroupSpecExpression(const Annotations& p_annotations, + const std::string& p_name, + const std::vector& p_featureSpecs, + bool p_isExtendedExperimental, + bool p_isSmallExperimental, + bool p_isBlockLevelFeature, + bool p_isBodyBlockFeature, + bool p_isForwardIndexFeature, + const std::string& p_metaStreamName) + : Expression(p_annotations), + m_name(p_name), + m_featureSpecs(p_featureSpecs), + m_isExtendedExperimental(p_isExtendedExperimental), + m_isSmallExperimental(p_isSmallExperimental), + m_isBlockLevelFeature(p_isBlockLevelFeature), + m_isBodyBlockFeature(p_isBodyBlockFeature), + m_isForwardIndexFeature(p_isForwardIndexFeature), + m_metaStreamName(p_metaStreamName) +{ + FF2_ASSERT(p_featureSpecs.size() > 0); + + m_featureSpecType = p_featureSpecs[0]->GetFeatureSpecType(); + + BOOST_FOREACH(const FeatureSpecExpression* featureSpec, p_featureSpecs) + { + if (m_featureSpecType != featureSpec->GetFeatureSpecType()) + { + std::ostringstream err; + err << "All feature specifications within feature group '" << p_name << "' must be of the same type."; + throw ParseError(err.str(), GetSourceLocation()); + } + } +} + + +void +FreeForm2::FeatureGroupSpecExpression::Accept(Visitor& p_visitor) const +{ + if (!p_visitor.AlternativeVisit(*this)) + { + for(std::vector::const_iterator specExpression = m_featureSpecs.begin(); + specExpression != m_featureSpecs.end(); + ++specExpression) + { + (*specExpression)->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::FeatureGroupSpecExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::FeatureGroupSpecExpression::GetNumChildren() const +{ + return m_featureSpecs.size(); +} + + +FreeForm2::FeatureSpecExpression::FeatureSpecType +FreeForm2::FeatureGroupSpecExpression::GetFeatureSpecType() const +{ + return m_featureSpecType; +} + + +const std::string& +FreeForm2::FeatureGroupSpecExpression::GetName() const +{ + return m_name; +} + + +const std::vector& +FreeForm2::FeatureGroupSpecExpression::GetFeatureSpecs() const +{ + return m_featureSpecs; +} + + +bool +FreeForm2::FeatureGroupSpecExpression::IsExtendedExperimental() const +{ + return m_isExtendedExperimental; +} + + +bool +FreeForm2::FeatureGroupSpecExpression::IsSmallExperimental() const +{ + return m_isSmallExperimental; +} + + +bool +FreeForm2::FeatureGroupSpecExpression::IsBlockLevelFeature() const +{ + return m_isBlockLevelFeature; +} + + +bool +FreeForm2::FeatureGroupSpecExpression::IsBodyBlockFeature() const +{ + return m_isBodyBlockFeature; +} + + +bool +FreeForm2::FeatureGroupSpecExpression::IsForwardIndexFeature() const +{ + return m_isForwardIndexFeature; +} + + +const std::string& +FreeForm2::FeatureGroupSpecExpression::GetMetaStreamName() const +{ + return m_metaStreamName; +} + + +bool +FreeForm2::FeatureGroupSpecExpression::IsPerStream() const +{ + return !(m_featureSpecType == FeatureSpecExpression::AggregatedDerivedFeature + || m_featureSpecType == FeatureSpecExpression::AbInitioFeature + || !m_metaStreamName.empty() + || m_isBodyBlockFeature); +} + + +FreeForm2::AggregateContextExpression::AggregateContextExpression(const Annotations& p_annotations, + const Expression& p_body) + : Expression(p_annotations), m_body(p_body) +{ +} + + +void +FreeForm2::AggregateContextExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_body.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::AggregateContextExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::AggregateContextExpression::GetNumChildren() const +{ + return 1; +} + + +const FreeForm2::Expression& +FreeForm2::AggregateContextExpression::GetBody() const +{ + return m_body; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.h new file mode 100644 index 000000000000..556bd273da0a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/FeatureSpec.h @@ -0,0 +1,268 @@ +#pragma once + +#ifndef FREEFORM2_FEATURE_SPEC_H +#define FREEFORM2_FEATURE_SPEC_H + +#include "Expression.h" +#include "FreeForm2Features.h" +#include +#include +#include "SymbolTable.h" +#include +#include +#include + +namespace FreeForm2 +{ + class TypeManager; + + // This class encapsulates a function-like expression (either a feature + // specification or an actual function). If features are published within + // this spec, the Type should be Void, otherwise, the type is the same as + // the returned value. + class FeatureSpecExpression : public Expression + { + public: + // A class containing information about a feature name, including + // parameterization. + class FeatureName : boost::partially_ordered + { + public: + typedef std::map ParameterMap; + typedef ParameterMap::value_type Parameter; + + // Construct an empty feature name. + FeatureName(); + + // Construct a feature name without a parameterization. + explicit FeatureName(const std::string& p_name); + + // Construct a FeatureName was a base name and parameterization. + FeatureName(const std::string& p_name, + const ParameterMap& p_parameters); + + // Parse a feature name from a base name and a parameterization + // string. The parameterization string may be empty, even if the + // parameterization flag is true - this yields a parameterized + // feature with no parameters, which implies a single feature + // value exists for the parameter. + static FeatureName Parse(const std::string& p_name, + bool p_isParameterized, + const std::string& p_parameterization, + const SourceLocation& p_location); + + // Accessor methods. + const std::string& GetName() const; + bool IsParameterized() const; + const ParameterMap& GetParameters() const; + SymbolTable::Symbol GetSymbol() const; + + // Operators + bool operator==(const FeatureName& p_other) const; + bool operator<(const FeatureName& p_other) const; + + private: + // The name of the feature being imported. + std::string m_name; + + // The (optional) parameter of the feature being imported. + ParameterMap m_params; + + // The parameter string for creating a symbol. + mutable std::string m_paramStr; + + // Whether this feature is parameterized. + bool m_isParameterized; + }; + + // This structure is a simple functor used to compare feature names, + // ignoring parameterization. It is useful for publishing features, + // when the names themselves must be unique. + struct IgnoreParameterLess + { + bool operator()(const FeatureName& p_left, const FeatureName& p_right) const; + }; + + // Mapping of the names of the features being published to their types. + typedef std::map PublishFeatureMap; + + // The type of the feature specification. These names are against the + // coding guidelines only to temporarily limit amount of code touched + // by this CL. To be fixed in TFS ID 472156. + typedef FeatureInformation::FeatureType FeatureSpecType; + static const FeatureSpecType MetaStreamFeature = FeatureInformation::MetaStreamFeature; + static const FeatureSpecType DerivedFeature = FeatureInformation::DerivedFeature; + static const FeatureSpecType AggregatedDerivedFeature = FeatureInformation::AggregatedDerivedFeature; + static const FeatureSpecType AbInitioFeature = FeatureInformation::AbInitioFeature; + + // Construct a feature specification. + FeatureSpecExpression(const Annotations& p_annotations, + boost::shared_ptr p_publishFeatureMap, + const Expression& p_body, + FeatureSpecType p_featureSpecType, + bool p_returnsValue); + + // Methods inherited from Expression + virtual void Accept(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Accessor methods + const Expression& GetBody() const; + bool IsDerived() const; + FeatureSpecType GetFeatureSpecType() const; + boost::shared_ptr GetPublishFeatureMap() const; + + private: + // A mapping of feature name to type of the features being published. + boost::shared_ptr m_publishFeatureMap; + + // Body of the feature. + const Expression& m_body; + + // The type of feature specification. + FeatureSpecType m_featureSpecType; + + // Whether this FeatureSpec returns a value (versus publishing feature + // names). This is a temporary parameter that will be removed when a + // Function type is added to this class. This will be addressed with + // TFS 321891. + bool m_returnsValue; + }; + + // This class imports a feature values as a declaration. This class is + // distinct from FeatureRefExpression in that it imports features which + // are dependent on the current metastream; it is more efficient to + // determine these values at runtime than to have a FeatureRef for each + // stream and array dimensions. + class ImportFeatureExpression : public Expression + { + public: + // Import an array of parameterized feature values. + ImportFeatureExpression(const Annotations& p_annotations, + const FeatureSpecExpression::FeatureName& p_featureName, + const std::vector& p_dimensions, + VariableID p_id, + TypeManager& p_typeManager); + + // Import a single per-stream feature value. + ImportFeatureExpression(const Annotations& p_annotations, + const FeatureSpecExpression::FeatureName& p_featureName, + VariableID p_id); + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Accessor methods. + VariableID GetId() const; + const FeatureSpecExpression::FeatureName& GetFeatureName() const; + + private: + // A struct containing feature name information. + const FeatureSpecExpression::FeatureName m_featureName; + + // The type of the feature. Most importantly, this holds the array + // dimensions for the feature. + const TypeImpl& m_type; + + // Allocation ID for the value. + VariableID m_id; + }; + + // This class wraps several feature specifications into a feature group. + // This is needed to be able to emit code that is compatible with the current + // implementation of features in the IFM. + class FeatureGroupSpecExpression : public Expression + { + public: + FeatureGroupSpecExpression(const Annotations& p_annotations, + const std::string& p_name, + const std::vector& p_featureSpecs, + bool p_isExtendedExperimental, + bool p_isSmallExperimental, + bool p_isBlockLevelFeature, + bool p_isBodyBlockFeature, + bool p_isForwardIndexFeature, + const std::string& p_metaStreamName); + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Gets the type of all the feature specifications. + FeatureSpecExpression::FeatureSpecType GetFeatureSpecType() const; + + // Accessor methods. + const std::string& GetName() const; + const std::vector& GetFeatureSpecs() const; + bool IsExtendedExperimental() const; + bool IsSmallExperimental() const; + bool IsBlockLevelFeature() const; + bool IsBodyBlockFeature() const; + bool IsForwardIndexFeature() const; + const std::string& GetMetaStreamName() const; + + // Returns true if the features in this FeatureGroup are per-stream features. + bool IsPerStream() const; + + private: + // The name of the feature group. + const std::string m_name; + + // The child feature specs. + const std::vector m_featureSpecs; + + // Whether this feature group is extended experimental. + bool m_isExtendedExperimental; + + // Whether this feature group is small experimental. + bool m_isSmallExperimental; + + // Whether this feature group is block level. + bool m_isBlockLevelFeature; + + // Whether this feature group is body block. + bool m_isBodyBlockFeature; + + // Whether this feature group uses the forward index. + bool m_isForwardIndexFeature; + + // The stream over which this metastream feature is supposed to operate over. + const std::string m_metaStreamName; + + // The type of all the feature specifications. All the feature spec types must + // be the same. + FeatureSpecExpression::FeatureSpecType m_featureSpecType; + }; + + // This expression denotes the beggining of an aggregate block. This + // expression type is only allowed inside aggregate features and will be + // run exactly once per stream on which this feature is being evaluated. + class AggregateContextExpression : public Expression + { + public: + // Create an aggregate context containing a body expression. + AggregateContextExpression(const Annotations& p_annotations, + const Expression& p_body); + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Get the body of the aggregate context. + const Expression& GetBody() const; + + private: + // The body of the aggregator loop. + const Expression& m_body; + }; +} + +std::ostream& operator<<(std::ostream& p_out, + const FreeForm2::FeatureSpecExpression::FeatureName& p_name); + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.cpp new file mode 100644 index 000000000000..f7263f4cf22e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.cpp @@ -0,0 +1,255 @@ +#include "Function.h" + +#include "FreeForm2Assert.h" +#include "FunctionType.h" +#include "RefExpression.h" +#include "Visitor.h" + + +FreeForm2::FunctionExpression::FunctionExpression(const Annotations& p_annotations, + const FunctionType& p_type, + const std::string& p_name, + const std::vector& p_parameters, + const Expression& p_body) + : Expression(p_annotations), + m_type(p_type), + m_name(p_name), + m_parameters(p_parameters), + m_body(p_body) +{ +} + + +const FreeForm2::FunctionType& +FreeForm2::FunctionExpression::GetFunctionType() const +{ + return m_type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::FunctionExpression::GetType() const +{ + return m_type; +} + + +const std::string& +FreeForm2::FunctionExpression::GetName() const +{ + return m_name; +} + + +size_t +FreeForm2::FunctionExpression::GetNumChildren() const +{ + return GetNumParameters() + 1; +} + + +size_t +FreeForm2::FunctionExpression::GetNumParameters() const +{ + return m_parameters.size(); +} + + +const FreeForm2::Expression& +FreeForm2::FunctionExpression::GetBody() const +{ + return m_body; +} + + +const std::vector& +FreeForm2::FunctionExpression::GetParameters() const +{ + return m_parameters; +} + + +void +FreeForm2::FunctionExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < m_parameters.size(); i++) + { + m_parameters[i].m_parameter->Accept(p_visitor); + } + + m_body.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::FunctionCallExpression::FunctionCallExpression(const Annotations& p_annotations, + const Expression& p_function, + const std::vector& p_parameters) + : Expression(p_annotations), + m_function(p_function), + m_numParameters(p_parameters.size()) +{ + m_type = static_cast(&p_function.GetType()); + + // Note that we rely on this ctor not throwing exceptions during + // allocation below. + + // We rely on our allocator to size this object to be big enough to + // hold all children, and enforce this forcing construction via Alloc. + for (size_t i = 0; i < m_numParameters; i++) + { + m_parameters[i] = p_parameters[i]; + } +} + + +void +FreeForm2::FunctionCallExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_function.Accept(p_visitor); + + for (size_t i = 0; i < m_numParameters; i++) + { + if (GetFunctionType().BeginParameters()[i]->IsConst()) + { + m_parameters[i]->Accept(p_visitor); + } + else + { + m_parameters[i]->AcceptReference(p_visitor); + } + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::FunctionType& +FreeForm2::FunctionCallExpression::GetFunctionType() const +{ + return *m_type; +} + + +const FreeForm2::Expression& +FreeForm2::FunctionCallExpression::GetFunction() const +{ + return m_function; +} + + +const FreeForm2::TypeImpl& +FreeForm2::FunctionCallExpression::GetType() const +{ + return m_type->GetReturnType(); +} + + +size_t +FreeForm2::FunctionCallExpression::GetNumChildren() const +{ + return GetNumParameters() + 1; +} + + +size_t +FreeForm2::FunctionCallExpression::GetNumParameters() const +{ + return m_numParameters; +} + + +const FreeForm2::Expression* const* +FreeForm2::FunctionCallExpression::GetParameters() const +{ + return &m_parameters[0]; +} + + +boost::shared_ptr +FreeForm2::FunctionCallExpression::Alloc(const Annotations& p_annotations, + const Expression& p_function, + const std::vector& p_parameters) +{ + FF2_ASSERT(p_function.GetType().Primitive() == Type::Function); + + size_t bytes = sizeof(FunctionCallExpression) + + (std::max((size_t) 1ULL, p_parameters.size()) - 1) * sizeof(Expression*); + + // Allocate a shared_ptr that deletes an FunctionCallExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + FunctionCallExpression(p_annotations, p_function, p_parameters), DeleteAlloc); + return exp; +} + + +void +FreeForm2::FunctionCallExpression::DeleteAlloc(FunctionCallExpression* p_allocated) +{ + // Manually call dtor for operator expression. + p_allocated->~FunctionCallExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + +FreeForm2::ReturnExpression::ReturnExpression(const Annotations& p_annotations, + const Expression& p_value) + : Expression(p_annotations), + m_value(p_value) +{ +} + + +void +FreeForm2::ReturnExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_value.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +size_t +FreeForm2::ReturnExpression::GetNumChildren() const +{ + return 1; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ReturnExpression::GetType() const +{ + return m_value.GetType().AsConstType(); +} + + +const FreeForm2::Expression& +FreeForm2::ReturnExpression::GetValue() const +{ + return m_value; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.h new file mode 100644 index 000000000000..cd3c52fe0045 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Function.h @@ -0,0 +1,130 @@ +#pragma once + +#include "Expression.h" +#include "FeatureSpec.h" + +namespace FreeForm2 +{ + class FunctionType; + class VariableRefExpression; + + // Function expression represents the declaration of a user-defined function. + // Formal parameters, together with implicit feature parameters, are assigned + // during parse time for both function declarations and function calls. + class FunctionExpression : public Expression + { + public: + // A struct that holds the actual function parameter ref and the feature name + // metadata, to be able to bind these parameters by name. + struct Parameter + { + const VariableRefExpression* m_parameter; + bool m_isFeatureParameter; + FeatureSpecExpression::FeatureName m_featureName; + }; + + // Create an function expression. + FunctionExpression(const Annotations& p_annotations, + const FunctionType& p_type, + const std::string& p_name, + const std::vector& p_parameters, + const Expression& p_body); + + // Accessor methods. + const FunctionType& GetFunctionType() const; + const std::string& GetName() const; + size_t GetNumParameters() const; + const Expression& GetBody() const; + const std::vector& GetParameters() const; + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + private: + + // The function type. + const FunctionType& m_type; + + // The function name. + const std::string m_name; + + // The expression that, when evaluated, computes the return value of the + // function. + const Expression& m_body; + + // A list of parameters. + std::vector m_parameters; + }; + + // Function call expressions represent the calling of a function, which can be + // either external or user-defined. + class FunctionCallExpression : public Expression + { + public: + // Create an function call with an ExternExpression. + // p_function must have a Function type. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const Expression& p_function, + const std::vector& p_parameters); + + // Get the function type. + const FunctionType& GetFunctionType() const; + + // Get the function. + const Expression& GetFunction() const; + + size_t GetNumParameters() const; + const Expression* const* GetParameters() const; + + // Methods inherited from Expression. + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + private: + // Constructors are private, call Alloc instead. + FunctionCallExpression(const Annotations& p_annotations, + const Expression& p_function, + const std::vector& p_parameters); + + // The function type. + const FunctionType* m_type; + + // The expression that, when evaluated, becomes a callable expression of type + // Function. + const Expression& m_function; + + // Number of parameters of this node. + size_t m_numParameters; + + // Array of parameters of this node, allocated using struct hack. + static void DeleteAlloc(FunctionCallExpression* p_allocated); + const Expression* m_parameters[1]; + }; + + // The Return expression causes execution to exit out of a function, + // passing a value back to the caller. + class ReturnExpression : public Expression + { + public: + // Create a ReturnExpression that returns the value of an expression. + ReturnExpression(const Annotations& p_annotations, + const Expression& p_value); + + // Methods inherited from Expression + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Return the value of the expression to be returned by this + // expression. + const Expression& GetValue() const; + private: + + // The expression to be returned. + const Expression& m_value; + }; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.cpp new file mode 100644 index 000000000000..a5a361751983 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.cpp @@ -0,0 +1,107 @@ +#include "LetExpression.h" + +#include "FreeForm2Assert.h" +#include "RefExpression.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include + +FreeForm2::LetExpression::LetExpression(const Annotations& p_annotations, + const std::vector& p_children, + const Expression* p_value) + : Expression(p_annotations), + m_numBound(static_cast(p_children.size())), + m_value(p_value) +{ + for (unsigned int i = 0; i < p_children.size(); i++) + { + m_bound[i] = p_children[i]; + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::LetExpression::GetType() const +{ + return m_value->GetType(); +} + + +size_t +FreeForm2::LetExpression::GetNumChildren() const +{ + return m_numBound + 1; +} + + +const FreeForm2::Expression& +FreeForm2::LetExpression::GetValue() const +{ + return *m_value; +} + + +const FreeForm2::LetExpression::IdExpressionPair* +FreeForm2::LetExpression::GetBound() const +{ + return m_bound; +} + + +void +FreeForm2::LetExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (unsigned int i = 0; i < m_numBound; i++) + { + m_bound[i].second->Accept(p_visitor); + } + + m_value->Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +boost::shared_ptr +FreeForm2::LetExpression::Alloc(const Annotations& p_annotations, + const std::vector& p_children, + const Expression* p_value) +{ + size_t bytes = sizeof(LetExpression) + (p_children.size() - 1) * sizeof(IdExpressionPair); + + // Constructor assertions must appear in the allocation method: the + // constructor may not throw; otherwise, the raw memory allocation will + // leak. + FF2_ASSERT(!p_children.empty()); + for (size_t i = 0; i < p_children.size(); i++) + { + for (size_t j = 0; j < i; j++) + { + FF2_ASSERT(p_children[j].first != p_children[i].first); + } + } + + // Allocate a shared_ptr that deletes an LetExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) LetExpression(p_annotations, p_children, p_value), + DeleteAlloc); + return exp; +} + + +void +FreeForm2::LetExpression::DeleteAlloc(LetExpression* p_allocated) +{ + // Manually call dtor for let expression. + p_allocated->~LetExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.h new file mode 100644 index 000000000000..41d43d3fdbc9 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LetExpression.h @@ -0,0 +1,53 @@ +#pragma once + +#ifndef FREEFORM2_LETEXPRESSION_H +#define FREEFORM2_LETEXPRESSION_H + +#include "Expression.h" +#include +#include + +namespace FreeForm2 +{ + class LetExpression : public Expression + { + public: + typedef std::pair IdExpressionPair; + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + // Methods needed for the alternative visitation method. + const Expression& GetValue() const; + const IdExpressionPair* GetBound() const; + + // Custom allocation method. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const std::vector& p_children, + const Expression* p_value); + + private: + // Private constructor for struct hack allocation. + LetExpression(const Annotations& p_annotations, + const std::vector& p_children, + const Expression* p_value); + + // Custom deallocation method. + static void DeleteAlloc(LetExpression* p_allocated); + + // Sub-expression that dictates the value of the let expression. + const Expression* m_value; + + // Number of quantities (variables) bound by this let. + unsigned int m_numBound; + + // Array of quantites, allocated via struct hack, bound by this let. + IdExpressionPair m_bound[1]; + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.cpp new file mode 100644 index 000000000000..ade83b561b5e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.cpp @@ -0,0 +1,545 @@ +#include "LiteralExpression.h" + +#include "Allocation.h" +#include "FreeForm2Assert.h" +#include "Visitor.h" + + +FreeForm2::LeafTypeLiteralExpression::LeafTypeLiteralExpression(const Annotations& p_annotations, + Result::IntType p_value) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(p_value, p_value))) +{ + m_value.m_int = p_value; +} + + +FreeForm2::LeafTypeLiteralExpression::LeafTypeLiteralExpression(const Annotations& p_annotations, + Result::UInt64Type p_value) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(p_value, p_value))) +{ + m_value.m_uint64 = p_value; +} + + +FreeForm2::LeafTypeLiteralExpression::LeafTypeLiteralExpression(const Annotations& p_annotations, + Result::Int32Type p_value) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(p_value, p_value))) +{ + m_value.m_int32 = p_value; +} + + +FreeForm2::LeafTypeLiteralExpression::LeafTypeLiteralExpression(const Annotations& p_annotations, + Result::UInt32Type p_value) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(p_value, p_value))) +{ + m_value.m_uint32 = p_value; +} + + +FreeForm2::LeafTypeLiteralExpression::LeafTypeLiteralExpression(const Annotations& p_annotations, + Result::FloatType p_value) + : Expression(p_annotations) +{ + m_value.m_float = p_value; +} + + +FreeForm2::LeafTypeLiteralExpression::LeafTypeLiteralExpression(const Annotations& p_annotations, + Result::BoolType p_value) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(p_value, p_value))) +{ + m_value.m_bool = p_value; +} + + +bool +FreeForm2::LeafTypeLiteralExpression::IsConstant() const +{ + return true; +} + + +FreeForm2::ConstantValue +FreeForm2::LeafTypeLiteralExpression::GetConstantValue() const +{ + return m_value; +} + + +FreeForm2::LiteralIntExpression::LiteralIntExpression(const Annotations& p_annotations, + Result::IntType p_val) + : LeafTypeLiteralExpression(p_annotations, p_val) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralIntExpression::GetType() const +{ + return TypeImpl::GetIntInstance(true); +} + + +size_t +FreeForm2::LiteralIntExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralIntExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::LiteralUInt64Expression::LiteralUInt64Expression(const Annotations& p_annotations, + Result::UInt64Type p_val) + : LeafTypeLiteralExpression(p_annotations, p_val) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralUInt64Expression::GetType() const +{ + return TypeImpl::GetUInt64Instance(true); +} + + +size_t +FreeForm2::LiteralUInt64Expression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralUInt64Expression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::LiteralInt32Expression::LiteralInt32Expression(const Annotations& p_annotations, + int p_val) + : LeafTypeLiteralExpression(p_annotations, p_val) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralInt32Expression::GetType() const +{ + return TypeImpl::GetInt32Instance(true); +} + + +size_t +FreeForm2::LiteralInt32Expression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralInt32Expression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::LiteralUInt32Expression::LiteralUInt32Expression(const Annotations& p_annotations, + unsigned int p_val) + : LeafTypeLiteralExpression(p_annotations, p_val) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralUInt32Expression::GetType() const +{ + return TypeImpl::GetUInt32Instance(true); +} + + +size_t +FreeForm2::LiteralUInt32Expression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralUInt32Expression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::LiteralFloatExpression::LiteralFloatExpression(const Annotations& p_annotations, + Result::FloatType p_val) + : LeafTypeLiteralExpression(p_annotations, p_val) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralFloatExpression::GetType() const +{ + return TypeImpl::GetFloatInstance(true); +} + + +size_t +FreeForm2::LiteralFloatExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralFloatExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::LiteralBoolExpression::LiteralBoolExpression(const Annotations& p_annotations, + bool p_val) + : LeafTypeLiteralExpression(p_annotations, p_val) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralBoolExpression::GetType() const +{ + return TypeImpl::GetBoolInstance(true); +} + + +size_t +FreeForm2::LiteralBoolExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralBoolExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::LiteralVoidExpression& +FreeForm2::LiteralVoidExpression::GetInstance() +{ + static const Annotations s_annotations; + static const LiteralVoidExpression s_expr(s_annotations); + return s_expr; +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralVoidExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::LiteralVoidExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::LiteralVoidExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::LiteralVoidExpression::LiteralVoidExpression(const Annotations& p_annotations) + : Expression(p_annotations) +{ +} + + +FreeForm2::LiteralWordExpression::LiteralWordExpression( + const Annotations& p_annotations, + const Expression& p_word, + const Expression& p_offset, + const Expression* p_attribute, + const Expression* p_length, + const Expression* p_candidate, + VariableID p_variableID) + : Expression(p_annotations), + m_isHeader(false), + m_word(p_word), + m_offset(p_offset), + m_attribute(p_attribute), + m_length(p_length), + m_candidate(p_candidate), + m_variableID(p_variableID) +{ +} + + +FreeForm2::LiteralWordExpression::LiteralWordExpression( + const Annotations& p_annotations, + const Expression& p_length, + const Expression& p_count, + VariableID p_variableID) + : Expression(p_annotations), + m_isHeader(true), + m_word(p_length), + m_offset(p_count), + m_attribute(NULL), + m_length(NULL), + m_candidate(NULL), + m_variableID(p_variableID) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralWordExpression::GetType() const +{ + return TypeImpl::GetWordInstance(true); +} + + +size_t +FreeForm2::LiteralWordExpression::GetNumChildren() const +{ + return 2 + + (m_attribute != NULL ? 1 : 0) + + (m_length != NULL ? 1 : 0) + + (m_candidate != NULL ? 1 : 0); +} + + +FreeForm2::VariableID +FreeForm2::LiteralWordExpression::GetId() const +{ + return m_variableID; +} + + +void +FreeForm2::LiteralWordExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_word.Accept(p_visitor); + m_offset.Accept(p_visitor); + if (m_attribute != NULL) + { + m_attribute->Accept(p_visitor); + } + if (m_length != NULL) + { + m_length->Accept(p_visitor); + } + if (m_candidate != NULL) + { + m_candidate->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralStreamExpression::GetType() const +{ + return TypeImpl::GetStreamInstance(true); +} + + +size_t +FreeForm2::LiteralStreamExpression::GetNumChildren() const +{ + return m_numChildren; +} + + +void +FreeForm2::LiteralStreamExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (unsigned int i = 0; i < m_numChildren; i++) + { + m_children[i]->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::Expression* const* +FreeForm2::LiteralStreamExpression::GetChildren() const +{ + return m_children; +} + + +FreeForm2::VariableID +FreeForm2::LiteralStreamExpression::GetId() const +{ + return m_id; +} + + +FreeForm2::LiteralStreamExpression::LiteralStreamExpression(const Annotations& p_annotations, + const Expression** p_children, + size_t p_numChildren, + VariableID p_id) + : Expression(p_annotations), + m_numChildren(p_numChildren), + m_id(p_id) +{ + for (unsigned int i = 0; i < m_numChildren; i++) + { + m_children[i] = p_children[i]; + } +} + + +boost::shared_ptr +FreeForm2::LiteralStreamExpression::Alloc(const Annotations& p_annotations, + const Expression** p_children, + size_t p_numChildren, + VariableID p_id) +{ + FF2_ASSERT(p_numChildren > 0); + size_t bytes = sizeof(LiteralStreamExpression) + + (p_numChildren - 1) * sizeof(Expression*); + + // Allocate a shared_ptr that deletes an LiteralStreamExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + LiteralStreamExpression(p_annotations, p_children, p_numChildren, p_id), DeleteAlloc); + return exp; +} + + +void +FreeForm2::LiteralStreamExpression::DeleteAlloc(LiteralStreamExpression* p_allocated) +{ + // Manually call dtor for stream expression. + p_allocated->~LiteralStreamExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +FreeForm2::LiteralInstanceHeaderExpression::LiteralInstanceHeaderExpression(const Annotations& p_annotations, + const Expression& p_instanceCount, + const Expression& p_rank, + const Expression& p_instanceLength) + : Expression(p_annotations), + m_instanceCount(p_instanceCount), + m_rank(p_rank), + m_instanceLength(p_instanceLength) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::LiteralInstanceHeaderExpression::GetType() const +{ + return TypeImpl::GetInstanceHeaderInstance(true); +} + + +size_t +FreeForm2::LiteralInstanceHeaderExpression::GetNumChildren() const +{ + return 3; +} + + +void +FreeForm2::LiteralInstanceHeaderExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_instanceCount.Accept(p_visitor); + m_rank.Accept(p_visitor); + m_instanceLength.Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.h new file mode 100644 index 000000000000..652c8a5c1a14 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/LiteralExpression.h @@ -0,0 +1,221 @@ +#pragma once + +#ifndef FREEFORM2_LITERALEXPRESSION_H +#define FREEFORM2_LITERALEXPRESSION_H + +#include "Expression.h" +#include "FreeForm2Result.h" + +namespace FreeForm2 +{ + // This class represents a literal expression of any leaf type. + class LeafTypeLiteralExpression : public Expression + { + public: + explicit LeafTypeLiteralExpression(const Annotations& p_annotations, Result::IntType p_value); + explicit LeafTypeLiteralExpression(const Annotations& p_annotations, Result::UInt64Type p_value); + explicit LeafTypeLiteralExpression(const Annotations& p_annotations, Result::Int32Type p_value); + explicit LeafTypeLiteralExpression(const Annotations& p_annotations, Result::UInt32Type p_value); + explicit LeafTypeLiteralExpression(const Annotations& p_annotations, Result::FloatType p_value); + explicit LeafTypeLiteralExpression(const Annotations& p_annotations, Result::BoolType p_value); + + virtual bool IsConstant() const override; + virtual ConstantValue GetConstantValue() const override; + + private: + ConstantValue m_value; + }; + + class LiteralIntExpression : public LeafTypeLiteralExpression + { + public: + // Construct a literal expression from an int. + explicit LiteralIntExpression(const Annotations& p_annotations, Result::IntType p_val); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + class LiteralUInt64Expression : public LeafTypeLiteralExpression + { + public: + // Construct a literal expression from an uint64. + explicit LiteralUInt64Expression(const Annotations& p_annotations, Result::UInt64Type p_val); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + class LiteralInt32Expression : public LeafTypeLiteralExpression + { + public: + // Construct a literal expression from an int32. + explicit LiteralInt32Expression(const Annotations& p_annotations, Result::Int32Type p_val); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + class LiteralUInt32Expression : public LeafTypeLiteralExpression + { + public: + // Construct a literal expression from an uint32. + explicit LiteralUInt32Expression(const Annotations& p_annotations, Result::UInt32Type p_val); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + class LiteralFloatExpression : public LeafTypeLiteralExpression + { + public: + // Construct a literal expression from a float. + explicit LiteralFloatExpression(const Annotations& p_annotations, Result::FloatType p_val); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + class LiteralBoolExpression : public LeafTypeLiteralExpression + { + public: + // Construct a literal expression from an int. + explicit LiteralBoolExpression(const Annotations& p_annotations, Result::BoolType p_val); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + }; + + class LiteralVoidExpression : public Expression + { + public: + // Get the single instance of LiteralVoidExpression. There's no point + // creating many instances of this expression, since they're all the + // same. + static const LiteralVoidExpression& GetInstance(); + + // Virtual methods inherited from Expression. + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + + private: + LiteralVoidExpression(const Annotations& p_annotations); + }; + + class LiteralWordExpression : public Expression + { + public: + // Construct a literal expression from a set of integers. + LiteralWordExpression(const Annotations& p_annotations, + const Expression& p_word, + const Expression& p_offset, + const Expression* p_attribute, + const Expression* p_length, + const Expression* p_candidate, + VariableID p_variableId); + + // Construct a literal instance header from integers. + LiteralWordExpression(const Annotations& p_annotations, + const Expression& p_instanceHeaderLength, + const Expression& p_instanceHeaderOffset, + VariableID p_variableId); + + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + + // Gets the integer identificator for this literal. + VariableID GetId() const; + + // Whether this literal represents a stream instance header or an + // ordinary word occurrence. + bool m_isHeader; + + // Members in a WordOccurrence. TODO: ensure that a WordOccurrence + // struct is of the same size. Note that we've overloaded word to carry + // instance header lengths, and offset to carry instance header counts. + const Expression& m_word; + const Expression& m_offset; + const Expression* m_attribute; + const Expression* m_length; + const Expression* m_candidate; + + private: + // The integer identificator for this literal. + VariableID m_variableID; + }; + + class LiteralStreamExpression : public Expression + { + public: + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, const Expression** p_children, size_t p_numChildren, VariableID p_id); + + const Expression* const* GetChildren() const; + + VariableID GetId() const; + + private: + LiteralStreamExpression(const Annotations& p_annotations, + const Expression** p_children, + size_t p_numChildren, + VariableID p_id); + + static void DeleteAlloc(LiteralStreamExpression* p_allocated); + + // Number of children of this node. + size_t m_numChildren; + + // A unique identificator to allow separation of allocation and usage. + const VariableID m_id; + + // Array of children of this node, allocated using struct hack. + const Expression* m_children[1]; + }; + + // Represents an instance header in a stream. + class LiteralInstanceHeaderExpression : public Expression + { + public: + // Constructor. + LiteralInstanceHeaderExpression(const Annotations& p_annotations, + const Expression& p_instanceCount, + const Expression& p_rank, + const Expression& p_instanceLength); + + // Methods inherited from Expression class. + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + + // Properties associated with instance headers. + const Expression& m_instanceCount; + const Expression& m_rank; + const Expression& m_instanceLength; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.cpp new file mode 100644 index 000000000000..e77866bedc6f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.cpp @@ -0,0 +1,356 @@ +#include "Match.h" + +#include "FreeForm2Assert.h" +#include +#include "Visitor.h" + +FreeForm2::MatchExpression::MatchExpression(const Annotations& p_annotations, + const Expression& p_value, + const MatchSubExpression& p_pattern, + const Expression& p_action, + bool p_isOverlapping) + : Expression(p_annotations), + m_value(p_value), + m_pattern(p_pattern), + m_action(p_action), + m_isOverlapping(p_isOverlapping) +{ +} + + +void +FreeForm2::MatchExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_value.Accept(p_visitor); + m_pattern.Accept(p_visitor); + m_action.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::MatchExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::MatchExpression::GetNumChildren() const +{ + return 3; +} + + +const FreeForm2::Expression& +FreeForm2::MatchExpression::GetValue() const +{ + return m_value; +} + + +const FreeForm2::MatchSubExpression& +FreeForm2::MatchExpression::GetPattern() const +{ + return m_pattern; +} + + +const FreeForm2::Expression& +FreeForm2::MatchExpression::GetAction() const +{ + return m_action; +} + + +bool +FreeForm2::MatchExpression::IsOverlapping() const +{ + return m_isOverlapping; +} + + +void +FreeForm2::MatchOperatorExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < m_numChildren; i++) + { + m_children[i]->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::MatchOperatorExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::MatchOperatorExpression::GetNumChildren() const +{ + return m_numChildren; +} + + +FreeForm2::MatchSubExpression::Info +FreeForm2::MatchOperatorExpression::GetInfo() const +{ + switch (GetOperator()) + { + case kleene: + { + FF2_ASSERT(m_numChildren == 1); + return Info(0, Info::c_indeterminate); + } + + case atLeastOne: + { + FF2_ASSERT(m_numChildren == 1); + return Info(m_children[0]->GetInfo().m_minLength, + Info::c_indeterminate); + } + + case alternation: + { + unsigned int min = UINT_MAX; + unsigned int max = 0; + + for (size_t i = 0; i < GetNumChildren(); i++) + { + Info info = m_children[i]->GetInfo(); + min = std::min(min, info.m_minLength); + max = std::max(max, info.m_maxLength); + } + + return Info(min, max); + } + + case concatenation: + { + Info combined(0, 0); + + for (size_t i = 0; i < GetNumChildren(); i++) + { + Info info = m_children[i]->GetInfo(); + combined.m_minLength += info.m_minLength; + + if (combined.m_maxLength != Info::c_indeterminate + && info.m_maxLength != Info::c_indeterminate) + { + combined.m_maxLength += info.m_maxLength; + } + else + { + combined.m_maxLength = Info::c_indeterminate; + } + } + + return combined; + } + + default: + { + Unreachable(__FILE__, __LINE__); + break; + } + } +} + + +boost::shared_ptr +FreeForm2::MatchOperatorExpression::Alloc(const Annotations& p_annotations, + const MatchSubExpression** p_children, + size_t p_numChildren, + Operator p_op) +{ + size_t bytes = sizeof(MatchOperatorExpression) + + sizeof(Expression*) * (p_numChildren - 1); + + // Allocate a shared_ptr that deletes an MatchOperatorExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + MatchOperatorExpression(p_annotations, p_children, p_numChildren, p_op), DeleteAlloc); + return exp; +} + + +boost::shared_ptr +FreeForm2::MatchOperatorExpression::Alloc(const Annotations& p_annotations, + const MatchSubExpression& p_left, + const MatchSubExpression& p_right, + Operator p_op) +{ + const MatchSubExpression* array[2]; + array[0] = &p_left; + array[1] = &p_right; + return Alloc(p_annotations, array, sizeof(array) / sizeof(*array), p_op); +} + + +boost::shared_ptr +FreeForm2::MatchOperatorExpression::Alloc(const Annotations& p_annotations, + const MatchSubExpression& p_expr, + Operator p_op) +{ + const MatchSubExpression* array[1]; + array[0] = &p_expr; + return Alloc(p_annotations, array, sizeof(array) / sizeof(*array), p_op); +} + + +FreeForm2::MatchOperatorExpression::Operator +FreeForm2::MatchOperatorExpression::GetOperator() const +{ + return m_op; +} + + +const FreeForm2::MatchSubExpression* const* +FreeForm2::MatchOperatorExpression::GetChildren() const +{ + return m_children; +} + + +FreeForm2::MatchOperatorExpression::MatchOperatorExpression(const Annotations& p_annotations, + const MatchSubExpression** p_children, + size_t p_numChildren, + Operator p_op) + : MatchSubExpression(p_annotations), + m_numChildren(p_numChildren), + m_op(p_op) +{ + for (size_t i = 0; i < p_numChildren; i++) + { + m_children[i] = p_children[i]; + } +} + + +void FreeForm2::MatchOperatorExpression::DeleteAlloc(MatchOperatorExpression* p_allocated) +{ + // Manually call dtor for operator expression. + p_allocated->~MatchOperatorExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +FreeForm2::MatchGuardExpression::MatchGuardExpression(const Annotations& p_annotations, + const Expression& p_guard) + : MatchSubExpression(p_annotations), + m_guard(p_guard) +{ +} + + +void +FreeForm2::MatchGuardExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_guard.Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::MatchGuardExpression::GetType() const +{ + if (m_guard.GetType().Primitive() != Type::Bool) + { + std::ostringstream err; + err << "Guard expression must evaluate to a boolean (evalutes to " + << m_guard.GetType() << ")"; + throw ParseError(err.str(), GetSourceLocation()); + } + + return TypeImpl::GetBoolInstance(true); +} + + +size_t +FreeForm2::MatchGuardExpression::GetNumChildren() const +{ + return 1; +} + + +FreeForm2::MatchSubExpression::Info +FreeForm2::MatchGuardExpression::GetInfo() const +{ + return Info(0, 0); +} + + +FreeForm2::MatchBindExpression::MatchBindExpression(const Annotations& p_annotations, + const MatchSubExpression& p_value, + VariableID p_id) + : MatchSubExpression(p_annotations), + m_value(p_value), + m_id(p_id) +{ +} + + +void +FreeForm2::MatchBindExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_value.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::MatchBindExpression::GetType() const +{ + return m_value.GetType(); +} + + +size_t +FreeForm2::MatchBindExpression::GetNumChildren() const +{ + return 1; +} + + +FreeForm2::MatchSubExpression::Info +FreeForm2::MatchBindExpression::GetInfo() const +{ + return m_value.GetInfo(); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.h new file mode 100644 index 000000000000..f57422b8ed92 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Match.h @@ -0,0 +1,164 @@ +#pragma once + +#ifndef FREEFORM2_MATCH_H +#define FREEFORM2_MATCH_H + +#include "Expression.h" +#include "MatchSub.h" + +namespace FreeForm2 +{ + // A match expression represents a used-entered match statement that + // evaluates a pattern against a given value, and takes the given action if + // the pattern matches. + class MatchExpression : public Expression + { + public: + // Create a match expression from the value to match, the + // pattern, and the corresponding action. + MatchExpression(const Annotations& p_annotations, + const Expression& p_value, + const MatchSubExpression& p_pattern, + const Expression& p_action, + bool p_isOverlapping); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + const Expression& GetValue() const; + const MatchSubExpression& GetPattern() const; + const Expression& GetAction() const; + bool IsOverlapping() const; + + private: + // Value to be matched against. + const Expression& m_value; + + // Pattern to match. + const MatchSubExpression& m_pattern; + + // Action to take. + const Expression& m_action; + + // Flag to indicate whether or not the match should overlap. + bool m_isOverlapping; + }; + + // A MatchOperatorExpression combines matching constraints with a variety of + // operators. We translate these to finite-state matchines before issuing + // code, so the operations aren't very distinct in the syntax tree. + class MatchOperatorExpression : public MatchSubExpression + { + public: + // Enumeration of different match operations. + enum Operator + { + // Kleene star, unbounded repetition. + kleene, + + // '+' operation, which matches at least one repetition. + atLeastOne, + + // Alternation, allowing any of a given set of matching constraints. + alternation, + + // Concatenation, matching constraints in sequence. + concatenation, + + invalid + }; + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Methods inherited from MatchSubExpression. + virtual Info GetInfo() const override; + + // Methods to allocate MatchOperatorExpression objects. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const MatchSubExpression** p_children, + size_t p_numChildren, + Operator p_op); + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const MatchSubExpression& p_left, + const MatchSubExpression& p_right, + Operator p_op); + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const MatchSubExpression& p_expr, + Operator p_op); + + // Get the operator represented by this expression. + Operator GetOperator() const; + + // Get the children of this expression. + const MatchSubExpression* const* GetChildren() const; + + private: + MatchOperatorExpression(const Annotations& p_annotations, + const MatchSubExpression** p_children, + size_t p_numChildren, + Operator p_op); + + static void DeleteAlloc(MatchOperatorExpression* p_allocated); + + // Matching operator used. + const Operator m_op; + + // Number of children of this node. + size_t m_numChildren; + + // Array of children of this node, allocated using struct hack. + const MatchSubExpression* m_children[1]; + }; + + // A MatchGuardExpression represents guarding of a pattern by an arbitrary + // statement evaluating to a bool + class MatchGuardExpression : public MatchSubExpression + { + public: + MatchGuardExpression(const Annotations& p_annotations, + const Expression& p_guard); + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Methods inherited from MatchSubExpression. + virtual Info GetInfo() const override; + + const Expression& m_guard; + }; + + // Expression to allow binding of variables during matching. Note that we + // need to handle this separately from normal binding, due to extensive + // state machine transformations. + class MatchBindExpression : public MatchSubExpression + { + public: + MatchBindExpression(const Annotations& p_annotations, + const MatchSubExpression& p_value, + VariableID p_id); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Methods inherited from MatchSubExpression. + virtual Info GetInfo() const override; + + const MatchSubExpression& m_value; + + const VariableID m_id; + }; +}; + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MatchSub.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MatchSub.h new file mode 100644 index 000000000000..4e38a57796ea --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MatchSub.h @@ -0,0 +1,54 @@ +#pragma once + +#ifndef FREEFORM2_MATCH_SUB_H +#define FREEFORM2_MATCH_SUB_H + +#include "Expression.h" +#include "FreeForm2Assert.h" + +namespace FreeForm2 +{ + // A match sub- expression is a base class representing match + // operators (that is, repetition, concatenation, etc). + class MatchSubExpression : public Expression + { + public: + struct Info + { + Info(unsigned int p_minLength, unsigned int p_maxLength) + : m_minLength(p_minLength), m_maxLength(p_maxLength) + { + FF2_ASSERT(m_minLength <= m_maxLength); + FF2_ASSERT(m_minLength != c_indeterminate); + } + + // Calculated minimum limit on the length of matches from this FSM. + // Must be less than or equal to m_maxLength, and not c_indeterminate. + unsigned int m_minLength; + + // Calculated maximum limit on the length of matches from this FSM. + // Will be c_indeterminate if there's no easily calculable limit on the + // length of match using this FSM. + unsigned int m_maxLength; + + // Constant indicating that a pattern matches arbitrarily long input. + static const unsigned int c_indeterminate = UINT_MAX; + }; + + + MatchSubExpression(const Annotations& p_annotations) + : Expression(p_annotations) + { + } + + virtual ~MatchSubExpression() + { + } + + // Calculate information for this sub-expression. + virtual Info GetInfo() const = 0; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.cpp new file mode 100644 index 000000000000..37de13e49978 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.cpp @@ -0,0 +1,204 @@ +#include "MemberAccessExpression.h" + +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" +#include "StateMachine.h" +#include "StateMachineType.h" +#include "Mutation.h" +#include "Visitor.h" +#include "TypeUtil.h" +#include + +namespace +{ + const FreeForm2::Expression& GetInitializer(const FreeForm2::StateMachineExpression& p_machine, + const FreeForm2::CompoundType::Member& p_memberInfo) + { + const FreeForm2::TypeInitializerExpression& expr = p_machine.GetInitializer(); + for (const FreeForm2::TypeInitializerExpression::Initializer* iter = expr.BeginInitializers(); + iter != expr.EndInitializers(); + ++iter) + { + if (&p_memberInfo == iter->m_member) + { + return *iter->m_initializer; + } + } + FreeForm2::Unreachable(__FILE__, __LINE__); + } +} + + +FreeForm2::MemberAccessExpression::MemberAccessExpression(const Annotations& p_annotations, + const Expression& p_struct, + const CompoundType::Member& p_memberInfo, + size_t p_version) + : Expression(Annotations(p_annotations.m_sourceLocation, ValueBounds(*p_memberInfo.m_type))), + m_struct(p_struct), + m_memberInfo(p_memberInfo), + m_version(p_version) +{ + FF2_ASSERT(CompoundType::IsCompoundType(m_struct.GetType())); + const CompoundType& compoundType = static_cast(m_struct.GetType()); + FF2_ASSERT(compoundType.FindMember(p_memberInfo.m_name) != NULL); +} + + +const FreeForm2::TypeImpl& +FreeForm2::MemberAccessExpression::GetType() const +{ + return *m_memberInfo.m_type; +} + + +size_t +FreeForm2::MemberAccessExpression::GetNumChildren() const +{ + return 1; +} + + +void +FreeForm2::MemberAccessExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_struct.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +void +FreeForm2::MemberAccessExpression::AcceptReference(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + m_struct.AcceptReference(p_visitor); + + p_visitor.VisitReference(*this); + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +bool +FreeForm2::MemberAccessExpression::IsConstant() const +{ + if (m_struct.GetType().Primitive() == Type::StateMachine) + { + const StateMachineType& machine = static_cast(m_struct.GetType()); + boost::shared_ptr ptr = machine.GetDefinition(); + return ptr != nullptr && m_memberInfo.m_type->IsConst() + && GetInitializer(*ptr, m_memberInfo).IsConstant(); + } + else + { + return false; + } +} + + +FreeForm2::ConstantValue +FreeForm2::MemberAccessExpression::GetConstantValue() const +{ + FF2_ASSERT(m_struct.GetType().Primitive() == Type::StateMachine); + const StateMachineType& machine = static_cast(m_struct.GetType()); + boost::shared_ptr ptr = machine.GetDefinition(); + return GetInitializer(*ptr, m_memberInfo).GetConstantValue(); +} + + +const FreeForm2::Expression& +FreeForm2::MemberAccessExpression::GetStruct() const +{ + return m_struct; +} + + +const FreeForm2::CompoundType::Member& +FreeForm2::MemberAccessExpression::GetMemberInfo() const +{ + return m_memberInfo; +} + + +size_t +FreeForm2::MemberAccessExpression::GetVersion() const +{ + return m_version; +} + + +FreeForm2::UnresolvedAccessExpression::UnresolvedAccessExpression(const Annotations& p_annotations, + const Expression& p_object, + const std::string& p_memberName, + const TypeImpl& p_expectedType) + : Expression(p_annotations), + m_object(p_object), + m_memberName(p_memberName), + m_expectedType(p_expectedType) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::UnresolvedAccessExpression::GetType() const +{ + return m_expectedType; +} + + +size_t +FreeForm2::UnresolvedAccessExpression::GetNumChildren() const +{ + return 1; +} + + +void +FreeForm2::UnresolvedAccessExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_object.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +void +FreeForm2::UnresolvedAccessExpression::AcceptReference(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + m_object.AcceptReference(p_visitor); + + p_visitor.VisitReference(*this); + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::Expression& +FreeForm2::UnresolvedAccessExpression::GetObject() const +{ + return m_object; +} + + +const std::string& +FreeForm2::UnresolvedAccessExpression::GetMemberName() const +{ + return m_memberName; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.h new file mode 100644 index 000000000000..47fd1729bdda --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/MemberAccessExpression.h @@ -0,0 +1,75 @@ +#pragma once + +#ifndef FREEFORM2_MEMBER_ACCESS_EXPRESSION_H +#define FREEFORM2_MEMBER_ACCESS_EXPRESSION_H + +#include "Expression.h" +#include "CompoundType.h" + +namespace FreeForm2 +{ + // An array-dereference expression removes a dimension from an array. + class MemberAccessExpression : public Expression + { + public: + MemberAccessExpression(const Annotations& p_annotations, + const Expression& p_struct, + const CompoundType::Member& p_memberInfo, + size_t p_version); + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual void AcceptReference(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual bool IsConstant() const override; + virtual ConstantValue GetConstantValue() const override; + + const Expression& GetStruct() const; + const CompoundType::Member& GetMemberInfo() const; + size_t GetVersion() const; + + private: + // Struct expression whose member will be accessed. + const Expression& m_struct; + + // Member information. + const CompoundType::Member& m_memberInfo; + + // A unique version number associated with a particular + // value for this variable. + const size_t m_version; + }; + + // An unresolved member to be accessed. + class UnresolvedAccessExpression : public Expression + { + public: + UnresolvedAccessExpression(const Annotations& p_annotations, + const Expression& p_object, + const std::string& p_memberName, + const TypeImpl& p_expectedType); + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual void AcceptReference(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + const Expression& GetObject() const; + const std::string& GetMemberName() const; + + private: + // Struct expression whose member will be accessed. + const Expression& m_object; + + // Member information. + std::string m_memberName; + + // The expected type of the member access expression. + const TypeImpl& m_expectedType; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.cpp new file mode 100644 index 000000000000..7eda0e46f9a9 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.cpp @@ -0,0 +1,248 @@ +#include "Mutation.h" + +#include +#include "FreeForm2Assert.h" +#include +#include +#include "StateMachineType.h" +#include "StructType.h" +#include "TypeUtil.h" +#include "Visitor.h" + +namespace +{ + void DeleteAlloc(FreeForm2::TypeInitializerExpression* p_delete) + { + // Explicitly invoke the destructor. + p_delete->~TypeInitializerExpression(); + + // Cast the char*, as the memory was allocated as char[]. + char* mem = reinterpret_cast(p_delete); + delete[] mem; + } +} + +FreeForm2::MutationExpression::MutationExpression(const Annotations& p_annotations, + const Expression& p_lvalue, + const Expression& p_rvalue) + : Expression(p_annotations), + m_lvalue(p_lvalue), + m_rvalue(p_rvalue) +{ +} + + +void +FreeForm2::MutationExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + try + { + m_lvalue.AcceptReference(p_visitor); + } + catch (const std::exception&) + { + std::ostringstream err; + err << "Invalid l-value in mutation expression"; + throw ParseError(err.str(), GetSourceLocation()); + } + + m_rvalue.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::MutationExpression::GetType() const +{ + const TypeImpl& left = m_lvalue.GetType(); + const TypeImpl& right = m_rvalue.GetType(); + + if (!TypeUtil::IsAssignable(left, right)) + { + std::ostringstream err; + err << "Mismatched types in assignment (" << left << " and " << right << ")"; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (left.Primitive() == Type::Array) + { + std::ostringstream err; + err << "Can't assign types that are not of fixed size (such as arrays)"; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (left.IsConst()) + { + std::ostringstream err; + err << "Can't assign to constant types"; + throw ParseError(err.str(), GetSourceLocation()); + } + + return TypeImpl::GetVoidInstance(); +} + + +size_t FreeForm2::MutationExpression::GetNumChildren() const +{ + return 2; +} + + +const FreeForm2::Expression& +FreeForm2::MutationExpression::GetLeftValue() const +{ + return m_lvalue; +} + + +const FreeForm2::Expression& +FreeForm2::MutationExpression::GetRightValue() const +{ + return m_rvalue; +} + + +FreeForm2::TypeInitializerExpression::TypeInitializerExpression( + const Annotations& p_annotations, + const CompoundType& p_type, + const Initializer* p_initializers, + size_t p_numInitializers) + : Expression(p_annotations), + m_type(p_type), + m_numInitializers(p_numInitializers) +{ + memcpy(m_initializers, p_initializers, sizeof(Initializer) * m_numInitializers); + ValidateMembers(); +} + + +boost::shared_ptr +FreeForm2::TypeInitializerExpression::Alloc(const Annotations& p_annotations, + const CompoundType& p_type, + const Initializer* p_initializers, + size_t p_numInitializers) +{ + const size_t memSize = sizeof(TypeInitializerExpression) + + sizeof(Initializer) * (std::max(p_numInitializers, (size_t) 1ULL) - 1); + char* mem = NULL; + try + { + mem = new char[memSize]; + return boost::shared_ptr( + new (mem) TypeInitializerExpression(p_annotations, p_type, p_initializers, p_numInitializers)); + } + catch (...) + { + delete[] mem; + throw; + } +} + + +void +FreeForm2::TypeInitializerExpression::ValidateMembers() const +{ + // Collect all member names in the type into a set. + std::set names; + if (m_type.Primitive() == Type::Struct) + { + const StructType& type = static_cast(m_type); + BOOST_FOREACH (const StructType::MemberInfo& member, type.GetMembers()) + { + names.insert(member.m_name); + } + } + else + { + FF2_ASSERT(m_type.Primitive() == Type::StateMachine); + const StateMachineType& type = static_cast(m_type); + for (const StructType::Member* iter = type.BeginMembers(); iter != type.EndMembers(); ++iter) + { + names.insert(iter->m_name); + } + } + + // Search for names not being initialized. + for (const Initializer* iter = BeginInitializers(); iter != EndInitializers(); ++iter) + { + const TypeImpl& memberType = *iter->m_member->m_type; + const TypeImpl& initType = iter->m_initializer->GetType(); + if (!TypeUtil::IsAssignable(memberType, initType)) + { + std::ostringstream err; + err << "Mismatched types in initializer (" << memberType << " and " << initType << ")"; + throw ParseError(err.str(), GetSourceLocation()); + } + + FF2_ASSERT(iter != NULL && iter->m_member != NULL); + std::set::iterator find = names.find(iter->m_member->m_name); + FF2_ASSERT(find != names.end()); + names.erase(find); + } + + if (!names.empty()) + { + std::ostringstream err; + err << "all members must be initialized; missing: "; + BOOST_FOREACH (const std::string& name, names) + { + err << name << " "; + } + throw ParseError(err.str(), GetSourceLocation()); + } +} + + +void +FreeForm2::TypeInitializerExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (const Initializer* iter = BeginInitializers(); iter != EndInitializers(); ++iter) + { + iter->m_initializer->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeInitializerExpression::GetType() const +{ + return m_type; +} + + +size_t +FreeForm2::TypeInitializerExpression::GetNumChildren() const +{ + return m_numInitializers; +} + + +const FreeForm2::TypeInitializerExpression::Initializer* +FreeForm2::TypeInitializerExpression::BeginInitializers() const +{ + return m_initializers; +} + + +const FreeForm2::TypeInitializerExpression::Initializer* +FreeForm2::TypeInitializerExpression::EndInitializers() const +{ + return m_initializers + m_numInitializers; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.h new file mode 100644 index 000000000000..d0e9367eb2da --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Mutation.h @@ -0,0 +1,87 @@ +#pragma once + +#ifndef FREEFORM2_MUTATION_H +#define FREEFORM2_MUTATION_H + +#include "CompoundType.h" +#include "Expression.h" + +namespace FreeForm2 +{ + class MutationExpression : public Expression + { + public: + // Create a mutation expression from a type and an initialiser. + MutationExpression(const Annotations& p_annotations, + const Expression& p_lvalue, + const Expression& p_rvalue); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Get the l-value and r-value expressions. + const Expression& GetLeftValue() const; + const Expression& GetRightValue() const; + + private: + // Expression to be mutated. + const Expression& m_lvalue; + + // Value assigned to l-value after mutation. + const Expression& m_rvalue; + }; + + class TypeInitializerExpression : public Expression + { + public: + // This struct represents a member-initializer pair. Note that + // std::pair is not used because it is a non-POD type, and doesn't + // work with the struct hack. + struct Initializer + { + const CompoundType::Member* m_member; + const Expression* m_initializer; + size_t m_version; + }; + + // Allocate a new TypeInitializerExpression. + static boost::shared_ptr Alloc(const Annotations& p_annotations, + const CompoundType& p_type, + const Initializer* p_initializers, + size_t p_numInitializers); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Access initializers. + const Initializer* BeginInitializers() const; + const Initializer* EndInitializers() const; + + private: + // Create a type initializer with member-expression pairs to initialize + // each member. All members in the CompoundType must be specified. + TypeInitializerExpression(const Annotations& p_annotations, + const CompoundType& p_type, + const Initializer* p_initializers, + size_t p_initializerCount); + + // Validate that all members in the CompoundType are initialized. + void ValidateMembers() const; + + // The type being initialized. + const CompoundType& m_type; + + // The number of initializers. + size_t m_numInitializers; + + // Initializer list, allocated using the struct hack. + Initializer m_initializers[1]; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.cpp new file mode 100644 index 000000000000..3ae8e128a927 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.cpp @@ -0,0 +1,265 @@ +#include "OperatorExpression.h" + +#include "FreeForm2Assert.h" +#include "BinaryOperator.h" +#include "UnaryOperator.h" +#include "Visitor.h" +#include "TypeUtil.h" +#include + + +FreeForm2::UnaryOperatorExpression::UnaryOperatorExpression(const Annotations& p_annotations, + const Expression& p_child, + UnaryOperator::Operation p_op) + : Expression(p_annotations), + m_child(p_child), + m_op(p_op), + m_type(InferType()) +{ +} + + +void +FreeForm2::UnaryOperatorExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_child.Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::UnaryOperatorExpression::GetType() const +{ + return m_type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::UnaryOperatorExpression::InferType() const +{ + const TypeImpl& operandType = UnaryOperator::GetBestOperandType(m_op, m_child.GetType()); + if (!operandType.IsValid()) + { + std::ostringstream err; + err << "Invalid operand type " << m_child.GetType() + << " supplied to operator"; + throw ParseError(err.str(), GetSourceLocation()); + } + + const TypeImpl& type = UnaryOperator::GetReturnType(m_op, operandType); + FF2_ASSERT(type.IsValid()); + return type.AsConstType(); +} + + +size_t +FreeForm2::UnaryOperatorExpression::GetNumChildren() const +{ + return 1; +} + + +FreeForm2::BinaryOperatorExpression::BinaryOperatorExpression(const Annotations& p_annotations, + const std::vector& p_children, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager) + : Expression(p_annotations), + m_resultType(NULL), + m_childType(NULL), + m_numChildren(p_children.size()), + m_binaryOp(p_binaryOp) +{ + // Note that we rely on this ctor not throwing exceptions during + // allocation below. + + // We rely on our allocator to size this object to be big enough to + // hold all children, and enforce this forcing construction via Alloc. + FF2_ASSERT(m_numChildren >= 2); + for (size_t i = 0; i < p_children.size(); i++) + { + m_children[i] = p_children[i]; + } + + // Infer the child and return types. + m_childType = &InferChildType(p_typeManager); + m_resultType = &BinaryOperator::GetResultType(m_binaryOp, *m_childType); + + m_valueBounds = ValueBounds(*m_resultType); +} + + +FreeForm2::BinaryOperatorExpression::BinaryOperatorExpression(const Annotations& p_annotations, + const Expression& p_leftChild, + const Expression& p_rightChild, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager) + : Expression(p_annotations), + m_resultType(NULL), + m_childType(NULL), + m_numChildren(2), + m_binaryOp(p_binaryOp) +{ + // We rely on our allocator to size this object to be big enough to + // hold all children, and enforce this forcing construction via Alloc. + m_children[0] = &p_leftChild; + m_children[1] = &p_rightChild; + + // Infer the child and return types. + m_childType = &InferChildType(p_typeManager); + m_resultType = &BinaryOperator::GetResultType(m_binaryOp, *m_childType); +} + + +void +FreeForm2::BinaryOperatorExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < m_numChildren; i++) + { + m_children[i]->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::BinaryOperatorExpression::InferChildType(TypeManager& p_typeManager) const +{ + const TypeImpl* unifiedType(&TypeImpl::GetUnknownType()); + + for (size_t i = 0; i < m_numChildren; i++) + { + const TypeImpl& type + = TypeUtil::Unify(m_children[i]->GetType(), *unifiedType, p_typeManager, false, true); + + if (!type.IsValid()) + { + std::ostringstream err; + err << "Arguments to binary operators (except index) " + << "are expected to be of a unifiable type. The first " + << i << " elements are of type '" << *unifiedType + << "', element " << i + 1 << " is of type '" + << m_children[i]->GetType() << "'"; + throw ParseError(err.str(), GetSourceLocation()); + } + unifiedType = &type; + } + + if (!unifiedType->IsLeafType() && unifiedType->Primitive() != Type::Unknown) + { + std::ostringstream err; + err << "Expected fixed-size type; got type: " + << unifiedType; + throw ParseError(err.str(), GetSourceLocation()); + } + + const TypeImpl& childType = BinaryOperator::GetBestOperandType(m_binaryOp, *unifiedType); + + if (!childType.IsValid()) + { + std::ostringstream err; + err << "Invalid operand type " << *unifiedType + << " supplied to operator"; + throw ParseError(err.str(), GetSourceLocation()); + } + + return childType; +} + + +const FreeForm2::TypeImpl& +FreeForm2::BinaryOperatorExpression::GetType() const +{ + return *m_resultType; +} + + +const FreeForm2::TypeImpl& +FreeForm2::BinaryOperatorExpression::GetChildType() const +{ + return *m_childType; +} + + +size_t +FreeForm2::BinaryOperatorExpression::GetNumChildren() const +{ + return m_numChildren; +} + + +const FreeForm2::Expression* const* +FreeForm2::BinaryOperatorExpression::GetChildren() const +{ + return m_children; +} + + +FreeForm2::BinaryOperator::Operation +FreeForm2::BinaryOperatorExpression::GetOperator() const +{ + return m_binaryOp; +} + + +boost::shared_ptr +FreeForm2::BinaryOperatorExpression::Alloc(const Annotations& p_annotations, + const std::vector& p_children, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager) +{ + size_t bytes = sizeof(BinaryOperatorExpression) + + (p_children.size() - 1) * sizeof(Expression*); + + // Allocate a shared_ptr that deletes an BinaryOperatorExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + BinaryOperatorExpression(p_annotations, p_children, p_binaryOp, p_typeManager), DeleteAlloc); + return exp; +} + + +boost::shared_ptr +FreeForm2::BinaryOperatorExpression::Alloc(const Annotations& p_annotations, + const Expression& p_leftChild, + const Expression& p_rightChild, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager) +{ + size_t bytes = sizeof(BinaryOperatorExpression) + sizeof(Expression*); + + // Allocate a shared_ptr that deletes an BinaryOperatorExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + BinaryOperatorExpression(p_annotations, p_leftChild, p_rightChild, p_binaryOp, p_typeManager), + DeleteAlloc); + return exp; +} + + +void +FreeForm2::BinaryOperatorExpression::DeleteAlloc(BinaryOperatorExpression* p_allocated) +{ + // Manually call dtor for operator expression. + p_allocated->~BinaryOperatorExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.h new file mode 100644 index 000000000000..561a227d168c --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/OperatorExpression.h @@ -0,0 +1,120 @@ +#pragma once + +#ifndef FREEFORM2_OPERATOR_EXPRESSION_H +#define FREEFORM2_OPERATOR_EXPRESSION_H + +#include "BinaryOperator.h" +#include "Expression.h" +#include "FreeForm2Type.h" +#include "UnaryOperator.h" +#include + +namespace FreeForm2 +{ + class BinaryOperator; + class TypeManager; + + class UnaryOperatorExpression : public Expression + { + public: + UnaryOperatorExpression(const Annotations& p_annotations, + const Expression& p_child, + UnaryOperator::Operation p_op, + ValueBounds p_valueBounds); + UnaryOperatorExpression(const Annotations& p_annotations, + const Expression& p_child, + UnaryOperator::Operation p_op); + + virtual void Accept(Visitor& p_visitor) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Unary operator used in this expression. + const UnaryOperator::Operation m_op; + + // Child of this node. + const Expression& m_child; + + private: + // Infer the resulting type of the operator expression. + const TypeImpl& InferType() const; + + // The stored result type of the operator expression. + const TypeImpl& m_type; + }; + + class BinaryOperatorExpression : public Expression + { + public: + void Accept(Visitor& p_visitor) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + const TypeImpl& GetChildType() const; + const Expression* const* GetChildren() const; + + BinaryOperator::Operation GetOperator() const; + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const std::vector& p_children, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager); + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const std::vector& p_children, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager, + ValueBounds p_valueBounds); + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const Expression& p_leftChild, + const Expression& p_rightChild, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager); + + private: + // Constructors are private, call Alloc instead. + BinaryOperatorExpression(const Annotations& p_annotations, + const std::vector& p_children, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager); + BinaryOperatorExpression(const Annotations& p_annotations, + const std::vector& p_children, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager, + ValueBounds p_valueBounds); + BinaryOperatorExpression(const Annotations& p_annotations, + const Expression& p_leftChild, + const Expression& p_rightChild, + const BinaryOperator::Operation p_binaryOp, + TypeManager& p_typeManager); + + // Infer the resulting type of this operator expression. + const TypeImpl* m_resultType; + const TypeImpl* m_childType; + + // Infer the type that children of this expression need to be. + const TypeImpl& InferChildType(TypeManager& p_typeManager) const; + + // Binary operator used to compile arithmetic. + const BinaryOperator::Operation m_binaryOp; + + // Number of children of this node. + size_t m_numChildren; + + // The statically calculated bounds of the values this expression + // can take. + ValueBounds m_valueBounds; + + // Array of children of this node, allocated using struct hack. + const Expression* m_children[1]; + + static void DeleteAlloc(BinaryOperatorExpression* p_allocated); + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.cpp new file mode 100644 index 000000000000..7dc664632662 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.cpp @@ -0,0 +1,103 @@ +#include "PhiNode.h" + +#include "FreeForm2Assert.h" +#include "Visitor.h" + +boost::shared_ptr +FreeForm2::PhiNodeExpression::Alloc(const Annotations& p_annotations, + size_t p_version, + size_t p_incomingVersionsCount, + const size_t* p_incomingVersions) +{ + FF2_ASSERT(p_incomingVersionsCount > 0); + + size_t bytes = sizeof(PhiNodeExpression) + (p_incomingVersionsCount - 1) * sizeof(unsigned long long); + + // Allocate a shared_ptr that deletes an BlockExpression + // allocated in a char[]. + boost::shared_ptr + exp(new (new char[bytes]) PhiNodeExpression(p_annotations, + p_version, + p_incomingVersionsCount, + p_incomingVersions), + DeleteAlloc); + return exp; +} + + +const FreeForm2::TypeImpl& +FreeForm2::PhiNodeExpression::GetType() const +{ + return FreeForm2::TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::PhiNodeExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::PhiNodeExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +size_t +FreeForm2::PhiNodeExpression::GetVersion() const +{ + return m_version; +} + + +size_t +FreeForm2::PhiNodeExpression::GetIncomingVersionsCount() const +{ + return m_incomingVersionsCount; +} + + +const size_t* +FreeForm2::PhiNodeExpression::GetIncomingVersions() const +{ + return m_incomingVersions; +} + + +FreeForm2::PhiNodeExpression::PhiNodeExpression(const Annotations& p_annotations, + size_t p_version, + size_t p_incomingVersionsCount, + const size_t* p_incomingVersions) + : Expression(p_annotations), + m_version(p_version), + m_incomingVersionsCount(p_incomingVersionsCount) +{ + // We rely on the custom allocator Alloc to provide enough space + // for all of the incomings. + for (unsigned int i = 0; i < m_incomingVersionsCount; i++) + { + m_incomingVersions[i] = p_incomingVersions[i]; + } +} + + +void +FreeForm2::PhiNodeExpression::DeleteAlloc(PhiNodeExpression* p_allocated) +{ + // Manually call dtor for phi node expression. + p_allocated->~PhiNodeExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.h new file mode 100644 index 000000000000..7281377d6c5c --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/PhiNode.h @@ -0,0 +1,48 @@ +#pragma once + +#include "Expression.h" + +namespace FreeForm2 +{ + // The Phi node marks places in the code where the value of a variable can + // change in different branches of the code. This node should be ignored + // by the backends since it doesn't affect the compiled output. + // + // The incoming array refers to the list of variable versions that can reach + // this point of the code. + class PhiNodeExpression : public Expression + { + public: + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + size_t p_version, + size_t p_incomingVersionsCount, + const size_t* p_incomingVersions); + + // Methods inherited from Expression + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Getter methods. + size_t GetVersion() const; + size_t GetIncomingVersionsCount() const; + const size_t* GetIncomingVersions() const; + + private: + // Create a PhiNode expression. + PhiNodeExpression(const Annotations& p_annotations, + size_t p_version, + size_t p_incomingVersionsCount, + const size_t* p_incomingVersions); + + static void DeleteAlloc(PhiNodeExpression* p_allocated); + + size_t m_version; + + size_t m_incomingVersionsCount; + + // Allocated using the struct hack. + size_t m_incomingVersions[1]; + }; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.cpp new file mode 100644 index 000000000000..4807ea8f6843 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.cpp @@ -0,0 +1,171 @@ +#include "Publish.h" +#include "Visitor.h" +#include "FreeForm2Assert.h" +#include +#include "TypeUtil.h" + +FreeForm2::PublishExpression::PublishExpression(const Annotations& p_annotations, + const std::string& p_featureName, + const Expression& p_value) + : Expression(p_annotations), + m_featureName(p_featureName), + m_value(p_value) +{ +} + + +void +FreeForm2::PublishExpression::Accept(Visitor& p_visitor) const +{ + if (!p_visitor.AlternativeVisit(*this)) + { + m_value.Accept(p_visitor); + + p_visitor.Visit(*this); + } +} + + +size_t +FreeForm2::PublishExpression::GetNumChildren() const +{ + return 1; +} + + +const FreeForm2::TypeImpl& +FreeForm2::PublishExpression::GetType() const +{ + return FreeForm2::TypeImpl::GetVoidInstance(); +} + + +const FreeForm2::Expression& +FreeForm2::PublishExpression::GetValue() const +{ + return m_value; +} + + +const std::string& +FreeForm2::PublishExpression::GetFeatureName() const +{ + return m_featureName; +} + + +FreeForm2::DirectPublishExpression::DirectPublishExpression(const Annotations& p_annotations, + const std::string& p_featureName, + const Expression** p_indices, + const unsigned int p_numIndices, + const Expression& p_value) + : Expression(p_annotations), + m_featureName(p_featureName), + m_numIndices(p_numIndices), + m_value(p_value) +{ + for (size_t i = 0; i < m_numIndices; i++) + { + m_indices[i] = p_indices[i]; + } +} + + +boost::shared_ptr +FreeForm2::DirectPublishExpression::Alloc(const Annotations& p_annotations, + const std::string& p_featureName, + const Expression** p_indices, + const unsigned int p_numIndices, + const Expression& p_value) +{ + size_t bytes = sizeof(DirectPublishExpression) + + (p_numIndices - 1) * sizeof(Expression*); + + // Allocate a shared_ptr that deletes an DirectPublishExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + DirectPublishExpression(p_annotations, p_featureName, p_indices, p_numIndices, p_value), DeleteAlloc); + return exp; +} + + +void +FreeForm2::DirectPublishExpression::DeleteAlloc(DirectPublishExpression* p_allocated) +{ + // Manually call dtor for expression. + p_allocated->~DirectPublishExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +void +FreeForm2::DirectPublishExpression::Accept(Visitor& p_visitor) const +{ + if (!p_visitor.AlternativeVisit(*this)) + { + for (unsigned int i = 0; i < m_numIndices; ++i) + { + m_indices[i]->Accept(p_visitor); + } + + m_value.Accept(p_visitor); + + p_visitor.Visit(*this); + } +} + + +size_t +FreeForm2::DirectPublishExpression::GetNumChildren() const +{ + return 1 + m_numIndices; +} + + +const FreeForm2::TypeImpl& +FreeForm2::DirectPublishExpression::GetType() const +{ + for (unsigned int i = 0; i < m_numIndices; ++i) + { + if (!m_indices[i]->GetType().IsIntegerType()) + { + std::ostringstream err; + err << "Index type in array publication is not an integer type " + << "instead, it is a " << m_indices[i]->GetType() << ")."; + throw ParseError(err.str(), m_indices[i]->GetSourceLocation()); + } + } + + return FreeForm2::TypeImpl::GetVoidInstance(); +} + + +const FreeForm2::Expression& +FreeForm2::DirectPublishExpression::GetValue() const +{ + return m_value; +} + + +const std::string& +FreeForm2::DirectPublishExpression::GetFeatureName() const +{ + return m_featureName; +} + + +unsigned int +FreeForm2::DirectPublishExpression::GetNumIndices() const +{ + return m_numIndices; +} + + +const FreeForm2::Expression* const * +FreeForm2::DirectPublishExpression::GetIndices() const +{ + return m_indices; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.h new file mode 100644 index 000000000000..12bec536a749 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Publish.h @@ -0,0 +1,93 @@ +#pragma once + +#ifndef FREEFORM2_PUBLISH_H +#define FREEFORM2_PUBLISH_H + +#include "Expression.h" + +namespace FreeForm2 +{ + // The Publish expression will declare the value of a feature. + class PublishExpression : public Expression + { + public: + // Create a Publish expression that declares the value of a feature. + PublishExpression(const Annotations& p_annotations, + const std::string& p_featureName, + const Expression& p_value); + + // Methods inherited from Expression + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Return the value of the expression. + const Expression& GetValue() const; + + // Return the feature name of the expression. + const std::string& GetFeatureName() const; + + private: + // The value of the feature being published. + const Expression& m_value; + + // The name of the feature being published. + const std::string m_featureName; + }; + + // The DirectPublish expression will declare the value of an element in a feature array. + class DirectPublishExpression : public Expression + { + public: + // Allocate and construct a new DirectPublishExpression. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const std::string& p_featureName, + const Expression** p_indices, + const unsigned int p_numIndices, + const Expression& p_value); + + // Methods inherited from Expression + virtual void Accept(Visitor&) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Return the value of the expression. + const Expression& GetValue() const; + + // Return the feature name of the expression. + const std::string& GetFeatureName() const; + + // Return the number of indices for this array. + unsigned int GetNumIndices() const; + + // Return a pointer to the list of indices. + const Expression* const * GetIndices() const; + + private: + // Create a Publish expression that declares the value of a feature. + DirectPublishExpression(const Annotations& p_annotations, + const std::string& p_featureName, + const Expression** p_indices, + const unsigned int p_numIndices, + const Expression& p_value); + + // The value of the feature being published. + const Expression& m_value; + + // The name of the feature being published. + const std::string m_featureName; + + // The number of indices. + const unsigned int m_numIndices; + + // The destructor for the struct hack. + static void DeleteAlloc(DirectPublishExpression* p_allocated); + + // The indices of the array element to publish. + // Allocated using the struct hack. + const Expression* m_indices[1]; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.cpp new file mode 100644 index 000000000000..d80ee3091047 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.cpp @@ -0,0 +1,86 @@ +#include "RandExpression.h" + +#include "FreeForm2Assert.h" +#include "Visitor.h" + +FreeForm2::RandFloatExpression::RandFloatExpression(const Annotations& p_annotations) + : Expression(p_annotations) +{ +} + + +void +FreeForm2::RandFloatExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::RandFloatExpression::GetType() const +{ + return TypeImpl::GetFloatInstance(true); +} + + +size_t +FreeForm2::RandFloatExpression::GetNumChildren() const +{ + return 0; +} + + +const FreeForm2::RandFloatExpression& +FreeForm2::RandFloatExpression::GetInstance() +{ + static const Annotations s_annotations; + static const RandFloatExpression s_instance(s_annotations); + return s_instance; +} + + +FreeForm2::RandIntExpression::RandIntExpression(const Annotations& p_annotations, + const Expression& p_lowerBoundExpression, + const Expression& p_upperBoundExpression) + : Expression(p_annotations), + m_lowerBoundExpression(p_lowerBoundExpression), + m_upperBoundExpression(p_upperBoundExpression) +{ +} + + +void +FreeForm2::RandIntExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_lowerBoundExpression.Accept(p_visitor); + m_upperBoundExpression.Accept(p_visitor); + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::RandIntExpression::GetType() const +{ + return TypeImpl::GetIntInstance(true); +} + + +size_t +FreeForm2::RandIntExpression::GetNumChildren() const +{ + return 2; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.h new file mode 100644 index 000000000000..6c39dd5818dd --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RandExpression.h @@ -0,0 +1,54 @@ +#pragma once + +#ifndef FREEFORM2_RAND_EXPRESSION_H +#define FREEFORM2_RAND_EXPRESSION_H + +#include "Expression.h" + +namespace FreeForm2 +{ + // The RandFloatExpression generates a random float + // in the range of 0 to 1 inclusive. + class RandFloatExpression : public Expression + { + public: + // Methods inherited from Expression. + void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Get a reference to the singleton for RandFloatExpression. + static const RandFloatExpression& GetInstance(); + + private: + // Private constructor for singleton class. + RandFloatExpression(const Annotations& p_annotations); + }; + + // The RandIntExpression generates a random integer + // in the range specified inclusive of the lower bound + // and upper bound exclusive. + class RandIntExpression: public Expression + { + public: + // Constructor for RandIntExpression. + RandIntExpression(const Annotations& p_annotations, + const Expression& p_lowerBoundExpression, + const Expression& p_upperBoundExpression); + + // Methods inherited from Expression. + void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + private: + + // Lower bound expression to generate a random number. + const Expression& m_lowerBoundExpression; + + // Upper bound expression to generate a random number. + const Expression& m_upperBoundExpression; + }; +}; + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.cpp new file mode 100644 index 000000000000..30d76b940078 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.cpp @@ -0,0 +1,615 @@ +#include "RangeReduceExpression.h" + +#include "BinaryOperator.h" +#include +#include +#include "Conditional.h" +#include "LiteralExpression.h" +#include "FreeForm2Assert.h" +#include "OperatorExpression.h" +#include "RefExpression.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include "TypeUtil.h" +#include "UnaryOperator.h" +#include + +namespace +{ + // This method creates the following precondition expression: + // (max(low, high) - abs(step) < max(low, high)) + // && (min(low, high) + abs(step) > min(low, high)) + // && ((step > 0) == (high > low)) + // && step != 0) + // and the following loop condition: + // (step > 0 ? loopVar <= high - step + // : loopVar >= high - step). + + boost::tuple + CreateGenericLoopConditions( + const std::pair& p_range, + const FreeForm2::Expression& p_step, + const FreeForm2::Expression& p_loopVar, + FreeForm2::SimpleExpressionOwner& p_owner, + FreeForm2::TypeManager& p_typeManager) + { + using namespace FreeForm2; + + // Common expressions. + auto zero = boost::make_shared(p_loopVar.GetAnnotations(), 0); + p_owner.AddExpression(zero); + auto stepSign + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), p_step, *zero, BinaryOperator::gt, p_typeManager); + p_owner.AddExpression(stepSign); + + // Create the precondition + const Expression* precondition = nullptr; + { + auto high + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *p_range.first, *p_range.second, BinaryOperator::max, p_typeManager); + p_owner.AddExpression(high); + auto low + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *p_range.first, *p_range.second, BinaryOperator::min, p_typeManager); + p_owner.AddExpression(low); + auto step + = boost::make_shared(p_loopVar.GetAnnotations(), p_step, UnaryOperator::abs); + p_owner.AddExpression(step); + auto highMinusStep + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *high, *step, BinaryOperator::minus, p_typeManager); + p_owner.AddExpression(highMinusStep); + auto lowPlusStep + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *low, *step, BinaryOperator::plus, p_typeManager); + p_owner.AddExpression(lowPlusStep); + auto underflowCheck + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *highMinusStep, *high, BinaryOperator::lt, p_typeManager); + p_owner.AddExpression(underflowCheck); + auto overflowCheck + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *lowPlusStep, *low, BinaryOperator::gt, p_typeManager); + p_owner.AddExpression(overflowCheck); + auto stepMoving + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), p_step, *zero, BinaryOperator::neq, p_typeManager); + p_owner.AddExpression(stepMoving); + auto rangeSign + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *p_range.second, *p_range.first, BinaryOperator::gt, p_typeManager); + p_owner.AddExpression(rangeSign); + auto rangeCheck + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *stepSign, *rangeSign, BinaryOperator::eq, p_typeManager); + p_owner.AddExpression(rangeCheck); + auto and1 + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *underflowCheck, *overflowCheck, BinaryOperator::_and, p_typeManager); + p_owner.AddExpression(and1); + auto and2 + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *and1, *rangeCheck, BinaryOperator::_and, p_typeManager); + p_owner.AddExpression(and2); + auto and3 + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *and2, *stepMoving, BinaryOperator::_and, p_typeManager); + p_owner.AddExpression(and3); + precondition = and3.get(); + } + + // Create the loop condition. + const Expression* condition = nullptr; + { + auto endMinusStep + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *p_range.second, p_step, BinaryOperator::minus, p_typeManager); + p_owner.AddExpression(endMinusStep); + auto incRange + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), p_loopVar, *endMinusStep, BinaryOperator::lte, p_typeManager); + p_owner.AddExpression(incRange); + auto decRange + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), p_loopVar, *endMinusStep, BinaryOperator::gte, p_typeManager); + p_owner.AddExpression(decRange); + auto cond = boost::make_shared(p_loopVar.GetAnnotations(), *stepSign, *incRange, *decRange); + p_owner.AddExpression(cond); + condition = cond.get(); + } + + return boost::make_tuple(precondition, condition); + } + + + boost::tuple + CreateConditionsForKnownStep( + const std::pair& p_range, + const FreeForm2::Expression& p_step, + const FreeForm2::Expression& p_loopVar, + FreeForm2::SimpleExpressionOwner& p_owner, + FreeForm2::TypeManager& p_typeManager) + { + using namespace FreeForm2; + FF2_ASSERT(p_step.IsConstant() && p_step.GetConstantValue().GetInt(p_step.GetType()) != 0); + const Result::IntType stepVal = p_step.GetConstantValue().GetInt(p_step.GetType()); + const bool isIncreasing = stepVal > 0; + + // For increasing ranges, create the expression: + // ((high - step < high) && (low + step > low) && high > low) + // For decreasing ranges, create the expression: + // ((high - step > high) && (low + step < low) && high < low) + const Expression* precondition = nullptr; + { + const Expression& low = *p_range.first; + const Expression& high = *p_range.second; + auto highMinusStep = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), high, p_step, BinaryOperator::minus, p_typeManager); + p_owner.AddExpression(highMinusStep); + auto lowPlusStep = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), low, p_step, BinaryOperator::plus, p_typeManager); + p_owner.AddExpression(lowPlusStep); + const BinaryOperator::Operation check1Op = isIncreasing ? BinaryOperator::lt : BinaryOperator::gt; + auto check1 = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *highMinusStep, high, check1Op, p_typeManager); + p_owner.AddExpression(check1); + const BinaryOperator::Operation check2Op = isIncreasing ? BinaryOperator::gt : BinaryOperator::lt; + auto check2 = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *lowPlusStep, low, check2Op, p_typeManager); + p_owner.AddExpression(check2); + const BinaryOperator::Operation rangeOp = isIncreasing ? BinaryOperator::gt : BinaryOperator::lt; + auto rangeCheck = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), high, low, rangeOp, p_typeManager); + p_owner.AddExpression(rangeCheck); + auto and1 = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *check1, *check2, BinaryOperator::_and, p_typeManager); + p_owner.AddExpression(and1); + auto and2 = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), *and1, *rangeCheck, BinaryOperator::_and, p_typeManager); + p_owner.AddExpression(and2); + precondition = and2.get(); + } + + // For increasing ranges, create the expression: loopVar <= high - step. + // For decreasing ranges, create the expression: loopVar >= high - step. + const Expression* condition = nullptr; + { + const Expression& high = *p_range.second; + auto highMinusStep + = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), high, p_step, BinaryOperator::minus, p_typeManager); + p_owner.AddExpression(highMinusStep); + const BinaryOperator::Operation compareOp = isIncreasing ? BinaryOperator::lte : BinaryOperator::gte; + auto compare = BinaryOperatorExpression::Alloc(p_loopVar.GetAnnotations(), p_loopVar, *highMinusStep, compareOp, p_typeManager); + p_owner.AddExpression(compare); + condition = compare.get(); + } + + return boost::make_tuple(precondition, condition); + } +} + + +FreeForm2::RangeReduceExpression::RangeReduceExpression( + const Annotations& p_annotations, + const Expression& p_low, + const Expression& p_high, + const Expression& p_initial, + const Expression& p_reduce, + VariableID p_stepId, + VariableID p_reduceId) + : Expression(p_annotations), + m_low(p_low), + m_high(p_high), + m_initial(p_initial), + m_reduce(p_reduce), + m_stepId(p_stepId), + m_reduceId(p_reduceId), + m_type(InferType()) +{ +} + +const FreeForm2::TypeImpl& +FreeForm2::RangeReduceExpression::InferType() const +{ + if (!m_low.GetType().IsIntegerType() || !m_high.GetType().IsIntegerType()) + { + std::ostringstream err; + err << "Expected low range and high range arguments to be of compatible integer types;" + << " got " << m_low.GetType() << ", " << m_high.GetType() + << " respectively."; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (!(m_initial.GetType().IsSameAs(m_reduce.GetType(), true))) + { + std::ostringstream err; + err << "Expected initial reduction argument to range-reduce to " + "be of the same type as the reduction expression. Got " + << m_initial.GetType() << " and " + << m_reduce.GetType() << " respectively."; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (m_reduce.GetType().Primitive() == Type::Array) + { + std::ostringstream err; + err << "An array cannot be the result of a looping expression, " + "such as range-reduce, as our array representation " + "relies on reusing array space (and thus uses constant " + "space). If arrays were the result of loops using this " + "representation, dangling pointers would result."; + throw ParseError(err.str(), GetSourceLocation()); + } + return m_reduce.GetType().AsConstType(); +} + + +size_t +FreeForm2::RangeReduceExpression::GetNumChildren() const +{ + return 4; +} + + +void +FreeForm2::RangeReduceExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_initial.Accept(p_visitor); + m_high.Accept(p_visitor); + m_low.Accept(p_visitor); + m_reduce.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::RangeReduceExpression::GetType() const +{ + return m_type; +} + + +const FreeForm2::Expression& +FreeForm2::RangeReduceExpression::GetLow() const +{ + return m_low; +} + + +const FreeForm2::Expression& +FreeForm2::RangeReduceExpression::GetHigh() const +{ + return m_high; +} + + +const FreeForm2::Expression& +FreeForm2::RangeReduceExpression::GetInitial() const +{ + return m_initial; +} + + +const FreeForm2::Expression& +FreeForm2::RangeReduceExpression::GetReduceExpression() const +{ + return m_reduce; +} + + +FreeForm2::VariableID +FreeForm2::RangeReduceExpression::GetReduceId() const +{ + return m_reduceId; +} + + +FreeForm2::VariableID +FreeForm2::RangeReduceExpression::GetStepId() const +{ + return m_stepId; +} + + +FreeForm2::ForEachLoopExpression::ForEachLoopExpression( + const Annotations& p_annotations, + const std::pair& p_bounds, + const Expression& p_next, + const Expression& p_body, + VariableID p_iteratorId, + size_t p_version, + LoopHint p_hint, + TypeManager& p_typeManager) + : Expression(p_annotations), + m_begin(*p_bounds.first), + m_end(*p_bounds.second), + m_next(p_next), + m_body(p_body), + m_iteratorType(nullptr), + m_iteratorId(p_iteratorId), + m_version(p_version), + m_hint(p_hint) +{ + FF2_ASSERT(p_bounds.first && p_bounds.second); + m_iteratorType = &TypeUtil::Unify(m_begin.GetType(), m_end.GetType(), p_typeManager, false, true); + m_iteratorType = &TypeUtil::Unify(m_next.GetType(), *m_iteratorType, p_typeManager, false, true); + + if (!m_iteratorType->IsValid()) + { + std::ostringstream err; + err << "For-each bounds must have unifiable types. Got " + << m_begin.GetType() << ", " << m_end.GetType() + << ", and " << m_next.GetType() + << " for beginning, ending, and step values respectively."; + throw ParseError(err.str(), GetSourceLocation()); + } +} + +size_t +FreeForm2::ForEachLoopExpression::GetNumChildren() const +{ + return 4; +} + + +void +FreeForm2::ForEachLoopExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_begin.Accept(p_visitor); + m_end.Accept(p_visitor); + m_next.Accept(p_visitor); + m_body.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ForEachLoopExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +const FreeForm2::Expression& +FreeForm2::ForEachLoopExpression::GetBegin() const +{ + return m_begin; +} + + +const FreeForm2::Expression& +FreeForm2::ForEachLoopExpression::GetEnd() const +{ + return m_end; +} + + +const FreeForm2::Expression& +FreeForm2::ForEachLoopExpression::GetNext() const +{ + return m_next; +} + + +const FreeForm2::Expression& +FreeForm2::ForEachLoopExpression::GetBody() const +{ + return m_body; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ForEachLoopExpression::GetIteratorType() const +{ + return *m_iteratorType; +} + + +FreeForm2::VariableID +FreeForm2::ForEachLoopExpression::GetIteratorId() const +{ + return m_iteratorId; +} + + +size_t +FreeForm2::ForEachLoopExpression::GetVersion() const +{ + return m_version; +} + + +FreeForm2::ForEachLoopExpression::LoopHint +FreeForm2::ForEachLoopExpression::GetHint() const +{ + return m_hint; +} + + +FreeForm2::ComplexRangeLoopExpression::ComplexRangeLoopExpression( + const Annotations& p_annotations, + const std::pair& p_range, + const Expression& p_step, + const Expression& p_body, + const Expression& p_precondition, + const Expression& p_loopCondition, + const TypeImpl& p_stepType, + VariableID p_stepId, + size_t p_version) + : Expression(p_annotations), + m_low(*p_range.first), + m_high(*p_range.second), + m_step(p_step), + m_body(p_body), + m_precondition(p_precondition), + m_loopCondition(p_loopCondition), + m_stepType(p_stepType), + m_stepId(p_stepId), + m_version(p_version) +{ + FF2_ASSERT(p_range.first && p_range.second); + FF2_ASSERT(p_stepId != VariableID::c_invalidID); + if (!p_range.first->GetType().IsIntegerType() || !p_range.second->GetType().IsIntegerType() + || !p_step.GetType().IsIntegerType()) + { + std::ostringstream err; + err << "Range bounds and step value must all be integral types. Got " + << p_range.first->GetType() << ", " << p_range.second->GetType() + << ", and " << p_step.GetType() << " for low, high, and step respectively."; + throw ParseError(err.str(), GetSourceLocation()); + } + + if (p_precondition.GetType().Primitive() != Type::Bool + || p_loopCondition.GetType().Primitive() != Type::Bool) + { + std::ostringstream err; + err << "Loop conditions must evaluate to boolean types. Got " + << p_precondition.GetType() << " and " << p_loopCondition.GetType() + << " for the precondition and loop condition respectively."; + throw ParseError(err.str(), GetSourceLocation()); + } +} + + +const FreeForm2::ComplexRangeLoopExpression& +FreeForm2::ComplexRangeLoopExpression::Create( + const Annotations& p_annotations, + const std::pair& p_range, + const Expression& p_step, + const Expression& p_body, + const Expression& p_loopVar, + VariableID p_stepId, + size_t p_version, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) +{ + FF2_ASSERT(p_range.first && p_range.second); + const TypeImpl* stepType + = &TypeUtil::Unify(p_range.first->GetType(), p_range.second->GetType(), p_typeManager, false, true); + stepType = &TypeUtil::Unify(*stepType, p_step.GetType(), p_typeManager, false, true); + + if (!p_range.first->GetType().IsIntegerType() || !p_range.second->GetType().IsIntegerType() + || !p_step.GetType().IsIntegerType() || !p_loopVar.GetType().IsIntegerType() + || !TypeUtil::IsAssignable(p_loopVar.GetType(), *stepType)) + { + std::ostringstream err; + err << "Range bounds, step value, and loop variable must all be integral types. Got " + << p_range.first->GetType() << ", " << p_range.second->GetType() + << ", " << p_step.GetType() << ", and " << p_loopVar.GetType() + << " for low, high, and step respectively."; + throw ParseError(err.str(), p_annotations.m_sourceLocation); + } + + const Expression* precondition = nullptr; + const Expression* condition = nullptr; + if (p_step.IsConstant()) + { + boost::tie(precondition, condition) + = CreateConditionsForKnownStep(p_range, p_step, p_loopVar, p_owner, p_typeManager); + } + else + { + boost::tie(precondition, condition) + = CreateGenericLoopConditions(p_range, p_step, p_loopVar, p_owner, p_typeManager); + } + + FF2_ASSERT(precondition && condition); + auto loop = boost::make_shared( + p_annotations, p_range, p_step, p_body, *precondition, *condition, p_loopVar.GetType(), p_stepId, p_version); + p_owner.AddExpression(loop); + return *loop; +} + + +size_t +FreeForm2::ComplexRangeLoopExpression::GetNumChildren() const +{ + return 6; +} + + +void +FreeForm2::ComplexRangeLoopExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_precondition.Accept(p_visitor); + m_low.Accept(p_visitor); + m_high.Accept(p_visitor); + m_step.Accept(p_visitor); + m_body.Accept(p_visitor); + m_loopCondition.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ComplexRangeLoopExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +const FreeForm2::Expression& +FreeForm2::ComplexRangeLoopExpression::GetPrecondition() const +{ + return m_precondition; +} + + +const FreeForm2::Expression& +FreeForm2::ComplexRangeLoopExpression::GetLow() const +{ + return m_low; +} + + +const FreeForm2::Expression& +FreeForm2::ComplexRangeLoopExpression::GetHigh() const +{ + return m_high; +} + + +const FreeForm2::Expression& +FreeForm2::ComplexRangeLoopExpression::GetStep() const +{ + return m_step; +} + + +const FreeForm2::Expression& +FreeForm2::ComplexRangeLoopExpression::GetBody() const +{ + return m_body; +} + + +const FreeForm2::Expression& +FreeForm2::ComplexRangeLoopExpression::GetLoopCondition() const +{ + return m_loopCondition; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ComplexRangeLoopExpression::GetStepType() const +{ + return m_stepType; +} + + +FreeForm2::VariableID +FreeForm2::ComplexRangeLoopExpression::GetStepId() const +{ + return m_stepId; +} + + +size_t +FreeForm2::ComplexRangeLoopExpression::GetVersion() const +{ + return m_version; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.h new file mode 100644 index 000000000000..43eef5bb4fb9 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RangeReduceExpression.h @@ -0,0 +1,229 @@ +#pragma once + +#ifndef FREEFORM2_RANGE_REDUCE_EXPRESSION_H +#define FREEFORM2_RANGE_REDUCE_EXPRESSION_H + +#include "Expression.h" + +// (range-reduce curr 0 10 prev 0.0 (+ prev curr)) + +namespace FreeForm2 +{ + class SimpleExpressionOwner; + class TypeManager; + + // A range-reduce expression generates a loop that reduces integer values + // in a given range to a final quantity. + class RangeReduceExpression : public Expression + { + public: + // Construct a range-reduce expression to loop over a range specified + // as an Expression pair; the + RangeReduceExpression(const Annotations& p_annotations, + const Expression& p_low, + const Expression& p_high, + const Expression& p_initial, + const Expression& p_reduce, + VariableID p_stepId, + VariableID p_reduceId); + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + // Expose all children nodes for the various AlternativeVisit methods. + const Expression& GetLow() const; + const Expression& GetHigh() const; + const Expression& GetInitial() const; + const Expression& GetReduceExpression() const; + VariableID GetReduceId() const; + VariableID GetStepId() const; + + private: + // Infer the resulting type of the range-reduce expression. + const TypeImpl& InferType() const; + + // Low range expression. + const Expression& m_low; + + // High range expression. + const Expression& m_high; + + // ID of the step variable. + VariableID m_stepId; + + // Initial reduction value. + const Expression& m_initial; + + // Reduction expression. + const Expression& m_reduce; + + // ID of the reduction variable + VariableID m_reduceId; + + // The type of this expression. + const TypeImpl& m_type; + }; + + // A for-each loop expression represents an iterative loop with a beginning + // value, an ending value, a next expression, and a loop body. As a + // required precondition, successive evaluations of the next expression + // will eventually result in the expression [current == end] being true. + class ForEachLoopExpression : public Expression + { + public: + // Loop hints allow the backend to optimize its implementation of the + // loop. + enum LoopHint + { + NoHint, + HintStepIncreasing, + HintStepDecreasing + }; + + // Create a for-each loop over a set of bounds. At the end of each + // evaluation of body, the iterator variable is assigned to the result + // of the p_next expression. The loop breaks when the iterator variable + // is equal to the second member of the bounds pair. + ForEachLoopExpression(const Annotations& p_annotations, + const std::pair& p_bounds, + const Expression& p_next, + const Expression& p_body, + VariableID p_iteratorId, + size_t p_version, + LoopHint p_hint, + TypeManager& p_typeManager); + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + // Accessor methods + const Expression& GetBegin() const; + const Expression& GetEnd() const; + const Expression& GetNext() const; + const Expression& GetBody() const; + const TypeImpl& GetIteratorType() const; + VariableID GetIteratorId() const; + size_t GetVersion() const; + LoopHint GetHint() const; + + private: + // The beginning and ending bounds of the loop. + const Expression& m_begin; + const Expression& m_end; + + // This expression is evaluated after each iteration to get the next + // iterator value. + const Expression& m_next; + + // The loop body. + const Expression& m_body; + + // The type of the iterator variable. This type is compatible with the + // bounds expressions and the next expression. + const TypeImpl* m_iteratorType; + + // The variable ID of the iterator variable. + VariableID m_iteratorId; + + // A unique version number associated with a particular + // value for this variable. + const size_t m_version; + + // Implementation loop hint. + LoopHint m_hint; + }; + + // A complex range loop is a loop with the following properties: + // - The loop has a range [low, high) and a step, which are all integers. + // - The loop includes preconditions that assert the safety of the loop, + // specifically testing under/overflow. + // - The following is an expected precondition of the loop: + // step > 0 == high > low && step != 0. + // - If all preconditions are met, the loop will execute at least once. + // - At each iteration of the loop, there exists a variable i such that + // i = low + step * j, where j is the number of times the loop body has + // executed. + // - The loop condition follows post-test execution pattern. + // This is a more complex loop than the above structures, as it is not the + // case that the iterative variable ever be equal to the high value; the + // loop will break before passing the high value of the range. + class ComplexRangeLoopExpression : public Expression + { + public: + // Create a complex range loop expression with correct loop conditions + // according to the properties above. This method derives the + // precondition and loop condition from the given data. + static const ComplexRangeLoopExpression& + Create(const Annotations& p_annotations, + const std::pair& p_range, + const Expression& p_step, + const Expression& p_body, + const Expression& p_loopVar, + VariableID p_loopVarId, + size_t p_version, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager); + + // Create a complex range loop specifying all properties of the loop. + ComplexRangeLoopExpression(const Annotations& p_annotations, + const std::pair& p_range, + const Expression& p_step, + const Expression& p_body, + const Expression& p_precondition, + const Expression& p_loopCondition, + const TypeImpl& p_stepType, + VariableID p_stepId, + size_t p_version); + + // Methods inherited from Expression. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + + // Accessor methods + const Expression& GetPrecondition() const; + const Expression& GetLow() const; + const Expression& GetHigh() const; + const Expression& GetStep() const; + const Expression& GetBody() const; + const Expression& GetLoopCondition() const; + const TypeImpl& GetStepType() const; + VariableID GetStepId() const; + size_t GetVersion() const; + + private: + // The low and high bounds of the range. + const Expression& m_low; + const Expression& m_high; + + // This expression is evaluated after each iteration to get the next + // iterator value. + const Expression& m_step; + + // The loop body. + const Expression& m_body; + + // These expressions are generated by the loop expression. The + // precondition is evaluated before the loop, and the loop condition + // is evaluated at every + const Expression& m_precondition; + const Expression& m_loopCondition; + + // The type of the iterator variable. + const TypeImpl& m_stepType; + + // The variable ID of the iterator variable. + VariableID m_stepId; + + // A unique version number associated with a particular + // value for this variable. + const size_t m_version; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.cpp new file mode 100644 index 000000000000..3fa2057a3190 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.cpp @@ -0,0 +1,156 @@ +#include "RefExpression.h" + +#include +#include "CompoundType.h" +#include "ConvertExpression.h" +#include "Declaration.h" +#include "FreeForm2Assert.h" +#include "TypeUtil.h" +#include "Visitor.h" +#include +#include + +using namespace FreeForm2; +FreeForm2::FeatureRefExpression::FeatureRefExpression(const FreeForm2::Annotations& p_annotations, + UInt32 p_index) + : FreeForm2::Expression(p_annotations), + m_index(p_index) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::FeatureRefExpression::GetType() const +{ + return TypeImpl::GetIntInstance(true); +} + + +size_t +FreeForm2::FeatureRefExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::FeatureRefExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +FreeForm2::VariableRefExpression::VariableRefExpression(const FreeForm2::Annotations& p_annotations, + VariableID p_id, + size_t p_version, + const TypeImpl& p_type) + : Expression(p_annotations), + m_id(p_id), + m_version(p_version), + m_type(p_type) +{ +} + + +FreeForm2::VariableRefExpression::~VariableRefExpression() +{ +} + + +size_t +FreeForm2::VariableRefExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::VariableRefExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +void +FreeForm2::VariableRefExpression::AcceptReference(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + p_visitor.VisitReference(*this); + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::VariableRefExpression::GetType() const +{ + return m_type; +} + + +FreeForm2::VariableID +FreeForm2::VariableRefExpression::GetId() const +{ + return m_id; +} + + +size_t +FreeForm2::VariableRefExpression::GetVersion() const +{ + return m_version; +} + + +FreeForm2::ThisExpression::ThisExpression(const Annotations& p_annotations, + const TypeImpl& p_type) + : Expression(p_annotations), + m_type(p_type) +{ +} + + +size_t +FreeForm2::ThisExpression::GetNumChildren() const +{ + return 0; +} + + +void +FreeForm2::ThisExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + p_visitor.Visit(*this); + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +void +FreeForm2::ThisExpression::AcceptReference(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + p_visitor.VisitReference(*this); + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ThisExpression::GetType() const +{ + FF2_ASSERT(m_type.Primitive() == Type::Unknown || CompoundType::IsCompoundType(m_type)); + return m_type; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.h new file mode 100644 index 000000000000..e1ba0fed5e7e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/RefExpression.h @@ -0,0 +1,85 @@ +#pragma once + +#ifndef FREEFORM2_REFEXPRESSION_H +#define FREEFORM2_REFEXPRESSION_H + +#include "Expression.h" + +namespace FreeForm2 +{ + // Class representing a reference to a feature. + class FeatureRefExpression : public Expression + { + public: + FeatureRefExpression(const Annotations& p_annotations, + UInt32 p_index); + + virtual const TypeImpl& GetType() const override; + + virtual size_t GetNumChildren() const override; + + virtual void Accept(Visitor& p_visitor) const override; + + UInt32 m_index; + }; + + // Class representing a reference to a stack location. We keep track of only + // a stack slot, so that the value of the expression can be generated + // during compilation, rather than parsing, and looked up using the stack + // slot as identifier. + class VariableRefExpression : public Expression + { + public: + // Construct a stack expression from a stack slot. + VariableRefExpression(const Annotations& p_annotations, + VariableID p_id, + size_t p_version, + const TypeImpl& p_type); + virtual ~VariableRefExpression(); + + // Methods inherited from Expression (note that VariableRefExpression + // can generate references). + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual void AcceptReference(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + + VariableID GetId() const; + size_t GetVersion() const; + + private: + // ID assigned to the value. + VariableID m_id; + + // A unique version number associated with a particular + // value for this variable. + const size_t m_version; + + // Type of the value. + const TypeImpl& m_type; + }; + + // This class represents an expression which refers to the not-yet- + // instantiated object for the scope being compiled (for instance, a state + // machine will use a ThisExpression to refer to a member inside the state + // machine, as the machine is not instantiated at this point). + class ThisExpression : public Expression + { + public: + // Construct a ThisExpression from a compound type. + ThisExpression(const Annotations& p_annotations, + const TypeImpl& p_type); + + // Methods inherited from Visitor. + virtual size_t GetNumChildren() const override; + virtual void Accept(Visitor& p_visitor) const override; + virtual void AcceptReference(Visitor&) const override; + virtual const TypeImpl& GetType() const override; + + private: + // The type of the object. + const TypeImpl& m_type; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.cpp new file mode 100644 index 000000000000..6565b71aaf76 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.cpp @@ -0,0 +1,210 @@ +#include "SelectNth.h" + +#include "ArrayType.h" +#include "Expression.h" +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" +#include "Visitor.h" +#include "TypeManager.h" +#include "TypeUtil.h" +#include + +FreeForm2::SelectNthExpression::SelectNthExpression( + const Annotations& p_annotations, + const std::vector& p_children) + : Expression(p_annotations), + m_index(*p_children[0]), + m_numChildren(static_cast(p_children.size()) - 1), + m_type(NULL) +{ + FF2_ASSERT(p_children.size() >= 2); + for (unsigned int i = 0; i + 1 < p_children.size(); i++) + { + m_children[i] = p_children[i + 1]; + } + + m_type = &InferType(); +} + + +void +FreeForm2::SelectNthExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (unsigned int i = 0; i < m_numChildren; i++) + { + m_children[i]->Accept(p_visitor); + } + + m_index.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::SelectNthExpression::GetType() const +{ + return *m_type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::SelectNthExpression::InferType() const +{ + FF2_ASSERT(m_numChildren > 0); + + if (!m_index.GetType().IsIntegerType()) + { + std::ostringstream err; + err << "First argument to selection expression must be " + << " an integer (got type '" + << m_index.GetType() << "')"; + throw ParseError(err.str(), GetSourceLocation()); + } + + const TypeImpl& unifiedType = m_children[0]->GetType().AsConstType(); + for (unsigned int i = 1; i < m_numChildren; i++) + { + if (!unifiedType.IsSameAs(m_children[i]->GetType(), true)) + { + std::ostringstream err; + err << "All arguments to selection expression (except index) " + << "are expected to be of the same type. The first " + << i - 1 << " elements are of type '" << unifiedType + << "', element " << i << " is of type '" + << m_children[i]->GetType() << "'"; + throw ParseError(err.str(), GetSourceLocation()); + } + } + + return unifiedType; +} + + +size_t +FreeForm2::SelectNthExpression::GetNumChildren() const +{ + return m_numChildren + 1; +} + + +const FreeForm2::Expression& +FreeForm2::SelectNthExpression::GetIndex() const +{ + return m_index; +} + + +const FreeForm2::Expression& +FreeForm2::SelectNthExpression::GetChild(size_t p_index) const +{ + FF2_ASSERT(p_index < m_numChildren); + return *m_children[p_index]; +} + + +boost::shared_ptr +FreeForm2::SelectNthExpression::Alloc(const Annotations& p_annotations, + const std::vector& p_children) +{ + size_t bytes = sizeof(SelectNthExpression) + + (p_children.size() - 2) * sizeof(Expression*); + + // Allocate a shared_ptr that deletes an SelectNthExpression + // allocated in a char[]. + boost::shared_ptr exp(new (new char[bytes]) + SelectNthExpression(p_annotations, p_children), DeleteAlloc); + return exp; +} + + +void +FreeForm2::SelectNthExpression::DeleteAlloc(SelectNthExpression* p_allocated) +{ + // Manually call dtor for arith expression. + p_allocated->~SelectNthExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +FreeForm2::SelectRangeExpression::SelectRangeExpression(const Annotations& p_annotations, + const Expression& p_start, + const Expression& p_count, + const Expression& p_array, + TypeManager& p_typeManager) + : Expression(p_annotations), + m_type(nullptr), + m_start(p_start), + m_count(p_count), + m_array(p_array) +{ + FF2_ASSERT(m_array.GetType().Primitive() == Type::Array); + const ArrayType& type = static_cast(m_array.GetType()); + m_type = &p_typeManager.GetArrayType(type.GetChildType(), + true, + type.GetDimensionCount(), + type.GetMaxElements()); +} + + +void +FreeForm2::SelectRangeExpression::Accept(Visitor& p_visitor) const +{ + const size_t startSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_start.Accept(p_visitor); + m_count.Accept(p_visitor); + m_array.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == startSize + p_visitor.StackIncrement()); +} + + +size_t +FreeForm2::SelectRangeExpression::GetNumChildren() const +{ + return 3; +} + + +const FreeForm2::TypeImpl& +FreeForm2::SelectRangeExpression::GetType() const +{ + return *m_type; +} + + +const FreeForm2::Expression& +FreeForm2::SelectRangeExpression::GetStart() const +{ + return m_start; +} + + +const FreeForm2::Expression& +FreeForm2::SelectRangeExpression::GetCount() const +{ + return m_count; +} + + +const FreeForm2::Expression& +FreeForm2::SelectRangeExpression::GetArray() const +{ + return m_array; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.h new file mode 100644 index 000000000000..267b2665b5a7 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SelectNth.h @@ -0,0 +1,107 @@ +#pragma once + +#ifndef FREEFORM2_SELECT_NTH_H +#define FREEFORM2_SELECT_NTH_H + +#include +#include +#include "Expression.h" + +namespace DynamicRank +{ + class IFeatureMap; + class INeuralNetFeatures; +} + +namespace FreeForm2 +{ + class TypeManager; + class ArrayType; + + // SelectNth is a simple expression to which the programmer provides an + // index (first arg), and then a series of values. The index selects a + // value based on ordinal position (0 is first, 1 second, etc). If the + // programmer provides an out-of-bounds index, we provide the value from the + // nearest available index (either lowest or highest, depending on which + // side the programmer fell off the end of the expression list). + class SelectNthExpression : public Expression + { + public: + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + const Expression& GetIndex() const; + const Expression& GetChild(size_t p_index) const; + + static boost::shared_ptr Alloc( + const Annotations& p_annotations, + const std::vector& p_children); + + private: + SelectNthExpression(const Annotations& p_annotations, + const std::vector& p_children); + + static void DeleteAlloc(SelectNthExpression* p_allocated); + + // Infer the result type of the select-nth expression. + const TypeImpl& InferType() const; + + // The type of this expression. + const TypeImpl* m_type; + + // Integer expression giving the selection index. + const Expression& m_index; + + // Number of children stored in m_numChildren. + unsigned int m_numChildren; + + // Array of children of this node, allocated using struct hack. + const Expression* m_children[1]; + }; + + // SelectRange is similar in concept to SelectNth: it selects a slice of + // an array. The expression takes a start index, a count, and an array. It + // evaluates to an array which has the same content as the source array + // slice starting at the start index, with count elements. If the start of + // the slice is negative, an empty array is returned (regardless of size). + // The count will be limited such that this expression will produce the + // largest array possible containing elements from the source. A negative + // or zero count will return an empty array. + class SelectRangeExpression : public Expression + { + public: + SelectRangeExpression(const Annotations& p_annotations, + const Expression& p_start, + const Expression& p_count, + const Expression& p_array, + TypeManager& p_typeManager); + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual size_t GetNumChildren() const override; + virtual const TypeImpl& GetType() const override; + + // Accessor methods for the properties of this class. + const Expression& GetStart() const; + const Expression& GetCount() const; + const Expression& GetArray() const; + + private: + // The type of this expression. + const ArrayType* m_type; + + // Integer expression giving the start of the slice. + const Expression& m_start; + + // Integer expression giving the count of elements in the slice. + const Expression& m_count; + + // The array from which a slice is being taken. + const Expression& m_array; + }; +}; + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SimpleExpressionOwner.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SimpleExpressionOwner.h new file mode 100644 index 000000000000..700d7891245d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SimpleExpressionOwner.h @@ -0,0 +1,29 @@ +#pragma once + +#ifndef FREEFORM2_SIMPLEEXPRESSIONOWNER +#define FREEFORM2_SIMPLEEXPRESSIONOWNER + +#include "Expression.h" +#include + +namespace FreeForm2 +{ + // Straight-forward expression owner, that keeps shared pointers + // to given expressions. + class SimpleExpressionOwner : public ExpressionOwner + { + public: + // Transfer ownership of the given expression to the expression owner. + const Expression* AddExpression(const boost::shared_ptr& p_expr) + { + m_exp.push_back(p_expr); + return m_exp.back().get(); + } + + private: + // Vector of managed expressions. + std::vector> m_exp; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.cpp new file mode 100644 index 000000000000..96e560441fe0 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.cpp @@ -0,0 +1,717 @@ +#include "StateMachine.h" + +#include +#include "FreeForm2Assert.h" +#include "Mutation.h" +#include +#include "TypeManager.h" +#include "Visitor.h" + +namespace +{ + // Custom deleter for StateMachineExpressions. This is required because + // these objects are allocated using the struct hack. + void + DeleteStateMachine(FreeForm2::StateMachineExpression* p_ptr) + { + // Explicitly call the destructor. + p_ptr->~StateMachineExpression(); + + // Delete the memory, which is allocated as a char[]. + char* mem = reinterpret_cast(p_ptr); + delete[] mem; + } +} + + +FreeForm2::StateExpression::StateExpression(const Annotations& p_annotations) + : Expression(p_annotations) +{ +} + + +void +FreeForm2::StateExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::StateExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::StateExpression::GetNumChildren() const +{ + return 0; +} + + +boost::shared_ptr +FreeForm2::StateMachineExpression::Alloc(const Annotations& p_annotations, + const TypeInitializerExpression& p_initializer, + const StateExpression* const* p_states, + size_t p_numStates, + size_t p_startStateId) +{ + const size_t allocSize = sizeof(StateMachineExpression) + + (p_numStates > 0 ? (p_numStates - 1) * sizeof(const StateExpression*) : 0); + boost::shared_ptr expr; + char* mem = NULL; + + try + { + mem = new char[allocSize]; + expr.reset(new (mem) StateMachineExpression(p_annotations, + p_initializer, + p_states, + p_numStates, + p_startStateId), + DeleteStateMachine); + } + catch (...) + { + delete[] mem; + throw; + } + return expr; +} + + +boost::shared_ptr +FreeForm2::StateMachineExpression::Alloc(const Annotations& p_annotations, + const StateMachineType& p_type, + const TypeInitializerExpression& p_initializer, + const StateExpression* const* p_states, + size_t p_numStates, + size_t p_startStateId) +{ + boost::shared_ptr expr( + Alloc(p_annotations, p_initializer, p_states, p_numStates, p_startStateId)); + + // Link the type and expression. + expr->m_type = &p_type; + FF2_ASSERT(p_type.m_expr.expired()); + p_type.m_expr = expr; + + return expr; +} + + +FreeForm2::StateMachineExpression::StateMachineExpression(const Annotations& p_annotations, + const TypeInitializerExpression& p_initializer, + const StateExpression* const* p_states, + size_t p_numStates, + size_t p_startStateId) + : Expression(p_annotations), + m_startStateId(p_startStateId), + m_type(NULL), + m_initializer(p_initializer), + m_numStates(p_numStates) +{ + memcpy(m_states, p_states, sizeof(const StateExpression*) * p_numStates); +} + + +FreeForm2::StateMachineExpression::~StateMachineExpression() +{ +} + + +void +FreeForm2::StateMachineExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < m_numStates; i++) + { + m_states[i]->Accept(p_visitor); + } + + m_initializer.Accept(p_visitor); + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::StateMachineExpression::GetType() const +{ + FF2_ASSERT(m_type != NULL); + return *m_type; +} + + +size_t +FreeForm2::StateMachineExpression::GetNumChildren() const +{ + return m_numStates + 1; +} + + +const FreeForm2::TypeInitializerExpression& +FreeForm2::StateMachineExpression::GetInitializer() const +{ + return m_initializer; +} + + +const FreeForm2::StateExpression* const* +FreeForm2::StateMachineExpression::GetChildren() const +{ + return m_states; +} + + +size_t +FreeForm2::StateMachineExpression::GetStartStateId() const +{ + return m_startStateId; +} + + +std::string +FreeForm2::StateMachineExpression::GetAugmentedMemberName( + const std::string& p_machineName, + const std::string& p_memberName) +{ + std::ostringstream out; + out << "__" << p_machineName << "_" << p_memberName; + return out.str(); +} + + +boost::shared_ptr +FreeForm2::ExecuteMachineExpression::Alloc(const Annotations& p_annotations, + const Expression& p_stream, + const Expression& p_machine, + const std::pair* p_yieldActions, + const size_t p_numYieldActions) +{ + size_t bytes = sizeof(ExecuteMachineExpression) + + (std::max(p_numYieldActions, 1) - 1) + * sizeof(std::pair); + + // Allocate a shared_ptr that deletes an ExecuteMachineExpression + // allocated in a char[]. + boost::shared_ptr exp; + exp.reset(new (new char[bytes]) ExecuteMachineExpression(p_annotations, + p_stream, + p_machine, + p_yieldActions, + p_numYieldActions), + DeleteAlloc); + return exp; +} + + +void +FreeForm2::ExecuteMachineExpression::DeleteAlloc(ExecuteMachineExpression* p_allocated) +{ + for (size_t i = 1; i < p_allocated->m_numYieldActions; ++i) + { + p_allocated->m_yieldActions[i].first.~basic_string(); + } + + // Manually call dtor for expression. + p_allocated->~ExecuteMachineExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +FreeForm2::ExecuteMachineExpression::ExecuteMachineExpression( + const Annotations& p_annotations, + const Expression& p_stream, + const Expression& p_machine, + const std::pair* p_yieldActions, + size_t p_numYieldActions) + : Expression(p_annotations), + m_machine(p_machine), + m_stream(p_stream), + m_numYieldActions(p_numYieldActions) +{ + for (size_t i = 0; i < p_numYieldActions; ++i) + { + if (i > 0) + { + new (&m_yieldActions[i].first) std::string(); + } + + m_yieldActions[i] = p_yieldActions[i]; + } +} + + +void +FreeForm2::ExecuteMachineExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + m_stream.Accept(p_visitor); + m_machine.Accept(p_visitor); + + for (size_t i = 0; i < m_numYieldActions; ++i) + { + m_yieldActions[i].second->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ExecuteMachineExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::ExecuteMachineExpression::GetNumChildren() const +{ + return 2 + m_numYieldActions; +} + + +const FreeForm2::Expression& +FreeForm2::ExecuteMachineExpression::GetStream() const +{ + return m_stream; +} + + +const FreeForm2::Expression& +FreeForm2::ExecuteMachineExpression::GetMachine() const +{ + return m_machine; +} + + +const std::pair* +FreeForm2::ExecuteMachineExpression::GetYieldActions() const +{ + return m_yieldActions; +} + + +size_t +FreeForm2::ExecuteMachineExpression::GetNumYieldActions() const +{ + return m_numYieldActions; +} + + +FreeForm2::ExecuteMachineGroupExpression::ExecuteMachineGroupExpression(const Annotations& p_annotations, + const FreeForm2::ExecuteMachineGroupExpression::MachineInstance* p_machineInstances, + unsigned int p_numMachineInstances, + unsigned int p_numBound) + : Expression(p_annotations), + m_numMachineInstances(p_numMachineInstances), + m_numBound(p_numBound) +{ + for (unsigned int i = 0; i < m_numMachineInstances; ++i) + { + m_machineInstances[i] = p_machineInstances[i]; + } +} + + +boost::shared_ptr +FreeForm2::ExecuteMachineGroupExpression::Alloc(const Annotations& p_annotations, + const FreeForm2::ExecuteMachineGroupExpression::MachineInstance* p_machineInstances, + unsigned int p_numMachineInstances, + unsigned int p_numBound) +{ + FF2_ASSERT(p_numMachineInstances > 0); + + size_t bytes = sizeof(ExecuteMachineGroupExpression) + (p_numMachineInstances - 1) * sizeof(MachineInstance); + + // Allocate a shared_ptr that deletes an ExecuteMachineExpression + // allocated in a char[]. + boost::shared_ptr exp; + exp.reset(new (new char[bytes]) ExecuteMachineGroupExpression(p_annotations, + p_machineInstances, + p_numMachineInstances, + p_numBound), + DeleteAlloc); + return exp; +} + + +void +FreeForm2::ExecuteMachineGroupExpression::DeleteAlloc(ExecuteMachineGroupExpression* p_allocated) +{ + // Manually call dtor for expression. + p_allocated->~ExecuteMachineGroupExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +void +FreeForm2::ExecuteMachineGroupExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < m_numMachineInstances; ++i) + { + m_machineInstances[i].m_machineDeclaration->Accept(p_visitor); + m_machineInstances[i].m_machineExpression->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ExecuteMachineGroupExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::ExecuteMachineGroupExpression::GetNumChildren() const +{ + return 2 * m_numMachineInstances; +} + + +unsigned int +FreeForm2::ExecuteMachineGroupExpression::GetNumBound() const +{ + return m_numBound; +} + + +const FreeForm2::ExecuteMachineGroupExpression::MachineInstance* +FreeForm2::ExecuteMachineGroupExpression::GetMachineInstances() const +{ + return m_machineInstances; +} + + +unsigned int +FreeForm2::ExecuteMachineGroupExpression::GetNumMachineInstances() const +{ + return m_numMachineInstances; +} + + +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::ExecuteStreamRewritingStateMachineGroupExpression( + const Annotations& p_annotations, + const FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::MachineInstance* p_machineInstances, + unsigned int p_numMachineInstances, + unsigned int p_numBound, + VariableID p_machineIndexID, + unsigned int p_machineArraySize, + StreamRewritingType p_streamRewritingType, + const Expression* p_duplicateTermInformation, + const Expression* p_numQueryPaths, + const Expression* p_queryPathCandidates, + const Expression* p_queryLength, + const Expression* p_tupleOfInterestCount, + const Expression* p_tuplesOfInterest, + bool p_isNearChunk, + unsigned int p_minChunkNumber) + : Expression(p_annotations), + m_numMachineInstances(p_numMachineInstances), + m_numBound(p_numBound), + m_machineIndexID(p_machineIndexID), + m_machineArraySize(p_machineArraySize), + m_streamRewritingType(p_streamRewritingType), + m_duplicateTermInformation(p_duplicateTermInformation), + m_numQueryPaths(p_numQueryPaths), + m_queryPathCandidates(p_queryPathCandidates), + m_queryLength(p_queryLength), + m_tupleOfInterestCount(p_tupleOfInterestCount), + m_tuplesOfInterest(p_tuplesOfInterest), + m_isNearChunk(p_isNearChunk), + m_minChunkNumber(p_minChunkNumber) +{ + for (unsigned int i = 0; i < m_numMachineInstances; ++i) + { + m_machineInstances[i] = p_machineInstances[i]; + } +} + + +boost::shared_ptr +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::Alloc(const Annotations& p_annotations, + const FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::MachineInstance* p_machineInstances, + unsigned int p_numMachineInstances, + unsigned int p_numBound, + VariableID p_machineIndexID, + unsigned int p_machineArraySize, + StreamRewritingType p_streamRewritingType, + const Expression* p_duplicateTermInformation, + const Expression* p_numQueryPaths, + const Expression* p_queryPathCandidates, + const Expression* p_queryLength, + const Expression* p_tupleOfInterestCount, + const Expression* p_tuplesOfInterest, + bool p_isNearChunk, + unsigned int p_minChunkNumber) +{ + FF2_ASSERT(p_numMachineInstances > 0 && p_machineArraySize > 0); + + // Check that the machine array size is appropriately matched with the StreamRewritingType. + //FF2_ASSERT(//(p_streamRewritingType == BodyBlock && p_machineArraySize == MetaWords::BBWM_Max) || + // (p_streamRewritingType == QueryPath && p_machineArraySize == FeatureData::c_maxNumberOfQueryPaths) + // || (p_streamRewritingType == Chunk && p_machineArraySize == (FeatureData::c_chunkTypeEndIndex - FeatureData::c_chunkTypeStartIndex + 1))); + + size_t bytes = sizeof(ExecuteStreamRewritingStateMachineGroupExpression) + (p_numMachineInstances - 1) * sizeof(MachineInstance); + + // Allocate a shared_ptr that deletes an ExecuteMachineExpression + // allocated in a char[]. + boost::shared_ptr exp; + exp.reset(new (new char[bytes]) ExecuteStreamRewritingStateMachineGroupExpression(p_annotations, + p_machineInstances, + p_numMachineInstances, + p_numBound, + p_machineIndexID, + p_machineArraySize, + p_streamRewritingType, + p_duplicateTermInformation, + p_numQueryPaths, + p_queryPathCandidates, + p_queryLength, + p_tupleOfInterestCount, + p_tuplesOfInterest, + p_isNearChunk, + p_minChunkNumber), + DeleteAlloc); + return exp; +} + + +void +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::DeleteAlloc(ExecuteStreamRewritingStateMachineGroupExpression* p_allocated) +{ + // Manually call dtor for expression. + p_allocated->~ExecuteStreamRewritingStateMachineGroupExpression(); + + // Dispose of memory, which we allocated in a char[]. + char* mem = reinterpret_cast(p_allocated); + delete[] mem; +} + + +void +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + for (size_t i = 0; i < m_numMachineInstances; ++i) + { + m_machineInstances[i].m_machineDeclaration->Accept(p_visitor); + m_machineInstances[i].m_machineExpression->Accept(p_visitor); + } + + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetNumChildren() const +{ + return 2 * m_numMachineInstances; +} + + +unsigned int +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetNumBound() const +{ + return m_numBound; +} + + +const FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::MachineInstance* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetMachineInstances() const +{ + return m_machineInstances; +} + + +unsigned int +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetNumMachineInstances() const +{ + return m_numMachineInstances; +} + + +FreeForm2::VariableID +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetMachineIndexID() const +{ + return m_machineIndexID; +} + + +unsigned int +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetMachineArraySize() const +{ + return m_machineArraySize; +} + + +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::StreamRewritingType +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetStreamRewritingType() const +{ + return m_streamRewritingType; +} + + +const FreeForm2::Expression* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetDuplicateTermInformation() const +{ + return m_duplicateTermInformation; +} + + +const FreeForm2::Expression* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetNumQueryPaths() const +{ + return m_numQueryPaths; +} + + +const FreeForm2::Expression* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetQueryPathCandidates() const +{ + return m_queryPathCandidates; +} + + +const FreeForm2::Expression* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetQueryLength() const +{ + return m_queryLength; +} + + +const FreeForm2::Expression* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetTupleOfInterestCount() const +{ + return m_tupleOfInterestCount; +} + + +const FreeForm2::Expression* +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetTuplesOfInterest() const +{ + return m_tuplesOfInterest; +} + + +bool +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::IsNearChunk() const +{ + return m_isNearChunk; +} + + +unsigned int +FreeForm2::ExecuteStreamRewritingStateMachineGroupExpression::GetMinChunkNumber() const +{ + return m_minChunkNumber; +} + + +FreeForm2::YieldExpression::YieldExpression(const Annotations& p_annotations, + const std::string& p_machineName, + const std::string& p_name) + : Expression(p_annotations), + m_name(p_name), + m_machineName(p_machineName), + m_fullName(p_machineName + "::" + p_name) +{ +} + + +void +FreeForm2::YieldExpression::Accept(Visitor& p_visitor) const +{ + const size_t stackSize = p_visitor.StackSize(); + + p_visitor.Visit(*this); + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::YieldExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::YieldExpression::GetNumChildren() const +{ + return 0; +} + + +const std::string& +FreeForm2::YieldExpression::GetName() const +{ + return m_name; +} + + +const std::string& +FreeForm2::YieldExpression::GetMachineName() const +{ + return m_machineName; +} + + +const std::string& +FreeForm2::YieldExpression::GetFullName() const +{ + return m_fullName; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.h new file mode 100644 index 000000000000..645de57ac917 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StateMachine.h @@ -0,0 +1,418 @@ +#pragma once + +#ifndef FREEFORM2_STATEMACHINE_H +#define FREEFORM2_STATEMACHINE_H + +#include +#include "StateMachineType.h" +#include "Expression.h" +#include +#include "Declaration.h" + +namespace FreeForm2 +{ + class SimpleExpressionOwner; + class TypeInitializerExpression; + class TypeManager; + + // A StateExpression represents a single state of a finite state machine. + // States contain actions and transitions, which are described below. + class StateExpression : public Expression + { + public: + // All actions and matches have a match type, which corresponds to a + // possible type of word occurrences at the current stream location. + // An unconstrained match type does not take into account the current + // word. + enum MatchType + { + Unconstrained, + MatchWord, + MatchInstanceHeader, + MatchBodyBlockHeader, + EndStream + }; + + // Actions modify the state of a state machine, and are optionally + // executed when entering or leaving a state. Leaving actions are + // contained within the Transition struct; this struct is for entering + // actions only. + struct Action + { + // Type of the current word for the action. + MatchType m_matchType; + + // Expression tree for the action. + const Expression* m_action; + }; + + // Transitions make the edges of the state machine graph. Because the + // finite state machines are digraphs, each state contains the + // transitions originating from that state. + struct Transition + { + // Type of the current word required for this transition. + MatchType m_matchType; + + // Optional condition for this transition. This should be NULL if + // not used. + const Expression* m_condition; + + // Destination state ID in the context of the parent machine. + size_t m_destinationId; + + // Optional action to execute when leaving the state. + const Expression* m_leavingAction; + }; + + StateExpression(const Annotations& p_annotations); + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // An identifier which is unique within the same parsing context. + size_t m_id; + + // Actions to perform on this state. + std::list m_actions; + + // Transitions leaving this state. + std::list m_transitions; + }; + + // A StateMachineExpression is a machine node, which contains the digraph + // state machine. Note that even though the machine is a graph, the syntax + // tree does not contain physical cycles; transitions are referenced by ID + // as oppose to pointers. + class StateMachineExpression : public Expression + { + public: + // Allocate an instance of a StateMachineExpression from an existing + // StateMachineType. The StateMachineType must not have an associated + // definition expression. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const StateMachineType& p_type, + const TypeInitializerExpression& p_initializer, + const StateExpression* const* p_states, + size_t p_numStates, + size_t p_startStateId); + + // Empty destructor for use with custom allocation. + virtual ~StateMachineExpression(); + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Getters for member variables. + const TypeInitializerExpression& GetInitializer() const; + const StateExpression* const* GetChildren() const; + size_t GetStartStateId() const; + + // Generates the augmented name for state machine member names. + static std::string GetAugmentedMemberName(const std::string& p_machineName, + const std::string& p_memberName); + + private: + // Private constructor due to struct hack allocation. + StateMachineExpression(const Annotations& p_annotations, + const TypeInitializerExpression& p_initializer, + const StateExpression* const* p_states, + size_t p_numStates, + size_t p_startStateId); + + // Allocate the StateMachineExpression independent of the TypeImpl. + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const TypeInitializerExpression& p_initializer, + const StateExpression* const* p_states, + size_t p_numStates, + size_t p_startStateId); + + // Start state of the machine. + size_t m_startStateId; + + // The type of the state machine type. + const StateMachineType* m_type; + + // The type initializer for the state machine. + const TypeInitializerExpression& m_initializer; + + // Number of children in m_children. + size_t m_numStates; + + // Children of the machine, allocated using the struct hack. This + // should be either StateExpressions or DeclarationExpressions. + const StateExpression* m_states[1]; + }; + + // This expression executes an instantiated state machine on a stream. + class ExecuteMachineExpression : public Expression + { + public: + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const Expression& p_stream, + const Expression& p_machine, + const std::pair* p_yieldActions, + size_t p_numYieldActions); + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Getters for member variables. + const Expression& GetStream() const; + const Expression& GetMachine() const; + const std::pair* GetYieldActions() const; + size_t GetNumYieldActions() const; + + private: + ExecuteMachineExpression(const Annotations& p_annotations, + const Expression& p_stream, + const Expression& p_machine, + const std::pair* p_yieldActions, + size_t p_numYieldActions); + + static void DeleteAlloc(ExecuteMachineExpression* p_allocated); + + // The stream object to be mached by the state machine + const Expression& m_stream; + + // The machine variable to execute. + const Expression& m_machine; + + // The number of available yield actions. + size_t m_numYieldActions; + + // Action code to be executed on each named yield, + // allocated using the struct hack. + std::pair m_yieldActions[1]; + }; + + // This expression groups several state machine declarations and ExecuteMachineExpressions. + // This is a wrapper class to help state machine composition. + class ExecuteMachineGroupExpression : public Expression + { + public: + struct MachineInstance + { + // The VariableRefExpression that points to the + // state machine declaration. + const Expression* m_machineDeclaration; + + // The ExecuteMachineExpression associated with this + // group. + const ExecuteMachineExpression* m_machineExpression; + }; + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const MachineInstance* p_machines, + unsigned int p_numMachines, + unsigned int p_numBound); + + // Return the number of symbols bound by immediate children of this + // expression, and left open. + unsigned int GetNumBound() const; + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Getters for member variables. + const MachineInstance* GetMachineInstances() const; + unsigned int GetNumMachineInstances() const; + + private: + ExecuteMachineGroupExpression(const Annotations& p_annotations, + const MachineInstance* p_machines, + unsigned int p_numMachines, + unsigned int p_numBound); + + static void DeleteAlloc(ExecuteMachineGroupExpression* p_allocated); + + // Number of machine instances. + unsigned int m_numMachineInstances; + + // Number of symbols left bound by the machine instances. + unsigned int m_numBound; + + // Action code to be executed on each named yield, + // allocated using the struct hack. + MachineInstance m_machineInstances[1]; + }; + + // This expression executes a state machine group where the state machines + // act on a modified version of the stream. The original stream + // given to the feature group is "rewritten" before the state machines process it. + class ExecuteStreamRewritingStateMachineGroupExpression : public Expression + { + public: + struct MachineInstance + { + // The DeclarationExpression that points to the + // state machine declaration. + const DeclarationExpression* m_machineDeclaration; + + // The ExecuteMachineExpression associated with this + // group. + const ExecuteMachineExpression* m_machineExpression; + }; + + // The type of stream rewriting mechanism that will be used for this state machine group. + enum StreamRewritingType + { + BodyBlock, + QueryPath, + Chunk + }; + + static boost::shared_ptr + Alloc(const Annotations& p_annotations, + const MachineInstance* p_machines, + unsigned int p_numMachines, + unsigned int p_numBound, + VariableID p_machineIndexID, + unsigned int p_machineArraySize, + StreamRewritingType p_streamRewritingType, + const Expression* p_duplicateTermInformation, + const Expression* p_numQueryPaths, + const Expression* p_queryPathCandidates, + const Expression* p_queryLength, + const Expression* p_tupleOfInterestCount, + const Expression* p_tuplesOfInterest, + bool p_isNearChunk, + unsigned int p_minChunkNumber); + + // Return the number of symbols bound by immediate children of this + // expression, and left open. + unsigned int GetNumBound() const; + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Getters for member variables. + const MachineInstance* GetMachineInstances() const; + unsigned int GetNumMachineInstances() const; + VariableID GetMachineIndexID() const; + unsigned int GetMachineArraySize() const; + StreamRewritingType GetStreamRewritingType() const; + const Expression* GetDuplicateTermInformation() const; + const Expression* GetNumQueryPaths() const; + const Expression* GetQueryPathCandidates() const; + const Expression* GetQueryLength() const; + const Expression* GetTupleOfInterestCount() const; + const Expression* GetTuplesOfInterest() const; + bool IsNearChunk() const; + unsigned int GetMinChunkNumber() const; + + private: + ExecuteStreamRewritingStateMachineGroupExpression(const Annotations& p_annotations, + const MachineInstance* p_machines, + unsigned int p_numMachines, + unsigned int p_numBound, + VariableID p_machineIndexID, + unsigned int p_machineArraySize, + StreamRewritingType p_streamRewritingType, + const Expression* p_duplicateTermInformation, + const Expression* p_numQueryPaths, + const Expression* p_queryPathCandidates, + const Expression* p_queryLength, + const Expression* p_tupleOfInterestCount, + const Expression* p_tuplesOfInterest, + bool p_isNearChunk, + unsigned int p_minChunkNumber); + + static void DeleteAlloc(ExecuteStreamRewritingStateMachineGroupExpression* p_allocated); + + // Number of machine instances. + unsigned int m_numMachineInstances; + + // Number of symbols left bound by the machine instances. + unsigned int m_numBound; + + // Variable ID of state machine index. + VariableID m_machineIndexID; + + // Size of State Machine array. + unsigned int m_machineArraySize; + + // Stream rewriting type. + StreamRewritingType m_streamRewritingType; + + // An (optional) extern expression representing a raw array with information about + // duplicate terms. + const Expression* m_duplicateTermInformation; + + // An (optional) extern expression representing the number of query path candidates. + const Expression* m_numQueryPaths; + + // An (optional) extern expression representing the query path candidates. + const Expression* m_queryPathCandidates; + + // An (optional) extern expression representing the length of the query. + const Expression* m_queryLength; + + // An (optional) extern expression representing the number of tuples of interest. + const Expression* m_tupleOfInterestCount; + + // An (optional) extern expression representing the tuples of interest. + const Expression* m_tuplesOfInterest; + + // A flag whether or not the chunk match only considers near occurrences or not. + bool m_isNearChunk; + + // The minimum number of chunks required to process a chunk feature. + unsigned int m_minChunkNumber; + + // Action code to be executed on each named yield, + // allocated using the struct hack. + MachineInstance m_machineInstances[1]; + }; + + // Yield expressions are used to call back to the caller of a state + // machine. A yield expression may only appear in a state machine. + class YieldExpression : public Expression + { + public: + YieldExpression(const Annotations& p_annotations, + const std::string& p_machineName, + const std::string& p_name); + + // Methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Get the name of the yield. + const std::string& GetName() const; + + // Get the name of machine where this yield expression is originally executed. + const std::string& GetMachineName() const; + + // Get the name of the yield block that should be invoked. + const std::string& GetFullName() const; + + private: + // The name of the yield. + const std::string m_name; + + // The name of the machine where this yield expression is originally executed. + const std::string m_machineName; + + // The name of the yield block that should be invoked. + const std::string m_fullName; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.cpp new file mode 100644 index 000000000000..fad185bb6dd8 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.cpp @@ -0,0 +1,82 @@ +#include "StreamData.h" + +#include "FreeForm2Assert.h" +#include "Visitor.h" + +FreeForm2::StreamDataExpression::StreamDataExpression(const FreeForm2::Annotations& p_annotations, + bool p_requestsLength) + : Expression(p_annotations), + m_requestsLength(p_requestsLength) +{ +} + + +void +FreeForm2::StreamDataExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::StreamDataExpression::GetType() const +{ + return TypeImpl::GetIntInstance(true); +} + + +size_t +FreeForm2::StreamDataExpression::GetNumChildren() const +{ + return 0; +} + + +const FreeForm2::UpdateStreamDataExpression& +FreeForm2::UpdateStreamDataExpression::GetInstance() +{ + static const Annotations emptyAnnotations; + static const UpdateStreamDataExpression exp(emptyAnnotations); + return exp; +} + + +void +FreeForm2::UpdateStreamDataExpression::Accept(Visitor& p_visitor) const +{ + size_t stackSize = p_visitor.StackSize(); + + if (!p_visitor.AlternativeVisit(*this)) + { + p_visitor.Visit(*this); + } + + FF2_ASSERT(p_visitor.StackSize() == stackSize + p_visitor.StackIncrement()); +} + + +const FreeForm2::TypeImpl& +FreeForm2::UpdateStreamDataExpression::GetType() const +{ + return TypeImpl::GetVoidInstance(); +} + + +size_t +FreeForm2::UpdateStreamDataExpression::GetNumChildren() const +{ + return 0; +} + + +FreeForm2::UpdateStreamDataExpression::UpdateStreamDataExpression(const FreeForm2::Annotations& p_annotations) + : Expression(p_annotations) +{ +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.h new file mode 100644 index 000000000000..8c215b39314f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/StreamData.h @@ -0,0 +1,53 @@ +#pragma once + +#ifndef FREEFORM2_STREAM_DATA_H +#define FREEFORM2_STREAM_DATA_H + +#include "Expression.h" + +namespace FreeForm2 +{ + // StreamDataExpression pulls either the stream count (opaque integer data + // field) or the length (calculated length of the stream instance) from + // currently matched stream instance. These have traditionally been called + // PhraseCount/Length, or 'click phrase' data. However, i've avoided those + // names because they're inaccurate and incomplete (as metastreams are a + // broader abstraction than click phrases, and the data in the metastream + // instance has no connection to phrases outside of click metastreams). + class StreamDataExpression : public Expression + { + public: + StreamDataExpression(const Annotations& p_annotations, + bool p_requestsLength); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + // Whether this expression requests the stream length, with the + // alternative being the stream count. + bool m_requestsLength; + }; + + // Singleton expression class to issue side-effect code to update the + // stream data (length, count) during matching. This requires an + // expression to leverage the FSM architecture to decide when this + // must occur. If we accumulate other similar expressions, they + // can be aggregated into a SystemEffectExpression, or similar. + class UpdateStreamDataExpression : public Expression + { + public: + static const UpdateStreamDataExpression& GetInstance(); + + // Virtual methods inherited from Expression. + virtual void Accept(Visitor& p_visitor) const override; + virtual const TypeImpl& GetType() const override; + virtual size_t GetNumChildren() const override; + + private: + UpdateStreamDataExpression(const Annotations& p_annotations); + }; +}; + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.cpp new file mode 100644 index 000000000000..000fca488475 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.cpp @@ -0,0 +1,241 @@ +#include "SymbolTable.h" + +#include "FreeForm2Assert.h" +#include "FreeForm2Utils.h" +#include "LiteralExpression.h" +#include "RefExpression.h" +#include +#include "SimpleExpressionOwner.h" +#include "CsHash.h" +#include + +namespace +{ + static const FreeForm2::LiteralBoolExpression c_falseExpression(FreeForm2::Annotations(), false); + static const FreeForm2::LiteralBoolExpression c_trueExpression(FreeForm2::Annotations(), true); +} + + +FreeForm2::SymbolTable::SymbolTable(SimpleExpressionOwner& p_owner, + DynamicRank::IFeatureMap* p_featureMap) + : m_owner(p_owner), m_featureMap(p_featureMap), m_localStackStart(0), m_allowFeatures(true) +{ + // Push true and false symbols onto stack. + static const char* c_false = "false"; + Bind(FreeForm2::SymbolTable::Symbol(CStackSizedString(c_false)), &c_falseExpression); + + static const char* c_true = "true"; + Bind(FreeForm2::SymbolTable::Symbol(CStackSizedString(c_true)), &c_trueExpression); + + // Measure the size of the local stack, to ignore effects of pushing on + // 'system' symbols. + m_localStackStart = m_localStack.size(); +} + + +void +FreeForm2::SymbolTable::Bind(const Symbol& p_symbol, const Expression* p_expr) +{ + if (IsSimpleName(p_symbol.GetSymbolName())) + { + m_localStack.push_back(std::make_pair(p_symbol, p_expr)); + } + else + { + // Prevent binding of odd names, consisting of punctuation. This allows + // us to change tokenisation of punctuation in the future without + // breaking backward compatibility (because we don't have to worry about + // how the tokenisation of '$%)*%$[]' will occur, and we know the pool + // of names bound by primitive operations). + std::ostringstream err; + err << "Failed to bind name '" << p_symbol + << "'. Bound names can only contain alphanumeric characters, hyphens, and underscores."; + throw ParseError(err.str(), p_expr->GetSourceLocation()); + } +} + + +std::pair +FreeForm2::SymbolTable::Unbind() +{ + FF2_ASSERT(!m_localStack.empty()); + + SIZED_STRING name = m_localStack.back().first.GetSymbolName(); + const Expression* expr = m_localStack.back().second; + + m_localStack.pop_back(); + return std::make_pair(name, expr); +} + + +std::pair +FreeForm2::SymbolTable::Unbind(const Symbol& p_symbol) +{ + FF2_ASSERT(!m_localStack.empty()); + FF2_ASSERT(m_localStack.back().first == p_symbol); + + return Unbind(); +} + + +const FreeForm2::Expression* +FreeForm2::SymbolTable::FindLocal(const Symbol& p_symbol) const +{ + // Search through the stack of local variables, from most recent + // to least recent. + for (LocalStack::const_reverse_iterator iter = m_localStack.rbegin(); + iter != m_localStack.rend(); + ++iter) + { + if (iter->first == p_symbol) + { + return iter->second; + } + } + return NULL; +} + + +bool +FreeForm2::SymbolTable::FindFeatureIndex(SIZED_STRING p_str, UInt32& p_index) const +{ + if (m_featureMap != NULL && m_allowFeatures) + { + if (m_featureMap->ObtainFeatureIndex(p_str, p_index)) + { + return true; + } + } + return false; +} + + +const FreeForm2::Expression& +FreeForm2::SymbolTable::Lookup(const Symbol& p_symbol) const +{ + const Expression* expr = FindLocal(p_symbol); + if (expr != NULL) + { + return *expr; + } + + // Failed to find symbol in local stack, check feature map. + UInt32 index = 0; + if (FindFeatureIndex(p_symbol.GetSymbolName(), index)) + { + // Found it in feature map, create an expression for it. + boost::shared_ptr exp(new FeatureRefExpression(Annotations(), index)); + m_owner.AddExpression(exp); + return *exp.get(); + } + + // Failed to resolve symbol. + std::ostringstream err; + err << "Failed to find '" << p_symbol << "' in local variables and features."; + throw std::runtime_error(err.str()); +} + + +bool +FreeForm2::SymbolTable::IsBound(const Symbol& p_symbol) const +{ + UInt32 index; + return (FindLocal(p_symbol) != NULL || FindFeatureIndex(p_symbol.GetSymbolName(), index)); +} + + +size_t +FreeForm2::SymbolTable::GetNumLocal() const +{ + return m_localStack.size() - m_localStackStart; +} + + +void +FreeForm2::SymbolTable::SetAllowFeatures(bool p_allowFeatures) +{ + m_allowFeatures = p_allowFeatures; +} + + +bool +FreeForm2::SymbolTable::GetAllowFeatures() const +{ + return m_allowFeatures; +} + + +FreeForm2::SymbolTable::Symbol::Symbol(SIZED_STRING p_str) + : m_str(p_str), + m_hash(CsHash64::Compute(p_str.pcData, p_str.cbData)), + m_paramHash(0), + m_isParameterized(false) +{ +} + + +FreeForm2::SymbolTable::Symbol::Symbol(SIZED_STRING p_str, SIZED_STRING p_param) + : m_str(p_str), + m_param(p_param), + m_isParameterized(true), + m_hash(CsHash64::Compute(p_str.pcData, p_str.cbData)), + m_paramHash(CsHash64::Compute(p_param.pcData, p_param.cbData)) +{ +} + + +bool +FreeForm2::SymbolTable::Symbol::operator==(const Symbol& p_other) const +{ + return m_isParameterized == p_other.m_isParameterized + && m_hash == p_other.m_hash + && m_paramHash == p_other.m_paramHash + && m_str.cbData == p_other.m_str.cbData + && memcmp(m_str.pcData, p_other.m_str.pcData, m_str.cbData) == 0 + && (m_isParameterized ? m_param.cbData == p_other.m_param.cbData + && memcmp(m_param.pcData, p_other.m_param.pcData, m_param.cbData) == 0 + : true); +} + + +SIZED_STRING +FreeForm2::SymbolTable::Symbol::GetSymbolName() const +{ + return m_str; +} + + +bool +FreeForm2::SymbolTable::Symbol::IsParameterized() const +{ + return m_isParameterized; +} + + +SIZED_STRING +FreeForm2::SymbolTable::Symbol::GetSymbolParameter() const +{ + FF2_ASSERT(m_isParameterized); + return m_param; +} + + +std::string +FreeForm2::SymbolTable::Symbol::ToString() const +{ + std::ostringstream buffer; + buffer << *this; + return buffer.str(); +} + + +std::ostream& +FreeForm2::operator<<(std::ostream& p_out, const FreeForm2::SymbolTable::Symbol& p_symbol) +{ + p_out << p_symbol.GetSymbolName(); + if (p_symbol.IsParameterized()) + { + p_out << "<" << p_symbol.GetSymbolParameter() << ">"; + } + return p_out; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.h new file mode 100644 index 000000000000..7ca83e4d7c44 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/SymbolTable.h @@ -0,0 +1,126 @@ +#pragma once + +#ifndef FREEFORM2_SYMBOL_TABLE_H +#define FREEFORM2_SYMBOL_TABLE_H + +#include +#include +#include +#include +#include + +namespace DynamicRank +{ + class IFeatureMap; +} + +namespace FreeForm2 +{ + class Expression; + class SimpleExpressionOwner; + + class SymbolTable : boost::noncopyable + { + public: + // Symbol represents a symbol string, which is hashed for + // faster linear search. + class Symbol : boost::equality_comparable + { + public: + Symbol(SIZED_STRING p_str); + Symbol(SIZED_STRING p_str, SIZED_STRING p_parameter); + + // Comparison operator. + bool operator==(const Symbol& p_other) const; + + // Get the name of this symbol. + SIZED_STRING GetSymbolName() const; + + // Whether this feature is parameterized. + bool IsParameterized() const; + + // Get the parametrization string of this symbol. + // This function may only be called if the symbol is parameterized. + SIZED_STRING GetSymbolParameter() const; + + // Get a std::string representation of this symbol. + std::string ToString() const; + + private: + // Symbol string. + SIZED_STRING m_str; + + // Hash of symbol string. + UInt64 m_hash; + + // Whether this feature is parameterized. + bool m_isParameterized; + + // Parameter string. + SIZED_STRING m_param; + + // Hash of symbol string. + UInt64 m_paramHash; + }; + + // Create a symbol table, that can (optionally) refer to an + // underlying feature map to supply + SymbolTable(SimpleExpressionOwner& p_owner, DynamicRank::IFeatureMap* p_featureMap); + + // Bind a symbol to an expression. + void Bind(const Symbol& p_symbol, const Expression* p_expr); + + // Unbind the top symbol. + std::pair Unbind(); + + // Unbind the top symbol. It must match the p_symbol + // parameter. + std::pair Unbind(const Symbol& p_symbol); + + // Look up a string in the symbol table. + const Expression& Lookup(const Symbol& p_symbol) const; + + // Test if a name has been bound. + bool IsBound(const Symbol& p_symbol) const; + + // Get the number of local symbols currently bound. + size_t GetNumLocal() const; + + // Turn on or off feature symbols. + void SetAllowFeatures(bool p_allowFeatures); + bool GetAllowFeatures() const; + + private: + + // Find a name in the local stack. Return NULL if p_str does not name + // a local. + const Expression* FindLocal(const Symbol& p_str) const; + + // Find a feature index by name. Returns true if p_str names a feature + // and features are enabled; otherwise, returns false. + bool FindFeatureIndex(SIZED_STRING p_str, UInt32& p_index) const; + + // Stack of local variables, in order of creation. Since this + // number is expected to be fairly low, we linearly search + // these variables. + typedef std::vector> LocalStack; + LocalStack m_localStack; + + // Offset at which bound symbols start in m_localStack. + size_t m_localStackStart; + + // Owner of expressions produced by this table. + SimpleExpressionOwner& m_owner; + + // Feature map which is searched after local variables. + DynamicRank::IFeatureMap* m_featureMap; + + // Flag to indicate whether features should be looked up in the + // feature map. + bool m_allowFeatures; + }; + + std::ostream& operator<<(std::ostream& p_out, const SymbolTable::Symbol& p_symbol); +}; + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.cpp new file mode 100644 index 000000000000..ae7909523422 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.cpp @@ -0,0 +1,360 @@ +#include "TypeUtil.h" + +#include "ArrayType.h" +#include +#include +#include "ConvertExpression.h" +#include "FreeForm2Assert.h" +#include "LiteralExpression.h" +#include +#include +#include "StructType.h" +#include "TypeImpl.h" +#include "TypeManager.h" + + +namespace +{ + using namespace FreeForm2; + + template + boost::shared_ptr + ConvertConstant(const Annotations& p_annotations, F p_value, Type::TypePrimitive p_to) + { + switch (p_to) + { + case Type::Int: + { + return boost::make_shared(p_annotations, static_cast(p_value)); + } + case Type::UInt64: + { + return boost::make_shared(p_annotations, static_cast(p_value)); + } + case Type::Int32: + { + return boost::make_shared(p_annotations, static_cast(p_value)); + } + case Type::UInt32: + { + return boost::make_shared(p_annotations, static_cast(p_value)); + } + case Type::Float: + { + return boost::make_shared(p_annotations, static_cast(p_value)); + } + case Type::Bool: + { + return boost::make_shared(p_annotations, p_value != 0); + } + default: + { + throw std::bad_cast(); + } + } + } +} + + +bool +FreeForm2::TypeUtil::IsConvertible(const TypeImpl& p_from, const TypeImpl& p_to) +{ + if (p_from.Primitive() == Type::Unknown + || p_from.Primitive() == Type::Invalid + || p_to.Primitive() == Type::Unknown + || p_to.Primitive() == Type::Invalid) + { + return false; + } + else if (p_from.Primitive() == Type::Array && p_to.Primitive() == Type::Array) + { + return false; + } + else if (p_from.IsLeafType() && p_to.IsLeafType()) + { + return p_to.IsConst() || !p_from.IsConst(); + } + else if (p_to.Primitive() == Type::Void) + { + return true; + } + + return false; +} + + +boost::shared_ptr +FreeForm2::TypeUtil::Convert(const Expression& p_expr, Type::TypePrimitive p_type) +{ + if (p_type != Type::Void && p_expr.IsConstant()) + { + return TypeUtil::ConvertConstant(p_expr.GetAnnotations(), p_expr.GetConstantValue(), p_expr.GetType().Primitive(), p_type); + } + + switch (p_type) + { + case Type::Int: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + case Type::UInt64: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + case Type::Int32: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + case Type::UInt32: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + case Type::Float: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + case Type::Bool: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + case Type::Void: + { + return boost::make_shared(p_expr.GetAnnotations(), p_expr); + } + default: + { + std::ostringstream err; + err << "Unable to convert from " << p_expr.GetType() + << " to " << Type::Name(p_type); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + } +} + + +boost::shared_ptr +FreeForm2::TypeUtil::ConvertConstant(const Annotations& p_annotations, ConstantValue p_value, Type::TypePrimitive p_from, Type::TypePrimitive p_to) +{ + try + { + switch (p_from) + { + case Type::Int: + { + return ::ConvertConstant(p_annotations, p_value.m_int, p_to); + } + case Type::UInt64: + { + return ::ConvertConstant(p_annotations, p_value.m_uint64, p_to); + } + case Type::Int32: + { + return ::ConvertConstant(p_annotations, p_value.m_int32, p_to); + } + case Type::UInt32: + { + return ::ConvertConstant(p_annotations, p_value.m_uint32, p_to); + } + case Type::Float: + { + return ::ConvertConstant(p_annotations, p_value.m_float, p_to); + } + case Type::Bool: + { + return ::ConvertConstant(p_annotations, p_value.m_bool, p_to); + } + default: + { + throw std::bad_cast(); + } + } + } + catch (std::bad_cast&) + { + std::ostringstream err; + err << "Unable to convert from " << Type::Name(p_from) + << " to " << Type::Name(p_to); + throw ParseError(err.str(), p_annotations.m_sourceLocation); + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeUtil::Unify(const TypeImpl& p_type1, + const TypeImpl& p_type2, + TypeManager& p_typeManager, + bool p_allowArray, + bool p_allowPromotion) +{ + if (!p_type1.IsValid() || !p_type2.IsValid()) + { + return TypeImpl::GetInvalidType(); + } + + const bool isConst = p_type1.IsConst() || p_type2.IsConst(); + + if (!p_allowArray && (p_type1.Primitive() == Type::Array + || p_type2.Primitive() == Type::Array)) + { + std::ostringstream err; + err << "We don't currently allow unification of array types, pending " + "a decision on whether to track array bounds statically (TFS 62552)."; + throw std::runtime_error(err.str()); + } + + if (p_type1.IsSameAs(p_type2, true)) + { + return p_type1.IsConst() ? p_type1 : p_type2; + } + + if (p_type1.Primitive() == Type::Array && p_type2.Primitive() == Type::Array) + { + const ArrayType& left = static_cast(p_type1); + const ArrayType& right = static_cast(p_type2); + + if (left.GetChildType().Primitive() == Type::Unknown + && left.GetDimensionCount() <= right.GetDimensionCount()) + { + FF2_ASSERT(left.GetMaxElements() == 0); + + // Left is unknown, use right. + return isConst ? p_type2.AsConstType() : p_type2; + } + else if (right.GetChildType().Primitive() == Type::Unknown + && right.GetDimensionCount() <= left.GetDimensionCount()) + { + FF2_ASSERT(right.GetMaxElements() == 0); + + // Right is unknown, use left. + return isConst ? p_type1.AsConstType() : p_type1; + } + else if (left.GetDimensionCount() == right.GetDimensionCount() + && left.GetChildType() == right.GetChildType()) + { + // Return the type with more information. A fixed-size array + // provides the sizes of each dimension; we prefer to return the + // fixed-size array. Note that the case of two fix-sized arrays + // with the same dimensions is covered in the IsSameAs test. + if (left.IsFixedSize() && !right.IsFixedSize()) + { + return isConst ? left.AsConstType() : left; + } + else if (!left.IsFixedSize() && right.IsFixedSize()) + { + return isConst ? right.AsConstType() : right; + } + // Otherwise, fall through to the error case. + } + } + else if (p_type1.Primitive() == p_type2.Primitive()) + { + return isConst ? p_type1.AsConstType() : p_type1; + } + else if (p_type1.Primitive() == Type::Unknown) + { + return isConst ? p_type2.AsConstType() : p_type2; + } + else if (p_type2.Primitive() == Type::Unknown) + { + return isConst ? p_type1.AsConstType() : p_type1; + } + else if (p_type1.IsLeafType() && p_type2.IsLeafType()) + { + if (p_allowPromotion + && (p_type1.Primitive() == Type::Float || p_type2.Primitive() == Type::Float) + && (p_type1.IsIntegerType() || p_type2.IsIntegerType())) + { + return TypeImpl::GetFloatInstance(isConst); + } + + // Note that no unification exists for uint64 and int64. + if (p_allowPromotion + && p_type1.IsIntegerType() + && p_type2.IsIntegerType() + && p_type1.Primitive() != Type::UInt64 + && p_type2.Primitive() != Type::UInt64) + { + return TypeImpl::GetIntInstance(isConst); + } + + if (p_allowPromotion + && p_type1.IsIntegerType() + && p_type2.IsIntegerType() + && (p_type1.Primitive() == Type::UInt64 || p_type2.Primitive() == Type::UInt64) + && p_type1.Primitive() != Type::Int + && p_type2.Primitive() != Type::Int) + { + return TypeImpl::GetUInt64Instance(isConst); + } + } + + return TypeImpl::GetInvalidType(); +} + + +bool +FreeForm2::TypeUtil::IsAssignable(const TypeImpl& p_dest, const TypeImpl& p_source) +{ + if (p_dest.Primitive() == Type::Array && p_source.Primitive() == Type::Array) + { + const ArrayType& source = static_cast(p_source); + const ArrayType& dest = static_cast(p_dest); + if (!source.GetChildType().IsSameAs(dest.GetChildType(), true)) + { + return false; + } + + if (source.GetDimensionCount() != dest.GetDimensionCount()) + { + return false; + } + + if (source.IsFixedSize() && dest.IsFixedSize()) + { + if (memcmp(source.GetDimensions(), + dest.GetDimensions(), + sizeof(unsigned int) * source.GetDimensionCount()) != 0) + { + return false; + } + else + { + return true; + } + } + else + { + return true; + } + } + else + { + if (p_source.Primitive() == Type::Unknown || p_dest.Primitive() == Type::Unknown) + { + return true; + } + else if ((p_source.Primitive() == Type::Int32 || p_source.Primitive() == Type::UInt32) + && p_dest.Primitive() == Type::Int) + { + return true; + } + else if (p_source.Primitive() == Type::UInt32 + && p_dest.Primitive() == Type::UInt64) + { + return true; + } + else + { + return p_source.IsSameAs(p_dest, true); + } + } +} + + + +const FreeForm2::TypeImpl& +FreeForm2::TypeUtil::SetConstness(const TypeImpl& p_type, bool p_isConst) +{ + return p_isConst ? p_type.AsConstType() : p_type.AsMutableType(); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.h new file mode 100644 index 000000000000..711645ba29e1 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/TypeUtil.h @@ -0,0 +1,55 @@ +#pragma once + +#ifndef FREEFORM2_TYPE_UTIL_H +#define FREEFORM2_TYPE_UTIL_H + +#include +#include "Expression.h" +#include "TypeImpl.h" + +namespace FreeForm2 +{ + class ArrayType; + class ConversionExpression; + class TypeManager; + + class TypeUtil + { + public: + // Check if one type is convertible to another type. Returns true if + // a ConversionExpression exists; false otherwise. + static bool IsConvertible(const TypeImpl& p_from, const TypeImpl& p_to); + + // Create a new Expression of the specified type for a child + // expression. + static boost::shared_ptr Convert(const Expression& p_expr, + Type::TypePrimitive p_type); + + // Create a new literal Expression of the specified type for a constant + // value. + static boost::shared_ptr ConvertConstant(const Annotations& p_annotations, + ConstantValue p_value, + Type::TypePrimitive p_from, + Type::TypePrimitive p_to); + + // Return a type that's compatible with two types, or return + // Type::Invalid. If the types differ in const-ness, the resulting + // type will be constant. p_allowArray specifies whether we are + // allowed to unify array types. + static const TypeImpl& Unify(const TypeImpl& p_type1, + const TypeImpl& p_type2, + TypeManager& p_typeManager, + bool p_allowArray, + bool p_allowPromotion); + + // Determine if a source type is assignable to a destination type. + static bool IsAssignable(const TypeImpl& p_dest, + const TypeImpl& p_source); + + // Return a type which matches the type argument, but with const-ness + // specified by a flag. + static const TypeImpl& SetConstness(const TypeImpl& p_type, bool p_isConst); + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.cpp new file mode 100644 index 000000000000..8ab14200a08f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.cpp @@ -0,0 +1,105 @@ +#include "UnaryOperator.h" + +#include "FreeForm2Assert.h" +#include "FreeForm2Type.h" +#include "TypeImpl.h" +#include "TypeUtil.h" +#include + +static +std::vector +GetOperandTypes(FreeForm2::UnaryOperator::Operation p_op) +{ + using FreeForm2::UnaryOperator; + std::vector types; + types.reserve(2); + + switch (p_op) + { + case UnaryOperator::_not: + types.push_back(FreeForm2::Type::Bool); + break; + + case UnaryOperator::abs: __attribute__((__fallthrough__)); + case UnaryOperator::minus: __attribute__((__fallthrough__)); + case UnaryOperator::round: __attribute__((__fallthrough__)); + case UnaryOperator::trunc: __attribute__((__fallthrough__)); + case UnaryOperator::log: __attribute__((__fallthrough__)); + case UnaryOperator::log1: __attribute__((__fallthrough__)); + case UnaryOperator::tanh: + types.push_back(FreeForm2::Type::Int); + types.push_back(FreeForm2::Type::Int32); + types.push_back(FreeForm2::Type::UInt32); + types.push_back(FreeForm2::Type::UInt64); + types.push_back(FreeForm2::Type::Float); + break; + + case UnaryOperator::bitnot: + types.push_back(FreeForm2::Type::Int); + types.push_back(FreeForm2::Type::Int32); + types.push_back(FreeForm2::Type::UInt32); + types.push_back(FreeForm2::Type::UInt64); + break; + + default: + FreeForm2::Unreachable(__FILE__, __LINE__); + } + + return types; +} + +const FreeForm2::TypeImpl& +FreeForm2::UnaryOperator::GetBestOperandType(Operation p_operator, + const TypeImpl& p_operand) +{ + if (p_operand.Primitive() == Type::Unknown) + { + return p_operand; + } + + const std::vector types = GetOperandTypes(p_operator); + + if (std::find(types.begin(), types.end(), p_operand.Primitive()) != types.end()) + { + return p_operand; + } + else + { + return TypeImpl::GetInvalidType(); + } +} + +const FreeForm2::TypeImpl& +FreeForm2::UnaryOperator::GetReturnType(Operation p_operator, + const TypeImpl& p_operand) +{ + switch (p_operator) + { + case UnaryOperator::minus: + { + if (p_operand.Primitive() == Type::UInt32) + { + return TypeImpl::GetIntInstance(true); + } + else + { + return p_operand; + } + } + case UnaryOperator::trunc: __attribute__((__fallthrough__)); + case UnaryOperator::round: + { + return TypeImpl::GetIntInstance(true); + } + case UnaryOperator::log: __attribute__((__fallthrough__)); + case UnaryOperator::log1: __attribute__((__fallthrough__)); + case UnaryOperator::tanh: + { + return TypeImpl::GetFloatInstance(true); + } + default: + { + return p_operand; + } + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.h new file mode 100644 index 000000000000..3df50bcf362f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/UnaryOperator.h @@ -0,0 +1,42 @@ +#pragma once + +#ifndef FREEFORM2_UNARY_OPERATOR_H +#define FREEFORM2_UNARY_OPERATOR_H + +namespace FreeForm2 +{ + class TypeImpl; + + class UnaryOperator + { + public: + enum Operation + { + minus, + log, + log1, + abs, + round, + trunc, + _not, + bitnot, + tanh, + + invalid + }; + + // Select the best operand type for an operator. Best is defined in + // terms of TypeUtil::SelectBestType. If no valid type is found, an + // invalid TypeImpl is returned. + static const TypeImpl& GetBestOperandType(Operation p_operator, const TypeImpl& p_operand); + + // Return the type of a unary operator result given an operator and + // an operand type. If the operand type is not a valid operand type for + // the operator, the return type is undefined. + static const TypeImpl& GetReturnType(Operation p_operator, const TypeImpl& p_operand); + + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Visitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Visitor.h new file mode 100644 index 000000000000..dc59ee7aab45 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Expression/Visitor.h @@ -0,0 +1,565 @@ +#pragma once + +#ifndef FREEFORM2_VISITOR_H +#define FREEFORM2_VISITOR_H + +#include +#include +#include "FreeForm2Type.h" + +namespace FreeForm2 +{ + class Expression; + class Allocation; + class AggregateContextExpression; + class ArrayDereferenceExpression; + class ArrayLengthExpression; + class ArrayLiteralExpression; + class BinaryOperatorExpression; + class BlockExpression; + class ComplexRangeLoopExpression; + class ConditionalExpression; + class ConvertToBoolExpression; + class ConvertToFloatExpression; + class ConvertToImperativeExpression; + class ConvertToIntExpression; + class ConvertToUInt64Expression; + class ConvertToInt32Expression; + class ConvertToUInt32Expression; + class DebugExpression; + class DeclarationExpression; + class DirectPublishExpression; + class ExecuteStreamRewritingStateMachineGroupExpression; + class ExecuteMachineExpression; + class ExecuteMachineGroupExpression; + class ExternExpression; + class FeatureRefExpression; + class FeatureSpecExpression; + class FeatureGroupSpecExpression; + class ForEachLoopExpression; + class FunctionExpression; + class FunctionCallExpression; + class ImportFeatureExpression; + class LetExpression; + class LiteralBoolExpression; + class LiteralFloatExpression; + class LiteralIntExpression; + class LiteralUInt64Expression; + class LiteralInt32Expression; + class LiteralUInt32Expression; + class LiteralStreamExpression; + class LiteralVoidExpression; + class LiteralWordExpression; + class LiteralInstanceHeaderExpression; + class MatchExpression; + class MatchOperatorExpression; + class MatchGuardExpression; + class MatchBindExpression; + class MemberAccessExpression; + class MutationExpression; + class PhiNodeExpression; + class PublishExpression; + class RandFloatExpression; + class RandIntExpression; + class RangeReduceExpression; + class ReturnExpression; + class SelectNthExpression; + class SelectRangeExpression; + class StateExpression; + class StateMachineExpression; + class StreamDataExpression; + class ThisExpression; + class TypeInitializerExpression; + class UnaryOperatorExpression; + class UnresolvedAccessExpression; + class UpdateStreamDataExpression; + class VariableRefExpression; + class YieldExpression; + + // Visitor, an interface to implement the visitor pattern to generate code. + class Visitor : boost::noncopyable + { + public: + // The visit method of the visitor pattern for each expression type. + // The AlternativeVisit function for each type should return false if + // a normal visitation is intended, and manually do the visitation return true if it will manage the + // visitation itself. + + // The order of the Accept calls is: + // * Children (0..number of children) + // * Index + virtual void Visit(const SelectNthExpression&) = 0; + virtual bool AlternativeVisit(const SelectNthExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * Start index + // * Element count + // * Source array + virtual void Visit(const SelectRangeExpression&) = 0; + virtual bool AlternativeVisit(const SelectRangeExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * else + // * then + // * condition + virtual void Visit(const ConditionalExpression&) = 0; + virtual bool AlternativeVisit(const ConditionalExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * Children (number of children-1..0) + virtual void Visit(const ArrayLiteralExpression&) = 0; + virtual bool AlternativeVisit(const ArrayLiteralExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * Expressions (0..number of bound variables) + // * The let value + virtual void Visit(const LetExpression&) = 0; + virtual bool AlternativeVisit(const LetExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * Expressions (0..number of child expressions) + virtual void Visit(const BlockExpression&) = 0; + virtual bool AlternativeVisit(const BlockExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * Children (number of children-1..0) + virtual void Visit(const BinaryOperatorExpression&) = 0; + virtual bool AlternativeVisit(const BinaryOperatorExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * Initial value + // * High value + // * Low value + // * Reduce value + virtual void Visit(const RangeReduceExpression&) = 0; + virtual bool AlternativeVisit(const RangeReduceExpression&) + { + return false; + } + + virtual void Visit(const ForEachLoopExpression&) = 0; + virtual bool AlternativeVisit(const ForEachLoopExpression&) + { + return false; + } + + virtual void Visit(const ComplexRangeLoopExpression&) = 0; + virtual bool AlternativeVisit(const ComplexRangeLoopExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * l-value + // * r-value + virtual void Visit(const MutationExpression&) = 0; + virtual bool AlternativeVisit(const MutationExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * value to be matched + // * pattern + // * action + virtual void Visit(const MatchExpression&) = 0; + virtual bool AlternativeVisit(const MatchExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * operands + virtual void Visit(const MatchOperatorExpression&) = 0; + virtual bool AlternativeVisit(const MatchOperatorExpression&) + { + return false; + } + + virtual void Visit(const MatchGuardExpression&) = 0; + virtual bool AlternativeVisit(const MatchGuardExpression&) + { + return false; + } + + virtual void Visit(const MatchBindExpression&) = 0; + virtual bool AlternativeVisit(const MatchBindExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * constraints + // virtual void Visit(const MatchWordExpression&) = 0; + // virtual bool AlternativeVisit(const MatchWordExpression&) + // { + // return false; + // } + + // The order of the Accept calls is: + // * struct + virtual void Visit(const MemberAccessExpression&) = 0; + virtual bool AlternativeVisit(const MemberAccessExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * parameters (0..number of parameters) + // * import feature parameters (0..number of import feature parameters) + // * function body + virtual void Visit(const FunctionExpression&) = 0; + virtual bool AlternativeVisit(const FunctionExpression&) + { + return false; + } + + // The order of the Accept calls is: + // * function expression + // * parameters (0..number of parameters) + virtual void Visit(const FunctionCallExpression&) = 0; + virtual bool AlternativeVisit(const FunctionCallExpression&) + { + return false; + } + + // For all the unary expressions, the child accepts the visitor + // before the call to Visit. + virtual void Visit(const LiteralIntExpression&) = 0; + virtual bool AlternativeVisit(const LiteralIntExpression&) + { + return false; + } + + virtual void Visit(const LiteralUInt64Expression&) = 0; + virtual bool AlternativeVisit(const LiteralUInt64Expression&) + { + return false; + } + + virtual void Visit(const LiteralInt32Expression&) = 0; + virtual bool AlternativeVisit(const LiteralInt32Expression&) + { + return false; + } + + virtual void Visit(const LiteralUInt32Expression&) = 0; + virtual bool AlternativeVisit(const LiteralUInt32Expression&) + { + return false; + } + + virtual void Visit(const ArrayLengthExpression&) = 0; + virtual bool AlternativeVisit(const ArrayLengthExpression&) + { + return false; + } + + virtual void Visit(const ArrayDereferenceExpression&) = 0; + virtual bool AlternativeVisit(const ArrayDereferenceExpression&) + { + return false; + } + + virtual void Visit(const ConvertToFloatExpression&) = 0; + virtual bool AlternativeVisit(const ConvertToFloatExpression&) + { + return false; + } + + virtual void Visit(const ConvertToIntExpression&) = 0; + virtual bool AlternativeVisit(const ConvertToIntExpression&) + { + return false; + } + + virtual void Visit(const ConvertToUInt64Expression&) = 0; + virtual bool AlternativeVisit(const ConvertToUInt64Expression&) + { + return false; + } + + virtual void Visit(const ConvertToInt32Expression&) = 0; + virtual bool AlternativeVisit(const ConvertToInt32Expression&) + { + return false; + } + + virtual void Visit(const ConvertToUInt32Expression&) = 0; + virtual bool AlternativeVisit(const ConvertToUInt32Expression&) + { + return false; + } + + virtual void Visit(const ConvertToBoolExpression&) = 0; + virtual bool AlternativeVisit(const ConvertToBoolExpression&) + { + return false; + } + + virtual void Visit(const ConvertToImperativeExpression&) = 0; + virtual bool AlternativeVisit(const ConvertToImperativeExpression&) + { + return false; + } + + virtual void Visit(const DeclarationExpression&) = 0; + virtual bool AlternativeVisit(const DeclarationExpression&) + { + return false; + } + + virtual void Visit(const DirectPublishExpression&) = 0; + virtual bool AlternativeVisit(const DirectPublishExpression&) + { + return false; + } + + virtual void Visit(const ExternExpression&) = 0; + virtual bool AlternativeVisit(const ExternExpression&) + { + return false; + } + + virtual void Visit(const LiteralFloatExpression&) = 0; + virtual bool AlternativeVisit(const LiteralFloatExpression&) + { + return false; + } + + virtual void Visit(const LiteralBoolExpression&) = 0; + virtual bool AlternativeVisit(const LiteralBoolExpression&) + { + return false; + } + + virtual void Visit(const LiteralVoidExpression&) = 0; + virtual bool AlternativeVisit(const LiteralVoidExpression&) + { + return false; + } + + virtual void Visit(const LiteralStreamExpression&) = 0; + virtual bool AlternativeVisit(const LiteralStreamExpression&) + { + return false; + } + + virtual void Visit(const LiteralWordExpression&) = 0; + virtual bool AlternativeVisit(const LiteralWordExpression&) + { + return false; + } + + virtual void Visit(const LiteralInstanceHeaderExpression&) = 0; + virtual bool AlternativeVisit(const LiteralInstanceHeaderExpression&) + { + return false; + } + + virtual void Visit(const FeatureRefExpression&) = 0; + virtual bool AlternativeVisit(const FeatureRefExpression&) + { + return false; + } + + virtual void Visit(const UnaryOperatorExpression&) = 0; + virtual bool AlternativeVisit(const UnaryOperatorExpression&) + { + return false; + } + + virtual void Visit(const PhiNodeExpression&) = 0; + virtual bool AlternativeVisit(const PhiNodeExpression&) + { + return false; + } + + virtual void Visit(const PublishExpression&) = 0; + virtual bool AlternativeVisit(const PublishExpression&) + { + return false; + } + + virtual void Visit(const ReturnExpression&) = 0; + virtual bool AlternativeVisit(const ReturnExpression&) + { + return false; + } + + virtual void Visit(const FeatureSpecExpression&) = 0; + virtual bool AlternativeVisit(const FeatureSpecExpression&) + { + return false; + } + + virtual void Visit(const FeatureGroupSpecExpression&) = 0; + virtual bool AlternativeVisit(const FeatureGroupSpecExpression&) + { + return false; + } + + virtual void Visit(const StreamDataExpression&) = 0; + virtual bool AlternativeVisit(const StreamDataExpression&) + { + return false; + } + + virtual void Visit(const UpdateStreamDataExpression&) = 0; + virtual bool AlternativeVisit(const UpdateStreamDataExpression&) + { + return false; + } + + virtual void Visit(const VariableRefExpression&) = 0; + virtual bool AlternativeVisit(const VariableRefExpression&) + { + return false; + } + + virtual void Visit(const ImportFeatureExpression&) = 0; + virtual bool AlternativeVisit(const ImportFeatureExpression&) + { + return false; + } + + virtual void Visit(const StateExpression&) = 0; + virtual bool AlternativeVisit(const StateExpression&) + { + return false; + } + + virtual void Visit(const StateMachineExpression&) = 0; + virtual bool AlternativeVisit(const StateMachineExpression&) + { + return false; + } + + virtual void Visit(const ExecuteStreamRewritingStateMachineGroupExpression&) = 0; + virtual bool AlternativeVisit(const ExecuteStreamRewritingStateMachineGroupExpression&) + { + return false; + } + + virtual void Visit(const ExecuteMachineExpression&) = 0; + virtual bool AlternativeVisit(const ExecuteMachineExpression&) + { + return false; + } + + virtual void Visit(const ExecuteMachineGroupExpression&) = 0; + virtual bool AlternativeVisit(const ExecuteMachineGroupExpression&) + { + return false; + } + + virtual void Visit(const YieldExpression&) = 0; + virtual bool AlternativeVisit(const YieldExpression&) + { + return false; + } + + virtual void Visit(const RandFloatExpression&) = 0; + virtual bool AlternativeVisit(const RandFloatExpression&) + { + return false; + } + + virtual void Visit(const RandIntExpression&) = 0; + virtual bool AlternativeVisit(const RandIntExpression&) + { + return false; + } + + virtual void Visit(const ThisExpression&) = 0; + virtual bool AlternativeVisit(const ThisExpression&) + { + return false; + } + + virtual void Visit(const UnresolvedAccessExpression&) = 0; + virtual bool AlternativeVisit(const UnresolvedAccessExpression&) + { + return false; + } + + virtual void Visit(const TypeInitializerExpression&) = 0; + virtual bool AlternativeVisit(const TypeInitializerExpression&) + { + return false; + } + + virtual void Visit(const AggregateContextExpression&) = 0; + virtual bool AlternativeVisit(const AggregateContextExpression&) + { + return false; + } + + virtual void Visit(const DebugExpression&) = 0; + virtual bool AlternativeVisit(const DebugExpression&) + { + return false; + } + + // VisitReference functions act like regular visit functions, except + // that the quantity generated should reference the expressed value, + // instead of the value itself. This arrangement, of strongly + // separating reference visitation from value vistation, forces us to be + // careful about reference/value distinctions, which seems to be wise. + // + // Note that we drawing a distinction between reference and value return + // at the class level would involve more duplication (duplicating + // classes, as well as Visit methods). + // + // Another alternative (with less duplication, but seems + // dangerously unsafe) is to set flags during iteration to indicate + // whether reference/value return is expected. + virtual void VisitReference(const ArrayDereferenceExpression&) = 0; + virtual bool AlternativeVisitReference(const ArrayDereferenceExpression&) + { + return false; + } + + virtual void VisitReference(const VariableRefExpression&) = 0; + virtual void VisitReference(const MemberAccessExpression&) = 0; + virtual void VisitReference(const ThisExpression&) = 0; + virtual void VisitReference(const UnresolvedAccessExpression&) = 0; + + // These two methods allow us to check on the correctness of a common + // idiom, which is to use a stack to keep results from subexpressions. + // Bugs affecting the size of this stack are irritating to track, since + // they aren't caught until significantly later than they occur. + // Override these to benefit from stack size checking during visitation. + virtual size_t StackSize() const + { + return 0; + } + virtual size_t StackIncrement() const + { + return 0; + } + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/ArrayResult.h b/src/transform/DynamicRank.FreeForm.Library/libs/External/ArrayResult.h new file mode 100644 index 000000000000..6abc35625231 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/ArrayResult.h @@ -0,0 +1,396 @@ +#pragma once + +#ifndef FREEFORM2_ARRAY_RESULT_H +#define FREEFORM2_ARRAY_RESULT_H + +#include "ArrayType.h" +#include +#include +#include "Expression.h" +#include "FreeForm2Result.h" +#include "FreeForm2Type.h" +#include "FreeForm2Tokenizer.h" +#include "FreeForm2Assert.h" +#include "ResultIteratorImpl.h" +#include "ValueResult.h" +#include + +namespace FreeForm2 +{ + // Class to iterate through elements of an array, returning basic (i.e. + // no arrays) values from it. + template + class BaseResultIterator : public ResultIteratorImpl + { + public: + BaseResultIterator(const T* p_pos, + unsigned int p_idx, + const boost::shared_array& p_space) + : m_pos(p_pos), + m_idx(p_idx), + m_space(p_space), + m_result(NULL) + { + } + + + virtual void + increment() + { + m_result.reset(NULL); + m_pos++; + m_idx++; + } + + + virtual void + decrement() + { + m_result.reset(NULL); + m_pos--; + m_idx--; + } + + + virtual const Result& + dereference() const + { + // Note that we need to initialise the result here, because we + // aren't sure it is valid at any other point. + if (m_result.get() == NULL) + { + m_result.reset(new ValueResult(*m_pos)); + } + return *m_result; + } + + + virtual void + advance(std::ptrdiff_t p_distance) + { + m_result.reset(NULL); + m_pos += p_distance; + m_idx += static_cast(p_distance); + } + + + virtual std::auto_ptr + Clone() const + { + return std::auto_ptr( + new BaseResultIterator(m_pos, m_idx, m_space)); + } + + + virtual std::pair + Position() const + { + return std::make_pair(reinterpret_cast(m_pos), + m_idx); + } + + + virtual unsigned int + ElementSize() const + { + return sizeof(T); + } + + private: + // Current position in array. + const T* m_pos; + + // Currend index in array. + unsigned int m_idx; + + // Shared pointer to space allocated to hold the array. + boost::shared_array m_space; + + // Current result. + mutable boost::scoped_ptr m_result; + }; + + + typedef boost::shared_ptr> SharedDimensions; + template class ArrayResultIterator; + + + template + class ArrayResult : public Result + { + public: + ArrayResult(const TypeImpl& p_type, + unsigned int p_dimensionPos, + const SharedDimensions& p_dimensions, + const T* p_pos, + const boost::shared_array& p_space) + : m_arrayType(NULL), + m_type(p_type), + m_dimensionPos(p_dimensionPos), + m_dimensions(p_dimensions), + m_pos(p_pos), + m_end(p_pos + ArrayResult::CalculateArrayStep(p_dimensionPos, *p_dimensions)), + m_space(p_space) + { + FF2_ASSERT(p_type.Primitive() == Type::Array); + m_arrayType = static_cast(&p_type); + FF2_ASSERT(m_arrayType->GetDimensionCount() == p_dimensions->size() - p_dimensionPos); + } + + + virtual const Type& + GetType() const + { + return m_type; + } + + + virtual IntType + GetInt() const + { + // Can't retrieve an int from an array. + Unreachable(__FILE__, __LINE__); + } + + + virtual UInt64Type + GetUInt64() const + { + // Can't retrieve a uint64 from an array. + Unreachable(__FILE__, __LINE__); + } + + + virtual int + GetInt32() const + { + // Can't retrieve an int32 from an array. + Unreachable(__FILE__, __LINE__); + } + + + virtual unsigned int + GetUInt32() const + { + // Can't retrieve an uint32 from an array. + Unreachable(__FILE__, __LINE__); + } + + + virtual FloatType + GetFloat() const + { + // Can't retrieve a float from an array. + Unreachable(__FILE__, __LINE__); + } + + + virtual bool + GetBool() const + { + // Can't retrieve a bool from an array. + Unreachable(__FILE__, __LINE__); + } + + + virtual ResultIterator + BeginArray() const + { + if (m_arrayType->GetDimensionCount() > 1) + { + // Step type down a dimension. + return ResultIterator(std::auto_ptr( + new ArrayResultIterator(m_arrayType->GetDerefType(), + m_dimensionPos + 1, + m_dimensions, + m_pos, + 0, + m_space))); + } + else + { + return ResultIterator(std::auto_ptr( + new BaseResultIterator(m_pos, 0, m_space))); + } + } + + + virtual ResultIterator + EndArray() const + { + if (m_arrayType->GetDimensionCount() > 1) + { + // Step type down a dimension. + return ResultIterator(std::auto_ptr( + new ArrayResultIterator(m_arrayType->GetDerefType(), + m_dimensionPos + 1, + m_dimensions, + m_end, + (*m_dimensions)[m_dimensions->size() - m_dimensionPos - 1], + m_space))); + } + else + { + return ResultIterator(std::auto_ptr( + new BaseResultIterator(m_end, (*m_dimensions)[0], m_space))); + } + } + + // Calculate the number of elements at a given array level, providing the + // step size across the array. + static unsigned int CalculateArrayStep(unsigned int p_dimensionPos, + const std::vector& p_dimensions) + { + unsigned int step = 1; + for (unsigned int i = p_dimensionPos; i < p_dimensions.size(); i++) + { + step *= p_dimensions[p_dimensions.size() - 1 - i]; + } + return step; + } + + + private: + // Type and TypeImpl (need both as we need to pass the impl to type). + const ArrayType* m_arrayType; + Type m_type; + + // Position in the dimension array. For example, if at 0 we are at the + // highest dimension, 1 is next down (one level of dereference), etc. + unsigned int m_dimensionPos; + + // Shared dimension array. + SharedDimensions m_dimensions; + + // Position in the array space. + const T* m_pos; + + // End position in the array space. + const T* m_end; + + // Shared pointer to space allocated to hold array. + boost::shared_array m_space; + }; + + + // Class to iterate through array elements of an array, returning subarrays. + // no arrays) values from it. + template + class ArrayResultIterator : public ResultIteratorImpl + { + public: + ArrayResultIterator(const TypeImpl& p_type, + unsigned int p_dimensionPos, + const SharedDimensions& p_dimensions, + const T* p_pos, + unsigned int p_idx, + const boost::shared_array& p_space) + : m_type(p_type), + m_dimensionPos(p_dimensionPos), + m_dimensions(p_dimensions), + m_pos(p_pos), + m_idx(p_idx), + m_space(p_space), + m_step(ArrayResult::CalculateArrayStep(p_dimensionPos, *p_dimensions)), + m_result(NULL) + { + FF2_ASSERT(m_type.Primitive() == Type::Array); + } + + + virtual void + increment() + { + m_result.reset(NULL); + m_pos += m_step; + m_idx++; + } + + + virtual void + decrement() + { + m_result.reset(NULL); + m_pos -= m_step; + m_idx--; + } + + + virtual const Result& + dereference() const + { + // Note that we need to initialise the result here, because we + // aren't sure it is valid at any other point. + if (m_result.get() == NULL) + { + m_result.reset(new ArrayResult(m_type, + m_dimensionPos, + m_dimensions, + m_pos, + m_space)); + } + return *m_result; + } + + + virtual void + advance(std::ptrdiff_t p_distance) + { + m_result.reset(NULL); + m_pos += p_distance * m_step; + m_idx += static_cast(p_distance); + } + + + virtual std::auto_ptr + Clone() const + { + return std::auto_ptr( + new ArrayResultIterator(m_type, m_dimensionPos, m_dimensions, m_pos, m_idx, m_space)); + } + + + virtual std::pair + Position() const + { + return std::make_pair(reinterpret_cast(m_pos), + m_idx); + } + + + virtual unsigned int + ElementSize() const + { + return sizeof(T) * m_step; + } + + + private: + // Child type of the array + const TypeImpl& m_type; + + // Position in the dimension array. For example, if at 0 we are at the + // highest dimension, 1 is next down (one level of dereference), etc. + unsigned int m_dimensionPos; + + // Shared dimension array. + SharedDimensions m_dimensions; + + // Position in the array space. + const T* m_pos; + + // Index in the array. + unsigned int m_idx; + + // Shared pointer to space allocated to hold array. + boost::shared_array m_space; + + // Step size that we're taking for each element. + unsigned int m_step; + + // Current result. + mutable boost::scoped_ptr> m_result; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/libs/External/CMakeLists.txt new file mode 100644 index 000000000000..8d3638ee79a4 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/CMakeLists.txt @@ -0,0 +1,34 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME DRFreeFormLibrary) + +project(${PROJECT_NAME}) + +set(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + +add_library(${PROJECT_NAME} STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/Compiler.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Executable.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FreeForm2ExternalData.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FreeForm2Result.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FreeForm2Type.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/NeuralInputFreeForm2.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Program.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ValueResult.cpp +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../NeuralTree.Library/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../Shared + ${CMAKE_CURRENT_SOURCE_DIR}/../Expression + ${CMAKE_CURRENT_SOURCE_DIR}/../Backend/llvm + ${CMAKE_CURRENT_SOURCE_DIR}/../Transform + ${CMAKE_CURRENT_SOURCE_DIR}/../Parse/SExpression/inc + ) + +install(TARGETS ${PROJECT_NAME} + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + ) \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.cpp new file mode 100644 index 000000000000..2acc5d172a2b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.cpp @@ -0,0 +1,33 @@ +#include "Compiler.h" +#include "FreeForm2Compiler.h" + +#include "FreeForm2Program.h" +#include + + +FreeForm2::CompilerImpl::~CompilerImpl() +{ +} + + +FreeForm2::CompilerResults::~CompilerResults() +{ +} + + +FreeForm2::Compiler::Compiler(std::auto_ptr p_impl) + : m_impl(p_impl.release()) +{ +} + + +FreeForm2::Compiler::~Compiler() +{ +} + + +std::unique_ptr +FreeForm2::Compiler::Compile(const Program& p_program, bool p_debugOutput) +{ + return m_impl->Compile(p_program.GetImplementation(), p_debugOutput); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.h b/src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.h new file mode 100644 index 000000000000..8ea07998ea0f --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/Compiler.h @@ -0,0 +1,25 @@ +#pragma once + +#ifndef FREEFORM2_COMPILER +#define FREEFORM2_COMPILER + +#include +#include + +namespace FreeForm2 +{ + class CompilerResults; + class ProgramImpl; + + class CompilerImpl : boost::noncopyable + { + public: + virtual ~CompilerImpl(); + + // Compile the given program. + virtual std::unique_ptr Compile(const ProgramImpl& p_program, + bool p_debugOutput) = 0; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.cpp new file mode 100644 index 000000000000..4b4e5ade4d8e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.cpp @@ -0,0 +1,99 @@ +#include "Executable.h" +#include "FreeForm2Executable.h" + +#include "Compiler.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Compiler.h" +#include "FreeForm2CompilerFactory.h" +#include "FreeForm2Result.h" +#include "LlvmCompiler.h" +#include + +FreeForm2::Executable::Executable(std::auto_ptr p_impl) + : m_impl(p_impl.release()) +{ +} + + +boost::shared_ptr +FreeForm2::Executable::Evaluate(StreamFeatureInput* p_input, + const FeatureType p_features[]) const +{ + return m_impl->Evaluate(p_input, p_features); +} + + +boost::shared_ptr +FreeForm2::Executable::Evaluate(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache) const +{ + return m_impl->Evaluate(p_features, p_currentDocument, p_documentCount, p_cache); +} + + +FreeForm2::Executable::DirectEvalFun +FreeForm2::Executable::EvaluationFunction() const +{ + return m_impl->EvaluationFunction(); +} + + +FreeForm2::Executable::AggregatedEvalFun +FreeForm2::Executable::AggregatedEvaluationFunction() const +{ + return m_impl->AggregatedEvaluationFunction(); +} + + +const FreeForm2::Type& +FreeForm2::Executable::GetType() const +{ + return m_impl->GetType(); +} + + +size_t +FreeForm2::Executable::GetExternalSize() const +{ + return sizeof(FreeForm2::ExecutableImpl) + m_impl->GetExternalSize(); +} + + +const FreeForm2::ExecutableImpl& +FreeForm2::Executable::GetImplementation() const +{ + return *m_impl; +} + + +FreeForm2::ExecutableCompilerResults::ExecutableCompilerResults( + const boost::shared_ptr& p_executable) + : m_executable(p_executable) +{ +} + + +const boost::shared_ptr& +FreeForm2::ExecutableCompilerResults::GetExecutable() const +{ + return m_executable; +} + + +FreeForm2::ExecutableImpl::~ExecutableImpl() +{ +} + + +std::unique_ptr +FreeForm2::CompilerFactory::CreateExecutableCompiler( + unsigned int p_optimizationLevel, + FreeForm2::CompilerFactory::DestinationFunctionType p_destinationFunctionType) +{ + std::auto_ptr impl; + impl.reset(new LlvmCompilerImpl(p_optimizationLevel, p_destinationFunctionType)); + return std::unique_ptr(new Compiler(impl)); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.h b/src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.h new file mode 100644 index 000000000000..b577034d22e7 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/Executable.h @@ -0,0 +1,48 @@ +#pragma once + +#ifndef FREEFORM2_EXECUTABLE_H +#define FREEFORM2_EXECUTABLE_H + +#include +#include +#include "FreeForm2CompilerFactory.h" +#include "FreeForm2Executable.h" + +class StreamFeatureInput; + +namespace FreeForm2 +{ + class Result; + class Type; + + // An ExectuableImpl is the implementation class for exectuables, which + // currently compiles and runs via one of the backends. + class ExecutableImpl : boost::noncopyable + { + public: + virtual ~ExecutableImpl(); + + virtual boost::shared_ptr + Evaluate(StreamFeatureInput* p_input, + const Executable::FeatureType p_features[]) const = 0; + + // List based evaluation. + virtual boost::shared_ptr + Evaluate(const Executable::FeatureType* const* p_features, + UInt32 p_currentDocument, + UInt32 p_documentCount, + Int64* p_cache) const = 0; + + virtual Executable::DirectEvalFun EvaluationFunction() const = 0; + + // Get list based evaluation function. + virtual Executable::AggregatedEvalFun AggregatedEvaluationFunction() const = 0; + + virtual const Type& GetType() const = 0; + + // Get the size of external memory. + virtual size_t GetExternalSize() const = 0; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2ExternalData.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2ExternalData.cpp new file mode 100644 index 000000000000..ed37a668fbf1 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2ExternalData.cpp @@ -0,0 +1,70 @@ +#include "FreeForm2ExternalData.h" + +#include "FreeForm2Type.h" +#include +#include "TypeImpl.h" +#include "TypeManager.h" + +FreeForm2::ExternalData::ExternalData(const std::string& p_name, const FreeForm2::TypeImpl& p_typeImpl) + : m_name(p_name), m_type(&p_typeImpl), m_isCompileTimeConst(false) +{ +} + + +FreeForm2::ExternalData::ExternalData(const std::string& p_name, + const TypeImpl& p_typeImpl, + ConstantValue p_value) + : m_name(p_name), m_type(&p_typeImpl), m_isCompileTimeConst(true), m_constantValue(p_value) +{ +} + + +FreeForm2::ExternalData::~ExternalData() +{ +} + + +const std::string& +FreeForm2::ExternalData::GetName() const +{ + return m_name; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ExternalData::GetType() const +{ + return *m_type; +} + + +bool +FreeForm2::ExternalData::IsCompileTimeConstant() const +{ + return m_isCompileTimeConst; +} + + +FreeForm2::ConstantValue +FreeForm2::ExternalData::GetCompileTimeValue() const +{ + return m_constantValue; +} + + +FreeForm2::ExternalDataManager::ExternalDataManager() + : m_typeFactory(new TypeFactory(TypeManager::CreateTypeManager())) +{ +} + + +FreeForm2::ExternalDataManager::~ExternalDataManager() +{ +} + + +FreeForm2::TypeFactory& +FreeForm2::ExternalDataManager::GetTypeFactory() +{ + return *m_typeFactory; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Result.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Result.cpp new file mode 100644 index 000000000000..62ce0d089c78 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Result.cpp @@ -0,0 +1,365 @@ +#include "FreeForm2Result.h" + +#include +#include "FreeForm2Assert.h" +#include "FreeForm2Tokenizer.h" +#include "ResultIteratorImpl.h" +#include "TypeImpl.h" +#include "TypeUtil.h" +#include +#include + +using namespace FreeForm2; +using namespace boost; + + +FreeForm2::Result::~Result() +{ +} + + +int +FreeForm2::Result::Compare(const Result& p_other) const +{ + if (!GetType().GetImplementation().IsSameAs(p_other.GetType().GetImplementation(), true)) + { + std::ostringstream err; + err << "Mismatched compare between " + << GetType() << " and " << p_other.GetType(); + throw std::runtime_error(err.str()); + } + + switch (GetType().Primitive()) + { + case Type::Bool: + { + bool left = GetBool(); + bool right = p_other.GetBool(); + + if (left == right) + { + return 0; + } + else if (right) + { + return -1; + } + else + { + return 1; + } + break; + } + + case Type::Int: + { + Result::IntType left = GetInt(); + Result::IntType right = p_other.GetInt(); + + if (left < right) + { + return -1; + } + else if (left > right) + { + return 1; + } + else + { + return 0; + } + break; + } + + case Type::UInt64: + { + Result::UInt64Type left = GetUInt64(); + Result::UInt64Type right = p_other.GetUInt64(); + + if (left < right) + { + return -1; + } + else if (left > right) + { + return 1; + } + else + { + return 0; + } + break; + } + + case Type::Int32: + { + int left = GetInt32(); + int right = p_other.GetInt32(); + + if (left < right) + { + return -1; + } + else if (left > right) + { + return 1; + } + else + { + return 0; + } + break; + } + + case Type::UInt32: + { + unsigned int left = GetUInt32(); + unsigned int right = p_other.GetUInt32(); + + if (left < right) + { + return -1; + } + else if (left > right) + { + return 1; + } + else + { + return 0; + } + break; + } + + case Type::Float: + { + return CompareFloat(GetFloat(), p_other.GetFloat()); + } + + case Type::Array: + { + ResultIterator leftPos = BeginArray(); + ResultIterator leftEnd = EndArray(); + ResultIterator rightPos = p_other.BeginArray(); + ResultIterator rightEnd = p_other.EndArray(); + + while (leftPos != leftEnd && rightPos != rightEnd) + { + int cmp = leftPos->Compare(*rightPos); + if (cmp != 0) + { + return cmp; + } + + ++leftPos; + ++rightPos; + } + + if (leftPos == leftEnd && rightPos == rightEnd) + { + return 0; + } + else if (leftPos == leftEnd) + { + return -1; + } + else + { + return 1; + } + break; + } + + default: + { + std::ostringstream err; + err << "Comparison of unknown type '" << GetType() << "'"; + throw std::runtime_error(err.str()); + } + } +} + + +void +FreeForm2::Result::Print(std::ostream& p_out) const +{ + switch (GetType().Primitive()) + { + case Type::Bool: + { + p_out << (GetBool() ? "true" : "false"); + break; + } + + case Type::Int: + { + p_out << GetInt(); + break; + } + + case Type::UInt64: + { + p_out << GetUInt64(); + break; + } + + case Type::Int32: + { + p_out << GetInt32(); + break; + } + + case Type::UInt32: + { + p_out << GetUInt32(); + break; + } + + case Type::Float: + { + const std::streamsize savePrecision = p_out.precision(); + + p_out << std::setprecision(9) << GetFloat(); + + // Restore precision. + p_out << std::setprecision(savePrecision); + break; + } + + case Type::Array: + { + p_out << "["; + ResultIterator end = EndArray(); + bool first = true; + for (ResultIterator iter = BeginArray(); iter != end; ++iter) + { + p_out << (first ? "" : " "); + first = false; + iter->Print(p_out); + } + p_out << "]"; + break; + } + + default: + { + std::ostringstream err; + err << "Printing unknown type '" << GetType() << "'"; + throw std::runtime_error(err.str()); + } + } +} + + +std::ostream& +FreeForm2::operator<<(std::ostream& p_out, const Result& p_result) +{ + p_result.Print(p_out); + return p_out; +} + + +int +FreeForm2::Result::CompareFloat(FloatType p_left, FloatType p_right) +{ + // This value was chosen to be compatible with the old freeforms. + const Result::FloatType relativeError = 1E-6F; + bool equal = false; + + // Check for identical values (this is needed to compare infinity). + if (p_left == p_right) + { + return 0; + } + + // Check whether right operand is small. + if (p_right < relativeError && p_right > -relativeError) + { + // Right is small, so they're equal iff left is small. + equal = (p_left < relativeError && p_left > -relativeError); + } + else + { + // Right isn't small, so check the difference between the two + // (related to right operand). They're equal iff the difference is + // small. + const Result::FloatType diff = (p_left - p_right) / p_right; + equal = (diff < relativeError && diff > -relativeError); + } + + if (equal) + { + return 0; + } + else if (p_left < p_right) + { + return -1; + } + else + { + return 1; + } +} + + +FreeForm2::ResultIterator::ResultIterator(std::auto_ptr p_impl) + : m_impl(p_impl) +{ +} + + +FreeForm2::ResultIterator::ResultIterator(const ResultIterator& p_other) + : m_impl(p_other.m_impl->Clone()) +{ +} + + +FreeForm2::ResultIterator::~ResultIterator() +{ +} + + +void +FreeForm2::ResultIterator::increment() +{ + return m_impl->increment(); +} + + +void +FreeForm2::ResultIterator::decrement() +{ + return m_impl->decrement(); +} + + +bool +FreeForm2::ResultIterator::equal(const ResultIterator& p_other) const +{ + return m_impl->Position() == p_other.m_impl->Position() + && m_impl->ElementSize() == p_other.m_impl->ElementSize(); +} + + +const Result& +FreeForm2::ResultIterator::dereference() const +{ + return m_impl->dereference(); +} + + +void +FreeForm2::ResultIterator::advance(std::ptrdiff_t p_distance) +{ + m_impl->advance(p_distance); +} + + +std::ptrdiff_t +FreeForm2::ResultIterator::distance_to(const ResultIterator& p_other) const +{ + FF2_ASSERT(m_impl->ElementSize() == p_other.m_impl->ElementSize()); + return p_other.m_impl->Position().second - m_impl->Position().second; +} + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Type.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Type.cpp new file mode 100644 index 000000000000..a8be506ea030 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/FreeForm2Type.cpp @@ -0,0 +1,226 @@ +#include "FreeForm2Type.h" + +#include "ArrayType.h" +#include "FreeForm2Assert.h" +#include "FunctionType.h" +#include +#include +#include +#include "StructType.h" +#include "TypeImpl.h" +#include "TypeManager.h" + +FreeForm2::Type::Type(const TypeImpl& p_impl) + : m_impl(p_impl) +{ +} + + +FreeForm2::Type::TypePrimitive +FreeForm2::Type::Primitive() const +{ + return m_impl.Primitive(); +} + + +const char* +FreeForm2::Type::Name(TypePrimitive p_type) +{ + switch (p_type) + { + case Float: return "float"; + case Int: return "int"; + case UInt64: return "uint64"; + case Int32: return "int32"; + case UInt32: return "uint32"; + case Bool: return "bool"; + case Array: return "array"; + case Void: return "void"; + case Stream: return "stream"; + case Word: return "word"; + case InstanceHeader: return "instanceHeader"; + case BodyBlockHeader: return "bodyBlockHeader"; + case Unknown: return "unknown"; + default: return ""; + }; +} + + +bool +FreeForm2::Type::operator==(const Type& p_other) const +{ + return GetImplementation() == p_other.GetImplementation(); +} + + +const FreeForm2::TypeImpl& +FreeForm2::Type::GetImplementation() const +{ + return m_impl; +} + + +FreeForm2::Type::TypePrimitive +FreeForm2::Type::ParsePrimitive(SIZED_STRING p_string) +{ + for (int i = 0; i < Type::Invalid; i++) + { + Type::TypePrimitive prim = static_cast(i); + const char* name = Name(prim); + + if (name != NULL) + { + size_t len = strlen(name); + + if (p_string.cbData == len && memcmp(p_string.pcData, name, len) == 0) + { + return prim; + } + } + } + + return Type::Invalid; +} + + +std::ostream& +FreeForm2::operator<<(std::ostream& p_out, const Type& p_type) +{ + return p_out << p_type.GetImplementation(); +} + + +FreeForm2::TypeFactory::TypeFactory(std::auto_ptr p_typeManager) + : m_typeManager(p_typeManager.release()) +{ +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetFloatType() +{ + return TypeImpl::GetFloatInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetIntType() +{ + return TypeImpl::GetIntInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetUInt64Type() +{ + return TypeImpl::GetUInt64Instance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetInt32Type() +{ + return TypeImpl::GetInt32Instance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetUInt32Type() +{ + return TypeImpl::GetUInt32Instance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetBoolType() +{ + return TypeImpl::GetBoolInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetVoidType() +{ + return TypeImpl::GetVoidInstance(); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetStreamType() +{ + return TypeImpl::GetStreamInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetWordType() +{ + return TypeImpl::GetWordInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetInstanceHeaderType() +{ + return TypeImpl::GetInstanceHeaderInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetBodyBlockHeaderType() +{ + return TypeImpl::GetBodyBlockHeaderInstance(true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetArrayType(const TypeImpl& p_child, + const UInt32* p_dimensions, + UInt32 p_dimensionCount) +{ + const UInt32 maxElements + = std::accumulate(p_dimensions, p_dimensions + p_dimensionCount, 1, std::multiplies()); + return m_typeManager->GetArrayType(p_child, true, p_dimensionCount, p_dimensions, maxElements); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetStructType(const std::string& p_name, + const StructMember* p_members, + UInt32 p_memberCount) +{ + std::vector members; + members.reserve(p_memberCount); + for (UInt32 i = 0; i < p_memberCount; i++) + { + members.push_back(StructType::MemberInfo(p_members[i].first, + p_members[i].second, + p_members[i].first, + 0, + 0)); + } + return m_typeManager->GetStructType(p_name, p_name, members, true); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeFactory::GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_parameters, + UInt32 p_parameterCount) +{ + return m_typeManager->GetFunctionType(p_returnType, p_parameters, p_parameterCount); +} + + +const FreeForm2::TypeImpl* +FreeForm2::TypeFactory::FindType(const std::string& p_name) const +{ + return m_typeManager->GetTypeInfo(p_name); +} + + +const FreeForm2::TypeManager& +FreeForm2::TypeFactory::GetTypeManager() const +{ + return *m_typeManager; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/NeuralInputFreeForm2.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/NeuralInputFreeForm2.cpp new file mode 100644 index 000000000000..12fe2db0268e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/NeuralInputFreeForm2.cpp @@ -0,0 +1,459 @@ +#include "NeuralInputFreeForm2.h" + +#include +#include "BinaryOperator.h" +#include +#include +#include +#include "Compiler.h" +#include "Conditional.h" +#include "ConvertExpression.h" +#include "Executable.h" +#include "Expression.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Compiler.h" +#include "FreeForm2CompilerFactory.h" +#include "FreeForm2Executable.h" +#include "FreeForm2Program.h" +#include "FreeForm2Result.h" +#include "FreeForm2Utils.h" +#include +#include "LiteralExpression.h" +#include "LlvmCompiler.h" +#include "OperatorExpression.h" +#include "Program.h" +#include "RefExpression.h" +#include "SimpleExpressionOwner.h" +#include +#include "TypeManager.h" +#include +#include + +using namespace LightGBM; + + +namespace +{ + // Gets the feature index for an input which transforms a single feature. + bool ReadAssociatedFeature(DynamicRank::Config& p_config, + const char *p_section, + DynamicRank::IFeatureMap& p_featureMap, + UInt32 *p_feature) + { + *p_feature = static_cast(-1); + + if (p_section == NULL) + { + return false; + } + + char szName[256]; + if (p_config.GetStringParameter(p_section, "Name", szName, sizeof(szName))) + { + if (!p_featureMap.ObtainFeatureIndex(szName, *p_feature)) + { + Log::Warning("DR:ReadAssociatedFeature Could not find index for feature name: %s in section: %s", szName, p_section); + return false; // input error, invalid feature name + } + return true; + } + + Log::Warning("DR:ReadAssociatedFeature Could not find Name of the feature for section: %s", p_section); + return false; // feature is not specified + } +} + + +FreeForm2::NeuralInputFreeForm2::NeuralInputFreeForm2() + : m_transform("FreeForm2") +{ +} + + +FreeForm2::NeuralInputFreeForm2::NeuralInputFreeForm2(const std::string& p_input, + const char* p_transform, + DynamicRank::IFeatureMap& p_map) + : m_input(p_input), + m_transform(p_transform), + m_map(&p_map), + m_fun(NULL), + m_program(Program::Parse(CStackSizedString(p_input.c_str(), p_input.size()), + p_map, + true, + nullptr, + NULL)) +{ + Init(); +} + + +FreeForm2::NeuralInputFreeForm2::NeuralInputFreeForm2(const std::string& p_input, + const char* p_transform, + DynamicRank::IFeatureMap& p_map, + boost::shared_ptr p_program) + : m_input(p_input), + m_transform(p_transform), + m_map(&p_map), + m_fun(NULL), + m_program(p_program) +{ + Init(); +} + + +void +FreeForm2::NeuralInputFreeForm2::Init() +{ + // Infer actual features used by this expression. + std::set features; + SetFromNeuralNetFeatures actualFeatures(features); + m_program->ProcessFeaturesUsed(actualFeatures); + + // Copy the features into m_features. + m_features.resize(features.size()); + std::copy(features.begin(), features.end(), m_features.begin()); +} + + +double +FreeForm2::NeuralInputFreeForm2::Evaluate(UInt32 p_input[]) const +{ + try + { + return m_fun(NULL, p_input, NULL); + } + catch(const std::exception& e) + { + + } + Unreachable(__FILE__, __LINE__); +} + + +bool +FreeForm2::NeuralInputFreeForm2::Train(double p_learningRate, + double p_outputHigh, + double p_outputLow, + double p_outputDelta, + UInt32 p_inputHigh[], + UInt32 p_inputLow[]) +{ + return false; +} + + +void +FreeForm2::NeuralInputFreeForm2::GetAllAssociatedFeatures( + std::vector& p_associatedFeaturesList) const +{ + std::copy(m_features.begin(), m_features.end(), std::back_inserter(p_associatedFeaturesList)); +} + + +void +FreeForm2::NeuralInputFreeForm2::Compile(Compiler* p_compiler) +{ + // Only compile once. + if (m_program == NULL) + { + return; + } + + try + { + if (p_compiler != NULL) + { + std::unique_ptr results = p_compiler->Compile(*m_program, false); + const ExecutableCompilerResults* exec + = boost::polymorphic_downcast(results.get()); + m_exec = exec->GetExecutable(); + } + else + { + std::unique_ptr compiler(CompilerFactory::CreateExecutableCompiler( + Compiler::c_defaultOptimizationLevel, + CompilerFactory::SingleDocumentEvaluation)); + std::unique_ptr results = compiler->Compile(*m_program, false); + const ExecutableCompilerResults* exec + = boost::polymorphic_downcast(results.get()); + m_exec = exec->GetExecutable(); + } + + m_fun = m_exec->EvaluationFunction(); + + // Free the expression tree from memory. + m_program.reset(); + } + catch (const std::exception& p_except) + { + Log::Fatal("NeuralInputFreeForm2::Compile,Failed to compile %s: %s", m_transform, p_except.what()); + Log::Fatal("NeuralInputFreeForm2::Compile,Failed to compile %s: %s", m_transform, GetStringRepresentation().c_str()); + throw; + } +} + + +const FreeForm2::Program& +FreeForm2::NeuralInputFreeForm2::GetProgram() const +{ + return *m_program.get(); +} + + +double +FreeForm2::NeuralInputFreeForm2::GetMin() const +{ + return -DBL_MAX; +} + + +double +FreeForm2::NeuralInputFreeForm2::GetMax() const +{ + return DBL_MAX; +} + + +size_t +FreeForm2::NeuralInputFreeForm2::GetSize() const +{ + return sizeof(NeuralInputFreeForm2) + GetExternalSize(); +} + + +size_t +FreeForm2::NeuralInputFreeForm2::GetExternalSize() const +{ + size_t externalSize =sizeof(UInt32) * m_features.size(); + + if (m_exec.get()) + { + externalSize += sizeof(FreeForm2::Executable) + m_exec->GetExternalSize(); + } + + externalSize += DynamicRank::NeuralInput::GetExternalSize(); + externalSize += m_input.capacity() * sizeof(std::string::value_type); + return externalSize; +} + + +bool +FreeForm2::NeuralInputFreeForm2::Save(FILE *p_out, size_t p_input, const DynamicRank::IFeatureMap& p_map) const +{ + // Write header. + bool success = DynamicRank::NeuralInput::Save(p_out, p_input, p_map); + success = success + && fprintf(p_out, "Transform=%s\n", (m_transform != NULL ? m_transform : "")); + typedef boost::split_iterator SplitIter; + + unsigned int numLine = 1; + for (SplitIter iter = boost::make_split_iterator(m_input, + boost::token_finder(boost::is_any_of("\r\n"))); + iter != SplitIter(); + ++iter) + { + if (iter->size() > 0) + { + const char* str = &(*iter->begin()); + success = success && fprintf(p_out, "Line%u=%.*s\n", numLine, static_cast(iter->size()), str); + numLine++; + } + } + + return success; +} + + +std::string +FreeForm2::NeuralInputFreeForm2::LoadProgram(DynamicRank::Config& p_config, + const char* p_section, + const DynamicRank::IFeatureMap* p_featureMap, + const char* p_transform) +{ + // Read multiple lines from config, and assemble them into a program. + unsigned int numLine = 1; + bool found = true; + std::ostringstream program; + do + { + std::ostringstream lineName; + lineName << "Line" << numLine; + std::string lineStr = lineName.str(); + std::string lineValue; + found = p_config.GetStringParameter(p_section, lineStr.c_str(), lineValue); + numLine++; + program << lineValue << std::endl; + } + while (found); + + if (numLine == 2 && !found) + { + Log::Warning("NeuralInputFreeForm2::Load NeuralInputFreeForm2::Load %s", p_section); + return ""; + } + + // Guard against skipped line numbers by refusing to load when we find + // something that looks like one. + for (unsigned int i = 0; i < 10; i++) + { + std::ostringstream lineName; + lineName << "Line" << numLine + i; + std::string lineStr = lineName.str(); + std::string lineValue; + found = p_config.GetStringParameter(p_section, lineStr.c_str(), lineValue); + + if (found) + { + Log::Warning("NeuralInputFreeForm2::Found ignored parameter %s in section %s: did you skip line number %s?", lineStr.c_str(), p_section, to_string(numLine + i).c_str()); + return ""; + } + } + + return program.str(); +} + + +FreeForm2::NeuralInputFreeForm2* +FreeForm2::NeuralInputFreeForm2::Load(DynamicRank::Config& p_config, + const char* p_section, + DynamicRank::IFeatureMap* p_featureMap, + const char* p_transform) +{ + std::string programStr = LoadProgram(p_config, p_section, p_featureMap, p_transform); + if (programStr.empty()) + { + return NULL; + } + + try + { + return new NeuralInputFreeForm2(programStr, p_transform, *p_featureMap); + } + catch (const std::exception& p_except) + { + Log::Warning("NeuralInputFreeForm2::Load Failed to load freeform2: %s (program is '%s')", p_except.what(), programStr.c_str()); + return NULL; + } +} + + +std::string +FreeForm2::NeuralInputFreeForm2::GetStringRepresentation() const +{ + return m_input; +} + + +bool +FreeForm2::NeuralInputFreeForm2::IsFreeForm2() const +{ + // Only this one says true!!! + // We will call BatchSerialize on Inputs that said yes to IsFreeForm2() + // When do bond serialization. + return true; +} + + +bool +FreeForm2::NeuralInputFreeForm2::Equal(const DynamicRank::NeuralInput* p_other) const +{ + const NeuralInputFreeForm2* t + = dynamic_cast(p_other); + + if (t == nullptr + || m_input != t->m_input + || m_features.size() != t->m_features.size()) + { + return false; + } + + for (size_t i = 0; i< m_features.size(); ++i) + { + if (m_features[i] != t->m_features[i]) + { + return false; + } + } + + if ((m_exec == nullptr) != (t->m_exec == nullptr)) + { + return false; + } + if (m_exec == nullptr) + { + return true; + } + + const LlvmExecutableImpl* left + = dynamic_cast(&m_exec->GetImplementation()); + const LlvmExecutableImpl* right + = dynamic_cast(&t->m_exec->GetImplementation()); + + if ((left == nullptr) != (right == nullptr)) + { + return false; + } + + if (left && right) + { + return (*left == *right); + } + + return true; +} + + +void +FreeForm2::NeuralInputFreeForm2::FillBond(DynamicRank::UnionBondInput& p_data) const +{ + // Selector. + p_data.m_inputType = c_freeform2_tranform; + + DynamicRank::NeuralInputFreeForm2BondData data; + data.m_input = m_input; + for (size_t i = 0; i< m_features.size(); ++i) + { + data.m_features.push_back(m_features[i]); + } + + p_data.m_freeform2.set(data); + + // Other members will use BatchSerialize. +} + + +FreeForm2::NeuralInputFreeForm2::NeuralInputFreeForm2(const DynamicRank::UnionBondInput& p_data) +{ + if("freeform2" != p_data.m_inputType) + { + Log::Fatal("Input type '%s' is not supported. Accepted type is freeform2.", p_data.m_inputType.c_str()); + } + const DynamicRank::NeuralInputFreeForm2BondData& data = p_data.m_freeform2.value(); + for (size_t i = 0; i< data.m_features.size(); ++i) + { + m_features.push_back(data.m_features[i]); + } + m_input = data.m_input; + + // Other members will use BatchUnSerialize. +} + + +void +FreeForm2::NeuralInputCompiler::Compile(const std::vector& p_inputs, + FreeForm2::Compiler& p_compiler) +{ + if (p_inputs.empty()) + { + return; + } + + for (unsigned int i = 0; i < p_inputs.size(); i++) + { + if(p_inputs[i] == NULL) + { + Log::Fatal("Input freeform is null."); + } + p_inputs[i]->Compile(&p_compiler); + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/Program.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/Program.cpp new file mode 100644 index 000000000000..0f1396e69f2d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/Program.cpp @@ -0,0 +1,262 @@ +#include "Program.h" +#include "FreeForm2Program.h" + +#include "Allocation.h" +#include "AllocationVisitor.h" +#include + +#include "FeatureSpec.h" +#include "FreeForm2Assert.h" +#include +#include "ObjectResolutionVisitor.h" +#include "OperandPromotionVisitor.h" +#include "ProcessFeaturesUsed.h" +#include "SExpressionParse.h" +#include +#include "TypeCheckingVisitor.h" +#include "TypeImpl.h" + +using namespace FreeForm2; + +namespace +{ + template + boost::tuples::tuple, + boost::shared_ptr> + ParseInternal(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput); + + + template <> + boost::tuples::tuple, + boost::shared_ptr> + ParseInternal(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput) + { + return SExpressionParse::Parse(p_input, p_map, p_mustProduceFloat, false); + } + + + template <> + boost::tuples::tuple, + boost::shared_ptr> + ParseInternal(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput) + { + return SExpressionParse::Parse(p_input, p_map, p_mustProduceFloat, true); + } + + std::string + ConstructParseErrorMessage( + const char* p_message, + unsigned int p_line, + unsigned int p_lineChar) + { + std::ostringstream err; + err << "Parse error: " << p_message + << " at line " << p_line + << ", char " << p_lineChar; + return err.str(); + } +} + + +FreeForm2::SourceLocation::SourceLocation() + : m_lineNo(0), + m_lineOffset(0) +{ +} + + +FreeForm2::SourceLocation::SourceLocation(unsigned int p_lineNo, unsigned int p_lineOffset) + : m_lineNo(p_lineNo), + m_lineOffset(p_lineOffset) +{ +} + + +FreeForm2::ParseError::ParseError( + const std::string& p_message, + const SourceLocation& p_location) + : runtime_error(ConstructParseErrorMessage(p_message.c_str(), p_location.m_lineNo, p_location.m_lineOffset)), + m_message(p_message), + m_sourceLocation(p_location) +{ +} + + +FreeForm2::ParseError::ParseError( + const std::exception& p_inner, + const SourceLocation& p_location) + : runtime_error(ConstructParseErrorMessage(p_inner.what(), p_location.m_lineNo, p_location.m_lineOffset)), + m_message(p_inner.what()), + m_sourceLocation(p_location) +{ +} + + +const std::string& +FreeForm2::ParseError::GetMessage() const +{ + return m_message; +} + + +const FreeForm2::SourceLocation& +FreeForm2::ParseError::GetSourceLocation() const +{ + return m_sourceLocation; +} + + +FreeForm2::ProgramImpl::ProgramImpl(const Expression& p_exp, + boost::shared_ptr p_owner, + boost::shared_ptr p_typeManager, + DynamicRank::IFeatureMap& p_map) + : m_typeImpl(p_exp.GetType()), + m_type(m_typeImpl), + m_exp(&p_exp), + m_owner(p_owner), + m_typeManager(p_typeManager), + m_map(p_map), + m_allocationVisitor(p_exp) +{ +} + + +const Type& +FreeForm2::ProgramImpl::GetType() const +{ + return m_type; +} + + +void +FreeForm2::ProgramImpl::ProcessFeaturesUsed(DynamicRank::INeuralNetFeatures& p_features) const +{ + ProcessFeaturesUsedVisitor visitor(p_features); + m_exp->Accept(visitor); +} + + +const Expression& +FreeForm2::ProgramImpl::GetExpression() const +{ + return *m_exp; +} + + +DynamicRank::IFeatureMap& +FreeForm2::ProgramImpl::GetFeatureMap() const +{ + return m_map; +} + + +const std::vector>& +FreeForm2::ProgramImpl::GetAllocations() const +{ + return m_allocationVisitor.GetAllocations(); +} + + +template +boost::shared_ptr +FreeForm2::Program::Parse(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + const ExternalDataManager* p_externalData, + std::ostream* p_debugOutput) +{ + using boost::tuples::get; + boost::tuples::tuple, + boost::shared_ptr> ret + = ::ParseInternal(p_input, p_map, p_mustProduceFloat, p_externalData, p_debugOutput); + + const Expression* syntaxTree = get<0>(ret); + FF2_ASSERT(syntaxTree != NULL); + boost::shared_ptr owner; + boost::shared_ptr typeManager; + + { + // Resolve unknown object types. + ObjectResolutionVisitor resolve; + get<0>(ret)->Accept(resolve); + syntaxTree = resolve.GetSyntaxTree(); + + // Ensure that all type information has been filled out. + TypeCheckingVisitor typeCheck; + syntaxTree->Accept(typeCheck); + + // Infer all missing type information. + OperandPromotionVisitor promotion; + syntaxTree->Accept(promotion); + + syntaxTree = promotion.GetSyntaxTree(); + owner = promotion.GetExpressionOwner(); + typeManager = promotion.GetTypeManager(); + } + + std::auto_ptr ptr(new ProgramImpl(*syntaxTree, owner, typeManager, p_map)); + return boost::shared_ptr(new Program(ptr)); +} + + +FreeForm2::Program::Program(std::auto_ptr p_impl) + : m_impl(p_impl.release()) +{ +} + + +const FreeForm2::Type& +FreeForm2::Program::GetType() const +{ + return m_impl->GetType(); +} + + +const FreeForm2::Expression& +FreeForm2::Program::GetExpression() const +{ + return m_impl->GetExpression(); +} + + +void +FreeForm2::Program::ProcessFeaturesUsed(DynamicRank::INeuralNetFeatures& p_features) const +{ + return m_impl->ProcessFeaturesUsed(p_features); +} + + +const std::vector>& +FreeForm2::Program::GetAllocations() const +{ + return m_impl->GetAllocations(); +} + +FreeForm2::ProgramImpl& +FreeForm2::Program::GetImplementation() +{ + return *m_impl; +} + + +const FreeForm2::ProgramImpl& +FreeForm2::Program::GetImplementation() const +{ + return *m_impl; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/Program.h b/src/transform/DynamicRank.FreeForm.Library/libs/External/Program.h new file mode 100644 index 000000000000..e034c56e7f0b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/Program.h @@ -0,0 +1,72 @@ +#pragma once + +#ifndef FREEFORM2_PROGRAM_H +#define FREEFORM2_PROGRAM_H + +#include "AllocationVisitor.h" +#include +#include +#include "FreeForm2Type.h" +#include + +namespace DynamicRank +{ + class IFeatureMap; + class INeuralNetFeatures; +} + +namespace FreeForm2 +{ + class Allocation; + class Expression; + class ExpressionOwner; + class TypeManager; + + // A ProgramImpl is essentially just the concrete instantiation of the + // Program class, into which we parse expressions. + class ProgramImpl : boost::noncopyable + { + public: + ProgramImpl(const Expression& p_exp, + boost::shared_ptr p_owner, + boost::shared_ptr p_typeManager, + DynamicRank::IFeatureMap& p_map); + + const Type& GetType() const; + + void + ProcessFeaturesUsed(DynamicRank::INeuralNetFeatures& p_features) const; + + const Expression& GetExpression() const; + + DynamicRank::IFeatureMap& + GetFeatureMap() const; + + const std::vector>& GetAllocations() const; + + private: + // Top-level type implementation of this program. + const TypeImpl& m_typeImpl; + + // Top-level type of this program. + Type m_type; + + // Pointer to root of parsed expression tree. + const Expression* m_exp; + + // Expression owner. + boost::shared_ptr m_owner; + + // Type manager. + boost::shared_ptr m_typeManager; + + // Feature map used to compile expression, used to print program + // information if exceptions occur. + DynamicRank::IFeatureMap& m_map; + + // A visitor to extract all the allocations from the program. + const AllocationVisitor m_allocationVisitor; + }; +} + +#endif \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/ResultIteratorImpl.h b/src/transform/DynamicRank.FreeForm.Library/libs/External/ResultIteratorImpl.h new file mode 100644 index 000000000000..3d494f24f01b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/ResultIteratorImpl.h @@ -0,0 +1,50 @@ +#pragma once + +#ifndef FREEFORM2_RESULTITERATORIMPL_H +#define FREEFORM2_RESULTITERATORIMPL_H + +#include "FreeForm2Result.h" + +namespace FreeForm2 +{ + class ResultIteratorImpl + { + public: + virtual ~ResultIteratorImpl() + { + } + + // Delegated iterator_facade function to increment the iterator. + virtual void increment() = 0; + + // Delegated iterator_facade function to decrement the iterator. + virtual void decrement() = 0; + + // Delegated iterator_facade function to get the current element. + virtual const Result& dereference() const = 0; + + // Delegated iterator_facade function to get the current element. + virtual void advance(std::ptrdiff_t p_distance) = 0; + + // Virtual copy constructor. + virtual std::auto_ptr Clone() const = 0; + + // Having an abstract iterator puts us in a tricky position, because + // some of the iterator methods (like equal) accept another iterator + // as arg. Since we may have any number of subclasses, equal must + // be able to compare iterators that have nothing to do with each + // other. As such, we use a couple of (somewhat hacky) methods + // below to return enough information from each iterator to compare + // and calculate the difference between them without further + // knowledge. + + // Returns a pointer indicating current position, plus the element index. + virtual std::pair Position() const = 0; + + // Returns number of bytes per element. + virtual unsigned int ElementSize() const = 0; + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.cpp new file mode 100644 index 000000000000..bec3764d8969 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.cpp @@ -0,0 +1,448 @@ +#include "ValueResult.h" + +#include "ArrayType.h" +#include +#include "FreeForm2Assert.h" +#include "FreeForm2Tokenizer.h" +#include "FreeForm2Type.h" +#include "FreeForm2Utils.h" +#include +#include "ResultIteratorImpl.h" +#include "TypeManager.h" + +using namespace FreeForm2; + +namespace +{ + // __declspec(noreturn) + void FailParseResult(SIZED_STRING p_result) + { + std::ostringstream err; + err << "Failed to parse value from result '" + << p_result << "'"; + throw std::runtime_error(err.str()); + } + + bool ParseBoolean(SIZED_STRING p_str) + { + std::string value(p_str.pcData, p_str.cbData); + + if (value == "true") + { + return true; + } + else if (value == "false") + { + return false; + } + else + { + FailParseResult(p_str); + } + } + + + class ValueResultIteratorImpl : public ResultIteratorImpl + { + public: + ValueResultIteratorImpl(const boost::shared_ptr* p_pos, + unsigned int p_idx) + : m_pos(p_pos), + m_idx(p_idx) + { + } + + virtual ~ValueResultIteratorImpl() + { + } + + private: + virtual void + increment() + { + m_pos++; + m_idx++; + } + + + virtual void decrement() + { + m_pos--; + m_idx--; + } + + + virtual const Result& + dereference() const + { + return **m_pos; + } + + + virtual void + advance(std::ptrdiff_t p_distance) + { + m_pos += p_distance; + m_idx += static_cast(p_distance); + } + + + virtual std::auto_ptr + Clone() const + { + return std::auto_ptr(new ValueResultIteratorImpl(m_pos, m_idx)); + } + + + virtual std::pair + Position() const + { + return std::make_pair(reinterpret_cast(m_pos), + m_idx); + } + + + virtual unsigned int + ElementSize() const + { + return sizeof(*m_pos); + } + + + const boost::shared_ptr* m_pos; + + + unsigned int m_idx; + }; +} + + +FreeForm2::ValueResult::ValueResult(FloatType p_float) + : m_typeImpl(&TypeImpl::GetFloatInstance(true)), m_type(*m_typeImpl) +{ + m_val.m_float = p_float; +} + + +FreeForm2::ValueResult::ValueResult(IntType p_int) + : m_typeImpl(&TypeImpl::GetIntInstance(true)), m_type(*m_typeImpl) +{ + m_val.m_int = p_int; +} + + +FreeForm2::ValueResult::ValueResult(UInt64Type p_int) + : m_typeImpl(&TypeImpl::GetUInt64Instance(true)), m_type(*m_typeImpl) +{ + m_val.m_uint64 = p_int; +} + + +FreeForm2::ValueResult::ValueResult(int p_int) + : m_typeImpl(&TypeImpl::GetInt32Instance(true)), m_type(*m_typeImpl) +{ + m_val.m_int32 = p_int; +} + + +FreeForm2::ValueResult::ValueResult(unsigned int p_int) + : m_typeImpl(&TypeImpl::GetUInt32Instance(true)), m_type(*m_typeImpl) +{ + m_val.m_uint32 = p_int; +} + + +FreeForm2::ValueResult::ValueResult(bool p_bool) + : m_typeImpl(&TypeImpl::GetBoolInstance(true)), m_type(*m_typeImpl) +{ + m_val.m_bool = p_bool; +} + + +FreeForm2::ValueResult::ValueResult(const ArrayType& p_arrayType, + const std::vector>& p_elements, + unsigned int p_numElements) + : m_typeImpl(&p_arrayType), m_type(*m_typeImpl), m_array(p_elements) +{ + m_val.m_array.m_elements = m_array.empty() ? NULL : &m_array[0]; + m_val.m_array.m_numElements = p_numElements; +} + + +boost::shared_ptr +FreeForm2::ValueResult::Parse(SIZED_STRING p_result, TypeManager& p_typeManager) +{ + Tokenizer tok(p_result); + boost::shared_ptr result = Parse(tok, p_result, p_typeManager); + + // Check that there's no trailing junk. + if (tok.GetToken() != TOKEN_END) + { + FailParseResult(p_result); + } + + return boost::static_pointer_cast(result); +} + + +FreeForm2::ValueResult::~ValueResult() +{ +} + + +const FreeForm2::Type& +FreeForm2::ValueResult::GetType() const +{ + return m_type; +} + + +FreeForm2::Result::IntType +FreeForm2::ValueResult::GetInt() const +{ + FF2_ASSERT(GetType().Primitive() == Type::Int); + return m_val.m_int; +} + + +FreeForm2::Result::UInt64Type +FreeForm2::ValueResult::GetUInt64() const +{ + FF2_ASSERT(GetType().Primitive() == Type::UInt64); + return m_val.m_uint64; +} + + +int +FreeForm2::ValueResult::GetInt32() const +{ + FF2_ASSERT(GetType().Primitive() == Type::Int32); + return m_val.m_int32; +} + + +unsigned int +FreeForm2::ValueResult::GetUInt32() const +{ + FF2_ASSERT(GetType().Primitive() == Type::UInt32); + return m_val.m_uint32; +} + + +FreeForm2::Result::FloatType +FreeForm2::ValueResult::GetFloat() const +{ + FF2_ASSERT(GetType().Primitive() == Type::Float); + return m_val.m_float; +} + + +bool +FreeForm2::ValueResult::GetBool() const +{ + FF2_ASSERT(GetType().Primitive() == Type::Bool); + return m_val.m_bool; +} + + +FreeForm2::ResultIterator +FreeForm2::ValueResult::BeginArray() const +{ + FF2_ASSERT(GetType().Primitive() == Type::Array); + return ResultIterator(std::auto_ptr( + new ValueResultIteratorImpl(m_val.m_array.m_elements, + 0))); +} + + +FreeForm2::ResultIterator +FreeForm2::ValueResult::EndArray() const +{ + FF2_ASSERT(GetType().Primitive() == Type::Array); + return ResultIterator(std::auto_ptr( + new ValueResultIteratorImpl(m_val.m_array.m_elements + m_val.m_array.m_numElements, + m_val.m_array.m_numElements))); +} + + +boost::shared_ptr +FreeForm2::ValueResult::Parse(Tokenizer& p_tok, SIZED_STRING p_original, TypeManager& p_typeManager) +{ + Token token = p_tok.GetToken(); + boost::shared_ptr result; + + switch (token) + { + case TOKEN_INT: + { + std::string value(p_tok.GetValue().pcData, p_tok.GetValue().cbData); + result.reset(new ValueResult(boost::lexical_cast(value))); + + // Consume token. + p_tok.Advance(); + break; + } + + case TOKEN_FLOAT: + { + std::string value(p_tok.GetValue().pcData, p_tok.GetValue().cbData); + result.reset(new ValueResult(boost::lexical_cast(value))); + + // Consume token. + p_tok.Advance(); + break; + } + + case TOKEN_ATOM: + { + std::string value(p_tok.GetValue().pcData, p_tok.GetValue().cbData); + + if (value == "infinity") + { + result.reset(new ValueResult(std::numeric_limits::infinity())); + } + else if (value == "-") + { + // The only valid possibility here would be -infinity. + p_tok.Advance(); + + value = std::string(p_tok.GetValue().pcData, p_tok.GetValue().cbData); + + if (p_tok.GetToken() == TOKEN_ATOM && value == "infinity") + { + result.reset(new ValueResult(-std::numeric_limits::infinity())); + } + else + { + std::ostringstream err; + err << "Failed to parse value from result '-'."; + throw std::runtime_error(err.str()); + } + } + else + { + // Freeforms handle literal booleans as atoms, but we need to + // recognise them here so we can produce literal boolean values. + result.reset(new ValueResult(ParseBoolean(p_tok.GetValue()))); + } + + // Consume token. + p_tok.Advance(); + break; + } + + case TOKEN_OPEN_ARRAY: + { + unsigned int numElements = 0; + std::vector> array; + p_tok.Advance(); + + // Parse first element, if any. + unsigned int sumElements = 0; + unsigned int subDimensions = 0; + Type::TypePrimitive basePrimitive = Type::Unknown; + if (p_tok.GetToken() != TOKEN_CLOSE_ARRAY) + { + // Parse array element. + array.push_back(boost::shared_ptr( + Parse(p_tok, p_original, p_typeManager))); + ValueResult& curr = *array.back(); + + if (curr.GetType().Primitive() != Type::Array) + { + sumElements += 1; + subDimensions = 0; + basePrimitive = curr.GetType().Primitive(); + } + else + { + const ArrayType& arrayType = static_cast(curr.GetType().GetImplementation()); + sumElements += arrayType.GetMaxElements(); + basePrimitive = arrayType.GetChildType().Primitive(); + subDimensions = arrayType.GetDimensionCount(); + } + + numElements++; + } + + // Parse subsequent elements, if any. + while (p_tok.GetToken() != TOKEN_CLOSE_ARRAY) + { + // Parse array element. + array.push_back(boost::shared_ptr( + Parse(p_tok, p_original, p_typeManager))); + ValueResult& curr = *array.back(); + + if (curr.GetType().Primitive() != Type::Array) + { + sumElements += 1; + if (subDimensions != 0) + { + std::ostringstream err; + err << "Previous array element was an array, current is not."; + throw std::runtime_error(err.str()); + } + + if (basePrimitive != curr.GetType().Primitive()) + { + std::ostringstream err; + err << "Previous array element was a " << basePrimitive + << ", current is a " << curr.GetType().Primitive() << "."; + throw std::runtime_error(err.str()); + } + } + else + { + const ArrayType& arrayType = static_cast(curr.GetType().GetImplementation()); + sumElements += arrayType.GetMaxElements(); + + if (subDimensions != arrayType.GetDimensionCount()) + { + std::ostringstream err; + err << "Previous array element was a " << subDimensions + << "-dimensional array, current is " + << arrayType.GetDimensionCount() + << "-dimensional"; + throw std::runtime_error(err.str()); + } + + if (basePrimitive != arrayType.GetChildType().Primitive()) + { + std::ostringstream err; + err << "Previous array element contained " << basePrimitive + << ", current contains " << curr.GetType().Primitive() << "."; + throw std::runtime_error(err.str()); + } + } + + numElements++; + } + + const ArrayType& type = p_typeManager.GetArrayType(TypeImpl::GetCommonType(basePrimitive, true), + true, + subDimensions + 1, + sumElements); + + result.reset(new ValueResult(type, array, numElements)); + + // Remove array close before proceeding. + p_tok.Advance(); + break; + } + + case TOKEN_CLOSE_ARRAY: + case TOKEN_END: + case TOKEN_OPEN: + case TOKEN_CLOSE: + { + FailParseResult(p_original); + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } + return result; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.h b/src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.h new file mode 100644 index 000000000000..ea784ea355d7 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/External/ValueResult.h @@ -0,0 +1,94 @@ +#pragma once + +#ifndef FREEFORM2_VALUE_RESULT_H +#define FREEFORM2_VALUE_RESULT_H + +#include +#include "Expression.h" +#include "FreeForm2Result.h" +#include "FreeForm2Type.h" +#include "FreeForm2Tokenizer.h" +#include + +namespace FreeForm2 +{ + class ArrayType; + class TypeManager; + + class ValueResult : public Result + { + public: + ValueResult(FloatType p_float); + ValueResult(IntType p_int); + ValueResult(UInt64Type p_int); + ValueResult(Int32Type p_int); + ValueResult(UInt32Type p_int); + ValueResult(BoolType p_bool); + + // Parse a result from a string. + static boost::shared_ptr + Parse(SIZED_STRING p_result, TypeManager& p_typeManager); + + virtual ~ValueResult(); + + virtual const Type& GetType() const override; + virtual IntType GetInt() const override; + virtual UInt64Type GetUInt64() const override; + virtual Int32Type GetInt32() const override; + virtual UInt32Type GetUInt32() const override; + virtual FloatType GetFloat() const override; + virtual BoolType GetBool() const override; + virtual ResultIterator BeginArray() const override; + virtual ResultIterator EndArray() const override; + + private: + // Construct an array ValueResult; only called internally from Parse. + ValueResult(const ArrayType& p_arrayType, + const std::vector>& p_elements, + unsigned int p_numElements); + + // Internal method to parse a ValueResult from a string tokenizer. + // p_original gives the original string that is being parsed, for use in + // error messages. + static boost::shared_ptr + Parse(Tokenizer& p_tok, SIZED_STRING p_original, TypeManager& p_typeManager); + + // Type of result. + const TypeImpl* m_typeImpl; + Type m_type; + + // Structure (no constructor, destructor so that it can be used in below + // union) to represent array values. Note that an array at this level + // is simply a series of elements with a length (and has none of the + // restrictions that our implementation might, such as requiring + // 'square' arrays, or limiting dimensions). + struct ArrayVal + { + const boost::shared_ptr* m_elements; + unsigned int m_numElements; + }; + + // Val defines the different value types used by freeform expressions. + union Val + { + Result::FloatType m_float; + Result::IntType m_int; + Result::UInt64Type m_uint64; + Result::Int32Type m_int32; + Result::UInt32Type m_uint32; + Result::BoolType m_bool; + ArrayVal m_array; + }; + + // Result value. + Val m_val; + + // Container to hold array elements, if any. We separate this from the + // value above to avoid union issues. + std::vector> m_array; + }; +}; + +#endif + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/FreeForm2Tokenizer.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/FreeForm2Tokenizer.h new file mode 100644 index 000000000000..f2ab8ec04938 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/FreeForm2Tokenizer.h @@ -0,0 +1,201 @@ +#pragma once + +#ifndef FREEFORM2_TOKENIZER_H +#define FREEFORM2_TOKENIZER_H + +#include +#include +#include +#include +#include + +namespace FreeForm2 +{ + enum Token + { + // End-of-stream token. + TOKEN_END, + + // Open parenthesis. + TOKEN_OPEN, + + // Close parenthesis. + TOKEN_CLOSE, + + // Open array. + TOKEN_OPEN_ARRAY, + + // Close array. + TOKEN_CLOSE_ARRAY, + + // An atom (a name). + TOKEN_ATOM, + + // Integer. + TOKEN_INT, + + // Floating point number. + TOKEN_FLOAT, + }; + + // Class that turns a stream of characters into tokens. + class Tokenizer + { + public: + // Construct a tokeniser over the given input. + explicit Tokenizer(SIZED_STRING p_input); + + // Advance to the next token, which is returned. + Token Advance(); + + // Get the current token type. + Token GetToken() const; + + // Get the text that produced the current token. + SIZED_STRING GetValue() const; + + // Gets the offset of the current token with respect to + // the input. + unsigned int GetPosition() const; + + // Returns the name of the given token. + static const char* TokenName(Token p_token); + + // Signal the tokenizer to start recording tokens to be associated + // with a macro. The current token is included in the macro. + void StartMacro(SIZED_STRING p_macroName); + + // Signal the tokenizer that macro recording should end. The current + // token will not be included in the macro. All subsequent ATOM tokens + // will be compared against the macro name. Any matching atom will be + // expanded into the recorded macro tokens. Macros will be expanded + // during playback, not during recording. + void EndMacro(); + + // Delete a macro previously recorded with calls to Start/EndMacro. + // This method returns true if the macro with the specified name was + // successfully deleted; otherwise, returns false. + bool DeleteMacro(SIZED_STRING p_macroName); + + // Return a flag to determine if the tokenizer is currently expanding + // a macro. + bool IsExpandingMacro() const; + + // Return a flag to determine if the tokenizer is currently recording + // a macro. + bool IsRecordingMacro() const; + + private: + // Consume one character from the input. + void AdvanceChar(); + + // Advance the token being read from the input stream. + Token AdvanceInput(); + + // The original input. + SIZED_STRING m_originalInput; + + // Remaining input. + SIZED_STRING m_input; + + // Current token type. + Token m_current; + + // Text that produced the current token. + SIZED_STRING m_value; + + // This struct contains state data related to the recording of macros. + // The macro state object can essentially be in one of three modes: + // 1. No action is needed by the state. + // 2. Macro expansion is in progress. The expansion of a macro is + // referred to as 'playback' in this class. + // 3. A macro is being recorded. + struct MacroState + { + // Constructor to initialize flags. + MacroState(); + + // This method manages the macro recording state to start recording + // a macro. + void StartMacro(SIZED_STRING p_macroName); + + // Record a token to the current macro. + void RecordToken(Token p_token, SIZED_STRING p_value); + + // This method manages the macro recording state to end recording a + // macro. It also pops the last token off the stream to exclude it + // from the macro. + void EndMacro(); + + // Signal if macro recording is in progress. + bool IsRecording() const; + + // Signal if the macro playback is in progress. + bool IsInPlayback() const; + + // A token recorded during macro recording. + struct RecordedToken + { + // The type of the recorded token. + Token m_token; + + // The string associated with the token. + SIZED_STRING m_value; + }; + + // Return the current token. If not in playback, this will return + // a TOKEN_END. + RecordedToken GetCurrentToken() const; + + // Start the playback of a macro. Playback here means the stateful + // expansion of a macro. If a macro does not exist for the given + // name, this function returns false. + bool PlayMacro(SIZED_STRING p_name); + + // Advance playback by one token. + void Advance(); + + // Delete a macro by name. This function returns true if the macro + // was successfully deleted; otherwise, returns false. + bool DeleteMacro(SIZED_STRING p_name); + + private: + // This type represents a stream of recorded tokens which can be + // used to play back a macro. + typedef std::vector MacroStream; + + // This type is the iterator type for macro streams. + typedef MacroStream::const_iterator MacroStreamIter; + + // This type represents the playback state of a macro stream. The + // beginning of the range is the current playback location, which + // is advanced during playback until the range is empty. + typedef std::pair PlaybackState; + + // A comparison functor to check less-than equality for + // SIZED_STRING objects. + struct SizedStringLess + { + bool operator()(SIZED_STRING p_left, SIZED_STRING p_right) const; + }; + + // A map of macro names to the token stream with which they were + // recorded. + std::map m_macros; + + // The stream to which macro recording is currently writing. + MacroStream* m_recordingStream; + + // This list contains all playback states currently in progress. + // The back element is the state of the macro currently being + // expanded. + std::list m_playbackStack; + }; + + MacroState m_macroState; + }; +} + +#endif + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/SExpressionParse.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/SExpressionParse.h new file mode 100644 index 000000000000..02f171fe7560 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/inc/SExpressionParse.h @@ -0,0 +1,67 @@ +#pragma once + +#ifndef FREEFORM2_SEXPRESSIONPARSE_H +#define FREEFORM2_SEXPRESSIONPARSE_H + +#include +#include +#include "FreeForm2.h" +#include "FreeForm2Tokenizer.h" + +namespace FreeForm2 +{ + class ProgramParseState; + class Expression; + class ExpressionOwner; + class ArrayLiteralExpression; + class TypeManager; + + class SExpressionParse + { + public: + // Main driver function for parsing. p_expressionLimit can be used to limit + // the number of expressions parsed into the current expression, with zero + // reserved to indicate no limit. + static Token ParseTokens(ProgramParseState& p_state, + unsigned int p_expressionLimit); + + typedef boost::tuples::tuple, + boost::shared_ptr> ParserResults; + + // Parse an expression from a string. + static ParserResults + Parse(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + bool p_parsingAggregatedExpression); + + // Parse an array dereference expression. + static Token ParseArrayDereference(ProgramParseState& p_state); + + // Function to parse a let expression. + static Token ParseLet(ProgramParseState& p_state); + + // Function to parse a macro-let expression. + static Token ParseMacroLet(ProgramParseState& p_state); + + // Parse a RangeReduceExpression. + static Token ParseRangeReduce(ProgramParseState& p_state); + + // Parse an ArrayLiteralExpression. + static Token ParseArrayLiteral(ProgramParseState& p_state); + + // Parse a lambda expression. + static Token ParseLambda(ProgramParseState& p_state); + + // Parse an invoke expression. + static Token ParseInvoke(ProgramParseState& p_state); + + private: + // Recursively parse array literal expressions. + static const ArrayLiteralExpression& ParseArrayLiteralRecurse(ProgramParseState& p_state); + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.cpp new file mode 100644 index 000000000000..bc42de96c9ba --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.cpp @@ -0,0 +1,223 @@ +#include "Arithmetic.h" + +#include "BinaryOperator.h" +#include "OperatorExpressionFactory.h" +#include "UnaryOperator.h" +#include "TypeUtil.h" + +using namespace FreeForm2; + +namespace +{ + // This class adds ConvertToIntExpressions where appropriate to truncate + // all operands before passing them to the OperatorExpressionFactory. + class TruncatingOperatorFactory : public ExpressionFactory + { + public: + TruncatingOperatorFactory(UnaryOperator::Operation p_unaryOp, + BinaryOperator::Operation p_binaryOp) + : m_factory(p_unaryOp, p_binaryOp, false) + { + } + + private: + // Internal factory to create the operator expression. + OperatorExpressionFactory m_factory; + + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + ProgramParseState::ExpressionParseState opState(m_factory, p_state.m_atom, p_state.m_offset); + + for (size_t i = 0; i < p_state.m_children.size(); i++) + { + if (p_state.m_children[i]->GetType().Primitive() == Type::Float) + { + boost::shared_ptr expr( + TypeUtil::Convert(*p_state.m_children[i], Type::Int)); + p_owner.AddExpression(expr); + opState.Add(*expr); + } + else + { + opState.Add(*p_state.m_children[i]); + } + } + opState.m_variableIds.insert(opState.m_variableIds.begin(), + p_state.m_variableIds.begin(), + p_state.m_variableIds.end()); + return opState.Finish(p_owner, p_typeManager); + }; + + + virtual + std::pair + Arity() const override + { + return std::make_pair(1, 2); + } + }; + + typedef OperatorExpressionFactory OperatorFactory; + static const OperatorFactory c_plusFactory(UnaryOperator::invalid, + BinaryOperator::plus, + true); + static const OperatorFactory c_minusFactory(UnaryOperator::minus, + BinaryOperator::minus, + false); + static const OperatorFactory c_mulFactory(UnaryOperator::invalid, + BinaryOperator::multiply, + false); + static const OperatorFactory c_divFactory(UnaryOperator::invalid, + BinaryOperator::divides, + false); + static const OperatorFactory c_modFactory(UnaryOperator::invalid, + BinaryOperator::mod, + false); + static const OperatorFactory c_maxFactory(UnaryOperator::invalid, + BinaryOperator::max, + false); + static const OperatorFactory c_minFactory(UnaryOperator::invalid, + BinaryOperator::min, + false); + static const OperatorFactory c_powFactory(UnaryOperator::invalid, + BinaryOperator::pow, + false); + static const OperatorFactory c_unaryLogFactory(UnaryOperator::log, + BinaryOperator::invalid, + false); + static const OperatorFactory c_binaryLogFactory(UnaryOperator::invalid, + BinaryOperator::log, + false); + static const OperatorFactory c_log1Factory(UnaryOperator::log1, + BinaryOperator::invalid, + false); + static const OperatorFactory c_absFactory(UnaryOperator::abs, + BinaryOperator::invalid, + false); + static const OperatorFactory c_roundFactory(UnaryOperator::round, + BinaryOperator::invalid, + false); + static const OperatorFactory c_truncFactory(UnaryOperator::trunc, + BinaryOperator::invalid, + false); + static const TruncatingOperatorFactory c_intDivFactory(UnaryOperator::invalid, + BinaryOperator::divides); + static const TruncatingOperatorFactory c_intModFactory(UnaryOperator::invalid, + BinaryOperator::mod); +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetPlusInstance() +{ + return c_plusFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetMinusInstance() +{ + return c_minusFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetMultiplyInstance() +{ + return c_mulFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetDividesInstance() +{ + return c_divFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetIntegerDivInstance() +{ + return c_intDivFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetIntegerModInstance() +{ + return c_intModFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetModInstance() +{ + return c_modFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetMaxInstance() +{ + return c_maxFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetMinInstance() +{ + return c_minFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetPowInstance() +{ + return c_powFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetUnaryLogInstance() +{ + return c_unaryLogFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetBinaryLogInstance() +{ + return c_binaryLogFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetLog1Instance() +{ + return c_log1Factory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetAbsInstance() +{ + return c_absFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetRoundInstance() +{ + return c_roundFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Arithmetic::GetTruncInstance() +{ + return c_truncFactory; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.h new file mode 100644 index 000000000000..a329e2a3243a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Arithmetic.h @@ -0,0 +1,33 @@ +#pragma once + +#ifndef FREEFORM2_ARITHMETIC_H +#define FREEFORM2_ARITHMETIC_H + +namespace FreeForm2 +{ + class ExpressionFactory; + + namespace Arithmetic + { + const ExpressionFactory& GetPlusInstance(); + const ExpressionFactory& GetMinusInstance(); + const ExpressionFactory& GetMultiplyInstance(); + const ExpressionFactory& GetDividesInstance(); + const ExpressionFactory& GetIntegerDivInstance(); + const ExpressionFactory& GetIntegerModInstance(); + const ExpressionFactory& GetModInstance(); + const ExpressionFactory& GetMaxInstance(); + const ExpressionFactory& GetMinInstance(); + const ExpressionFactory& GetPowInstance(); + const ExpressionFactory& GetUnaryLogInstance(); + const ExpressionFactory& GetBinaryLogInstance(); + const ExpressionFactory& GetLog1Instance(); + const ExpressionFactory& GetAbsInstance(); + const ExpressionFactory& GetRoundInstance(); + const ExpressionFactory& GetTruncInstance(); + } +}; + + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.cpp new file mode 100644 index 000000000000..4b4e89dd1ab7 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.cpp @@ -0,0 +1,37 @@ +#include "Bitwise.h" + +#include "BinaryOperator.h" +#include "OperatorExpressionFactory.h" +#include "UnaryOperator.h" + +using namespace FreeForm2; + +namespace +{ + typedef OperatorExpressionFactory OperatorExpression; + static const OperatorExpression c_and(UnaryOperator::invalid, BinaryOperator::_bitand, false); + static const OperatorExpression c_or(UnaryOperator::invalid, BinaryOperator::_bitor, false); + static const OperatorExpression c_not(UnaryOperator::bitnot, BinaryOperator::invalid, false); +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Bitwise::GetAndInstance() +{ + return c_and; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Bitwise::GetOrInstance() +{ + return c_or; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Bitwise::GetNotInstance() +{ + return c_not; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.h new file mode 100644 index 000000000000..76f156659474 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Bitwise.h @@ -0,0 +1,20 @@ +#pragma once + +#ifndef FREEFORM2_BITWISE_H +#define FREEFORM2_BITWISE_H + +namespace FreeForm2 +{ + class ExpressionFactory; + + namespace Bitwise + { + const ExpressionFactory& GetAndInstance(); + const ExpressionFactory& GetOrInstance(); + const ExpressionFactory& GetNotInstance(); + } +}; + + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/CMakeLists.txt new file mode 100644 index 000000000000..b1cad96ced8a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME DRFreeFormSExpressionLibrary) + + +Project(${PROJECT_NAME}) + +SET(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + + + +add_library(${PROJECT_NAME} STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/Arithmetic.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Bitwise.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ExpressionFactory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FreeForm2Tokenizer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/Logic.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/MiscFactory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ProgramParseState.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/SExpressionParse.cpp +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../NeuralTree.Library/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../Transform + ${CMAKE_CURRENT_SOURCE_DIR}/../../../Shared + ${CMAKE_CURRENT_SOURCE_DIR}/../../../Expression + ) diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.cpp new file mode 100644 index 000000000000..be19dd7d54ff --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.cpp @@ -0,0 +1,26 @@ +#include "ExpressionFactory.h" + +#include + +const FreeForm2::Expression& +FreeForm2::ExpressionFactory::Create(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const +{ + std::pair arity = Arity(); + if (p_state.m_children.size() >= arity.first + && p_state.m_children.size() <= arity.second) + { + return CreateExpression(p_state, p_owner, p_typeManager); + } + else + { + // Incorrect arity, throw exception. + std::ostringstream err; + err << "Arity of " << std::string(SIZED_STR(p_state.m_atom)) << " was " + << p_state.m_children.size() << " but was expected to be in range [" + << Arity().first << ", " + << Arity().second << "]"; + throw std::runtime_error(err.str()); + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.h new file mode 100644 index 000000000000..a7c13b1fd62d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ExpressionFactory.h @@ -0,0 +1,51 @@ +#pragma once + +#ifndef FREEFORM2_EXPRESSIONFACTORY_H +#define FREEFORM2_EXPRESSIONFACTORY_H + +#include +#include +#include +#include "ProgramParseState.h" +#include "TypeImpl.h" +#include + +namespace FreeForm2 +{ + class Expression; + class SimpleExpressionOwner; + class TypeManager; + + // Base class that assists parsing by creating an expression from a given + // set of children. + class ExpressionFactory : boost::noncopyable + { + public: + typedef boost::shared_ptr Ptr; + + typedef std::vector ChildVec; + + // Creates an expression from the given children, with the returned + // expression being owned by the given owner. p_atom specifies the atom + // with which this expression factory was identified during parsing, + // which allows us to provide decent error messages. + const Expression& Create(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const; + + private: + // Creates an expression from the given parse state. + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const = 0; + + // Indicates the allowed arity of expressions produced from + // this factory, in a [min, max] pair (both ends inclusive). + virtual std::pair Arity() const = 0; + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/FreeForm2Tokenizer.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/FreeForm2Tokenizer.cpp new file mode 100644 index 000000000000..11eb863c9776 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/FreeForm2Tokenizer.cpp @@ -0,0 +1,489 @@ +#include "FreeForm2Tokenizer.h" + +#include "FreeForm2Assert.h" +#include +#include +#include + +FreeForm2::Tokenizer::Tokenizer(SIZED_STRING p_input) + : m_input(p_input), + m_originalInput(p_input) +{ + Advance(); +} + + +FreeForm2::Token +FreeForm2::Tokenizer::GetToken() const +{ + return m_current; +} + + +FreeForm2::Token +FreeForm2::Tokenizer::Advance() +{ + // The Tokenizer should never be both expanding and recording a macro at + // the same time. + FF2_ASSERT(!(m_macroState.IsInPlayback() && m_macroState.IsRecording())); + + // Check for macro playback first. + if (m_macroState.IsInPlayback()) + { + MacroState::RecordedToken token = m_macroState.GetCurrentToken(); + m_value = token.m_value; + m_current = token.m_token; + } + else + { + Token token = AdvanceInput(); + FF2_ASSERT(m_current == token); + } + + if (m_macroState.IsRecording()) + { + m_macroState.RecordToken(m_current, m_value); + } + else + { + // Check for macro expansion. + while (m_current == TOKEN_ATOM && m_macroState.PlayMacro(m_value)) + { + MacroState::RecordedToken token = m_macroState.GetCurrentToken(); + m_current = token.m_token; + m_value = token.m_value; + } + } + + if (m_macroState.IsInPlayback()) + { + m_macroState.Advance(); + } + + return m_current; +} + + +FreeForm2::Token +FreeForm2::Tokenizer::AdvanceInput() +{ + // Remove all comments and whitespace from the front of the input. Note + // that we have to loop, as comments and whitespace can be arbitrarily long. + bool reduced = false; + do + { + reduced = false; + + // Discard comments. + if (m_input.cbData > 0 && m_input.pbData[0] == '#') + { + reduced = true; + AdvanceChar(); + while (m_input.cbData > 0 + && m_input.pcData[0] != '\r' && m_input.pcData[0] != '\n') + { + AdvanceChar(); + } + } + + // Discard whitespace. + while (m_input.cbData > 0 && isspace(m_input.pbData[0])) + { + reduced = true; + AdvanceChar(); + } + } + while (reduced); + + if (m_input.cbData == 0) + { + m_value.cbData = 0; + return (m_current = TOKEN_END); + } + + if (m_input.pbData[0] == '(') + { + m_value = m_input; + m_value.cbData = 1; + AdvanceChar(); + return (m_current = TOKEN_OPEN); + } + else if (m_input.pbData[0] == ')') + { + m_value = m_input; + m_value.cbData = 1; + AdvanceChar(); + return (m_current = TOKEN_CLOSE); + } + else if (m_input.pbData[0] == '[') + { + m_value = m_input; + m_value.cbData = 1; + AdvanceChar(); + return (m_current = TOKEN_OPEN_ARRAY); + } + else if (m_input.pbData[0] == ']') + { + m_value = m_input; + m_value.cbData = 1; + AdvanceChar(); + return (m_current = TOKEN_CLOSE_ARRAY); + } + else if (isdigit(m_input.pbData[0]) || (m_input.pbData[0] == '-' + && m_input.cbData > 1 && isdigit(m_input.pbData[1]))) + { + // Parse literal numeric value. Note that we had to use a lookahead to + // tell the difference between '-1.0' and '-' (the atom). + m_value.pbData = m_input.pbData; + AdvanceChar(); + + while (m_input.cbData > 0 && isdigit(m_input.pbData[0])) + { + AdvanceChar(); + } + + Token tok = TOKEN_INT; + + // Parse decimal in float. + if (m_input.cbData > 0 && m_input.pcData[0] == '.') + { + tok = TOKEN_FLOAT; + AdvanceChar(); + + while (m_input.cbData > 0 && isdigit(m_input.pbData[0])) + { + AdvanceChar(); + } + } + + // Allow exponents on floating point numbers. + if (m_input.cbData > 0 && (m_input.pcData[0] == 'e' || m_input.pcData[0] == 'E')) + { + tok = TOKEN_FLOAT; + AdvanceChar(); + + if (m_input.cbData > 0 && (m_input.pcData[0] == '-' || m_input.pcData[0] == '+')) + { + // Allow negative exponent. + AdvanceChar(); + } + + while (m_input.cbData > 0 && isdigit(m_input.pbData[0])) + { + AdvanceChar(); + } + } + + m_value.cbData = m_input.pbData - m_value.pbData; + return (m_current = tok); + } + else if (isalpha(m_input.pbData[0])) + { + // Parse atom. + m_value.pbData = m_input.pbData; + AdvanceChar(); + + while (m_input.cbData > 0 + && (isalnum(m_input.pbData[0]) + || m_input.pbData[0] == '_' + || m_input.pbData[0] == '-' + || m_input.pbData[0] == ':' + || m_input.pbData[0] == '@') + || m_input.pbData[0] == '.' + || m_input.pbData[0] == '|') + { + AdvanceChar(); + } + + m_value.cbData = m_input.pbData - m_value.pbData; + return (m_current = TOKEN_ATOM); + } + else if (ispunct(m_input.pbData[0]) && m_input.pcData[0] != '#' && m_input.pcData[0] != '@') + { + // Parse atom. + m_value.pbData = m_input.pbData; + + for (AdvanceChar(); + m_input.cbData > 0 && ispunct(m_input.pbData[0]) + && m_input.pbData[0] != '#' && m_input.pbData[0] != '@' + && m_input.pbData[0] != ')' && m_input.pbData[0] != ']' + && m_input.pbData[0] != '(' && m_input.pbData[0] != '['; + AdvanceChar()); + + m_value.cbData = m_input.pbData - m_value.pbData; + return (m_current = TOKEN_ATOM); + } + else + { + std::ostringstream err; + err << "Invalid character '" << m_input.pcData[0] << "' (ascii " + << static_cast(m_input.pbData[0]) + << " in decimal) found in input."; + throw std::runtime_error(err.str()); + } +} + + +SIZED_STRING +FreeForm2::Tokenizer::GetValue() const +{ + return m_value; +} + + +unsigned int +FreeForm2::Tokenizer::GetPosition() const +{ + return static_cast(m_value.pbData - m_originalInput.pbData); +} + + +void +FreeForm2::Tokenizer::AdvanceChar() +{ + m_input.pbData++; + m_input.cbData--; +} + + +const char* +FreeForm2::Tokenizer::TokenName(Token p_token) +{ + switch (p_token) + { + case TOKEN_END: return "end"; + case TOKEN_OPEN: return "open"; + case TOKEN_CLOSE: return "close"; + case TOKEN_OPEN_ARRAY: return "open array"; + case TOKEN_CLOSE_ARRAY: return "close array"; + case TOKEN_ATOM: return "atom"; + case TOKEN_INT: return "int"; + case TOKEN_FLOAT: return "float"; + + default: + { + Unreachable(__FILE__, __LINE__); + break; + } + } +} + + +void +FreeForm2::Tokenizer::StartMacro(SIZED_STRING p_macroName) +{ + if (m_macroState.IsInPlayback()) + { + std::ostringstream err; + err << "Cannot define a macro while another macro is being expanded " + << "(additional macro definition at offset " << GetPosition() << ")"; + throw std::runtime_error(err.str()); + } + m_macroState.StartMacro(p_macroName); +} + + +void +FreeForm2::Tokenizer::EndMacro() +{ + m_macroState.EndMacro(); +} + + +bool +FreeForm2::Tokenizer::DeleteMacro(SIZED_STRING p_name) +{ + return m_macroState.DeleteMacro(p_name); +} + + +bool +FreeForm2::Tokenizer::IsExpandingMacro() const +{ + return m_macroState.IsInPlayback(); +} + + +bool +FreeForm2::Tokenizer::IsRecordingMacro() const +{ + return m_macroState.IsRecording(); +} + + +FreeForm2::Tokenizer::MacroState::MacroState() + : m_recordingStream(nullptr) +{ +} + + +void +FreeForm2::Tokenizer::MacroState::StartMacro(SIZED_STRING p_macroName) +{ + FF2_ASSERT(!IsRecording() && "Cannot nest macro definitions"); + + FF2_ASSERT(p_macroName.cbData > 0 && p_macroName.pcData != nullptr + && "Macro name cannot be empty"); + + if (m_macros.find(p_macroName) != m_macros.end()) + { + std::ostringstream err; + err << "Macro already defined: " << SIZED_STR_STL(p_macroName); + throw std::runtime_error(err.str()); + } + + const auto ret = m_macros.insert(std::make_pair(p_macroName, MacroStream())); + FF2_ASSERT(ret.second && ret.first != m_macros.end()); + m_recordingStream = &ret.first->second; +} + + +void +FreeForm2::Tokenizer::MacroState::RecordToken(Token p_token, SIZED_STRING p_value) +{ + FF2_ASSERT(IsRecording() && "Macro recording not started"); + RecordedToken token; + token.m_token = p_token; + token.m_value = p_value; + m_recordingStream->push_back(token); +} + + +void +FreeForm2::Tokenizer::MacroState::EndMacro() +{ + FF2_ASSERT(IsRecording() && "Macro recording not started"); + FF2_ASSERT(!m_recordingStream->empty() && "Macro token stream cannot be empty"); + + m_recordingStream->pop_back(); + + m_recordingStream = nullptr; +} + + +bool +FreeForm2::Tokenizer::MacroState::IsRecording() const +{ + return m_recordingStream != nullptr; +} + + +bool +FreeForm2::Tokenizer::MacroState::IsInPlayback() const +{ + return !m_playbackStack.empty(); +} + + +FreeForm2::Tokenizer::MacroState::RecordedToken +FreeForm2::Tokenizer::MacroState::GetCurrentToken() const +{ + if (IsInPlayback()) + { + const PlaybackState& state = m_playbackStack.back(); + FF2_ASSERT(state.second != state.first->cend()); + return *state.second; + } + else + { + const RecordedToken token = { TOKEN_END, { nullptr, 0 } }; + return token; + } +} + + +bool +FreeForm2::Tokenizer::MacroState::PlayMacro(SIZED_STRING p_name) +{ + const auto find = m_macros.find(p_name); + if (find != m_macros.end()) + { + const MacroStream& stream = find->second; + PlaybackState range(&stream, stream.cbegin()); + FF2_ASSERT(!range.first->empty() && "Empty macros are not allowed."); + + for (auto iter = m_playbackStack.cbegin(); iter != m_playbackStack.cend(); ++iter) + { + if (&stream == iter->first) + { + std::ostringstream err; + err << "Macro definition is malformed: recursive macros are not allowed " + << "for macro: " << SIZED_STR_STL(p_name); + throw std::runtime_error(err.str()); + } + } + + m_playbackStack.push_back(range); + return true; + } + else + { + return false; + } +} + + +void +FreeForm2::Tokenizer::MacroState::Advance() +{ + FF2_ASSERT(IsInPlayback() && "Must be in playback to advance token"); + + PlaybackState& range = m_playbackStack.back(); + FF2_ASSERT(range.second != range.first->cend()); + ++range.second; + + if (range.second == range.first->cend()) + { + // This invalidates the reference contained by range. Range may not be + // accessed after the pop call. + m_playbackStack.pop_back(); + if (IsInPlayback()) + { + Advance(); + } + } +} + + +bool +FreeForm2::Tokenizer::MacroState::DeleteMacro(SIZED_STRING p_name) +{ + const auto find = m_macros.find(p_name); + if (find != m_macros.cend()) + { + const MacroStream& stream = find->second; + for (auto iter = m_playbackStack.cbegin(); iter != m_playbackStack.cend(); ++iter) + { + if (iter->first == &stream) + { + std::ostringstream err; + err << "Macro definition is malformed or contains more closing " + << "tokens than opening. Name: " << SIZED_STR_STL(p_name); + throw std::runtime_error(err.str()); + } + } + + m_macros.erase(find); + return true; + } + else + { + return false; + } +} + + +bool +FreeForm2::Tokenizer::MacroState::SizedStringLess::operator()( + SIZED_STRING p_left, + SIZED_STRING p_right) const +{ + if (p_left.cbData < p_right.cbData) + { + return std::char_traits::compare(p_left.pcData, p_right.pcData, p_left.cbData) <= 0; + } + else + { + return std::char_traits::compare(p_left.pcData, p_right.pcData, p_right.cbData) < 0; + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.cpp new file mode 100644 index 000000000000..e832d4cbfbe3 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.cpp @@ -0,0 +1,86 @@ +#include "Logic.h" + +#include "BinaryOperator.h" +#include "OperatorExpressionFactory.h" +#include "UnaryOperator.h" + +using namespace FreeForm2; + +namespace +{ + typedef OperatorExpressionFactory OperatorFactory; + static const OperatorFactory c_eqFactory(UnaryOperator::invalid, BinaryOperator::eq, false); + static const OperatorFactory c_notEqFactory(UnaryOperator::invalid, BinaryOperator::neq, false); + static const OperatorFactory c_ltFactory(UnaryOperator::invalid, BinaryOperator::lt, false); + static const OperatorFactory c_lteFactory(UnaryOperator::invalid, BinaryOperator::lte, false); + static const OperatorFactory c_gtFactory(UnaryOperator::invalid, BinaryOperator::gt, false); + static const OperatorFactory c_gteFactory(UnaryOperator::invalid, BinaryOperator::gte, false); + static const OperatorFactory c_andFactory(UnaryOperator::invalid, BinaryOperator::_and, true); + static const OperatorFactory c_orFactory(UnaryOperator::invalid, BinaryOperator::_or, true); + static const OperatorFactory c_notFactory(UnaryOperator::_not, BinaryOperator::invalid, false); +} + + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetCmpEqInstance() +{ + return c_eqFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetCmpNotEqInstance() +{ + return c_notEqFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetCmpLTInstance() +{ + return c_ltFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetCmpLTEInstance() +{ + return c_lteFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetCmpGTInstance() +{ + return c_gtFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetCmpGTEInstance() +{ + return c_gteFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetAndInstance() +{ + return c_andFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetOrInstance() +{ + return c_orFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Logic::GetNotInstance() +{ + return c_notFactory; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.h new file mode 100644 index 000000000000..fd70ed5106fc --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/Logic.h @@ -0,0 +1,27 @@ +#pragma once + +#ifndef FREEFORM2_LOGIC_H +#define FREEFORM2_LOGIC_H + +namespace FreeForm2 +{ + class ExpressionFactory; + + namespace Logic + { + const ExpressionFactory& GetCmpEqInstance(); + const ExpressionFactory& GetCmpNotEqInstance(); + const ExpressionFactory& GetCmpLTInstance(); + const ExpressionFactory& GetCmpLTEInstance(); + const ExpressionFactory& GetCmpGTInstance(); + const ExpressionFactory& GetCmpGTEInstance(); + + const ExpressionFactory& GetAndInstance(); + const ExpressionFactory& GetOrInstance(); + const ExpressionFactory& GetNotInstance(); + } +}; + + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.cpp new file mode 100644 index 000000000000..e8f80a449801 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.cpp @@ -0,0 +1,445 @@ +#include "MiscFactory.h" + +#include +#include "ArrayLength.h" +#include "Conditional.h" +#include "ConvertExpression.h" +#include "Expression.h" +#include "ExpressionFactory.h" +#include "FeatureSpec.h" +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" +#include "SelectNth.h" +#include "TypeUtil.h" +#include "RandExpression.h" +#include + +using namespace FreeForm2; + +namespace FreeForm2 +{ + class ConditionalExpressionFactory : public ExpressionFactory + { + public: + ConditionalExpressionFactory() + { + } + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr expr( + new ConditionalExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0], + *p_state.m_children[1], + *p_state.m_children[2])); + p_owner.AddExpression(expr); + return *expr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(3, 3); + } + }; + + static const ConditionalExpressionFactory c_cond; + + class ArrayLengthExpressionFactory : public ExpressionFactory + { + public: + ArrayLengthExpressionFactory() + { + } + + private: + // Create array-length expression from results of type-checking and + // accumulated children. + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr expr(new ArrayLengthExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0])); + p_owner.AddExpression(expr); + return *expr; + } + + + // Indicate what the min/max arity (number of arguments) to the + // array-length expression is (both are one, as array-length takes a + // single array). + virtual std::pair + Arity() const override + { + return std::make_pair(1, 1); + } + }; + + // Singleton instance of the array-length expression factory. + static const ArrayLengthExpressionFactory c_arrayLengthFactory; + + // Base class for numeric conversions, which take a single numeric argument. + class NumericConversionFactory : public ExpressionFactory + { + private: + virtual std::pair Arity() const override + { + return std::make_pair(1, 1); + } + }; + + // Factory that accepts a single child expresion and upconverts + // it to a float: useful as a root for parsing, as the freeform language + // implicitly returns a float. + class FloatConversionFactory : public NumericConversionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + // Convert expression to floating point. + boost::shared_ptr convert(new ConvertToFloatExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0])); + p_owner.AddExpression(convert); + return *convert; + } + }; + + // Factory that accepts a single child expresion and implicitly upconverts + // it to a float: useful as a root for parsing, as the freeform language + // implicitly returns a float. + class IntConversionFactory : public NumericConversionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + // Convert expression to int. + boost::shared_ptr convert(new ConvertToIntExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0])); + p_owner.AddExpression(convert); + + return *convert; + } + }; + + // Factory that accepts a single child expresion and truncates + // it to a bool. + class BoolConversionFactory : public NumericConversionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + // Convert expression to int. + boost::shared_ptr convert(new ConvertToBoolExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0])); + p_owner.AddExpression(convert); + + return *convert; + } + }; + + // Factory that accepts a single child expression and then simply + // returns it - useful for parsing. + class IdentityExpressionFactory : public ExpressionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + return *p_state.m_children[0]; + } + + virtual std::pair Arity() const override + { + return std::make_pair(1, 1); + } + }; + + static const FloatConversionFactory c_floatFactory; + static const IntConversionFactory c_intFactory; + static const BoolConversionFactory c_boolFactory; + static const IdentityExpressionFactory c_identityFactory; + + class SelectNthExpressionFactory : public ExpressionFactory + { + public: + SelectNthExpressionFactory() + { + } + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr expr + = SelectNthExpression::Alloc(Annotations(SourceLocation(1, p_state.m_offset)), + p_state.m_children); + p_owner.AddExpression(expr); + return *expr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(2, UINT_MAX); + } + }; + + class SelectRangeExpressionFactory : public ExpressionFactory + { + public: + SelectRangeExpressionFactory() + { + } + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr expr( + new SelectRangeExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0], + *p_state.m_children[1], + *p_state.m_children[2], + p_typeManager)); + p_owner.AddExpression(expr); + return *expr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(3, 3); + } + }; + + static const SelectNthExpressionFactory c_selectNth; + static const SelectRangeExpressionFactory c_selectRange; + + // This class creates a derived feature specification expression. It will + // optionally convert the result value to float before wrapping. The + // feature specification expression should be the root of the expression + // tree. + class FeatureSpecExpressionFactory : public ExpressionFactory + { + public: + FeatureSpecExpressionFactory(bool p_mustProduceFloat) + : m_produceFloat(p_mustProduceFloat) + { + } + + private: + const bool m_produceFloat; + + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + const Expression* body = p_state.m_children[0]; + + if (m_produceFloat) + { + boost::shared_ptr convert( + TypeUtil::Convert(*p_state.m_children[0], Type::Float)); + p_owner.AddExpression(convert); + body = convert.get(); + } + + // Create the DerivedFeatureSpecExpression. "Feature" is a generic + // name (FreeForm2 expressions are anonymous). The (NULL, 0) pair + // signifies that the expression does not take special parameters. + boost::shared_ptr featureMap = + boost::make_shared(); + featureMap->emplace(FeatureSpecExpression::FeatureName("Feature"), body->GetType()); + + boost::shared_ptr expr( + new FeatureSpecExpression(Annotations(SourceLocation(1, p_state.m_offset)), + featureMap, + *body, + FeatureSpecExpression::AggregatedDerivedFeature, + true)); + p_owner.AddExpression(expr); + return *expr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(1, 1); + } + }; + + static const FeatureSpecExpressionFactory c_featureSpec(false); + static const FeatureSpecExpressionFactory c_floatFeatureSpec(true); + + // Factory that returns the instance of the RandFloatExpression singleton + // for the S-Expression language parser. + class RandomFloatFactory : public ExpressionFactory + { + public: + RandomFloatFactory() + { + } + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + return RandFloatExpression::GetInstance(); + } + + virtual std::pair Arity() const override + { + return std::make_pair(0, 0); + } + }; + + + // Factory that returns an instance of the RandIntExpression. + class RandomIntFactory : public ExpressionFactory + { + public: + RandomIntFactory() + { + } + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr randomInteger(new RandIntExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0], + *p_state.m_children[1])); + p_owner.AddExpression(randomInteger); + return *randomInteger; + } + + virtual std::pair Arity() const override + { + return std::make_pair(2, 2); + } + }; + + static const RandomFloatFactory c_randFloatFactory; + static const RandomIntFactory c_randIntFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::GetArrayLengthInstance() +{ + return c_arrayLengthFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Conditional::GetIfInstance() +{ + return c_cond; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Convert::GetFloatConvertFactory() +{ + return c_floatFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Convert::GetIntConvertFactory() +{ + return c_intFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Convert::GetBoolConversionFactory() +{ + return c_boolFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Convert::GetIdentityFactory() +{ + return c_identityFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Select::GetSelectNthInstance() +{ + return c_selectNth; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Select::GetSelectRangeInstance() +{ + return c_selectRange; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::GetFeatureSpecInstance(bool p_mustConvertToFloat) +{ + if (p_mustConvertToFloat) + { + return c_floatFeatureSpec; + } + else + { + return c_featureSpec; + } +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Random::GetRandomFloatInstance() +{ + return c_randFloatFactory; +} + + +const FreeForm2::ExpressionFactory& +FreeForm2::Random::GetRandomIntInstance() +{ + return c_randIntFactory; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.h new file mode 100644 index 000000000000..0b4c80e86074 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/MiscFactory.h @@ -0,0 +1,43 @@ +#pragma once + +#ifndef FREEFORM2_MISC_FACTORY_H +#define FREEFORM2_MISC_FACTORY_H + +namespace FreeForm2 +{ + class ExpressionFactory; + + namespace Conditional + { + const ExpressionFactory& GetIfInstance(); + } + + namespace Random + { + const ExpressionFactory& GetRandomFloatInstance(); + const ExpressionFactory& GetRandomIntInstance(); + } + + // Returns an expression factory for the array-length primitive. + const ExpressionFactory& GetArrayLengthInstance(); + + namespace Convert + { + const ExpressionFactory& GetFloatConvertFactory(); + const ExpressionFactory& GetIntConvertFactory(); + const ExpressionFactory& GetBoolConversionFactory(); + const ExpressionFactory& GetIdentityFactory(); + } + + namespace Select + { + const ExpressionFactory& GetSelectNthInstance(); + const ExpressionFactory& GetSelectRangeInstance(); + } + + const ExpressionFactory& GetFeatureSpecInstance(bool p_mustConvertToFloat); +}; + + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/OperatorExpressionFactory.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/OperatorExpressionFactory.h new file mode 100644 index 000000000000..402bdaafd4be --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/OperatorExpressionFactory.h @@ -0,0 +1,86 @@ +#pragma once + +#ifndef FREEFORM2_OPERATOR_EXPRESSION_FACTORY_H +#define FREEFORM2_OPERATOR_EXPRESSION_FACTORY_H + +#include "ConvertExpression.h" +#include "OperatorExpression.h" +#include "ExpressionFactory.h" +#include "FreeForm2Assert.h" +#include "SimpleExpressionOwner.h" + +namespace FreeForm2 +{ + // An OperatorExpressionFactory creates operator expressions over arguments. + class OperatorExpressionFactory : public ExpressionFactory + { + public: + // Constructor, taking an optional unary operator, and an optional + // binary operator (though it is not sensible to provide neither). + // p_multiArity indicates whether this expression factory allows + // arbitrary numbers of parameters, which are combined with multiple + // application of the binary operator. + OperatorExpressionFactory(UnaryOperator::Operation p_unaryOp, + BinaryOperator::Operation p_binaryOp, + bool p_multiArity) + : m_unaryOp(p_unaryOp), m_binaryOp(p_binaryOp), m_multiArity(p_multiArity) + { + } + + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + if (p_state.m_children.size() == 1) + { + // Handle unary expressions. + FF2_ASSERT(m_unaryOp != UnaryOperator::invalid); + boost::shared_ptr expr( + new UnaryOperatorExpression(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0], + m_unaryOp)); + p_owner.AddExpression(expr); + return *expr; + } + else + { + // Handle n-ary expressions + FF2_ASSERT(m_binaryOp != BinaryOperator::invalid); + boost::shared_ptr expr + = BinaryOperatorExpression::Alloc(Annotations(SourceLocation(1, p_state.m_offset)), + p_state.m_children, + m_binaryOp, + p_typeManager); + p_owner.AddExpression(expr); + return *expr; + } + } + + + virtual std::pair + Arity() const override + { + unsigned int upper = (m_binaryOp != BinaryOperator::invalid) + ? (m_multiArity ? UINT_MAX : 2) : 1; + unsigned int lower = (m_unaryOp != UnaryOperator::invalid) ? 1 : 2; + return std::make_pair(lower, upper); + } + + + // Unary operator used by created expressions to compile arithmetic. + const UnaryOperator::Operation m_unaryOp; + + // Binary operator used by created expressions to compile arithmetic. + const BinaryOperator::Operation m_binaryOp; + + // Whether this expression combines more than two arguments using + // multiple application of the binary operator. + bool m_multiArity; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.cpp new file mode 100644 index 000000000000..a1e0ef15231e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.cpp @@ -0,0 +1,152 @@ +#include "ProgramParseState.h" + +#include "ExpressionFactory.h" +#include "ConvertExpression.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Utils.h" +#include "FunctionInlineVisitor.h" +#include "MiscFactory.h" +#include +#include "TypeManager.h" + +FreeForm2::ProgramParseState::ExpressionParseState::ExpressionParseState( + const ExpressionFactory& p_factory, + SIZED_STRING p_atom, + unsigned int p_offset) + : m_factory(&p_factory), m_atom(p_atom), m_offset(p_offset) +{ +} + + +void +FreeForm2::ProgramParseState::ExpressionParseState::Add(const Expression& p_expr) +{ + m_children.push_back(&p_expr); +} + + +const FreeForm2::Expression& +FreeForm2::ProgramParseState::ExpressionParseState::Finish(SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const +{ + const FreeForm2::Expression& ret = m_factory->Create(*this, p_owner, p_typeManager); + + return ret; +} + + +FreeForm2::ProgramParseState::ProgramParseState(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + const OpMap& p_operators, + bool p_mustProduceFloat, + bool p_parsingAggregatedExpression) + : m_owner(new SimpleExpressionOwner()), + m_typeManager(TypeManager::CreateTypeManager().release()), + m_tokenizer(p_input), + m_symbols(*m_owner, &p_map), + m_operators(p_operators), + m_parsingLambdaBody(false), + m_parsingAggregatedExpression(p_parsingAggregatedExpression) +{ + // If the expression must be a float, then we create a base factory + // to handle that. + const ExpressionFactory* rootFactory = &GetFeatureSpecInstance(p_mustProduceFloat); + ExpressionParseState initialState(*rootFactory, CStackSizedString(""), 0); + m_parseStack.push_back(initialState); + + m_variableIdCounter.m_value = 0; +} + + +FreeForm2::SExpressionParse::ParserResults +FreeForm2::ProgramParseState::Finish() +{ + // An expression has now been parsed. We need to check a few things + // to ensure that it's valid. + + if (m_parseStack.size() > 1) + { + // Left expression(s) open. + std::ostringstream err; + err << "After parsing expression, " << m_parseStack.size() - 1 + << " expressions remain open."; + throw std::runtime_error(err.str()); + } + + Token tok = m_tokenizer.Advance(); + if (tok != TOKEN_END) + { + // Trailing junk after expression. + std::ostringstream err; + err << "Trailing " << Tokenizer::TokenName(tok) + << " token found after expression."; + throw std::runtime_error(err.str()); + } + + // Use the FunctionInlineVisitor to replace FunctionCallExpressions with LetExpressions. + // This needs to be done now since the parsed expression is then + // wrapped with a FeatureSpecExpression that uses the type of the parsed expression. + boost::shared_ptr functionVisitorOwner(new SimpleExpressionOwner()); + boost::shared_ptr functionVisitorTypeManager(TypeManager::CreateTypeManager().release()); + FunctionInlineVisitor functionInlineVisitor(functionVisitorOwner, + functionVisitorTypeManager, + GetNextVariableId()); + + FF2_ASSERT(m_parseStack.back().m_children.size() == 1); + const Expression* syntaxTree = m_parseStack.back().m_children.back(); + m_parseStack.back().m_children.pop_back(); + syntaxTree->Accept(functionInlineVisitor); + syntaxTree = functionInlineVisitor.GetSyntaxTree(); + m_parseStack.back().m_children.push_back(syntaxTree); + + m_owner.swap(functionVisitorOwner); + m_typeManager.swap(functionVisitorTypeManager); + + // The FunctionInlineVisitor may assign variable ids to ensure unique variable ids for each + // function call. Update the id of the ProgramParseState to ensure unique variable ids. + m_variableIdCounter = functionInlineVisitor.GetVariableId(); + + FF2_ASSERT(m_parseStack.size() == 1); + const Expression& root = m_parseStack.back().Finish(*m_owner, *m_typeManager); + + m_parseStack.pop_back(); + FF2_ASSERT(m_parseStack.empty()); + return boost::tuples::make_tuple(&root, + boost::shared_ptr(m_owner), + boost::shared_ptr(m_typeManager)); +} + + +FreeForm2::VariableID +FreeForm2::ProgramParseState::GetNextVariableId() +{ + VariableID id = m_variableIdCounter; + ++m_variableIdCounter.m_value; + return id; +} + + +const FreeForm2::Expression& +FreeForm2::ProgramParseState::GetLastParsed() const +{ + return *m_parseStack.back().m_children.back(); +} + + +void +FreeForm2::ProgramParseState::CloseExpression() +{ + const Expression& finished = m_parseStack.back().Finish(*m_owner, *m_typeManager); + m_parseStack.pop_back(); + + if (m_parseStack.empty()) + { + // Too many close parens. + std::ostringstream err; + err << "Mismatched close parenthesis"; + throw std::runtime_error(err.str()); + } + + m_parseStack.back().Add(finished); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.h b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.h new file mode 100644 index 000000000000..27668ce61efc --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/ProgramParseState.h @@ -0,0 +1,129 @@ +#pragma once + +#ifndef FREEFORM2_PROGRAM_PARSE_STATE_H +#define FREEFORM2_PROGRAM_PARSE_STATE_H + +#include +#include +#include +#include "Expression.h" +#include "FreeForm2Tokenizer.h" +#include +#include +#include "SExpressionParse.h" +#include "SimpleExpressionOwner.h" +#include "SymbolTable.h" + +namespace FreeForm2 +{ + class ExpressionFactory; + class TypeManager; + + // ProgramParseState represents the parsing state of a program. + class ProgramParseState + { + public: + // Class that represents an s-expression as it's being parsed. We + // accumulate children, and then create a root expression + class ExpressionParseState + { + public: + // Construct an ExpressionParseState object, from the + // expression factory used for parsing, and the atom that + // caused this object to be constructed (only used in error messages). + ExpressionParseState(const ExpressionFactory& p_factory, + SIZED_STRING p_atom, + unsigned int p_offset); + + // Add a subexpression to this expression. + void Add(const Expression& p_expr); + + // Finish parsing and return this expression. + const Expression& Finish(SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const; + + // Accumulated children. + std::vector m_children; + + // Variable IDs allocated by the parser for this expression. + std::vector m_variableIds; + + // Expression factory that will produce the finished expression. + const ExpressionFactory* m_factory; + + // Atom that caused this parsestate to be created. + SIZED_STRING m_atom; + + // The offset of the current token with respect to the beginning of the expression. + unsigned int m_offset; + }; + + // Function that parses a special form (such as 'let'). + typedef Token (*ParsingFunction)(ProgramParseState& p_state); + + // Variant that provides either an expression factory for standard + // parsing, or a ParsingFunction for special form parsing, depending on + // the operator. + typedef boost::variant OperatorInfo; + + // Map of operator names to parsing method. + typedef std::map OpMap; + + // Construct a parse state object from the input program, the feature + // map used for this program, the available set of operators, and an + // indication as to whether this program must produce a float (by + // conversion, if necessary - this also means things that can't be + // converted to a float produce errors during parsing). + ProgramParseState(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + const OpMap& p_operators, + bool p_mustProduceFloat, + bool p_parsingAggregatedExpression); + + // Get the next variable ID. + VariableID GetNextVariableId(); + + // Return the last expression parsed. + const Expression& GetLastParsed() const; + + // Finish parsing, producing the final expression and owner. + SExpressionParse::ParserResults Finish(); + + // Close the expression currently being parsed. + void CloseExpression(); + + // Object to own produced expressions. + boost::shared_ptr m_owner; + + // Object to own/manage the types. + boost::shared_ptr m_typeManager; + + // Stack of expressions being parsed. + std::list m_parseStack; + + // List of available operators. + const OpMap& m_operators; + + // Tokenizer, to turn textual input into tokens of different types. + Tokenizer m_tokenizer; + + // Mapping from strings to bound values. + SymbolTable m_symbols; + + // Whether a lambda expression is currently being parsed. + bool m_parsingLambdaBody; + + // Whether an aggregated expression is being parsed. + const bool m_parsingAggregatedExpression; + + private: + // Counter to keep track of the next variable ID. This tracks the + // number of values that have been allocated for a let statement, or + // other special forms, and allows us to assign IDs to each value for + // later use. + VariableID m_variableIdCounter; + }; +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/SExpressionParse.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/SExpressionParse.cpp new file mode 100644 index 000000000000..cfdb56e3d886 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Parse/SExpression/libs/SExpressionParse.cpp @@ -0,0 +1,1382 @@ +#include "SExpressionParse.h" + +#include "Arithmetic.h" +#include "ArrayLength.h" +#include "ArrayLiteralExpression.h" +#include "ArrayDereferenceExpression.h" +#include "Bitwise.h" +#include "Conditional.h" +#include "ConvertExpression.h" +#include "ExpressionFactory.h" +#include "FreeForm2Assert.h" +#include "FreeForm2Tokenizer.h" +#include "FreeForm2Utils.h" +#include "Function.h" +#include "FunctionType.h" +#include "LetExpression.h" +#include "LiteralExpression.h" +#include "Logic.h" +#include "MiscFactory.h" +#include "ProgramParseState.h" +#include "RangeReduceExpression.h" +#include "RefExpression.h" +#include "SelectNth.h" +#include "SimpleExpressionOwner.h" +#include "SymbolTable.h" +#include "TypeManager.h" +#include "TypeUtil.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace FreeForm2; + +namespace +{ + boost::shared_ptr ParseLiteralInt(const Annotations& p_annotations, + SIZED_STRING p_value) + { + std::string source(SIZED_STR(p_value)); + boost::shared_ptr expr( + boost::make_shared(p_annotations, + boost::lexical_cast(source))); + return expr; + } + + boost::shared_ptr ParseLiteralFloat(const Annotations& p_annotations, + SIZED_STRING p_value) + { + std::string source(SIZED_STR(p_value)); + boost::shared_ptr expr( + boost::make_shared(p_annotations, + boost::lexical_cast(source))); + return expr; + } + + class StaticOperatorMap : public ProgramParseState::OpMap + { + public: + StaticOperatorMap() + { + (*this)["+"] = ProgramParseState::OperatorInfo(&Arithmetic::GetPlusInstance()); + (*this)["-"] = ProgramParseState::OperatorInfo(&Arithmetic::GetMinusInstance()); + (*this)["*"] = ProgramParseState::OperatorInfo(&Arithmetic::GetMultiplyInstance()); + (*this)["/"] = ProgramParseState::OperatorInfo(&Arithmetic::GetDividesInstance()); + (*this)["trunc-div"] = ProgramParseState::OperatorInfo(&Arithmetic::GetIntegerDivInstance()); + (*this)["trunc-mod"] = ProgramParseState::OperatorInfo(&Arithmetic::GetIntegerModInstance()); + (*this)["mod"] = ProgramParseState::OperatorInfo(&Arithmetic::GetModInstance()); + (*this)["max"] = ProgramParseState::OperatorInfo(&Arithmetic::GetMaxInstance()); + (*this)["min"] = ProgramParseState::OperatorInfo(&Arithmetic::GetMinInstance()); + (*this)["**"] = ProgramParseState::OperatorInfo(&Arithmetic::GetPowInstance()); + (*this)["^"] = ProgramParseState::OperatorInfo(&Arithmetic::GetPowInstance()); + (*this)["ln"] = ProgramParseState::OperatorInfo(&Arithmetic::GetUnaryLogInstance()); + (*this)["log"] + = ProgramParseState::OperatorInfo(&Arithmetic::GetBinaryLogInstance()); + (*this)["ln1"] = ProgramParseState::OperatorInfo(&Arithmetic::GetLog1Instance()); + (*this)["abs"] = ProgramParseState::OperatorInfo(&Arithmetic::GetAbsInstance()); + (*this)["truncate"] + = ProgramParseState::OperatorInfo(&Arithmetic::GetTruncInstance()); + (*this)["round"] + = ProgramParseState::OperatorInfo(&Arithmetic::GetRoundInstance()); + (*this)["float"] + = ProgramParseState::OperatorInfo(&Convert::GetFloatConvertFactory()); + (*this)["int"] + = ProgramParseState::OperatorInfo(&Convert::GetIntConvertFactory()); + (*this)["bool"] + = ProgramParseState::OperatorInfo(&Convert::GetBoolConversionFactory()); + + (*this)["=="] = ProgramParseState::OperatorInfo(&Logic::GetCmpEqInstance()); + (*this)["!="] = ProgramParseState::OperatorInfo(&Logic::GetCmpNotEqInstance()); + (*this)["<"] = ProgramParseState::OperatorInfo(&Logic::GetCmpLTInstance()); + (*this)["<="] = ProgramParseState::OperatorInfo(&Logic::GetCmpLTEInstance()); + (*this)[">"] = ProgramParseState::OperatorInfo(&Logic::GetCmpGTInstance()); + (*this)[">="] = ProgramParseState::OperatorInfo(&Logic::GetCmpGTEInstance()); + + (*this)["and"] = ProgramParseState::OperatorInfo(&Logic::GetAndInstance()); + (*this)["&&"] = ProgramParseState::OperatorInfo(&Logic::GetAndInstance()); + (*this)["or"] = ProgramParseState::OperatorInfo(&Logic::GetOrInstance()); + (*this)["||"] = ProgramParseState::OperatorInfo(&Logic::GetOrInstance()); + (*this)["not"] = ProgramParseState::OperatorInfo(&Logic::GetNotInstance()); + + (*this)["bitand"] = ProgramParseState::OperatorInfo(&Bitwise::GetAndInstance()); + (*this)["bitor"] = ProgramParseState::OperatorInfo(&Bitwise::GetOrInstance()); + (*this)["bitnot"] = ProgramParseState::OperatorInfo(&Bitwise::GetNotInstance()); + + (*this)["if"] = ProgramParseState::OperatorInfo(&Conditional::GetIfInstance()); + (*this)["select-nth"] = ProgramParseState::OperatorInfo(&Select::GetSelectNthInstance()); + (*this)["select-range"] = ProgramParseState::OperatorInfo(&Select::GetSelectRangeInstance()); + + (*this)["let"] = ProgramParseState::OperatorInfo(&SExpressionParse::ParseLet); + (*this)["macro-let"] = ProgramParseState::OperatorInfo(&SExpressionParse::ParseMacroLet); + (*this)["range-reduce"] = ProgramParseState::OperatorInfo(&SExpressionParse::ParseRangeReduce); + (*this)["array-literal"] + = ProgramParseState::OperatorInfo(&SExpressionParse::ParseArrayLiteral); + (*this)["array-length"] = ProgramParseState::OperatorInfo(&GetArrayLengthInstance()); + (*this)["random-float"] = ProgramParseState::OperatorInfo(&Random::GetRandomFloatInstance()); + (*this)["random-int"] = ProgramParseState::OperatorInfo(&Random::GetRandomIntInstance()); + (*this)["lambda"] = ProgramParseState::OperatorInfo(&SExpressionParse::ParseLambda); + (*this)["invoke"] = ProgramParseState::OperatorInfo(&SExpressionParse::ParseInvoke); + } + }; + + static const StaticOperatorMap c_operators; + + // Indicates whether the current parse stack has hit the expression limit, + // (in which case parsing should return to caller for + bool HitExpressionLimit(const std::list& p_stack, + size_t p_depth, + size_t p_limit) + { + return p_stack.size() == p_depth && (p_stack.back().m_children.size() >= p_limit); + } + + class ArrayDereferenceFactory : public ExpressionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr ptr(boost::make_shared( + Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0], + *p_state.m_children[1], + 0)); + p_owner.AddExpression(ptr); + return *ptr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(2, 2); + } + }; + + // Global instance of the array dereference expression factory. + const static ArrayDereferenceFactory c_arrayDereferenceFactory; + + // Factory to create recursively nested array literal expressions. + class ArrayLiteralExpressionFactory : public ExpressionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + FF2_ASSERT(p_state.m_variableIds.size() == 1); + boost::shared_ptr ptr(ArrayLiteralExpression::Alloc(Annotations(SourceLocation(1, p_state.m_offset)), + TypeImpl::GetUnknownType(), + p_state.m_children, + p_state.m_variableIds[0], + p_typeManager)); + p_owner.AddExpression(ptr); + return *ptr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(0, UINT_MAX); + } + }; + + const static ArrayLiteralExpressionFactory c_arrayLiteralFactory; + + class LetExpressionFactory : public ExpressionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + // Count the number of non-function parameters that there are in the let statement. + int numNonFunctionParameters = 0; + + std::vector children; + for (size_t i = 0; i < p_state.m_variableIds.size(); ++i) + { + // Do not create a binding for lambdas. + if (p_state.m_children[i]->GetType().Primitive() != Type::Function) + { + children.push_back(std::make_pair(p_state.m_variableIds[i], p_state.m_children[i])); + ++numNonFunctionParameters; + } + } + + FF2_ASSERT(p_state.m_variableIds.size() == numNonFunctionParameters); + + if (children.size() > 0) + { + boost::shared_ptr expr + = LetExpression::Alloc(Annotations(SourceLocation(1, p_state.m_offset)), + children, + p_state.m_children.back()); + p_owner.AddExpression(expr); + return *expr; + } + else + { + // If all let bindings were lambdas, do not produce a let + // expression; no variable bindings are needed. + return *p_state.m_children.back(); + } + } + + + virtual std::pair Arity() const override + { + return std::make_pair(1, UINT_MAX); + } + }; + + // Global instance of the let expression factory. + const static LetExpressionFactory c_letFactory; + + class RangeReduceExpressionFactory : public ExpressionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + boost::shared_ptr one(new LiteralIntExpression(Annotations(SourceLocation(1, p_state.m_offset)), + 1)); + p_owner.AddExpression(one); + + FF2_ASSERT(p_state.m_variableIds.size() == 2); + const ChildVec& children = p_state.m_children; + boost::shared_ptr ptr( + boost::make_shared(Annotations(SourceLocation(1, p_state.m_offset)), + *children[0], + *children[1], + *children[2], + *children[3], + p_state.m_variableIds[0], + p_state.m_variableIds[1])); + + p_owner.AddExpression(ptr); + return *ptr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(4, 4); + } + }; + + // Global instance of the rangereduce expression factory. + const static RangeReduceExpressionFactory c_rangeFactory; + + class LambdaExpressionFactory : public ExpressionFactory + { + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + // Children of lambda ExpressionParseState are: + // All except last: Parameters of lambda expression. + // Last: The body of the lambda expression. + FF2_ASSERT(p_state.m_variableIds.size() == p_state.m_children.size() - 1); + std::vector params(p_state.m_variableIds.size()); + std::vector paramTypes(p_state.m_variableIds.size()); + for (size_t i = 0; i < p_state.m_variableIds.size(); i++) + { + params[i].m_parameter + = boost::polymorphic_downcast(p_state.m_children[i]); + params[i].m_isFeatureParameter = false; + paramTypes[i] = ¶ms[i].m_parameter->GetType(); + } + const Expression& body = *p_state.m_children.back(); + const FunctionType& type + = p_typeManager.GetFunctionType(body.GetType(), paramTypes.data(), paramTypes.size()); + + // The name of the lambda is unimportant to evaluation, but we will name it anyways. + static int id = 1; + std::ostringstream lambdaNameStream; + lambdaNameStream << "lambda<" << id++ << ">"; + std::string lambdaName = lambdaNameStream.str(); + + boost::shared_ptr expr + = boost::make_shared(Annotations(SourceLocation(1, p_state.m_offset)), + type, + lambdaName, + params, + body); + p_owner.AddExpression(expr); + return *expr; + } + + + virtual std::pair Arity() const override + { + return std::make_pair(1, UINT_MAX); + } + }; + + // Global instance of the lambda expression factory. + const static LambdaExpressionFactory c_lambdaFactory; + + class InvokeFactory : public ExpressionFactory + { + public: + InvokeFactory() + { + } + + private: + virtual + const Expression& + CreateExpression(const ProgramParseState::ExpressionParseState& p_state, + SimpleExpressionOwner& p_owner, + TypeManager& p_typeManager) const override + { + // The children of the inoke ExpressionParseState should be: + // 1st child: Function expression + // All others: Parameters in the invoke expression. (Must have at least one) + FF2_ASSERT(p_state.m_children.size() >= 2); + + if (p_state.m_children[0]->GetType().Primitive() != Type::Function) + { + std::ostringstream err; + err << "Invoke may only be called on a lambda or " + << "an expression bound to a lambda " + << "(called on expression of type " + << p_state.m_children[0]->GetType() + << ")"; + throw std::runtime_error(err.str()); + } + + const FunctionType& type = static_cast(p_state.m_children[0]->GetType()); + if (type.GetParameterCount() != p_state.m_children.size() - 1) + { + std::ostringstream err; + err << "Parameter count mismatch: expected " + << type.GetParameterCount() << " parameters, " + << "got " << p_state.m_children.size() - 1; + throw std::runtime_error(err.str()); + } + + std::vector params(p_state.m_children.size() - 1); + for (size_t i = 1; i < p_state.m_children.size(); i++) + { + const Expression& param = *p_state.m_children[i]; + + // Since the lambda is the first child in the Invoke expression, + // subtract 1 from the child index to get the correct parameter. + const TypeImpl& expected = *type.BeginParameters()[i - 1]; + + // Check that parameters with explicit types in the lambda + // definition are correctly specified. + if (!TypeUtil::IsAssignable(expected, param.GetType())) + { + // Unknown types should always be valid assignment destinations. + FF2_ASSERT(expected.Primitive() != Type::Unknown); + std::ostringstream err; + err << "Parameter type mismatch: expected type " + << expected << ", " + << "got " << param.GetType(); + throw std::runtime_error(err.str()); + } + params[i - 1] = ¶m; + } + + boost::shared_ptr expr( + FunctionCallExpression::Alloc(Annotations(SourceLocation(1, p_state.m_offset)), + *p_state.m_children[0], + params)); + p_owner.AddExpression(expr); + return *expr; + } + + virtual std::pair Arity() const override + { + // Invoke must have a function and at least one parameter. + return std::make_pair(2, MAX_UINT32); + } + }; + + // Global instance of the invoke factory. + static const InvokeFactory c_invokeFactory; + + // Continue advancing the tokenizer until a matched closing state is found. + // This method matches nested open-close pairs correctly. + FreeForm2::Token ParseUntilClosed(FreeForm2::ProgramParseState& p_state, + FreeForm2::Token p_open, + FreeForm2::Token p_close) + { + // Read tokens until we have a matched open- and close-parentheses. + size_t depth = 1; + FreeForm2::Token tok = p_open; + while (depth > 0 && tok != TOKEN_END) + { + tok = p_state.m_tokenizer.Advance(); + if (tok == p_open) + { + ++depth; + } + else if (tok == p_close) + { + --depth; + } + } + + return tok; + } +}; + + +FreeForm2::SExpressionParse::ParserResults +FreeForm2::SExpressionParse::Parse(SIZED_STRING p_input, + DynamicRank::IFeatureMap& p_map, + bool p_mustProduceFloat, + bool p_parsingAggregatedExpression) +{ + ProgramParseState parseState(p_input, p_map, c_operators, p_mustProduceFloat, p_parsingAggregatedExpression); + SExpressionParse::ParseTokens(parseState, 0); + + return parseState.Finish(); +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseTokens(ProgramParseState& p_state, + unsigned int p_expressionLimit) +{ + // Keep track of the depth of the parse stack at entry. This allows us + // to use this function recursively in a sensible fashion, by having it + // exit once we close back to this depth. + const size_t parseDepth = p_state.m_parseStack.size(); + const size_t limit = p_expressionLimit == 0 + ? static_cast(-1) + : p_state.m_parseStack.back().m_children.size() + p_expressionLimit; + + Token tok = p_state.m_tokenizer.GetToken(); + while (tok != TOKEN_END) + { + FF2_ASSERT(!p_state.m_parseStack.empty()); + + switch (tok) + { + case TOKEN_OPEN: + { + tok = p_state.m_tokenizer.Advance(); + + if (tok != TOKEN_ATOM) + { + std::ostringstream err; + err << "Expected atom (a name) after open parenthesis, " + << "got something else (" + << Tokenizer::TokenName(tok) << ")"; + throw std::runtime_error(err.str()); + } + + std::string atom(SIZED_STR(p_state.m_tokenizer.GetValue())); + ProgramParseState::OpMap::const_iterator iter + = p_state.m_operators.find(atom); + + if (iter == p_state.m_operators.end()) + { + // Couldn't find operator. + std::ostringstream err; + err << "Failed to find operator '" << atom << "'."; + throw std::runtime_error(err.str()); + } + + if (boost::get(&iter->second) == NULL) + { + // Continue parsing this operator. + const ExpressionFactory* const* factory + = boost::get(&iter->second); + FF2_ASSERT(factory != NULL && *factory != NULL); + ProgramParseState::ExpressionParseState + state(**factory, p_state.m_tokenizer.GetValue(), p_state.m_tokenizer.GetPosition()); + p_state.m_parseStack.push_back(state); + + // Advance to next token. + tok = p_state.m_tokenizer.Advance(); + } + else + { + // Handle special form. + ProgramParseState::ParsingFunction const* specialForm + = boost::get(&iter->second); + FF2_ASSERT(specialForm != NULL && *specialForm != NULL); + tok = (*specialForm)(p_state); + FF2_ASSERT(tok == TOKEN_CLOSE || tok == TOKEN_END); + } + + // Continue back up to the top of the loop, rather than going + // through the processing that follows other tokens. + continue; + } + + case TOKEN_CLOSE: + { + p_state.CloseExpression(); + + if (p_state.m_parseStack.size() < parseDepth) + { + // Have closed the expression we began parsing, consume + // token and return. + return p_state.m_tokenizer.Advance(); + } + break; + } + + case TOKEN_ATOM: + { + // Look up name in symbol table. + p_state.m_parseStack.back().Add(p_state.m_symbols.Lookup( + SymbolTable::Symbol(p_state.m_tokenizer.GetValue()))); + break; + } + + case TOKEN_INT: + { + boost::shared_ptr expr + = ParseLiteralInt(Annotations(SourceLocation(1, p_state.m_tokenizer.GetPosition())), + p_state.m_tokenizer.GetValue()); + p_state.m_owner->AddExpression(expr); + p_state.m_parseStack.back().Add(*expr); + break; + } + + case TOKEN_FLOAT: + { + boost::shared_ptr expr + = ParseLiteralFloat(Annotations(SourceLocation(1, p_state.m_tokenizer.GetPosition())), + p_state.m_tokenizer.GetValue()); + p_state.m_owner->AddExpression(expr); + p_state.m_parseStack.back().Add(*expr); + break; + } + + case TOKEN_OPEN_ARRAY: + { + // Parse the array dereference expression. Note that we can't + // handle this through the generic special form method, because + // they all start with '(', and this doesn't. + tok = ParseArrayDereference(p_state); + break; + } + + default: + { + std::ostringstream err; + err << "Unexpected token type '" << tok << "'"; + throw std::runtime_error(err.str()); + break; + } + }; + + // Advance to next token. + tok = p_state.m_tokenizer.Advance(); + + // We check to see if the next token is an open array, because it adds + // to the last expression (no other construct currently does), instead + // of creating more expressions. Thus, the expression limit hasn't + // really been hit if the next token is an open array. + if (tok != TOKEN_OPEN_ARRAY + && HitExpressionLimit(p_state.m_parseStack, parseDepth, limit)) + { + // Have closed the expression we began parsing, or ran into + // a limit on the number of sub-expressions, end now. + return tok; + } + } + + return tok; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseArrayDereference(ProgramParseState& p_state) +{ + if (p_state.m_parseStack.back().m_children.empty()) + { + // User tried to dereference without providing an expression + // to dereference. + std::ostringstream err; + err << "Received " << Tokenizer::TokenName(p_state.m_tokenizer.GetToken()) + << " token, which starts an " + << "array dereference, but there is nothing to dereference"; + throw std::runtime_error(err.str()); + } + + // Pop the last expression off of the stack. + const Expression& last = p_state.GetLastParsed(); + p_state.m_parseStack.back().m_children.pop_back(); + + // Create an array dereference factory, and push the last + // expression parsed into it as the first expression. + ProgramParseState::ExpressionParseState + state(c_arrayDereferenceFactory, p_state.m_tokenizer.GetValue(), p_state.m_tokenizer.GetPosition()); + p_state.m_parseStack.push_back(state); + p_state.m_parseStack.back().m_children.push_back(&last); + + // Parse index expression. + p_state.m_tokenizer.Advance(); + Token tok = SExpressionParse::ParseTokens(p_state, 1); + + if (tok == TOKEN_END) + { + return tok; + } + else if (tok != TOKEN_CLOSE_ARRAY) + { + // User tried to dereference without providing an expression + // to dereference. + std::ostringstream err; + err << "Expected a " << Tokenizer::TokenName(TOKEN_CLOSE_ARRAY) + << " token after array dereference index, but got a " + << Tokenizer::TokenName(tok) << " token instead."; + throw std::runtime_error(err.str()); + } + + // Finish off the expression. + p_state.CloseExpression(); + + return tok; +} + + +const FreeForm2::ArrayLiteralExpression& +FreeForm2::SExpressionParse::ParseArrayLiteralRecurse(ProgramParseState& p_state) +{ + Token tok = p_state.m_tokenizer.GetToken(); + if (tok != TOKEN_OPEN_ARRAY) + { + std::ostringstream err; + err << "Expected open array token, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + const size_t parseDepth = p_state.m_parseStack.size(); + ProgramParseState::ExpressionParseState + arrayState(c_arrayLiteralFactory, p_state.m_tokenizer.GetValue(), p_state.m_tokenizer.GetPosition()); + arrayState.m_variableIds.push_back(p_state.GetNextVariableId()); + p_state.m_parseStack.push_back(arrayState); + + tok = p_state.m_tokenizer.Advance(); + while (tok != TOKEN_CLOSE_ARRAY) + { + if (tok == TOKEN_OPEN_ARRAY) + { + p_state.m_parseStack.back().Add(ParseArrayLiteralRecurse(p_state)); + tok = p_state.m_tokenizer.Advance(); + } + else if (tok == TOKEN_END) + { + std::ostringstream err; + err << "Unexpected end to program with array literal still open"; + throw std::runtime_error(err.str()); + } + else + { + tok = SExpressionParse::ParseTokens(p_state, 1); + } + } + + FF2_ASSERT(p_state.m_tokenizer.GetToken() == TOKEN_CLOSE_ARRAY); + const ArrayLiteralExpression* result + = boost::polymorphic_downcast( + &p_state.m_parseStack.back().Finish(*p_state.m_owner, *p_state.m_typeManager)); + p_state.m_parseStack.pop_back(); + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth); + return *result; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseArrayLiteral(ProgramParseState& p_state) +{ + // Advance past token that started this special form. + p_state.m_tokenizer.Advance(); + + const size_t parseDepth = p_state.m_parseStack.size(); + const ArrayLiteralExpression& array = ParseArrayLiteralRecurse(p_state); + FF2_ASSERT(p_state.m_tokenizer.GetToken() == TOKEN_CLOSE_ARRAY); + Token tok = p_state.m_tokenizer.Advance(); + const ArrayLiteralExpression* flat = NULL; + if (tok == TOKEN_ATOM) + { + Type::TypePrimitive primitive = Type::ParsePrimitive(p_state.m_tokenizer.GetValue()); + if (primitive == Type::Invalid) + { + std::ostringstream err; + err << "Couldn't parse name of array element type from '" + << p_state.m_tokenizer.GetValue() + << "'"; + throw std::runtime_error(err.str()); + } + else if (!TypeImpl::IsLeafType(primitive)) + { + std::ostringstream err; + err << "Array elements must be of fixed size (such as int, float), " + << "not " << Type::Name(primitive); + throw std::runtime_error(err.str()); + } + + const TypeImpl& child = TypeImpl::GetCommonType(primitive, true); + flat = &array.Flatten(*p_state.m_owner, &child, p_state.m_typeManager.get()); + + tok = p_state.m_tokenizer.Advance(); + } + else + { + flat = &array.Flatten(*p_state.m_owner); + } + FF2_ASSERT(flat != NULL); + + if (tok != TOKEN_CLOSE && tok != TOKEN_END) + { + std::ostringstream err; + err << "Trailing junk (" << Tokenizer::TokenName(tok) << ") after array literal."; + throw std::runtime_error(err.str()); + } + + // Arrange for flattened array to be popped off the stack once the + // top-level parser receives the TOKEN_CLOSE. + ProgramParseState::ExpressionParseState + flatState(Convert::GetIdentityFactory(), p_state.m_tokenizer.GetValue(), p_state.m_tokenizer.GetPosition()); + p_state.m_parseStack.push_back(flatState); + p_state.m_parseStack.back().Add(*flat); + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + FF2_ASSERT(tok == TOKEN_CLOSE || tok == TOKEN_END); + return tok; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseLet(ProgramParseState& p_state) +{ + if (p_state.m_parsingLambdaBody) + { + std::ostringstream err; + err << "Cannot include a let expression within a lambda function body."; + throw std::runtime_error(err.str()); + } + + size_t parseDepth = p_state.m_parseStack.size(); + std::vector bindings; + + p_state.m_parseStack.push_back( + ProgramParseState::ExpressionParseState(c_letFactory, + p_state.m_tokenizer.GetValue(), + p_state.m_tokenizer.GetPosition())); + + ProgramParseState::ExpressionParseState& letState = p_state.m_parseStack.back(); + + Token tok = p_state.m_tokenizer.Advance(); + if (tok != TOKEN_OPEN) + { + std::ostringstream err; + err << "Expected open parenthesis after 'let', " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + // Iterate through binding pairs. + tok = p_state.m_tokenizer.Advance(); + do + { + if (tok != TOKEN_OPEN) + { + std::ostringstream err; + err << "Expected parenthesised pairs of bindings after " + << "let open parentheses, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + tok = p_state.m_tokenizer.Advance(); + bindings.push_back(p_state.m_tokenizer.GetValue()); + if (tok != TOKEN_ATOM) + { + std::ostringstream err; + err << "Expected atom to be bound in let binding pair, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + // Now parse the remaining expression using the main parser. + p_state.m_tokenizer.Advance(); + tok = SExpressionParse::ParseTokens(p_state, 1); + + if (tok != TOKEN_CLOSE) + { + std::ostringstream err; + err << "Expected binding of name to a single expression: " + << "got trailing junk (" << Tokenizer::TokenName(tok) << ")."; + throw std::runtime_error(err.str()); + } + + FF2_ASSERT(bindings.size() == p_state.m_parseStack.back().m_children.size()); + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + const SymbolTable::Symbol symbol(bindings.back()); + // If the variable is a function, create a binding from it directly to the function expression + // so the invoke ExpressionParseState can access the FunctionExpression. + if (p_state.GetLastParsed().GetType().Primitive() == Type::Function) + { + p_state.m_symbols.Bind(symbol, p_state.m_parseStack.back().m_children.back()); + } + else + { + // Bind variable (we do this incrementally, so each successive + // binding can refer to the previously bound expressions as well). + const VariableID id = p_state.GetNextVariableId(); + boost::shared_ptr exp( + new VariableRefExpression(Annotations(SourceLocation(1, p_state.m_tokenizer.GetPosition())), + id, + 0, + p_state.GetLastParsed().GetType())); + p_state.m_owner->AddExpression(exp); + p_state.m_symbols.Bind(symbol, exp.get()); + letState.m_variableIds.push_back(id); + } + + tok = p_state.m_tokenizer.Advance(); + } + while (tok != TOKEN_END && tok != TOKEN_CLOSE); + + // Make sure they closed with a close parens. + if (tok != TOKEN_CLOSE) + { + if (tok == TOKEN_END) + { + return tok; + } + else + { + std::ostringstream err; + err << "Expected parenthesised pairs of bindings to finish " + << "with a close parenthesis, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + } + + // Make sure they provided at least one binding (there's nothing explicitly + // wrong with not providing any bindings, but it seems weird, so i'm + // banning it for now). + if (bindings.size() == 0) + { + std::ostringstream err; + err << "Let expressions that bind no variables are not currently allowed."; + throw std::runtime_error(err.str()); + } + + // Parse the bound expression. + tok = p_state.m_tokenizer.Advance(); + if (tok == TOKEN_CLOSE || tok == TOKEN_END) + { + std::ostringstream err; + err << "Expected an expression after bindings in let " + << "expression, but let has no additional arguments."; + throw std::runtime_error(err.str()); + } + tok = SExpressionParse::ParseTokens(p_state, 1); + + // Make sure they closed with a close parens. + if (tok != TOKEN_CLOSE) + { + if (tok == TOKEN_END) + { + return tok; + } + else + { + std::ostringstream err; + err << "Expected a single expression after bindings in let " + << "expression, but let has additional arguments (" + << Tokenizer::TokenName(tok) << ")"; + throw std::runtime_error(err.str()); + } + } + + // Remove bindings (note reverse iteration order to pop things off local + // variable stack in the order they were pushed on). + for (std::vector::const_reverse_iterator rnameIter = bindings.rbegin(); + rnameIter != bindings.rend(); + ++rnameIter) + { + p_state.m_symbols.Unbind(SymbolTable::Symbol(*rnameIter)); + } + + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + return tok; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseMacroLet(ProgramParseState& p_state) +{ + // Ideally, disallowing macro-let while recording a macro should preclude + // macro-let expressions during macro playback, but if string concatenation + // or other complex macro expansion mechanisms are ever added, macro + // expansion could potentially introduce new operator calls which were not + // present in the original recording. + if (p_state.m_tokenizer.IsRecordingMacro() || p_state.m_tokenizer.IsExpandingMacro()) + { + throw std::runtime_error("Macro definitions may not contain macro-let expressions."); + } + + const size_t parseDepth = p_state.m_parseStack.size(); + + // Push an identity factory to allow this ExpressionParseState to evaluate + // to the macro-let body expression. + p_state.m_parseStack.push_back( + ProgramParseState::ExpressionParseState(Convert::GetIdentityFactory(), + p_state.m_tokenizer.GetValue(), + p_state.m_tokenizer.GetPosition())); + + ProgramParseState::ExpressionParseState& state = p_state.m_parseStack.back(); + + Token tok = p_state.m_tokenizer.Advance(); + if (tok != TOKEN_OPEN) + { + std::ostringstream err; + err << "Expected open parenthesis after 'macro-let', " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + std::vector macros; + + // Iterate through binding pairs. + tok = p_state.m_tokenizer.Advance(); + do + { + if (tok != TOKEN_OPEN) + { + std::ostringstream err; + err << "Expected parenthesised pairs of bindings after " + << "macro-let open parentheses, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + tok = p_state.m_tokenizer.Advance(); + SIZED_STRING name = p_state.m_tokenizer.GetValue(); + macros.push_back(name); + + if (tok != TOKEN_ATOM) + { + std::ostringstream err; + err << "Expected atom to be bound in macro-let binding pair, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + // Now parse the remaining expression using the main parser. + p_state.m_tokenizer.StartMacro(name); + tok = p_state.m_tokenizer.Advance(); + + if (tok == TOKEN_OPEN) + { + tok = ParseUntilClosed(p_state, TOKEN_OPEN, TOKEN_CLOSE); + if (tok == TOKEN_END) + { + p_state.m_tokenizer.EndMacro(); + return tok; + } + } + else if (tok == TOKEN_OPEN_ARRAY) + { + tok = ParseUntilClosed(p_state, TOKEN_OPEN_ARRAY, TOKEN_CLOSE_ARRAY); + if (tok == TOKEN_END) + { + p_state.m_tokenizer.EndMacro(); + return tok; + } + } + + // Always consume at least one token. + tok = p_state.m_tokenizer.Advance(); + p_state.m_tokenizer.EndMacro(); + + if (tok != TOKEN_CLOSE) + { + if (tok == TOKEN_END) + { + return tok; + } + else + { + std::ostringstream err; + err << "Expected binding of name to a single expression: " + << "got trailing junk (" + << Tokenizer::TokenName(tok) << ")."; + throw std::runtime_error(err.str()); + } + } + + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + state.m_children.clear(); + + tok = p_state.m_tokenizer.Advance(); + } + while (tok != TOKEN_END && tok != TOKEN_CLOSE); + + // Make sure they closed with a close parens. + if (tok != TOKEN_CLOSE) + { + if (tok == TOKEN_END) + { + return tok; + } + else + { + std::ostringstream err; + err << "Expected parenthesised pairs of bindings to finish " + << "with a close parenthesis, " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + } + + // Make sure they provided at least one binding (there's nothing explicitly + // wrong with not providing any bindings, but it seems weird, so it's + // banned it for now). + if (macros.size() == 0) + { + std::ostringstream err; + err << "Macro-let expressions that bind no macros are not currently allowed."; + throw std::runtime_error(err.str()); + } + + // Parse the bound expression. Since this will be added to an + // IdentityExpressionFactory, the ExpressionParseState which will be + // finished by the caller will produce the body of the macro-let. + p_state.m_tokenizer.Advance(); + tok = SExpressionParse::ParseTokens(p_state, 1); + + // Make sure they closed with a close parens. + if (tok != TOKEN_CLOSE) + { + if (tok == TOKEN_END) + { + return tok; + } + else + { + std::ostringstream err; + err << "Expected a single expression after bindings in macro-let " + << "expression, but macro-let has additional arguments (" + << Tokenizer::TokenName(tok) << ")"; + throw std::runtime_error(err.str()); + } + } + + // Remove bindings (note reverse iteration order to pop things off local + // variable stack in the order they were pushed on). + for (auto iter = macros.crbegin(); iter != macros.crend(); ++iter) + { + p_state.m_tokenizer.DeleteMacro(*iter); + } + + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + return tok; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseRangeReduce(ProgramParseState& p_state) +{ + const size_t parseDepth = p_state.m_parseStack.size(); + const VariableID stepId = p_state.GetNextVariableId(); + const VariableID reduceId = p_state.GetNextVariableId(); + + ProgramParseState::ExpressionParseState + rangeState(c_rangeFactory, p_state.m_tokenizer.GetValue(), p_state.m_tokenizer.GetPosition()); + rangeState.m_variableIds.push_back(stepId); + rangeState.m_variableIds.push_back(reduceId); + + p_state.m_parseStack.push_back(rangeState); + + // Get next token (should be an atom). + Token tok = p_state.m_tokenizer.Advance(); + if (tok != TOKEN_ATOM) + { + std::ostringstream err; + err << "Expected variable name 'range-reduce', " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + const SymbolTable::Symbol rangeSymbol(p_state.m_tokenizer.GetValue()); + + // Parse range limits. + tok = p_state.m_tokenizer.Advance(); + tok = SExpressionParse::ParseTokens(p_state, 2); + if (tok == TOKEN_END) + { + return tok; + } + + // Next token should be an atom. + if (tok != TOKEN_ATOM) + { + std::ostringstream err; + err << "Expected variable name 'range-reduce', " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + const SymbolTable::Symbol accSymbol(p_state.m_tokenizer.GetValue()); + + // Parse initial value expression. + tok = p_state.m_tokenizer.Advance(); + tok = SExpressionParse::ParseTokens(p_state, 1); + if (tok == TOKEN_END) + { + return tok; + } + + // Bind previous and current value variables. + boost::shared_ptr + boundRange(new VariableRefExpression(Annotations(SourceLocation(1, p_state.m_tokenizer.GetPosition())), + stepId, + 0, + TypeImpl::GetIntInstance(true))); + p_state.m_owner->AddExpression(boundRange); + boost::shared_ptr accExp( + new VariableRefExpression(Annotations(SourceLocation(1, p_state.m_tokenizer.GetPosition())), + reduceId, + 0, + p_state.GetLastParsed().GetType())); + p_state.m_owner->AddExpression(accExp); + p_state.m_symbols.Bind(accSymbol, accExp.get()); + p_state.m_symbols.Bind(rangeSymbol, boundRange.get()); + + // Parse reduction expression. + tok = SExpressionParse::ParseTokens(p_state, 1); + if (tok == TOKEN_END) + { + return tok; + } + + // Unbind the variables, so they can't be used anymore. + p_state.m_symbols.Unbind(rangeSymbol); + p_state.m_symbols.Unbind(accSymbol); + + // Next token should be a close. + if (tok != TOKEN_CLOSE) + { + std::ostringstream err; + err << "Expected close of 'range-reduce', " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + // Return close token to main parser, it will take care of the rest. + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + return tok; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseLambda(ProgramParseState& p_state) +{ + if (p_state.m_parsingAggregatedExpression) + { + std::ostringstream err; + err << "Cannot include a lambda expression within an aggregated expression."; + throw std::runtime_error(err.str()); + } + + if (p_state.m_parsingLambdaBody) + { + std::ostringstream err; + err << "Cannot include a lambda expression within a lambda function body."; + throw std::runtime_error(err.str()); + } + p_state.m_parsingLambdaBody = true; + + const size_t parseDepth = p_state.m_parseStack.size(); + std::vector bindings; + + // Push a factory that will produce the function expression. + p_state.m_parseStack.push_back( + ProgramParseState::ExpressionParseState(c_lambdaFactory, + p_state.m_tokenizer.GetValue(), + p_state.m_tokenizer.GetPosition())); + + ProgramParseState::ExpressionParseState& lambdaState = p_state.m_parseStack.back(); + + Token tok = p_state.m_tokenizer.Advance(); + if (tok != TOKEN_OPEN) + { + std::ostringstream err; + err << "Expected open parenthesis after 'lambda', " + << "got something else (" << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + // Iterate through lambda parameters + tok = p_state.m_tokenizer.Advance(); + while (tok != TOKEN_CLOSE && tok != TOKEN_END) + { + if (tok != TOKEN_ATOM && tok != TOKEN_OPEN) + { + std::ostringstream err; + err << "Expected formal declaration, got something else (" + << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + const TypeImpl* type = &TypeImpl::GetUnknownType().AsConstType(); + bool matchOpenParen = false; + if (tok == TOKEN_OPEN) + { + matchOpenParen = true; + tok = p_state.m_tokenizer.Advance(); + if (tok != TOKEN_ATOM) + { + std::ostringstream err; + err << "Expected atom declaration, got something else (" + << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + const Type::TypePrimitive prim = Type::ParsePrimitive(p_state.m_tokenizer.GetValue()); + if (prim != Type::Float && prim != Type::Int && prim != Type::Bool) + { + std::ostringstream err; + err << "Expected type in lambda formals, got something else (" + << p_state.m_tokenizer.GetValue() << ")."; + throw std::runtime_error(err.str()); + } + + type = &TypeImpl::GetCommonType(prim, true); + tok = p_state.m_tokenizer.Advance(); + } + + bindings.push_back(p_state.m_tokenizer.GetValue()); + const VariableID id = p_state.GetNextVariableId(); + boost::shared_ptr exp( + new VariableRefExpression(Annotations(SourceLocation(1, p_state.m_tokenizer.GetPosition())), + id, + 0, + *type)); + p_state.m_owner->AddExpression(exp); + p_state.m_symbols.Bind(SymbolTable::Symbol(bindings.back()), exp.get()); + lambdaState.m_variableIds.push_back(id); + lambdaState.m_children.push_back(exp.get()); + + tok = p_state.m_tokenizer.Advance(); + + if (matchOpenParen) + { + if (tok != TOKEN_CLOSE) + { + std::ostringstream err; + err << "Expected close parenthesis, got something else (" + << Tokenizer::TokenName(tok) << " token)."; + throw std::runtime_error(err.str()); + } + + tok = p_state.m_tokenizer.Advance(); + } + } + FF2_ASSERT(lambdaState.m_variableIds.size() == lambdaState.m_children.size()); + + if (bindings.size() == 0) + { + std::ostringstream err; + err << "lambdas must have formals (use a regular let binding " + << "for computations without parameters)"; + throw std::runtime_error(err.str()); + } + + // Parse the bound expression. + p_state.m_tokenizer.Advance(); + tok = SExpressionParse::ParseTokens(p_state, 1); + + // Make sure they closed with a close parens. + if (tok != TOKEN_CLOSE) + { + if (tok == TOKEN_END) + { + return tok; + } + else + { + std::ostringstream err; + err << "Expected a single expression after formals in lambda " + << "expression, but lambda has additional arguments (" + << Tokenizer::TokenName(tok) << ")"; + throw std::runtime_error(err.str()); + } + } + + // Remove bindings (note reverse iteration order to pop things off local + // variable stack in the order they were pushed on). + for (auto rnameIter = bindings.crbegin(); rnameIter != bindings.crend(); ++rnameIter) + { + p_state.m_symbols.Unbind(SymbolTable::Symbol(*rnameIter)); + } + + p_state.m_parsingLambdaBody = false; + + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + return tok; +} + + +FreeForm2::Token +FreeForm2::SExpressionParse::ParseInvoke(ProgramParseState& p_state) +{ + if (p_state.m_parsingAggregatedExpression) + { + std::ostringstream err; + err << "Cannot include an invoke expression within an aggregated expression."; + throw std::runtime_error(err.str()); + } + + if (p_state.m_parsingLambdaBody) + { + std::ostringstream err; + err << "Cannot include a invoke expression within a lambda function body."; + throw std::runtime_error(err.str()); + } + + const size_t parseDepth = p_state.m_parseStack.size(); + + // Push a factory that will produce the function expression. + p_state.m_parseStack.push_back( + ProgramParseState::ExpressionParseState(c_invokeFactory, + p_state.m_tokenizer.GetValue(), + p_state.m_tokenizer.GetPosition())); + + // Process the parameters within the Invoke expression. + Token tok = p_state.m_tokenizer.Advance(); + while (tok != TOKEN_CLOSE) + { + tok = SExpressionParse::ParseTokens(p_state, 1); + } + + FF2_ASSERT(p_state.m_parseStack.size() == parseDepth + 1); + return tok; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.cpp new file mode 100644 index 000000000000..dab7d1af0c99 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.cpp @@ -0,0 +1,282 @@ +#include "ArrayType.h" + +#include +#include "FreeForm2Assert.h" +#include +#include +#include +#include +#include "TypeManager.h" + +FreeForm2::ArrayType::ArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements, + TypeManager& p_typeManager) + : TypeImpl(Type::Array, p_isConst, &p_typeManager), + m_typeManager(p_typeManager), + m_derefType(NULL), + m_oppositeConstnessType(NULL), + m_isFixedSize(false), + m_child(p_child), + m_dimensionCount(p_dimensions), + m_maxElements(p_maxElements) +{ + FF2_ASSERT(m_maxElements <= c_maxElements); + FF2_ASSERT(m_child.Primitive() != Type::Array && m_child.Primitive() != Type::Invalid); + FF2_ASSERT(p_dimensions > 0); + + if (p_dimensions > c_maxDimensions) + { + std::ostringstream err; + err << "The FreeForm2 language doesn't currently support more than " + << c_maxDimensions << " dimensions per array."; + throw std::runtime_error(err.str()); + } + + m_dimensions[0] = 0; +} + + +FreeForm2::ArrayType::ArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements, + TypeManager& p_typeManager) + : TypeImpl(Type::Array, p_isConst, &p_typeManager), + m_typeManager(p_typeManager), + m_derefType(NULL), + m_oppositeConstnessType(NULL), + m_isFixedSize(true), + m_child(p_child), + m_dimensionCount(p_dimensions), + m_maxElements(p_maxElements) +{ + FF2_ASSERT(m_maxElements <= c_maxElements); + FF2_ASSERT(m_child.Primitive() != Type::Array && m_child.Primitive() != Type::Invalid); + FF2_ASSERT(p_dimensions > 0); + + if (p_dimensions > c_maxDimensions) + { + std::ostringstream err; + err << "The FreeForm2 language doesn't currently support more than " + << c_maxDimensions << " dimensions per array."; + throw std::runtime_error(err.str()); + } + + memcpy(m_dimensions, p_elementCounts, sizeof(unsigned int) * m_dimensionCount); +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayType::GetChildType() const +{ + return m_child; +} + + +const unsigned int* +FreeForm2::ArrayType::GetDimensions() const +{ + FF2_ASSERT(IsFixedSize()); + return m_dimensions; +} + + +unsigned int +FreeForm2::ArrayType::GetDimensionCount() const +{ + return m_dimensionCount; +} + + +unsigned int +FreeForm2::ArrayType::GetMaxElements() const +{ + return m_maxElements; +} + + +bool +FreeForm2::ArrayType::IsFixedSize() const +{ + return m_isFixedSize; +} + + +bool +FreeForm2::ArrayType::IsSameSubType(const TypeImpl& p_other, bool p_ignoreConst) const +{ + FF2_ASSERT(p_other.Primitive() == Type::Array); + + const ArrayType& other = static_cast(p_other); + if ((IsFixedSize() != other.IsFixedSize()) || !GetChildType().IsSameAs(other.GetChildType(), p_ignoreConst)) + { + return false; + } + + if (GetDimensionCount() == other.GetDimensionCount()) + { + if (IsFixedSize()) + { + return memcmp(GetDimensions(), + other.GetDimensions(), + sizeof(unsigned int) * GetDimensionCount()) == 0; + } + else + { + return true; + } + } + else + { + return false; + } +} + + +std::string +FreeForm2::ArrayType::GetName(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensionCount, + const unsigned int* p_dimensions, + unsigned int p_maxElements) +{ + // Ignore the const-ness on arrays, because the const-ness of the child and + // the constness of the array should be the same. + std::ostringstream out; + if (!p_isConst) + { + out << "mutable "; + } + + out << p_child; + + for (unsigned int i = 0; i < p_dimensionCount; i++) + { + out << "["; + if (p_dimensions != NULL) + { + out << p_dimensions[i]; + } + out << "]"; + } + return out.str(); +} + + +const std::string& +FreeForm2::ArrayType::GetName() const +{ + if (m_name.empty()) + { + m_name = GetName(GetChildType(), + IsConst(), + GetDimensionCount(), + IsFixedSize() ? GetDimensions() : NULL, + GetMaxElements()); + } + return m_name; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayType::AsConstType() const +{ + if (IsConst()) + { + return *this; + } + else + { + if (m_oppositeConstnessType == NULL) + { + if (IsFixedSize()) + { + m_oppositeConstnessType = &m_typeManager.GetArrayType(GetChildType().AsConstType(), + true, + GetDimensionCount(), + GetDimensions(), + GetMaxElements()); + } + else + { + m_oppositeConstnessType = &m_typeManager.GetArrayType(GetChildType().AsConstType(), + true, + GetDimensionCount(), + GetMaxElements()); + } + } + return *m_oppositeConstnessType; + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayType::AsMutableType() const +{ + if (!IsConst()) + { + return *this; + } + else + { + if (m_oppositeConstnessType == NULL) + { + if (IsFixedSize()) + { + m_oppositeConstnessType = &m_typeManager.GetArrayType(GetChildType().AsMutableType(), + false, + GetDimensionCount(), + GetDimensions(), + GetMaxElements()); + } + else + { + m_oppositeConstnessType = &m_typeManager.GetArrayType(GetChildType().AsMutableType(), + false, + GetDimensionCount(), + GetMaxElements()); + } + } + return *m_oppositeConstnessType; + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::ArrayType::GetDerefType() const +{ + if (m_derefType == NULL) + { + if (GetDimensionCount() > 1) + { + if (IsFixedSize()) + { + const unsigned int newMaxElements + = GetMaxElements() / GetDimensions()[0]; + + FF2_ASSERT(GetTypeManager() != NULL); + m_derefType = &GetTypeManager()->GetArrayType( + GetChildType(), IsConst(), GetDimensionCount() - 1, GetDimensions() + 1, newMaxElements); + } + else + { + // Note that we lose a lot of information here about the + // possible number of children, since we don't know how to + // allocate them between the dimensions. + FF2_ASSERT(GetTypeManager() != NULL); + m_derefType = &GetTypeManager()->GetArrayType( + GetChildType(), IsConst(), GetDimensionCount() - 1, GetMaxElements()); + } + } + else + { + FF2_ASSERT(GetDimensionCount() == 1); + m_derefType = &GetChildType(); + } + } + return *m_derefType; +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.h new file mode 100644 index 000000000000..a0446b7d9d03 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ArrayType.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include "TypeImpl.h" +#include + +namespace FreeForm2 +{ + class TypeFactory; + + class ArrayType : public TypeImpl + { + public: + // Maximum number of dimensions per array. + const static unsigned int c_maxDimensions = 7; + + // Maximum number of leaf elements per array. + const static unsigned int c_maxElements = 16384; + + // Maximum number of elements per array dimension. + const static unsigned int c_maxElementsPerDimension = 255; + + // Get the child type. + const TypeImpl& GetChildType() const; + + // Get a vector containing the sizes of each dimension. + const unsigned int* GetDimensions() const; + + // Get the number of dimensions. This is a convenience method for + // getting the size of the vector returned from GetDimensions. + unsigned int GetDimensionCount() const; + + // The the maximum number of elements in the array. + unsigned int GetMaxElements() const; + + // Tests whether this array is fixed-size. + bool IsFixedSize() const; + + // Get the name of an array type for the type desribed by the + // parameters. + static std::string + GetName(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensionCount, + const unsigned int* p_dimensions, + unsigned int p_maxElements); + + // Get a string representation of the type. + virtual const std::string& GetName() const override; + + // Methods to get types derived from this type. + virtual const TypeImpl& AsConstType() const override; + virtual const TypeImpl& AsMutableType() const override; + const TypeImpl& GetDerefType() const; + + private: + // Construct a variable-size array type. + ArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements, + TypeManager& p_typeManager); + + // Construct a fixed-size array type. + ArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements, + TypeManager& p_typeManager); + + virtual bool IsSameSubType(const TypeImpl& p_other, bool p_ignoreConst) const override; + + // Give the TypeManager access to the ArrayType constructor. + friend class TypeManager; + + // The type manager that created this type. + TypeManager& m_typeManager; + + // The name of this type. + mutable std::string m_name; + + // Derived type references, stored for efficiency. + mutable const TypeImpl* m_derefType; + mutable const TypeImpl* m_oppositeConstnessType; + + // This flag indicates whether or not this array has a fixed size. + bool m_isFixedSize; + + // Type of array. Note that we only allow basic types as child of + // an array, because we're using old-school c-style + // multi-dimensional arrays, rather than something compositional. + const TypeImpl& m_child; + + // Maximum number of elements of this array. We use this to track + // the maximum size of the array, and statically allocate space for it. + unsigned int m_maxElements; + + // The number of dimensions contained in this array. + unsigned int m_dimensionCount; + + // Number of elements in each dimension, allocated using the struct hack. + unsigned int m_dimensions[1]; + }; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/Attributes.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/Attributes.h new file mode 100644 index 000000000000..20af82f86753 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/Attributes.h @@ -0,0 +1,34 @@ +#pragma once + +// A simple means for extracting the attribute bits and word offset +// fields from a "location" as recorded in the index. + +enum WordAttribute +{ + ZeroAttributeState = 0, + + // UrlWord attribute values + + ServiceUrlWord = 0, + SubDomainUrlWord = 1, + BaseDomainUrlWord = 2, + TopDomainUrlWord = 3, + PortUrlWord = 4, + PathUrlWord = 5, + QueryUrlWord = 6, + AnyUrlWord = 7, + + NumberOfUrlAttributeStates = 8, + + // BodyText attribute values + + NormalBodyText = 0, + NavigationBodyText = 1, + Reserved1BodyText = 2, + Reserved2BodyText = 3, + AnyBodyText = 4, + + NumberOfTextAttributeStates = 5, + + MaximumNumberOfAttributeStates = 8 +}; diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/BFBFile.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/BFBFile.h new file mode 100644 index 000000000000..7767b4dc7c66 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/BFBFile.h @@ -0,0 +1,202 @@ +#pragma once + +#include +#include +#include + +namespace FreeForm2 +{ + namespace BFB + { + // Masks for the first word of a BFB stream definition. + static const UInt32 c_streamTypeMask = 0x80000000; + static const UInt32 c_streamTupleCountMask = 0x7FFFFFFF; + + // Values of m_decodeType in the various tuple types. + static const UInt64 c_decodeSmallTuple = 0; + static const UInt64 c_decodeLargeTuple = 1; + static const UInt64 c_decodeExtraLargeTuple = 2; + static const UInt64 c_decodeMetadataToken = 3; + + // Values of m_bodyBlockFlag in the various tuple types. + static const UInt64 c_inBothStreams = 0; + static const UInt64 c_inBodyStream = 1; + static const UInt64 c_inBodyBlockStream = 2; + +#pragma pack(push, 1) + union SmallTuple + { + struct + { + UInt16 m_decodeType : 2; + UInt16 m_bodyBlockFlag : 2; + UInt16 m_relativeOffset : 5; + UInt16 m_wordAttribute : 3; + UInt16 m_candidateID : 2; + UInt16 m_wordID : 2; + } m_tuple; + struct + { + UInt16 m_data1; + } m_data; + }; + BOOST_STATIC_ASSERT(sizeof(SmallTuple) * 8 == 16); + + // Large tuples are specified as 32-bit integers. + union LargeTuple + { + struct + { + UInt32 m_decodeType : 2; + UInt32 m_bodyBlockFlag : 2; + UInt32 m_relativeOffset : 14; + UInt32 m_wordLength : 4; + UInt32 m_wordAttribute : 3; + UInt32 m_candidateID : 2; + UInt32 m_wordID : 5; + } m_tuple; + struct + { + UInt32 m_data1; + } m_data; + }; + BOOST_STATIC_ASSERT(sizeof(LargeTuple) * 8 == 32); + + // Extra large tuples are specified as 48-bit integers. + union ExtraLargeTuple + { + // Because there is no 48-bit integer type, the m_tuple member of + // ExtraLargeTuple is actually the size of a 64-bit integer, though + // only the lower 48-bits are used. + struct + { + UInt64 m_decodeType : 2; + UInt64 m_bodyBlockFlag : 2; + UInt64 m_relativeOffset : 19; + UInt64 m_wordLength : 5; + UInt64 m_wordAttribute : 3; + UInt64 m_candidateID : 4; + UInt64 m_wordID : 5; + UInt64 m_reserved : 8; + } m_tuple; + struct + { + UInt32 m_data1; + UInt16 m_data2; + } m_data; + }; + BOOST_STATIC_ASSERT(sizeof(ExtraLargeTuple) * 8 == 64); + + // Metadata tokens currently only act as body block headers. + struct MetadataToken + { + UInt32 m_decodeType : 2; + UInt32 m_bodyBlockType : 3; + UInt32 m_bodyBlockLength : 27; + }; + BOOST_STATIC_ASSERT(sizeof(MetadataToken) * 8 == 32); + + // BFB tuple formats. + union Tuple + { + SmallTuple m_smallTuple; + LargeTuple m_largeTuple; + ExtraLargeTuple m_xlargeTuple; + MetadataToken m_metaTuple; + UInt64 m_value; + }; + BOOST_STATIC_ASSERT(sizeof(Tuple) * 8 == 64); + + // BFB InterestingTuple data. + union InterestingTupleData + { + struct + { + UInt32 m_tupleID : 3; + UInt32 m_tupleIndex : 3; + UInt32 m_firstWord : 4; + UInt32 m_lastWord : 4; + UInt32 m_tupleWeight : 18; + } m_struct; + UInt32 m_value; + }; + BOOST_STATIC_ASSERT(sizeof(InterestingTupleData) * 8 == 32); + + // BFB QueryPath data format. + union QueryPathData + { + struct + { + UInt32 m_pathIndex : 3; + UInt32 m_candidate0 : 2; + UInt32 m_candidate1 : 2; + UInt32 m_candidate2 : 2; + UInt32 m_candidate3 : 2; + UInt32 m_candidate4 : 2; + UInt32 m_candidate5 : 2; + UInt32 m_candidate6 : 2; + UInt32 m_candidate7 : 2; + UInt32 m_candidate8 : 2; + UInt32 m_candidate9 : 2; + UInt32 m_pathWeight : 9; + } m_struct; + UInt32 m_value; + }; + BOOST_STATIC_ASSERT(sizeof(QueryPathData) * 8 == 32); + + // Phrase normalize: a fixed-point 48-bit decimal number. + union PhraseNormalizer + { + UInt64 m_value; + struct + { + UInt32 m_value1; + UInt16 m_value2; + } m_data; + }; + BOOST_STATIC_ASSERT(sizeof(PhraseNormalizer) * 8 == 64); + + // The number of UInt32s required to hold the configuration bitmask. + static const size_t c_numberOfConfigBlocks + = (ExtractorConfig::ExtractorConfigCount - 1) / (sizeof(UInt32) * 8) + 1; + + struct PerStreamFeatures + { + UInt32 m_wordFound; + UInt32 m_wordsFound; + UInt32 m_bm25f; + UInt32 m_bm25fNorm; + UInt32 m_originalQueryBM25F; + UInt32 m_originalQueryBM25FNorm; + UInt32 m_proxBM25F; + UInt32 m_proxBM25FNorm; + UInt32 m_perStreamLMScore; + UInt32 m_parametersPresent[c_numberOfConfigBlocks]; + }; + BOOST_STATIC_ASSERT(sizeof(PerStreamFeatures) == sizeof(UInt32) * (9 + c_numberOfConfigBlocks)); + + union QueryDataBits + { + UInt32 m_value; + struct + { + UInt32 m_extractWordCandidateDate : 1; + UInt32 m_termWeightEnabled : 1; + UInt32 m_alterationWeightEnabled : 1; + UInt32 m_calculationMethod : 2; + UInt32 m_newFeatureFlags : 5; + UInt32 m_pbsTupleTypes : 6; + UInt32 m_unused : 16; + } m_data; + }; + static_assert(sizeof(QueryDataBits) == sizeof(UInt32), "QueryDataBits has bad size"); + + // Custom data structure to hold the max size for each feature + struct FeatureDefinition + { + UInt32 m_nameIndex; + unsigned char m_size[4]; + }; +#pragma pack(pop) + } +} \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CMakeLists.txt new file mode 100644 index 000000000000..9fae63dc647c --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CMakeLists.txt @@ -0,0 +1,34 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME DRFreeFormSharedLibrary) + + +Project(${PROJECT_NAME}) + +SET(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + + + +add_library(${PROJECT_NAME} STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/ArrayType.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/CompoundType.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FreeForm2Assert.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FreeForm2Utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FunctionType.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ObjectType.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/StateMachineType.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/StructType.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/TypeImpl.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/TypeManager.cpp +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../NeuralTree.Library/inc + ) + +install(TARGETS ${PROJECT_NAME} + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + ) \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.cpp new file mode 100644 index 000000000000..e7af5a86415d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.cpp @@ -0,0 +1,30 @@ +#include "CompoundType.h" + +FreeForm2::CompoundType::Member::Member(const std::string& p_name, + const TypeImpl& p_type) + : m_name(p_name), m_type(&p_type) +{ +} + + +FreeForm2::CompoundType::Member::Member() + : m_type(NULL) +{ +} + + +FreeForm2::CompoundType::CompoundType(Type::TypePrimitive p_prim, + bool p_isConst, + TypeManager* p_typeManager) + : TypeImpl(p_prim, p_isConst, p_typeManager) +{ +} + + +bool +FreeForm2::CompoundType::IsCompoundType(const TypeImpl& p_type) +{ + return p_type.Primitive() == Type::Struct + || p_type.Primitive() == Type::StateMachine + || p_type.Primitive() == Type::Object; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.h new file mode 100644 index 000000000000..6483ef3980c8 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/CompoundType.h @@ -0,0 +1,52 @@ +#pragma once + +#ifndef FREEFORM2_COMPOUND_TYPE_H +#define FREEFORM2_COMPOUND_TYPE_H + +#include "FreeForm2Type.h" +#include +#include "TypeImpl.h" + +namespace FreeForm2 +{ + class TypeManager; + + // A compound type is any type which contains named instantiations of other + // types. This class provides a mechanism for looking up member information + // by name. + class CompoundType : public TypeImpl + { + public: + // A member represents a single named entity contained within a + // compound type. + struct Member + { + public: + // Construct a member with a name and type. + Member(const std::string& p_name, const TypeImpl& p_type); + + // Default constructor to initialize members to empty values. + Member(); + + // The name of this member. + std::string m_name; + + // The type of this member. + const TypeImpl* m_type; + }; + + // Create a compound type of the give type, constness, and type + // manager. + CompoundType(Type::TypePrimitive p_prim, bool p_isConst, TypeManager* p_typeManager); + + // Find a member by name. Returns a pointer to the member object + // associated with a name; if this compound type does not contain a + // member of the specified name, this function returns NULL. + virtual const Member* FindMember(const std::string& p_name) const = 0; + + // Determine whether a TypePrimitive is a compound type. + static bool IsCompoundType(const TypeImpl& p_type); + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Assert.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Assert.cpp new file mode 100644 index 000000000000..9302b4074d05 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Assert.cpp @@ -0,0 +1,38 @@ +#include "FreeForm2Assert.h" + +#include +#include + +void +FreeForm2::ThrowAssert(bool p_condition, const char* p_file, unsigned int p_line) +{ + if (!p_condition) + { + std::ostringstream err; + err << "Assertion error at " << p_file << ":" << p_line; + throw std::runtime_error(err.str()); + } +} + + +void +FreeForm2::ThrowAssert(bool p_condition, const char* p_expression, const char* p_file, unsigned int p_line) +{ + if (!p_condition) + { + std::ostringstream err; + err << "Assertion error: \"" << p_expression << "\" failed at " << p_file << ":" << p_line; + throw std::runtime_error(err.str()); + } +} + + +void +FreeForm2::Unreachable(const char* p_file, unsigned int p_line) +{ + std::ostringstream err; + err << "Unreachable code reached at " << p_file << ":" << p_line; + throw std::runtime_error(err.str()); +} + + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.cpp new file mode 100644 index 000000000000..927fa8f86076 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.cpp @@ -0,0 +1,150 @@ +#include "FreeForm2Utils.h" + +#include "FreeForm2Assert.h" +// #include +#include +#include +#include +// #include + +void +FreeForm2::VectorFromNeuralNetFeatures::ProcessFeature(UInt32 p_featureIndex) +{ + m_associatedFeaturesList.push_back(p_featureIndex); +} + + +void +FreeForm2::VectorFromNeuralNetFeatures::ProcessFeature(UInt32 p_featureIndex, + const std::vector& p_segments) +{ + // Ignore segment information. + ProcessFeature(p_featureIndex); +} + + +void +FreeForm2::SetFromNeuralNetFeatures::ProcessFeature(UInt32 p_featureIndex) +{ + m_associatedFeaturesList.insert(p_featureIndex); +} + + +void +FreeForm2::SetFromNeuralNetFeatures::ProcessFeature(UInt32 p_featureIndex, + const std::vector& p_segments) +{ + // Ignore segment information. + ProcessFeature(p_featureIndex); +} + + +std::ostream& +FreeForm2::operator<<(std::ostream& p_out, SIZED_STRING p_str) +{ + return p_out.write(p_str.pcData, static_cast(p_str.cbData)); +} + + +void +FreeForm2::LogHardwareException(DWORD p_exceptionCode, + const Executable::FeatureType p_features[], + const DynamicRank::IFeatureMap& p_map, + const char* p_sourceFile, + unsigned int p_sourceLine) +{ + // Blech, windows programming. Use FormatMessage and some LoadLibrary + // trickery to get windows to format our exception code into text. +} + + +bool +FreeForm2::IsSimpleName(SIZED_STRING p_name) +{ + for (unsigned int i = 0; i < p_name.cbData; i++) + { + if (!isalnum(p_name.pbData[i]) + && p_name.pcData[i] != '-' + && p_name.pcData[i] != '_') + { + return false; + } + } + + return true; +} + + +void +FreeForm2::WriteCompressedVectorRLE(const UInt32* p_data, size_t p_numElements, std::ostream& p_out) +{ + FF2_ASSERT(p_data != NULL); + enum + { + MatchingNulls, + MatchingNonNulls + } state = MatchingNonNulls; + UInt32 numNulls = 0; + const UInt32 null = 0; + for (size_t i = 0; i < p_numElements; i++) + { + switch (state) + { + case MatchingNulls: + { + if (p_data[i] == null) + { + FF2_ASSERT(numNulls < MAX_UINT32); + numNulls++; + } + else + { + state = MatchingNonNulls; + p_out.write(reinterpret_cast(&numNulls), sizeof(UInt32)); + p_out.write(reinterpret_cast(&p_data[i]), sizeof(UInt32)); + } + break; + } + + case MatchingNonNulls: + { + if (p_data[i] == null) + { + state = MatchingNulls; + p_out.write(reinterpret_cast(&null), sizeof(UInt32)); + numNulls = 1; + } + else + { + p_out.write(reinterpret_cast(&p_data[i]), sizeof(UInt32)); + } + break; + } + } + } +} + + +void +FreeForm2::ReadCompressedVectorRLE(UInt32* p_data, size_t p_numElements, std::istream& p_in) +{ + FF2_ASSERT(p_data != NULL); + size_t i = 0; + while (i < p_numElements) + { + UInt32 value = 0; + p_in.read(reinterpret_cast(&value), sizeof(UInt32)); + if (value == 0) + { + p_in.read(reinterpret_cast(&value), sizeof(UInt32)); + memset(&p_data[i], 0, sizeof(UInt32) * value); + i += value; + } + else + { + p_data[i] = value; + i++; + } + } +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.h new file mode 100644 index 000000000000..088bc8a36573 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FreeForm2Utils.h @@ -0,0 +1,88 @@ +#pragma once + +#ifndef FREEFORM2_UTILS_H +#define FREEFORM2_UTILS_H + +#include +#include +#include "FreeForm2.h" +#include +#include +#include +#include + +namespace DynamicRank +{ + class IFeatureMap; +} + +namespace FreeForm2 +{ + // Utility class that populates a vector of features used using the + // INeuralNetFeatures interface. + class VectorFromNeuralNetFeatures + : public DynamicRank::INeuralNetFeatures, boost::noncopyable + { + public: + VectorFromNeuralNetFeatures(std::vector& p_associatedFeaturesList) + : m_associatedFeaturesList(p_associatedFeaturesList) + { + } + + + virtual void ProcessFeature(UInt32 p_featureIndex) override; + + virtual void ProcessFeature(UInt32 p_featureIndex, const std::vector& p_segments) override; + + private: + std::vector& m_associatedFeaturesList; + }; + + + // Utility class that populates a vector of features used using the + // INeuralNetFeatures interface. + class SetFromNeuralNetFeatures + : public DynamicRank::INeuralNetFeatures, boost::noncopyable + { + public: + SetFromNeuralNetFeatures(std::set& p_associatedFeaturesList) + : m_associatedFeaturesList(p_associatedFeaturesList) + { + } + + + virtual void ProcessFeature(UInt32 p_featureIndex) override; + + virtual void ProcessFeature(UInt32 p_featureIndex, const std::vector& p_segments) override; + + private: + std::set& m_associatedFeaturesList; + }; + + + // Print a SIZED_STRING to an output stream. + std::ostream& operator<<(std::ostream& p_out, SIZED_STRING p_str); + + // Log errors after a crash, suitable for use as a structured exception + // handling test. Note that this function always returns false, as it does + // not handle exceptions, it simply logs information regarding that exception. + void LogHardwareException(DWORD p_exceptionCode, + const Executable::FeatureType p_features[], + const DynamicRank::IFeatureMap& p_map, + const char* p_sourceFile, + unsigned int p_sourceLine); + + // Indicates whether the given name is composed of alphanumeric + // characters and '-' only. + bool IsSimpleName(SIZED_STRING p_name); + + // Write a sparse vector to the output stream. This algorithm uses a simple + // run-length encoding to compress ranges of 0's. + void WriteCompressedVectorRLE(const UInt32* p_data, size_t p_numElements, std::ostream& p_out); + + // Read a sparse integer vector incoded using the above method. + void ReadCompressedVectorRLE(UInt32* p_data, size_t p_numElements, std::istream& p_in); +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.cpp new file mode 100644 index 000000000000..afc7972769c4 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.cpp @@ -0,0 +1,134 @@ +#include "FunctionType.h" + +#include +#include "FreeForm2Assert.h" +#include +#include "TypeManager.h" + + +FreeForm2::FunctionType::FunctionType(TypeManager& p_typeManager, + const TypeImpl& p_returnType, + const TypeImpl* const* p_parameterTypes, + size_t p_numParameters) + : TypeImpl(Type::Function, true, &p_typeManager), + m_returnType(&p_returnType), + m_numParameters(p_numParameters) +{ + if (m_numParameters > 0) + { + m_parameterTypes[0] = p_parameterTypes[0]; + + for (size_t i = 1; i < m_numParameters; i++) + { + m_parameterTypes[i] = p_parameterTypes[i]; + } + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::FunctionType::GetReturnType() const +{ + return *m_returnType; +} + + +FreeForm2::FunctionType::ParameterIterator +FreeForm2::FunctionType::BeginParameters() const +{ + return const_cast(&m_parameterTypes[0]); +} + + +FreeForm2::FunctionType::ParameterIterator +FreeForm2::FunctionType::EndParameters() const +{ + return BeginParameters() + m_numParameters; +} + + +size_t +FreeForm2::FunctionType::GetParameterCount() const +{ + return m_numParameters; +} + + +const std::string& +FreeForm2::FunctionType::GetName() const +{ + if (m_name.size() == 0) + { + m_name = FunctionType::GetName(GetReturnType(), BeginParameters(), GetParameterCount()); + } + + return m_name; +} + + +std::string +FreeForm2::FunctionType::GetName(const TypeImpl& p_returnType, + const TypeImpl* const* p_parameterTypes, + size_t p_numParams) +{ + std::ostringstream out; + out << p_returnType << "("; + + bool first = true; + + for (int i = 0; i < p_numParams; i++) + { + if (!first) + { + out << ", "; + } + + out << *p_parameterTypes[i]; + + first = false; + } + + out << ")"; + + return out.str(); +} + + +const FreeForm2::TypeImpl& +FreeForm2::FunctionType::AsConstType() const +{ + return *this; +} + + +const FreeForm2::TypeImpl& +FreeForm2::FunctionType::AsMutableType() const +{ + FF2_ASSERT("Functions cannot be mutable." && false); + Unreachable(__FILE__, __LINE__); +} + + +bool +FreeForm2::FunctionType::IsSameSubType(const TypeImpl& p_other, bool p_ignoreConst) const +{ + FF2_ASSERT(p_other.Primitive() == Type::Function); + const FunctionType& other = static_cast(p_other); + + if (m_numParameters != other.m_numParameters) + { + return false; + } + + for (size_t i = 0; i < m_numParameters; i++) + { + const TypeImpl& m1 = *m_parameterTypes[i]; + const TypeImpl& m2 = *other.m_parameterTypes[i]; + if (!m1.IsSameAs(m2, p_ignoreConst)) + { + return false; + } + } + + return GetReturnType().IsSameAs(other.GetReturnType(), p_ignoreConst); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.h new file mode 100644 index 000000000000..e495df0e5982 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/FunctionType.h @@ -0,0 +1,69 @@ +#pragma once + +#ifndef FREEFORM2_FUNCTION_TYPE_H +#define FREEFORM2_FUNCTION_TYPE_H + +#include "FreeForm2Type.h" +#include "TypeImpl.h" +#include + +namespace FreeForm2 +{ + class TypeManager; + + // A function type is any type which can be called with certain parameters and + // returns a value of another type. + class FunctionType : public TypeImpl + { + public: + virtual ~FunctionType() {} + + // Gets the return type of the function. + const TypeImpl& GetReturnType() const; + + // Iterate over parameters. + typedef const TypeImpl** ParameterIterator; + ParameterIterator BeginParameters() const; + ParameterIterator EndParameters() const; + + // Get the number of parameters. + size_t GetParameterCount() const; + + // Get a string representation of the type. + virtual const std::string& GetName() const override; + static std::string GetName(const TypeImpl& p_returnType, + const TypeImpl* const* p_parameterTypes, + size_t p_numParams); + + // Create derived types based on this type. + virtual const TypeImpl& AsConstType() const override; + virtual const TypeImpl& AsMutableType() const override; + + private: + // Create a compound type of the give type, constness, and type + // manager. + FunctionType(TypeManager& p_typeManager, + const TypeImpl& p_returnType, + const TypeImpl* const* p_parameterTypes, + size_t p_numParams); + + // Compare subclass type data. + virtual bool IsSameSubType(const TypeImpl& p_type, bool p_ignoreConst) const override; + + friend class TypeManager; + + // The return type of the function. + const TypeImpl* m_returnType; + + // The string representation of this type. + mutable std::string m_name; + + // The number of parameters. + size_t m_numParameters; + + // The types of the parameters. This is allocated using the struct hack. + const TypeImpl* m_parameterTypes[1]; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.cpp new file mode 100644 index 000000000000..560a730f529b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.cpp @@ -0,0 +1,119 @@ +#include "ObjectType.h" + +#include +#include "FreeForm2Assert.h" +#include "TypeManager.h" + +FreeForm2::ObjectType::ObjectMember::ObjectMember(const std::string& p_name, + const TypeImpl& p_type, + const std::string& p_externName) + : CompoundType::Member(p_name, p_type), + m_externName(p_externName) +{ +} + +FreeForm2::ObjectType::ObjectMember::ObjectMember(const std::string& p_name, + const TypeImpl& p_type) + : CompoundType::Member(p_name, p_type), + m_externName(p_name) +{ +} + + +FreeForm2::ObjectType::ObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst, + TypeManager& p_typeManager) + : CompoundType(Type::Object, p_isConst, &p_typeManager), + m_name(p_name), + m_externName(p_externName) +{ + BOOST_FOREACH (ObjectType::ObjectMember member, p_members) + { + // Verify that all names are unique. + FF2_ASSERT(m_members.find(member.m_name) == m_members.end()); + m_members.insert(std::make_pair(member.m_name, member)); + } +} + + +const std::string& +FreeForm2::ObjectType::GetName() const +{ + return m_name; +} + + +const std::string& +FreeForm2::ObjectType::GetExternName() const +{ + return m_externName; +} + + +const FreeForm2::ObjectType::ObjectMember* +FreeForm2::ObjectType::FindMember(const std::string& p_name) const +{ + std::map::const_iterator member = m_members.find(p_name); + return member != m_members.end() ? &member->second : NULL; +} + + +const FreeForm2::TypeImpl& +FreeForm2::ObjectType::AsConstType() const +{ + if (IsConst()) + { + return *this; + } + else + { + FF2_ASSERT(GetTypeManager() != NULL); + + std::vector members; + for (std::map::const_iterator memberIterator = m_members.begin(); + memberIterator != m_members.end(); + ++memberIterator) + { + members.push_back(memberIterator->second); + } + + return GetTypeManager()->GetObjectType(GetName(), GetExternName(), members, true); + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::ObjectType::AsMutableType() const +{ + if (!IsConst()) + { + return *this; + } + else + { + FF2_ASSERT(GetTypeManager() != NULL); + + std::vector members; + for (std::map::const_iterator memberIterator = m_members.begin(); + memberIterator != m_members.end(); + ++memberIterator) + { + members.push_back(memberIterator->second); + } + + return GetTypeManager()->GetObjectType(GetName(), GetExternName(), members, false); + } +} + + +bool +FreeForm2::ObjectType::IsSameSubType(const TypeImpl& p_other, bool p_ignoreConst) const +{ + FF2_ASSERT(p_other.Primitive() == Type::Object); + const ObjectType& other = static_cast(p_other); + + return GetName() == other.GetName() + && GetExternName() == other.GetExternName(); +} \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.h new file mode 100644 index 000000000000..bcf155da992b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/ObjectType.h @@ -0,0 +1,70 @@ +#pragma once + +#ifndef FREEFORM2_OBJECT_TYPE_H +#define FREEFORM2_OBJECT_TYPE_H + +#include "CompoundType.h" +#include +#include +#include + +namespace FreeForm2 +{ + // Object types are used to store information about external objects. + class ObjectType : public CompoundType + { + public: + // A struct that holds the information about members of an object. + struct ObjectMember : public CompoundType::Member + { + // Constructor to initialize all members of the class. + ObjectMember(const std::string& p_name, + const TypeImpl& p_type, + const std::string& p_externName); + + // Constructor used when frontend name matches the external name. + ObjectMember(const std::string& p_name, + const TypeImpl& p_type); + + // The C++ name of the member. + std::string m_externName; + }; + + // Get the name of this object type. + virtual const std::string& GetName() const override; + + // Get the name of this object type. + const std::string& GetExternName() const; + + // Find a member by name. Return value is NULL if the object + // does not contain the specified member. + virtual const ObjectType::ObjectMember* FindMember(const std::string& p_name) const override; + + // Create derived types based on this type. + virtual const TypeImpl& AsConstType() const override; + virtual const TypeImpl& AsMutableType() const override; + private: + // Create an ObjectType of the given function information. + ObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst, + TypeManager& p_typeManager); + + friend class TypeManager; + + // Compare subclass type data. + virtual bool IsSameSubType(const TypeImpl& p_type, bool p_ignoreConst) const override; + + // Functions associated with this type. + std::map m_members; + + // Name of this object in the frontend. + std::string m_name; + + // Name of this object in the backend. + std::string m_externName; + }; +} + +#endif \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/QuietMetaStreams.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/QuietMetaStreams.h new file mode 100644 index 000000000000..cca9256c2713 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/QuietMetaStreams.h @@ -0,0 +1,97 @@ +#pragma once + +#include +#include "FreeForm2Assert.h" +#include +#include +#include + +namespace FreeForm2 +{ + // This class redirects an output file to a NUL output stream and restores + // the original output mechanism when destroyed. If the class fails to + // restore the output stream for some reason, it is redirected to console + // output. + class IORedirectGuard + { + public: + static const int c_badFileDescriptor = -1; + + // Redirect the file parameter to a NUL output stream. Do nothing if + // any I/O call fails. + explicit IORedirectGuard(FILE* p_out) + : m_out(p_out), + m_redirect(nullptr), + m_duplicatedFd(c_badFileDescriptor) + { + m_redirect = fopen("NUL", "w"); + if (m_redirect) + { + m_duplicatedFd = _dup(_fileno(m_out)); + if (m_duplicatedFd == c_badFileDescriptor) + { + // On failure, close the redirect file handle. + fclose(m_redirect); + m_redirect = nullptr; + } + else + { + if (_dup2(_fileno(m_redirect), _fileno(m_out)) != 0) + { + // On failure, close the opened handles. + fclose(m_redirect); + m_redirect = nullptr; + _close(m_duplicatedFd); + m_duplicatedFd = c_badFileDescriptor; + } + } + } + } + + // Restore the output file descriptor. + ~IORedirectGuard() + { + if (m_redirect) + { + fflush(m_out); + fclose(m_redirect); + if (_dup2(m_duplicatedFd, _fileno(m_out)) != 0) + { + // On failure, redirect the output stream to console + // output. This could cause issues with scripts, but it's + // better than silencing the output. + FILE* out = nullptr; + const errno_t err = freopen_s(&out, "CONOUT$", "w", m_out); + FF2_ASSERT(err == 0); + } + } + } + + private: + // The file being redirected. + FILE* m_out; + + // The newly created output stream to which m_out is redirected. + FILE* m_redirect; + + // The duplicated file descriptor of the original output source. + int m_duplicatedFd; + }; + + + // Load a metastream definition list from a file, returning a unique_ptr + // to the object. Because the MetaStreams constructor produces (incredibly) + // versbose output, this method will silence stdout and stderr during the + // loading process. + template + std::unique_ptr QuietLoadMSDL(const String& p_path) + { + std::unique_ptr msdl; + { + IORedirectGuard stdOutGuard(stdout); + IORedirectGuard stdErrGuard(stderr); + msdl.reset(new MetaStreams(p_path)); + } + return msdl; + } +} \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.cpp new file mode 100644 index 000000000000..04acfc78afd2 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.cpp @@ -0,0 +1,119 @@ +#include "StateMachineType.h" + +#include +#include "FreeForm2Assert.h" + + +FreeForm2::StateMachineType::StateMachineType(TypeManager& p_typeManager, + const std::string& p_name, + const Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) + : CompoundType(Type::StateMachine, false, &p_typeManager), + m_name(p_name), + m_expr(p_expr), + m_numMembers(p_numMembers) +{ + if (m_numMembers > 0) + { + m_members[0] = p_members[0]; + + // All subsequent members must be constructed. + for (size_t i = 1; i < m_numMembers; i++) + { + m_members[i] = Member(p_members[i]); + } + } +} + + +FreeForm2::StateMachineType::~StateMachineType() +{ + // Only members > 1 need to be destructed; member 0 will be destructed + // automatically. + for (size_t i = 1; i < m_numMembers; i++) + { + m_members[i].Member::~Member(); + } +} + + +const std::string& +FreeForm2::StateMachineType::GetName() const +{ + return m_name; +} + + +const FreeForm2::CompoundType::Member* +FreeForm2::StateMachineType::FindMember(const std::string& p_name) const +{ + const MemberIterator end = EndMembers(); + for (MemberIterator iter = BeginMembers(); iter != end; ++iter) + { + if (iter->m_name == p_name) + { + return iter; + } + } + return NULL; +} + + +FreeForm2::StateMachineType::MemberIterator +FreeForm2::StateMachineType::BeginMembers() const +{ + return m_members; +} + + +FreeForm2::StateMachineType::MemberIterator +FreeForm2::StateMachineType::EndMembers() const +{ + return BeginMembers() + m_numMembers; +} + + +size_t +FreeForm2::StateMachineType::GetMemberCount() const +{ + return m_numMembers; +} + + +const FreeForm2::TypeImpl& +FreeForm2::StateMachineType::AsConstType() const +{ + return *this; +} + + +const FreeForm2::TypeImpl& +FreeForm2::StateMachineType::AsMutableType() const +{ + return *this; +} + + +bool +FreeForm2::StateMachineType::HasDefinition() const +{ + return !m_expr.expired(); +} + + +boost::shared_ptr +FreeForm2::StateMachineType::GetDefinition() const +{ + return m_expr.lock(); +} + + +bool +FreeForm2::StateMachineType::IsSameSubType(const TypeImpl& p_type, bool p_ignoreConst) const +{ + FF2_ASSERT(p_type.Primitive() == Type::StateMachine); + const StateMachineType& other = static_cast(p_type); + return StateMachineType::GetName() == other.StateMachineType::GetName(); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.h new file mode 100644 index 000000000000..acfbb8dc714e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StateMachineType.h @@ -0,0 +1,80 @@ +#pragma once + +#ifndef FREEFORM2_STATE_MACHINE_TYPE_H +#define FREEFORM2_STATE_MACHINE_TYPE_H + +#include +#include +#include "CompoundType.h" +#include + +namespace FreeForm2 +{ + class TypeManager; + class StateMachineExpression; + + // This class represents the type of a state machine. State machine types + // contain the state variables which constitute the members of the type, + // as well as a weak reference to the definition of the state machine. + class StateMachineType : public CompoundType + { + public: + // Destructor to correctly dispose of members. + ~StateMachineType(); + + // Get the name of the state machine. + virtual const std::string& GetName() const override; + + // Find a member by name. + virtual const Member* FindMember(const std::string& p_name) const override; + + // Iterate over members. + typedef const Member* MemberIterator; + MemberIterator BeginMembers() const; + MemberIterator EndMembers() const; + + // Get the number of members in the type. + size_t GetMemberCount() const; + + // Create derived types based on this type. + virtual const TypeImpl& AsConstType() const override; + virtual const TypeImpl& AsMutableType() const override; + + // Manipulate the definition of this StateMacineType. + bool HasDefinition() const; + boost::shared_ptr GetDefinition() const; + + private: + // This private constructor for StateMachineType requires memory + // allocated for the size of the StateMachineType, plus the size of + // the number of state variables and the size of the StateExpression + // pointer, less the size of one char. This should be done + // by the TypeManager. + StateMachineType(TypeManager& p_typeManager, + const std::string& p_name, + const Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr); + + friend class TypeManager; + friend class StateMachineExpression; + + // Test if this function type is the same as another. + virtual bool IsSameSubType(const TypeImpl& p_type, bool p_ignoreConst) const override; + + // The state machine name. + std::string m_name; + + // The state machine definition associated with this type. + mutable boost::weak_ptr m_expr; + + // The number of state variables in the data blob. + size_t m_numMembers; + + // A blob of data holding both the state variables followed by state + // expression pointers, allocated using the struct hack. + Member m_members[1]; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.cpp new file mode 100644 index 000000000000..416e7ce737e7 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.cpp @@ -0,0 +1,186 @@ +#include "StructType.h" + +#include +#include "FreeForm2Assert.h" +#include +#include "TypeManager.h" + +FreeForm2::StructType::MemberInfo::MemberInfo(const std::string& p_name, + const TypeImpl& p_type, + const std::string& p_externName, + size_t p_offset, + size_t p_size) + : Member(p_name, p_type), + m_externName(p_externName), + m_offset(p_offset), + m_size(p_size) +{ +} + + +bool +FreeForm2::StructType::MemberInfo::operator==(const MemberInfo& p_other) const +{ + return (m_name == p_other.m_name + && m_externName == p_other.m_externName + && m_offset == p_other.m_offset + && m_size == p_other.m_size + && (*m_type == *p_other.m_type)); +} + + +bool +FreeForm2::StructType::MemberInfo::operator!=(const MemberInfo& p_other) const +{ + return !(*this == p_other); +} + + +FreeForm2::StructType::StructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst, + TypeManager& p_typeManager) + : CompoundType(Type::Struct, p_isConst, &p_typeManager), + m_name(p_name), + m_externName(p_externName) +{ + m_members.reserve(p_members.size()); + BOOST_FOREACH (StructType::MemberInfo member, p_members) + { + const TypeImpl& memberType + = p_isConst ? member.m_type->AsConstType() : member.m_type->AsMutableType(); + m_members.push_back(MemberInfo(member.m_name, + memberType, + member.m_externName, + member.m_offset, + member.m_size)); + + // Verify that all names are unique. + FF2_ASSERT(m_memberMapping.find(m_members.back().m_name) == m_memberMapping.end()); + m_memberMapping.insert(std::make_pair(m_members.back().m_name, &m_members.back())); + } +} + + +const FreeForm2::CompoundType::Member* +FreeForm2::StructType::FindMember(const std::string& p_name) const +{ + return FindStructMember(p_name); +} + + +const FreeForm2::StructType::MemberInfo* +FreeForm2::StructType::FindStructMember(const std::string& p_name) const +{ + std::map::const_iterator member + = m_memberMapping.find(p_name); + return (member != m_memberMapping.end()) ? member->second : NULL; +} + + +const std::vector& +FreeForm2::StructType::GetMembers() const +{ + return m_members; +} + + +const std::string& +FreeForm2::StructType::GetName() const +{ + return m_name; +} + + +const std::string& +FreeForm2::StructType::GetExternName() const +{ + return m_externName; +} + + +std::string +FreeForm2::StructType::GetString() const +{ + std::ostringstream out; + out << "struct " << GetName() << "{"; + + bool first = true; + + BOOST_FOREACH (StructType::MemberInfo member, GetMembers()) + { + if (!first) + { + out << ", "; + } + + out << *member.m_type << " " << member.m_name; + + first = false; + } + + out << "}"; + + return out.str(); +} + +const FreeForm2::TypeImpl& +FreeForm2::StructType::AsConstType() const +{ + if (IsConst()) + { + return *this; + } + else + { + FF2_ASSERT(GetTypeManager() != NULL); + return GetTypeManager()->GetStructType(GetName(), GetExternName(), GetMembers(), true); + } +} + + +const FreeForm2::TypeImpl& +FreeForm2::StructType::AsMutableType() const +{ + if (!IsConst()) + { + return *this; + } + else + { + FF2_ASSERT(GetTypeManager() != NULL); + return GetTypeManager()->GetStructType(GetName(), GetExternName(), GetMembers(), false); + } +} + + +bool +FreeForm2::StructType::IsSameSubType(const TypeImpl& p_other, bool p_ignoreConst) const +{ + FF2_ASSERT(p_other.Primitive() == Type::Struct); + const StructType& other = static_cast(p_other); + + if (GetMembers().size() != other.GetMembers().size()) + { + return false; + } + + for (size_t i = 0; i < GetMembers().size(); i++) + { + const MemberInfo& m1 = GetMembers()[i]; + const MemberInfo& m2 = other.GetMembers()[i]; + if (!m1.m_type->IsSameAs(*m2.m_type, p_ignoreConst) + || m1.m_externName != m2.m_externName + || m1.m_name != m2.m_name + || m1.m_offset != m2.m_offset + || m1.m_size != m2.m_size) + { + return false; + } + } + + return GetName() == other.GetName() + && GetExternName() == other.GetExternName(); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.h new file mode 100644 index 000000000000..f70dcbd26d42 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/StructType.h @@ -0,0 +1,96 @@ +#pragma once + +#ifndef FREEFORM2_STRUCT_TYPE_H +#define FREEFORM2_STRUCT_TYPE_H + +#include "CompoundType.h" +#include +#include +#include + +namespace FreeForm2 +{ + // Structure types are C-like structs, which are named and contain named + // members. + class StructType : public CompoundType + { + public: + // A struct that holds the information about members of a struct. + class MemberInfo : public CompoundType::Member + { + public: + // Constructor to initialize all members of the class. + MemberInfo(const std::string& p_name, + const TypeImpl& p_type, + const std::string& p_externName, + size_t p_offset, + size_t p_size); + + // Default constructor to initialize members to empty values. + MemberInfo(); + + // The C++ name of the member. + std::string m_externName; + + // The offset (in bytes) of the member from the beginning of the struct. + size_t m_offset; + + // The size (in bytes) of this member. + size_t m_size; + + // Equality operators. + bool operator==(const MemberInfo& p_other) const; + bool operator!=(const MemberInfo& p_other) const; + }; + + // Find a member of the given name within this struct. + virtual const Member* FindMember(const std::string& p_name) const override; + + // Find a member info object of the given name within the struct. + // Behaves similarly to FindMember. + const MemberInfo* FindStructMember(const std::string& p_name) const; + + // Gets a vector of member information objects. + const std::vector& GetMembers() const; + + // Gets the name of this struct. + virtual const std::string& GetName() const override; + + // Gets the C++ name of this struct. + const std::string& GetExternName() const; + + // Get a string representation of the type. + virtual std::string GetString() const; + + // Create derived types based on this type. + virtual const TypeImpl& AsConstType() const override; + virtual const TypeImpl& AsMutableType() const override; + + private: + // Creates a new StructInfo based on the list of members. + StructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst, + TypeManager& p_typeManager); + + friend class TypeManager; + + // Compare subclass type data. + virtual bool IsSameSubType(const TypeImpl& p_type, bool p_ignoreConst) const override; + + // The (ordered) list of members of this struct. + std::vector m_members; + + // A mapping between names and MemberInfo structures. + std::map m_memberMapping; + + // The name exposed to Visage. + std::string m_name; + + // The C++ name of the struct. + std::string m_externName; + }; +} + +#endif \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.cpp new file mode 100644 index 000000000000..3c2862560e66 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.cpp @@ -0,0 +1,373 @@ +#include "TypeImpl.h" + +#include "ArrayType.h" +#include "FreeForm2Assert.h" +#include +#include "StructType.h" + +namespace +{ + std::string GetTypeName(FreeForm2::Type::TypePrimitive p_type, bool p_isConst) + { + std::ostringstream out; + if (!p_isConst) + { + out << "mutable "; + } + out << FreeForm2::Type::Name(p_type); + return out.str(); + } + + // A class representing the builtin primitive types. + class PrimitiveType : public FreeForm2::TypeImpl + { + public: + PrimitiveType(FreeForm2::Type::TypePrimitive p_prim, bool p_isConst) + : TypeImpl(p_prim, p_isConst, NULL), m_name(GetTypeName(p_prim, p_isConst)) + { + } + + virtual const std::string& + GetName() const override + { + return m_name; + } + + virtual const TypeImpl& AsConstType() const override + { + using FreeForm2::Type; + switch (Primitive()) + { + case Type::Float: return TypeImpl::GetFloatInstance(true); + case Type::Int: return TypeImpl::GetIntInstance(true); + case Type::UInt64: return TypeImpl::GetUInt64Instance(true); + case Type::Int32: return TypeImpl::GetInt32Instance(true); + case Type::UInt32: return TypeImpl::GetUInt32Instance(true); + case Type::Bool: return TypeImpl::GetBoolInstance(true); + case Type::Stream: return TypeImpl::GetStreamInstance(true); + case Type::Word: return TypeImpl::GetWordInstance(true); + case Type::InstanceHeader: return TypeImpl::GetInstanceHeaderInstance(true); + case Type::BodyBlockHeader: return TypeImpl::GetBodyBlockHeaderInstance(true); + + // For Void, Unknown, and Invalid, constness is not definable. + case Type::Void: __attribute__((__fallthrough__)); + case Type::Unknown: __attribute__((__fallthrough__)); + case Type::Invalid: return *this; + + default: FreeForm2::Unreachable(__FILE__, __LINE__); + } + } + + virtual const TypeImpl& AsMutableType() const override + { + using FreeForm2::Type; + switch (Primitive()) + { + case Type::Float: return TypeImpl::GetFloatInstance(false); + case Type::Int: return TypeImpl::GetIntInstance(false); + case Type::UInt64: return TypeImpl::GetUInt64Instance(false); + case Type::Int32: return TypeImpl::GetInt32Instance(false); + case Type::UInt32: return TypeImpl::GetUInt32Instance(false); + case Type::Bool: return TypeImpl::GetBoolInstance(false); + case Type::Stream: return TypeImpl::GetStreamInstance(false); + case Type::Word: return TypeImpl::GetWordInstance(false); + case Type::InstanceHeader: return TypeImpl::GetInstanceHeaderInstance(false); + case Type::BodyBlockHeader: return TypeImpl::GetBodyBlockHeaderInstance(false); + + // For Void, Unknown, and Invalid, constness is not definable. + case Type::Void: __attribute__((__fallthrough__)); + case Type::Unknown: __attribute__((__fallthrough__)); + case Type::Invalid: return *this; + + default: FreeForm2::Unreachable(__FILE__, __LINE__); + } + } + + private: + virtual bool IsSameSubType(const TypeImpl& p_other, bool p_ignoreConst) const override + { + FF2_ASSERT(Primitive() == p_other.Primitive()); + return true; + } + + std::string m_name; + }; +} + + +FreeForm2::TypeImpl::TypeImpl(Type::TypePrimitive p_prim, bool p_isConst, TypeManager* p_typeManager) + : m_prim(p_prim), + m_isConst(p_isConst), + m_typeManager(p_typeManager) +{ +} + + +FreeForm2::Type::TypePrimitive +FreeForm2::TypeImpl::Primitive() const +{ + return m_prim; +} + + +bool +FreeForm2::TypeImpl::IsSameAs(const TypeImpl& p_other, bool p_ignoreConst) const +{ + return (Primitive() == p_other.Primitive() + && (p_ignoreConst || IsConst() == p_other.IsConst()) + && IsSameSubType(p_other, p_ignoreConst)); +} + + +bool +FreeForm2::TypeImpl::operator==(const TypeImpl& p_other) const +{ + return IsSameAs(p_other, false); +} + + +bool +FreeForm2::TypeImpl::operator!=(const TypeImpl& p_other) const +{ + return !(*this == p_other); +} + + +bool +FreeForm2::TypeImpl::IsLeafType(Type::TypePrimitive p_prim) +{ + switch (p_prim) + { + case Type::Bool: + case Type::Int: + case Type::UInt64: + case Type::Int32: + case Type::UInt32: + case Type::Float: + { + return true; + } + + // Note that Type::Void does not have fixed size (it effectively has no + // size). + case Type::Void: + case Type::Unknown: + case Type::Array: + case Type::Struct: + case Type::Stream: + case Type::Word: + case Type::InstanceHeader: + case Type::BodyBlockHeader: + case Type::StateMachine: + case Type::Function: + case Type::Object: + { + return false; + } + + default: + { + Unreachable(__FILE__, __LINE__); + break; + } + }; +} + + +bool +FreeForm2::TypeImpl::IsLeafType() const +{ + return IsLeafType(m_prim); +} + + +bool +FreeForm2::TypeImpl::IsIntegerType() const +{ + return (m_prim == Type::Int) + || (m_prim == Type::UInt64) + || (m_prim == Type::Int32) + || (m_prim == Type::UInt32); +} + + +bool +FreeForm2::TypeImpl::IsFloatingPointType() const +{ + return m_prim == Type::Float; +} + + +bool +FreeForm2::TypeImpl::IsSigned() const +{ + return (m_prim == Type::Int) + || (m_prim == Type::Int32) + || (m_prim == Type::Float); +} + + +bool +FreeForm2::TypeImpl::IsValid() const +{ + return Primitive() != Type::Invalid; +} + + +bool +FreeForm2::TypeImpl::IsConst() const +{ + return m_isConst; +} + + +FreeForm2::TypeManager* +FreeForm2::TypeImpl::GetTypeManager() const +{ + return m_typeManager; +} + + +std::ostream& +FreeForm2::operator<<(std::ostream& p_out, const TypeImpl& p_type) +{ + return p_out << p_type.GetName(); +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetFloatInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::Float, false); + static const PrimitiveType constType(Type::Float, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetIntInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::Int, false); + static const PrimitiveType constType(Type::Int, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetUInt64Instance(bool p_isConst) +{ + static const PrimitiveType type(Type::UInt64, false); + static const PrimitiveType constType(Type::UInt64, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetInt32Instance(bool p_isConst) +{ + static const PrimitiveType type(Type::Int32, false); + static const PrimitiveType constType(Type::Int32, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetUInt32Instance(bool p_isConst) +{ + static const PrimitiveType type(Type::UInt32, false); + static const PrimitiveType constType(Type::UInt32, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetBoolInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::Bool, false); + static const PrimitiveType constType(Type::Bool, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetVoidInstance() +{ + static const PrimitiveType type(Type::Void, true); + return type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetWordInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::Word, false); + static const PrimitiveType constType(Type::Word, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetInstanceHeaderInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::InstanceHeader, false); + static const PrimitiveType constType(Type::InstanceHeader, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetBodyBlockHeaderInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::BodyBlockHeader, false); + static const PrimitiveType constType(Type::BodyBlockHeader, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetStreamInstance(bool p_isConst) +{ + static const PrimitiveType type(Type::Stream, false); + static const PrimitiveType constType(Type::Stream, true); + return p_isConst ? constType : type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetUnknownType() +{ + static const PrimitiveType type(Type::Unknown, true); + return type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetInvalidType() +{ + static const PrimitiveType type(Type::Invalid, true); + return type; +} + + +const FreeForm2::TypeImpl& +FreeForm2::TypeImpl::GetCommonType(Type::TypePrimitive p_prim, bool p_isConst) +{ + switch (p_prim) + { + case Type::Float: return TypeImpl::GetFloatInstance(p_isConst); + case Type::Int: return TypeImpl::GetIntInstance(p_isConst); + case Type::UInt64: return TypeImpl::GetUInt64Instance(p_isConst); + case Type::Int32: return TypeImpl::GetInt32Instance(p_isConst); + case Type::UInt32: return TypeImpl::GetUInt32Instance(p_isConst); + case Type::Bool: return TypeImpl::GetBoolInstance(p_isConst); + case Type::Void: return TypeImpl::GetVoidInstance(); + case Type::Stream: return TypeImpl::GetStreamInstance(p_isConst); + case Type::Word: return TypeImpl::GetWordInstance(p_isConst); + case Type::InstanceHeader: return TypeImpl::GetInstanceHeaderInstance(p_isConst); + case Type::BodyBlockHeader: return TypeImpl::GetBodyBlockHeaderInstance(p_isConst); + case Type::Unknown: return TypeImpl::GetUnknownType(); + case Type::Invalid: return TypeImpl::GetInvalidType(); + default: Unreachable(__FILE__, __LINE__); + } +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.h new file mode 100644 index 000000000000..d61e3b85d245 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeImpl.h @@ -0,0 +1,111 @@ +#pragma once + +#ifndef FREEFORM2_TYPEIMPL_H +#define FREEFORM2_TYPEIMPL_H + +#include +#include "FreeForm2Type.h" +#include +#include +#include + +namespace FreeForm2 +{ + class TypeManager; + + class TypeImpl : boost::noncopyable + { + public: + // Construct a type implementation from the given type primitive + // and const-ness. + TypeImpl(Type::TypePrimitive p_prim, bool p_isConst, TypeManager* p_typeManager); + + // Get the type primitive of this type. + Type::TypePrimitive Primitive() const; + + // Equality comparison for types. This function returns true if two + // types are equivalent; otherwise returns false. This function will + // ignore constness if p_ignoreConst is true. + bool IsSameAs(const TypeImpl& p_type, bool p_ignoreConst) const; + + // Equality operator. Implemented in terms of IsSameAs, taking + // constness into account. + bool operator==(const TypeImpl& p_other) const; + + // Inequality operator. Implemented in terms of operator ==. + bool operator!=(const TypeImpl& p_other) const; + + // Get a string representation of the type. + virtual const std::string& GetName() const = 0; + + // A static and member function to determine whether a given or the + // current type is of fixed size. Types that are not of fixed size + // include arrays (variable-sized) and void (doesn't have a size). + static bool IsLeafType(Type::TypePrimitive p_prim); + bool IsLeafType() const; + + // Check if this type is an integer type. + bool IsIntegerType() const; + + // Check if this type is a floating-point type. + bool IsFloatingPointType() const; + + // Check if this type is signed. Signed types include signed integer + // types and the float type. + bool IsSigned() const; + + // Check if this type is valid. Unknown types are valid. + bool IsValid() const; + + // Check if this type is const. + bool IsConst() const; + + // Create derived types based on this type. + virtual const TypeImpl& AsConstType() const = 0; + virtual const TypeImpl& AsMutableType() const = 0; + + // Return convenience single instances of some common, leaf types. + static const TypeImpl& GetFloatInstance(bool p_isConst); + static const TypeImpl& GetIntInstance(bool p_isConst); + static const TypeImpl& GetUInt64Instance(bool p_isConst); + static const TypeImpl& GetInt32Instance(bool p_isConst); + static const TypeImpl& GetUInt32Instance(bool p_isConst); + static const TypeImpl& GetBoolInstance(bool p_isConst); + static const TypeImpl& GetVoidInstance(); + static const TypeImpl& GetStreamInstance(bool p_isConst); + static const TypeImpl& GetWordInstance(bool p_isConst); + static const TypeImpl& GetInstanceHeaderInstance(bool p_isConst); + static const TypeImpl& GetBodyBlockHeaderInstance(bool p_isConst); + static const TypeImpl& GetUnknownType(); + static const TypeImpl& GetInvalidType(); + + // Get the type for the singleton types above. + static const TypeImpl& GetCommonType(Type::TypePrimitive p_prim, bool p_isConst); + + protected: + // For derived classes, return the TypeManager passed to the + // constructor of this TypeImpl. + TypeManager* GetTypeManager() const; + + private: + // This method determines if the subclass data for this TypeImpl is the + // same as the paramter. This method may assume that Primitive() == + // p_type.Primitive(). + virtual bool IsSameSubType(const TypeImpl& p_type, bool p_ignoreConst) const = 0; + + // Type primitive of this type. + Type::TypePrimitive m_prim; + + // Constness of the type. + bool m_isConst; + + // Owning TypeManager of this type. + TypeManager* m_typeManager; + }; + + std::ostream& + operator<<(std::ostream& p_out, const TypeImpl& p_type); +} + +#endif + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.cpp new file mode 100644 index 000000000000..5e1e63fabc44 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.cpp @@ -0,0 +1,1097 @@ +#include "TypeManager.h" + +#include "ArrayType.h" +#include +#include +#include +#include +#include +#include +#include "FreeForm2Assert.h" +#include "FreeForm2ExternalData.h" +#include +//#include "MetaStreams.h" +#include "ObjectType.h" +//#include "RankerFeatures.h" +#include +#include "StateMachineType.h" +#include "StructType.h" +#include "TypeImpl.h" + +#include "MigratedApi.h" + +using namespace FreeForm2; + +namespace +{ + template + void ByteArrayDeleter(T* p_delete) + { + // Cast to char*, as this structure was allocated as char[]. + p_delete->~T(); + char* data = reinterpret_cast(p_delete); + delete[] data; + } + + // A class that manages type information. For now it only contains a + // std::string-to-StructInfo mapping. + class NamedTypeManager : public TypeManager + { + public: + // Create an empty type manager with the specified parent. + NamedTypeManager(const TypeManager& p_parent); + + // Constructor to create the global TypeManager. + NamedTypeManager(); + + // Gets the type information for the provided name. Returns NULL if the + // type is not found. + virtual const TypeImpl* GetTypeInfo(const std::string& p_name) const override; + + // Create a variable sized array type owned by this TypeManager. + virtual const ArrayType& + GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements) override; + + // Create a fixed-sized array type owned by this TypeManager. + virtual const ArrayType& + GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements) override; + + // Returns an array type owned by this TypeManager which has the same + // properties as another ArrayType. + virtual + const ArrayType& + GetArrayType(const ArrayType& p_type) override; + + // Create a struct type owned by this TypeManager. If the specified + // name exists, the existing structure is returned iff it is exactly + // the same as the parameters; otherwise, an the function asserts. + virtual const StructType& + GetStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) override; + + // Returns a struct type owned by this TypeManager which has the same + // properties as another StructType. + virtual + const StructType& + GetStructType(const StructType& p_type) override; + + // Get an object type owned by this TypeManager. The TypeManager is not + // required to allow multiple non-unique names exist in the context of + // its owned types. + virtual + const ObjectType& + GetObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) override; + + // Returns an object type owned by this TypeManager which has the same + // properties as another StructType. + virtual + const ObjectType& + GetObjectType(const ObjectType& p_type) override; + + // Returns a state machine type owned by this TypeManager with the same + // semantics as GetStructType. + virtual + const StateMachineType& + GetStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) override; + + // Returns a state machine type owned by this TypeManager which has the + // same properties as another state machine type. + virtual + const StateMachineType& + GetStateMachineType(const StateMachineType& p_type) override; + + // Get a function type owned by this TypeManager. The TypeManager will just store one + // function type per signature. + virtual + const FunctionType& + GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_parameters, + size_t p_numParameters) override; + + // Returns a function type owned by this TypeManager which has the same + // properties as another FunctionType. + virtual + const FunctionType& + GetFunctionType(const FunctionType& p_type) override; + + private: + template + const T& RegisterType(boost::shared_ptr p_type, const std::string& p_name) + { + BOOST_STATIC_ASSERT((boost::is_base_of::value == true)); + boost::shared_ptr ptr = boost::static_pointer_cast(p_type); + m_typeMap.insert(std::make_pair(p_name, ptr)); + return *p_type; + } + + + template + const T& RegisterType(boost::shared_ptr p_type) + { + BOOST_STATIC_ASSERT((boost::is_base_of::value == true)); + return RegisterType(p_type, p_type->GetName()); + } + + // A mapping from names to types. + std::map> m_typeMap; + }; +} + + +TypeManager::TypeManager(const TypeManager* p_parent) + : m_parent(p_parent) +{ +} + + +TypeManager::~TypeManager() +{ +} + + +NamedTypeManager::NamedTypeManager() : TypeManager(NULL) +{ + // Tuples of interest. + boost::array tuplesOfInterestArray = {{ + StructType::MemberInfo("WordStart", TypeImpl::GetUInt32Instance(true), "iWordStart", 0, 4), + StructType::MemberInfo("WordEnd", TypeImpl::GetUInt32Instance(true), "iWordEnd", 4, 4), + StructType::MemberInfo("Weight", TypeImpl::GetUInt32Instance(true), "iWeight", 8, 4) + }}; + + std::vector tuplesOfInterest(tuplesOfInterestArray.begin(), + tuplesOfInterestArray.end()); + + NamedTypeManager::GetStructType("TupleOfInterest","FreeForm2::TupleOfInterest", tuplesOfInterest, true); + + // AllDoublesDecodeIndexes + boost::array allDoublesDecodeIndexesArray = {{ + StructType::MemberInfo("FirstIndex", TypeImpl::GetUInt32Instance(true), "m_firstIndex", 0, 4), + StructType::MemberInfo("SecondIndex", TypeImpl::GetUInt32Instance(true), "m_secondIndex", 4, 4) + }}; + + std::vector allDoublesDecodeIndexes(allDoublesDecodeIndexesArray.begin(), + allDoublesDecodeIndexesArray.end()); + + NamedTypeManager::GetStructType("AllDoublesDecodeIndexes","FreeForm2::RuntimeLibrary::AllDoublesDecodeIndexes", allDoublesDecodeIndexes, true); + + // AllTriplesDecodeIndexes + boost::array allTriplesDecodeIndexesArray = {{ + StructType::MemberInfo("FirstIndex", TypeImpl::GetUInt32Instance(true), "m_firstIndex", 0, 4), + StructType::MemberInfo("SecondIndex", TypeImpl::GetUInt32Instance(true), "m_secondIndex", 4, 4), + StructType::MemberInfo("ThirdIndex", TypeImpl::GetUInt32Instance(true), "m_thirdIndex", 8, 4) + }}; + + std::vector allTriplesDecodeIndexes(allTriplesDecodeIndexesArray.begin(), + allTriplesDecodeIndexesArray.end()); + + NamedTypeManager::GetStructType("AllTriplesDecodeIndexes","FreeForm2::RuntimeLibrary::AllTriplesDecodeIndexes", allTriplesDecodeIndexes, true); + + // NumberOfTuples Object. + std::vector numberOfTuplesCommonMembers; + + const FunctionType& numberOfTuplesInitialize + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("Initialize", numberOfTuplesInitialize)); + + const FunctionType& numberOfTuplesReset + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("Reset", numberOfTuplesReset)); + + const FunctionType& numberOfTuplesInflateMatrix + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("InflateMatrix", numberOfTuplesInflateMatrix)); + + const TypeImpl* numberOfTuplesAddWordArray[] = { &TypeImpl::GetWordInstance(true) }; + const FunctionType& numberOfTuplesAddWord + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), + numberOfTuplesAddWordArray, + countof(numberOfTuplesAddWordArray)); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("AddWord", numberOfTuplesAddWord)); + + const TypeImpl* numberOfTuplesValueArray[] = { &TypeImpl::GetUInt32Instance(true), + &TypeImpl::GetUInt32Instance(true) }; + const FunctionType& numberOfTuplesValue + = NamedTypeManager::GetFunctionType(TypeImpl::GetInt32Instance(true), + numberOfTuplesValueArray, + countof(numberOfTuplesValueArray)); + const FunctionType& offsetOfTuplesValue + = NamedTypeManager::GetFunctionType(TypeImpl::GetUInt32Instance(true), + numberOfTuplesValueArray, + countof(numberOfTuplesValueArray)); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("numberOfTuples", numberOfTuplesValue)); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("firstOccurrenceOffsetOfTuples", offsetOfTuplesValue)); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("lastOccurrenceOffsetOfTuples", offsetOfTuplesValue)); + numberOfTuplesCommonMembers.push_back(ObjectType::ObjectMember("incrementValue", + TypeImpl::GetInt32Instance(false), + "m_iIncrementingValue")); + + NamedTypeManager::GetObjectType("NumberOfTuplesCommon", + "CNumberOfTuples", + numberOfTuplesCommonMembers, + false); + + NamedTypeManager::GetObjectType("NumberOfTuplesCommonNoDuplicate", + "CNumberOfTuples", + numberOfTuplesCommonMembers, + false); + + // NumberOfTuplesInTriples Object. + std::vector numberOfTuplesInTriplesCommonMembers; + + const TypeImpl* numberOfTuplesInTriplesInitializeArray[] = { &TypeImpl::GetUInt32Instance(true), + &TypeImpl::GetUInt32Instance(true), + &TypeImpl::GetUInt32Instance(true) }; + const FunctionType& numberOfTuplesInTriplesInitialize + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), + numberOfTuplesInTriplesInitializeArray, + countof(numberOfTuplesInTriplesInitializeArray)); + numberOfTuplesInTriplesCommonMembers.push_back(ObjectType::ObjectMember("Initialize", numberOfTuplesInTriplesInitialize)); + + const TypeImpl* numberOfTuplesInTriplesStartPhraseArray[] = { &TypeImpl::GetUInt32Instance(true) }; + const FunctionType& numberOfTuplesInTriplesStartPhrase + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), + numberOfTuplesInTriplesStartPhraseArray, + countof(numberOfTuplesInTriplesStartPhraseArray)); + numberOfTuplesInTriplesCommonMembers.push_back(ObjectType::ObjectMember("StartPhrase", numberOfTuplesInTriplesStartPhrase)); + + const TypeImpl* numberOfTuplesInTriplesAddWordArray[] = { &TypeImpl::GetWordInstance(true) }; + const FunctionType& numberOfTuplesInTriplesAddWord + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), + numberOfTuplesInTriplesAddWordArray, + countof(numberOfTuplesInTriplesAddWordArray)); + numberOfTuplesInTriplesCommonMembers.push_back(ObjectType::ObjectMember("AddWord", numberOfTuplesInTriplesAddWord)); + + const FunctionType& numberOfTuplesInTriplesEndPage + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + numberOfTuplesInTriplesCommonMembers.push_back(ObjectType::ObjectMember("EndPage", numberOfTuplesInTriplesEndPage)); + + const TypeImpl* numberOfTuplesInTriplesValueArray[] = { &TypeImpl::GetUInt32Instance(true) }; + const FunctionType& numberOfTuplesInTriplesValue + = NamedTypeManager::GetFunctionType(TypeImpl::GetUInt32Instance(true), + numberOfTuplesInTriplesValueArray, + countof(numberOfTuplesInTriplesValueArray)); + numberOfTuplesInTriplesCommonMembers.push_back(ObjectType::ObjectMember("numberOfTuples", numberOfTuplesInTriplesValue)); + numberOfTuplesInTriplesCommonMembers.push_back(ObjectType::ObjectMember("numberOfTuplesInOrder", numberOfTuplesInTriplesValue)); + + NamedTypeManager::GetObjectType("NumberOfTuplesInTriplesCommon", + "CNumberOfTuplesInTriples", + numberOfTuplesInTriplesCommonMembers, + false); + + // WeightingCalculator Object + std::vector weightingCalculatorMembers; + + const FunctionType& weightingCalculatorReset + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + weightingCalculatorMembers.push_back(ObjectType::ObjectMember("Reset", weightingCalculatorReset)); + + { + const TypeImpl* weightingCalculatorAddWordArray[] = { &TypeImpl::GetWordInstance(true) }; + const FunctionType& weightingCalculatorAddWord + = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), + weightingCalculatorAddWordArray, + countof(weightingCalculatorAddWordArray)); + weightingCalculatorMembers.push_back(ObjectType::ObjectMember("AddWord", weightingCalculatorAddWord)); + } + + { + const TypeImpl* weightingCalculatorApplyWeightingArray[] = { &TypeImpl::GetIntInstance(true) }; + const FunctionType& weightingCalculatorApplyWeighting + = NamedTypeManager::GetFunctionType(TypeImpl::GetIntInstance(true), + weightingCalculatorApplyWeightingArray, + countof(weightingCalculatorApplyWeightingArray)); + weightingCalculatorMembers.push_back(ObjectType::ObjectMember("ApplyWeightingRound", + weightingCalculatorApplyWeighting, + "ApplyWeighting")); + } + + { + const TypeImpl* weightingCalculatorApplyWeightingArray[] = { &TypeImpl::GetFloatInstance(true) }; + const FunctionType& weightingCalculatorApplyWeighting + = NamedTypeManager::GetFunctionType(TypeImpl::GetFloatInstance(true), + weightingCalculatorApplyWeightingArray, + countof(weightingCalculatorApplyWeightingArray)); + weightingCalculatorMembers.push_back( + ObjectType::ObjectMember("ApplyWeighting", weightingCalculatorApplyWeighting)); + } + + NamedTypeManager::GetObjectType("AlterationAndTermWeightingCalculator", + "BarramundiWeightingCalculator", + weightingCalculatorMembers, + false); + + NamedTypeManager::GetObjectType("AlterationWeightingCalculator", + "BarramundiWeightingCalculator", + weightingCalculatorMembers, + false); + + // TrueNearDoublesQueue Object + std::vector trueNearDoubleQueueMembers; + + const FunctionType& trueNearDoublesQueueReset = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + trueNearDoubleQueueMembers.push_back(ObjectType::ObjectMember("Reset", trueNearDoublesQueueReset)); + + { + const TypeImpl* trueNearDoublesQueueReceiveWordArray[] = { &TypeImpl::GetWordInstance(true), + &TypeImpl::GetUInt32Instance(false), + &TypeImpl::GetUInt32Instance(false) }; + const FunctionType& trueNearDoublesQueueReceiveWord = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), + trueNearDoublesQueueReceiveWordArray, + countof(trueNearDoublesQueueReceiveWordArray)); + trueNearDoubleQueueMembers.push_back(ObjectType::ObjectMember("ReceiveWord", trueNearDoublesQueueReceiveWord)); + } + + NamedTypeManager::GetObjectType("TrueNearDoubleQueue", "TrueNearDoubleQueue", trueNearDoubleQueueMembers, false); + + // BoundedQueue Object + std::vector boundedQueueMembers; + + { + const FunctionType& boundedQueueClear = NamedTypeManager::GetFunctionType(TypeImpl::GetVoidInstance(), NULL, 0); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Clear", boundedQueueClear)); + } + + { + const FunctionType& boundedQueueEmpty = NamedTypeManager::GetFunctionType(TypeImpl::GetBoolInstance(true), NULL, 0); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Empty", boundedQueueEmpty)); + } + + { + const FunctionType& boundedQueueFull = NamedTypeManager::GetFunctionType(TypeImpl::GetBoolInstance(true), NULL, 0); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Full", boundedQueueFull)); + } + + { + const TypeImpl* boundedQueueGetParamArray[] = { &TypeImpl::GetUInt32Instance(true) }; + const FunctionType& boundedQueueGet = NamedTypeManager::GetFunctionType(TypeImpl::GetIntInstance(true), boundedQueueGetParamArray, 1); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Get", boundedQueueGet)); + } + + { + const FunctionType& boundedQueuePop = NamedTypeManager::GetFunctionType(TypeImpl::GetIntInstance(true), NULL, 0); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Pop", boundedQueuePop)); + } + + { + const FunctionType& boundedQueuePush = NamedTypeManager::GetFunctionType(TypeImpl::GetIntInstance(true), NULL, 0); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Push", boundedQueuePush)); + } + + { + const FunctionType& boundedQueueSize = NamedTypeManager::GetFunctionType(TypeImpl::GetUInt32Instance(true), NULL, 0); + boundedQueueMembers.push_back(ObjectType::ObjectMember("Size", boundedQueueSize)); + } + + // Visage bounded queues can hold up to 41 indices. + NamedTypeManager::GetObjectType("BoundedQueue", "FreeForm2::RuntimeLibrary::BoundedQueue<41>", boundedQueueMembers, false); +} + + +NamedTypeManager::NamedTypeManager(const TypeManager& p_parent) : TypeManager(&p_parent) +{ +} + + +const TypeImpl* +NamedTypeManager::GetTypeInfo(const std::string& p_name) const +{ + std::map>::const_iterator info + = m_typeMap.find(p_name); + + if (info != m_typeMap.end()) + { + return info->second.get(); + } + else if (GetParent() != NULL) + { + return GetParent()->GetTypeInfo(p_name); + } + else + { + return NULL; + } +} + + +boost::shared_ptr +TypeManager::CreateArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements) +{ + const size_t structSize = sizeof(ArrayType) + sizeof(unsigned int) + * (p_dimensions > 0 ? p_dimensions - 1 : 0); + + char* mem = NULL; + try + { + mem = new char[structSize]; + return boost::shared_ptr( + new (mem) ArrayType(p_child, p_isConst, p_dimensions, p_elementCounts, p_maxElements, *this), + ByteArrayDeleter); + } + catch (...) + { + delete[] mem; + throw; + } +} + + +boost::shared_ptr +TypeManager::CreateArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements) +{ + return boost::shared_ptr( + new ArrayType(p_child, p_isConst, p_dimensions, p_maxElements, boost::ref(*this))); +} + + +boost::shared_ptr + TypeManager::CreateStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) +{ + return boost::shared_ptr( + new StructType(p_name, p_externName, p_members, p_isConst, boost::ref(*this))); +} + + +boost::shared_ptr +TypeManager::CreateObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) +{ + return boost::shared_ptr( + new ObjectType(p_name, p_externName, p_members, p_isConst, boost::ref(*this))); +} + + +boost::shared_ptr +TypeManager::CreateStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) +{ + const size_t memSize = sizeof(StateMachineType) + + sizeof(CompoundType::Member) * (std::max(p_numMembers, (size_t) 1ULL) - 1); + char* mem = NULL; + + try + { + mem = new char[memSize]; + return boost::shared_ptr(new (mem) StateMachineType(*this, + p_name, + p_members, + p_numMembers, + p_expr), + ByteArrayDeleter); + } + catch (...) + { + delete[] mem; + throw; + } +} + + +boost::shared_ptr +TypeManager::CreateFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_params, + size_t p_numParams) +{ + const size_t memSize = sizeof(FunctionType) + + sizeof(TypeImpl) * (std::max(p_numParams, (size_t) 1ULL) - 1); + char* mem = NULL; + + try + { + mem = new char[memSize]; + return boost::shared_ptr(new (mem) FunctionType(*this, + p_returnType, + p_params, + p_numParams), + ByteArrayDeleter); + } + catch (...) + { + delete[] mem; + throw; + } +} + + +const TypeImpl& +TypeManager::GetChildType(const TypeImpl& p_type) +{ + const TypeImpl* childType = &p_type; + switch (childType->Primitive()) + { + case Type::Struct: + { + childType = &GetStructType(static_cast(*childType)); + break; + } + + case Type::Array: + { + childType = &GetArrayType(static_cast(*childType)); + break; + } + + case Type::Object: + { + childType = &GetObjectType(static_cast(*childType)); + break; + } + + case Type::StateMachine: + { + childType = &GetStateMachineType(static_cast(*childType)); + break; + } + + default: + { + FF2_ASSERT(childType == &TypeImpl::GetCommonType(childType->Primitive(), childType->IsConst())); + break; + } + } + return *childType; +} + + +boost::shared_ptr +TypeManager::CopyArrayType(const ArrayType& p_arrayType) +{ + const TypeImpl& childType = GetChildType(p_arrayType.GetChildType()); + + if (p_arrayType.IsFixedSize()) + { + return CreateArrayType(childType, + p_arrayType.IsConst(), + p_arrayType.GetDimensionCount(), + p_arrayType.GetDimensions(), + p_arrayType.GetMaxElements()); + } + else + { + return CreateArrayType(childType, + p_arrayType.IsConst(), + p_arrayType.GetDimensionCount(), + p_arrayType.GetMaxElements()); + } +} + + +boost::shared_ptr +TypeManager::CopyStructType(const StructType& p_structType) +{ + std::vector members(p_structType.GetMembers()); + + BOOST_FOREACH(StructType::MemberInfo& info, members) + { + FF2_ASSERT(info.m_type != NULL); + info.m_type = &GetChildType(*info.m_type); + } + + return CreateStructType(p_structType.GetName(), + p_structType.GetExternName(), + members, + p_structType.IsConst()); +} + + +boost::shared_ptr +TypeManager::CopyObjectType(const ObjectType& p_objectType) +{ + std::vector members; + for (std::map::const_iterator memberIterator = p_objectType.m_members.begin(); + memberIterator != p_objectType.m_members.end(); + ++memberIterator) + { + members.push_back(memberIterator->second); + } + + return CreateObjectType(p_objectType.GetName(), + p_objectType.GetExternName(), + members, + p_objectType.IsConst()); +} + + +boost::shared_ptr +TypeManager::CopyStateMachineType(const StateMachineType& p_type) +{ + std::vector members(p_type.BeginMembers(), p_type.EndMembers()); + + BOOST_FOREACH(CompoundType::Member& member, members) + { + member.m_type = &GetChildType(*member.m_type); + } + + return CreateStateMachineType(p_type.GetName(), + &members[0], + members.size(), + p_type.GetDefinition()); +} + + +boost::shared_ptr +TypeManager::CopyFunctionType(const FunctionType& p_type) +{ + std::vector params; + + for (UInt32 i = 0; i < p_type.GetParameterCount(); i++) + { + const TypeImpl& param = *p_type.BeginParameters()[i]; + params.push_back(&GetChildType(param)); + } + + return CreateFunctionType(GetChildType(p_type.GetReturnType()), + params.size() > 0 ? ¶ms[0] : nullptr, + params.size()); +} + + +const TypeManager* +TypeManager::GetParent() const +{ + return m_parent; +} + + +const ArrayType& +NamedTypeManager::GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements) +{ + return GetArrayType(p_child, p_isConst, p_dimensions, NULL, p_maxElements); +} + + +const ArrayType& +NamedTypeManager::GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements) +{ + std::string signature + = ArrayType::GetName(p_child, p_isConst, p_dimensions, p_elementCounts, p_maxElements); + + const TypeImpl* type = GetTypeInfo(signature); + if (type != NULL) + { + FF2_ASSERT(type->IsConst() == p_isConst); + FF2_ASSERT(type->Primitive() == Type::Array); + const ArrayType& arrayType = static_cast(*type); + return arrayType; + } + else + { + if (p_elementCounts != NULL) + { + return RegisterType( + CreateArrayType(p_child, p_isConst, p_dimensions, p_elementCounts, p_maxElements)); + } + else + { + return RegisterType(CreateArrayType(p_child, p_isConst, p_dimensions, p_maxElements)); + } + } +} + + +const ArrayType& +NamedTypeManager::GetArrayType(const ArrayType& p_type) +{ + const TypeImpl* type = GetTypeInfo(p_type.GetName()); + if (type != NULL) + { + FF2_ASSERT(type->Primitive() == Type::Array); + return static_cast(*type); + } + else + { + return RegisterType(CopyArrayType(p_type)); + } +} + + +const StructType& +NamedTypeManager::GetStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) +{ + FF2_ASSERT(p_name.find(' ') == std::string::npos); + + std::string name; + name.reserve(8 + p_name.size()); + if (!p_isConst) + { + name.assign("mutable "); + } + name.append(p_name); + + const TypeImpl* type = GetTypeInfo(name); + if (type != NULL) + { + FF2_ASSERT(type->IsConst() == p_isConst); + FF2_ASSERT(type->Primitive() == Type::Struct); + const StructType& structType = static_cast(*type); + FF2_ASSERT(structType.GetExternName() == p_externName); + return structType; + } + else + { + return RegisterType(CreateStructType(p_name, p_externName, p_members, p_isConst), name); + } +} + + +const StructType& +NamedTypeManager::GetStructType(const StructType& p_type) +{ + FF2_ASSERT(p_type.GetName().find(' ') == std::string::npos); + + std::string name; + name.reserve(8 + p_type.GetName().size()); + if (!p_type.IsConst()) + { + name.assign("mutable "); + } + name.append(p_type.GetName()); + + const TypeImpl* type = GetTypeInfo(name); + if (type != NULL) + { + FF2_ASSERT(type->Primitive() == Type::Struct); + return static_cast(*type); + } + else + { + return RegisterType(CopyStructType(p_type), name); + } +} + + +const ObjectType& +NamedTypeManager::GetObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) +{ + FF2_ASSERT(p_name.find(' ') == std::string::npos); + + const TypeImpl* type = GetTypeInfo(p_name); + if (type != NULL) + { + FF2_ASSERT(type->IsConst() == p_isConst); + FF2_ASSERT(type->Primitive() == Type::Object); + const ObjectType& structType = static_cast(*type); + FF2_ASSERT(structType.GetExternName() == p_externName); + return structType; + } + else + { + return RegisterType(CreateObjectType(p_name, p_externName, p_members, p_isConst), p_name); + } +} + + +const ObjectType& +NamedTypeManager::GetObjectType(const ObjectType& p_type) +{ + FF2_ASSERT(p_type.GetName().find(' ') == std::string::npos); + + std::string name = p_type.GetName(); + const TypeImpl* type = GetTypeInfo(name); + if (type != NULL) + { + FF2_ASSERT(type->Primitive() == Type::Object); + return static_cast(*type); + } + else + { + return RegisterType(CopyObjectType(p_type), name); + } +} + + +const StateMachineType& +NamedTypeManager::GetStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) +{ + const TypeImpl* type = GetTypeInfo(p_name); + if (type != NULL) + { + FF2_ASSERT(type->Primitive() == Type::StateMachine); + const StateMachineType& machineType = static_cast(*type); + FF2_ASSERT(machineType.GetName() == p_name); + FF2_ASSERT(machineType.GetMemberCount() == p_numMembers); + const CompoundType::Member* members = machineType.BeginMembers(); + for (size_t i = 0; i < p_numMembers; i++) + { + FF2_ASSERT(members[i].m_name == p_members[i].m_name + && *members[i].m_type == *p_members[i].m_type); + } + return machineType; + } + else + { + return RegisterType(CreateStateMachineType(p_name, p_members, p_numMembers, p_expr)); + } +} + + +const StateMachineType& +NamedTypeManager::GetStateMachineType(const StateMachineType& p_type) +{ + const TypeImpl* type = GetTypeInfo(p_type.GetName()); + if (type != NULL) + { + FF2_ASSERT(*type == p_type); + const StateMachineType& machineType = static_cast(*type); + return machineType; + } + else + { + return RegisterType(CopyStateMachineType(p_type)); + } +} + + +const FunctionType& +NamedTypeManager::GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_params, + size_t p_numParams) +{ + std::string name + = FunctionType::GetName(p_returnType, p_params, p_numParams); + const TypeImpl* type = GetTypeInfo(name); + if (type != NULL) + { + FF2_ASSERT(type->Primitive() == Type::Function); + const FunctionType& functionType = static_cast(*type); + FF2_ASSERT(functionType.GetReturnType() == p_returnType); + FF2_ASSERT(functionType.GetParameterCount() == p_numParams); + + for (size_t i = 0; i < p_numParams; i++) + { + FF2_ASSERT(*functionType.BeginParameters()[i] == *p_params[i]); + } + return functionType; + } + else + { + return RegisterType(CreateFunctionType(p_returnType, p_params, p_numParams)); + } +} + + +const FunctionType& +NamedTypeManager::GetFunctionType(const FunctionType& p_type) +{ + const TypeImpl* type = GetTypeInfo(p_type.GetName()); + if (type != NULL) + { + FF2_ASSERT(*type == p_type); + const FunctionType& functionType = static_cast(*type); + return functionType; + } + else + { + return RegisterType(CopyFunctionType(p_type)); + } +} + + +const FreeForm2::TypeManager& +FreeForm2::TypeManager::GetGlobalTypeManager() +{ + static const NamedTypeManager s_instance; + + return s_instance; +} + + +std::auto_ptr +FreeForm2::TypeManager::CreateTypeManager() +{ + return std::auto_ptr(new NamedTypeManager(TypeManager::GetGlobalTypeManager())); +} + + +std::auto_ptr +FreeForm2::TypeManager::CreateTypeManager(const TypeManager& p_parent) +{ + return std::auto_ptr(new NamedTypeManager(p_parent)); +} + + +std::auto_ptr +FreeForm2::TypeManager::CreateTypeManager(const ExternalDataManager& p_parent) +{ + return std::auto_ptr(new NamedTypeManager(p_parent.m_typeFactory->GetTypeManager())); +} + + +AnonymousTypeManager::AnonymousTypeManager() + : TypeManager(NULL) +{ +} + + +const TypeImpl* +AnonymousTypeManager::GetTypeInfo(const std::string& p_name) const +{ + return NULL; +} + + +const ArrayType& +AnonymousTypeManager::GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements) +{ + boost::shared_ptr type( + CreateArrayType(p_child, p_isConst, p_dimensions, p_elementCounts, p_maxElements)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const ArrayType& +AnonymousTypeManager::GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements) +{ + boost::shared_ptr type( + CreateArrayType(p_child, p_isConst, p_dimensions, p_maxElements)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const ArrayType& +AnonymousTypeManager::GetArrayType(const ArrayType& p_arrayType) +{ + boost::shared_ptr type(CopyArrayType(p_arrayType)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const StructType& +AnonymousTypeManager::GetStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) +{ + boost::shared_ptr type( + CreateStructType(p_name, p_externName, p_members, p_isConst)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const StructType& +AnonymousTypeManager::GetStructType(const StructType& p_type) +{ + boost::shared_ptr type(CopyStructType(p_type)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const ObjectType& +AnonymousTypeManager::GetObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) +{ + boost::shared_ptr type( + CreateObjectType(p_name, p_externName, p_members, p_isConst)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const ObjectType& +AnonymousTypeManager::GetObjectType(const ObjectType& p_type) +{ + boost::shared_ptr type(CopyObjectType(p_type)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const StateMachineType& +AnonymousTypeManager::GetStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) +{ + boost::shared_ptr type( + CreateStateMachineType(p_name, p_members, p_numMembers, p_expr)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const StateMachineType& +AnonymousTypeManager::GetStateMachineType(const StateMachineType& p_type) +{ + boost::shared_ptr type(CopyStateMachineType(p_type)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const FunctionType& +AnonymousTypeManager::GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_parameters, + size_t p_numParameters) +{ + boost::shared_ptr type( + CreateFunctionType(p_returnType, p_parameters, p_numParameters)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} + + +const FunctionType& +AnonymousTypeManager::GetFunctionType(const FunctionType& p_type) +{ + boost::shared_ptr type(CopyFunctionType(p_type)); + m_types.push_back(boost::static_pointer_cast(type)); + return *type; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.h b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.h new file mode 100644 index 000000000000..6888d942b4eb --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Shared/TypeManager.h @@ -0,0 +1,315 @@ +#pragma once + +#ifndef FREEFORM2_TYPEMANAGER_H +#define FREEFORM2_TYPEMANAGER_H + +#include +#include +#include +#include "CompoundType.h" +#include "FunctionType.h" +#include +#include "ObjectType.h" +#include +#include "StructType.h" + +namespace FreeForm2 +{ + class ArrayType; + class ExternalDataManager; + class StateMachineExpression; + class StateMachineType; + class TypeImpl; + + // The TypeManager class acts as both an owner and a factory for TypeImpl + // objects. Type managers are implemented as a list, where each type + // manager has a parent (with the exception of the global type manager). + // If a type manager cannot find information about a type, it should query + // its parent type manager. + class TypeManager : boost::noncopyable + { + public: + // Create a TypeManager with an optional parent. + explicit TypeManager(const TypeManager* p_parent); + + // Unimplemented destructor. + virtual ~TypeManager(); + + // Gets the static instance of the TypeManager. This function is not thread-safe. + static const TypeManager& GetGlobalTypeManager(); + + // Create a type manager with the default implementation. + static std::auto_ptr CreateTypeManager(); + + // Create a type manager with the given parent. + static std::auto_ptr CreateTypeManager(const TypeManager& p_parent); + + // Create a type manager with the given parent. + static std::auto_ptr CreateTypeManager(const ExternalDataManager& p_parent); + + // Gets the type information for the provided name. Returns NULL if the + // type is not found. This method should check this type manager as + // well as its parent type manager. + virtual const TypeImpl* GetTypeInfo(const std::string& p_name) const = 0; + + // Get a variable sized array type owned by this TypeManager. + virtual + const ArrayType& + GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements) = 0; + + // Get a fixed-sized array type owned by this TypeManager. + virtual + const ArrayType& + GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements) = 0; + + // Returns an array type owned by this TypeManager which has the same + // properties as another ArrayType. + virtual + const ArrayType& + GetArrayType(const ArrayType& p_type) = 0; + + // Get a struct type owned by this TypeManager. The TypeManager is not + // required to allow multiple non-unique names exist in the context of + // its owned types. + virtual + const StructType& + GetStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) = 0; + + // Returns a struct type owned by this TypeManager which has the same + // properties as another StructType. + virtual + const StructType& + GetStructType(const StructType& p_type) = 0; + + // Get an object type owned by this TypeManager. The TypeManager is not + // required to allow multiple non-unique names exist in the context of + // its owned types. + virtual + const ObjectType& + GetObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) = 0; + + // Returns an object type owned by this TypeManager which has the same + // properties as another StructType. + virtual + const ObjectType& + GetObjectType(const ObjectType& p_type) = 0; + + // Returns a state machine type owned by this TypeManager. The type + // manager is not required to allow multiple non-unique names to exist + // in the context of its owned types. + virtual + const StateMachineType& + GetStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) = 0; + + // Returns a state machine type owned by this TypeManager which has the + // same properties as another state machine type. + virtual + const StateMachineType& + GetStateMachineType(const StateMachineType& p_type) = 0; + + // Get a function type owned by this TypeManager. The TypeManager will just store one + // function type per signature. + virtual + const FunctionType& + GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_params, + size_t p_numParams) = 0; + + // Returns a function type owned by this TypeManager which has the same + // properties as another FunctionType. + virtual + const FunctionType& + GetFunctionType(const FunctionType& p_type) = 0; + + // Get the parent of this type manager. This function may return NULL. + const TypeManager* GetParent() const; + + protected: + // Create an array type. + boost::shared_ptr + CreateArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements); + + // Create a fixed-sized array type. + boost::shared_ptr + CreateArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements); + + // Create a deep copy of an array type. The type should not contain + // cyclical dependencies. + boost::shared_ptr + CopyArrayType(const ArrayType& p_arrayType); + + // Create a struct type. + boost::shared_ptr + CreateStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst); + + // Create a deep copy of a state machine type. + boost::shared_ptr + CopyStructType(const StructType& p_type); + + // Create an object type. + boost::shared_ptr + CreateObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst); + + // Create a deep copy of an object type. + boost::shared_ptr + CopyObjectType(const ObjectType& p_type); + + // Create a new state machine type. + boost::shared_ptr + CreateStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr); + + // Create a deep copy of a state machine type. + boost::shared_ptr + CopyStateMachineType(const StateMachineType& p_type); + + // Create a function type. + boost::shared_ptr + CreateFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_params, + size_t p_numParams); + + // Create a deep copy of a state machine type. + boost::shared_ptr + CopyFunctionType(const FunctionType& p_type); + + private: + // Get a copy of the specified type that is owned by this TypeManager. + const TypeImpl& GetChildType(const TypeImpl& p_type); + + // The parent of this type manager. + const TypeManager* m_parent; + }; + + // This class acts as a lightweight TypeManager to keep ownership of + // TypeImpl objects. This class does not provide any name lookup of types. + class AnonymousTypeManager : public TypeManager + { + public: + // Construct an anonymous type manager with no parent. + AnonymousTypeManager(); + + // The anonymous type manager does not save any + virtual const TypeImpl* GetTypeInfo(const std::string& p_name) const override; + + // Get a variable sized array type owned by this TypeManager. + virtual + const ArrayType& + GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + unsigned int p_maxElements) override; + + // Get a fixed-sized array type owned by this TypeManager. + virtual + const ArrayType& + GetArrayType(const TypeImpl& p_child, + bool p_isConst, + unsigned int p_dimensions, + const unsigned int p_elementCounts[], + unsigned int p_maxElements) override; + + // Returns an array type owned by this TypeManager which has the same + // properties as another ArrayType. + virtual + const ArrayType& + GetArrayType(const ArrayType& p_type) override; + + // Get a struct type owned by this TypeManager. The TypeManager is not + // required to allow multiple non-unique names exist in the context of + // its owned types. + virtual + const StructType& + GetStructType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) override; + + // Returns a struct type owned by this TypeManager which has the same + // properties as another StructType. + virtual + const StructType& + GetStructType(const StructType& p_type) override; + + // Get an object type owned by this TypeManager. The TypeManager is not + // required to allow multiple non-unique names exist in the context of + // its owned types. + virtual + const ObjectType& + GetObjectType(const std::string& p_name, + const std::string& p_externName, + const std::vector& p_members, + bool p_isConst) override; + + // Returns an object type owned by this TypeManager which has the same + // properties as another ObjectType. + virtual + const ObjectType& + GetObjectType(const ObjectType& p_type) override; + + // Returns a state machine type owned by this TypeManager. The type + // manager is not required to allow multiple non-unique names to exist + // in the context of its owned types. + virtual + const StateMachineType& + GetStateMachineType(const std::string& p_name, + const CompoundType::Member* p_members, + size_t p_numMembers, + boost::weak_ptr p_expr) override; + + // Returns a state machine type owned by this TypeManager which has the + // same properties as another state machine type. + virtual + const StateMachineType& + GetStateMachineType(const StateMachineType& p_type) override; + + // Get a function type owned by this TypeManager. + virtual + const FunctionType& + GetFunctionType(const TypeImpl& p_returnType, + const TypeImpl* const* p_params, + size_t p_numParams) override; + + // Get a function type owned by this TypeManager. + virtual + const FunctionType& + GetFunctionType(const FunctionType& p_type) override; + + private: + // A vector containing all types created by this manager. + std::vector> m_types; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.cpp new file mode 100644 index 000000000000..3225b608982b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.cpp @@ -0,0 +1,165 @@ +#include "AllocationVisitor.h" + +#include "Allocation.h" +#include "ArrayLiteralExpression.h" +#include "Declaration.h" +#include "Extern.h" +#include "FeatureSpec.h" +#include "FreeForm2Assert.h" +#include "LiteralExpression.h" +#include "StateMachine.h" + +FreeForm2::AllocationVisitor::AllocationVisitor(const Expression& p_exp) +{ + p_exp.Accept(*this); +} + + +const FreeForm2::AllocationVisitor::AllocationVector& +FreeForm2::AllocationVisitor::GetAllocations() const +{ + return m_allocations; +} + + +void +FreeForm2::AllocationVisitor::Visit(const ArrayLiteralExpression& p_expr) +{ + if (m_allocationIds.find(p_expr.GetId()) == m_allocationIds.end()) + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::ArrayLiteral, + p_expr.GetId(), + p_expr.GetType()))); + m_allocationIds.insert(p_expr.GetId()); + } +} + + +void +FreeForm2::AllocationVisitor::Visit(const LiteralStreamExpression& p_expr) +{ + if (m_allocationIds.find(p_expr.GetId()) == m_allocationIds.end()) + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::LiteralStream, + p_expr.GetId(), + p_expr.GetType(), + p_expr.GetNumChildren()))); + m_allocationIds.insert(p_expr.GetId()); + } +} + + +void +FreeForm2::AllocationVisitor::Visit(const LiteralWordExpression& p_expr) +{ + if (m_allocationIds.find(p_expr.GetId()) == m_allocationIds.end()) + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::LiteralWord, + p_expr.GetId(), + p_expr.GetType()))); + m_allocationIds.insert(p_expr.GetId()); + } +} + + +void +FreeForm2::AllocationVisitor::Visit(const DeclarationExpression& p_expr) +{ + if (m_allocationIds.find(p_expr.GetId()) == m_allocationIds.end()) + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::Declaration, + p_expr.GetId(), + p_expr.GetDeclaredType()))); + m_allocationIds.insert(p_expr.GetId()); + } +} + + +void +FreeForm2::AllocationVisitor::Visit(const ExternExpression& p_expr) +{ + if (p_expr.GetType().Primitive() == Type::Array) + { + if (m_allocationIds.find(p_expr.GetId()) == m_allocationIds.end()) + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::ExternArray, + p_expr.GetId(), + p_expr.GetType()))); + m_allocationIds.insert(p_expr.GetId()); + } + } +} + + +void +FreeForm2::AllocationVisitor::Visit(const ImportFeatureExpression& p_expr) +{ + if (p_expr.GetType().Primitive() == Type::Array) + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::FeatureArray, + p_expr.GetId(), + p_expr.GetType()))); + } + else + { + m_allocations.push_back(boost::shared_ptr( + new Allocation(Allocation::Declaration, + p_expr.GetId(), + p_expr.GetType()))); + } + m_allocationIds.insert(p_expr.GetId()); +} + + +void +FreeForm2::AllocationVisitor::Visit(const StateExpression& p_expr) +{ + // Don't actually add any allocations for state expressions, but visit + // its actions and leaving actions for transitions, in case they allocate + // something. + for (auto& action : p_expr.m_actions) + { + action.m_action->Accept(*this); + } + + for (auto& transition : p_expr.m_transitions) + { + if (transition.m_leavingAction) + { + transition.m_leavingAction->Accept(*this); + } + } +} + + +void +FreeForm2::AllocationVisitor::Visit(const ExecuteMachineExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetMachine().GetType().Primitive() == Type::StateMachine); + + const StateMachineType& machineType = dynamic_cast(p_expr.GetMachine().GetType()); + + machineType.GetDefinition()->Accept(*this); +} + + +bool +FreeForm2::AllocationVisitor::AlternativeVisit(const FunctionExpression& p_expr) +{ + // Allocations should not cross function boundaries. + return true; +} + + +bool +FreeForm2::AllocationVisitor::AlternativeVisit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) +{ + // Allocations should not cross stream rewriting machine group boundaries. + return true; +} \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.h new file mode 100644 index 000000000000..d4d4b6e3432d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/AllocationVisitor.h @@ -0,0 +1,73 @@ +#pragma once + +#ifndef FREEFORM2_ARRAY_ALLOCATION_VISITOR_H +#define FREEFORM2_ARRAY_ALLOCATION_VISITOR_H + +#include +#include "Expression.h" +#include "NoOpVisitor.h" +#include +#include + +namespace FreeForm2 +{ + class AllocationVisitor : public NoOpVisitor + { + public: + typedef std::vector> AllocationVector; + + // Create a new instance. + AllocationVisitor(const Expression& p_exp); + + // Return the list of allocations for this expression tree. + const AllocationVector& GetAllocations() const; + + // Visits the ArrayLiteralExpression and adds an Array Allocation + // to the list. + virtual void Visit(const ArrayLiteralExpression& p_expr) override; + + // Visits the LiteralStreamExpression and adds an LiteralStreamAllocationExpression + // to the list. + virtual void Visit(const LiteralStreamExpression& p_expr) override; + + // Visits the LiteralWordExpression and adds an LiteralWord Allocation + // to the list. + virtual void Visit(const LiteralWordExpression& p_expr) override; + + // Visits the DeclarationExpression and adds a + // Declaration Allocation to the list. + virtual void Visit(const DeclarationExpression& p_expr) override; + + // Visits the ExternExpression and adds a Declaration Allocation + // and an ArrayLiteralAllocationExpression if applicable. + virtual void Visit(const ExternExpression& p_expr) override; + + // Visits the ImportFeatureExpression and adds either an array literal + // or declaration allocation depending on the type. + virtual void Visit(const ImportFeatureExpression& p_expr) override; + + // Visits the ExecuteMachineExpression and StateExpression so the allocations + // done within the actions or transitions of a state machine can also be + // considered for allocations. + virtual void Visit(const StateExpression& p_expr) override; + virtual void Visit(const ExecuteMachineExpression& p_expr) override; + + // Skips the visitation of FunctionExpressions. + virtual bool AlternativeVisit(const FunctionExpression& p_expr) override; + + // Skip the visitation of ExecuteStreamRewritingStateMachineGroupExpressions. + virtual bool AlternativeVisit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) override; + + private: + // Holds all the array allocation expressions. + AllocationVector m_allocations; + + // A set of allocation IDs created to avoid duplicated allocations. + // Duplcates are possible in, for example, array-based range reduce + // expressions, in which the array dereference and array length in the + // loop reference the same ArrayLiteralExpression. + std::set m_allocationIds; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CMakeLists.txt new file mode 100644 index 000000000000..f3beeb95ae59 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CMakeLists.txt @@ -0,0 +1,34 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME DRFreeFormTransformLibrary) + + +Project(${PROJECT_NAME}) + +SET(CMAKE_CXX_FLAGS "-std=c++11 ${CMAKE_CXX_FLAGS} -fpermissive") + + + +add_library(${PROJECT_NAME} STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/AllocationVisitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/CopyingVisitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/FunctionInlineVisitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ObjectResolutionVisitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/OperandPromotionVisitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ProcessFeaturesUsed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/TypeCheckingVisitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/UniformExpressionVisitor.cpp +) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../../NeuralTree.Library/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../Shared + ${CMAKE_CURRENT_SOURCE_DIR}/../Expression + ) + +install(TARGETS ${PROJECT_NAME} + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + ) \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.cpp new file mode 100644 index 000000000000..091be8c44c07 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.cpp @@ -0,0 +1,1581 @@ +#include "CopyingVisitor.h" + +#include "SimpleExpressionOwner.h" + +#include +#include +#include +#include +#include +#include "ArrayDereferenceExpression.h" +#include "ArrayLength.h" +#include "ArrayLiteralExpression.h" +#include "BlockExpression.h" +#include "Conditional.h" +#include "ConvertExpression.h" +#include "DebugExpression.h" +#include "Declaration.h" +#include "Expression.h" +#include "Extern.h" +#include "FeatureSpec.h" +#include "FreeForm2Assert.h" +#include "Function.h" +#include "LetExpression.h" +#include "LiteralExpression.h" +#include +#include "Match.h" +#include "MemberAccessExpression.h" +#include "Mutation.h" +#include "ObjectType.h" +#include "OperatorExpression.h" +#include "PhiNode.h" +#include "Publish.h" +#include "RangeReduceExpression.h" +#include "RandExpression.h" +#include "RefExpression.h" +#include "SelectNth.h" +#include "StateMachine.h" +#include "StateMachineType.h" +#include "StreamData.h" +#include +#include "TypeManager.h" +#include "TypeUtil.h" + + +FreeForm2::CopyingVisitor::CopyingVisitor() + : m_owner(boost::make_shared()), + m_typeManager(TypeManager::CreateTypeManager().release()) +{ +} + + +FreeForm2::CopyingVisitor::CopyingVisitor(const boost::shared_ptr& p_owner, + const boost::shared_ptr& p_typeManager) + : m_owner(p_owner), m_typeManager(p_typeManager) +{ +} + + +boost::shared_ptr +FreeForm2::CopyingVisitor::GetExpressionOwner() const +{ + return m_owner; +} + + +boost::shared_ptr +FreeForm2::CopyingVisitor::GetTypeManager() const +{ + return m_typeManager; +} + + +const FreeForm2::Expression* +FreeForm2::CopyingVisitor::GetSyntaxTree() const +{ + FF2_ASSERT(m_stack.size() == 1); + return m_stack.back(); +} + + +std::vector& +FreeForm2::CopyingVisitor::GetStack() +{ + return m_stack; +} + + +void +FreeForm2::CopyingVisitor::AddExpression( + const boost::shared_ptr& p_expr) +{ + m_owner->AddExpression(p_expr); + m_stack.push_back(p_expr.get()); +} + + +void +FreeForm2::CopyingVisitor::AddExpressionToOwner( + const boost::shared_ptr& p_expr) +{ + m_owner->AddExpression(p_expr); +} + + +void +FreeForm2::CopyingVisitor::Visit(const SelectNthExpression& p_expr) +{ + std::vector children(p_expr.GetNumChildren()); + + children[0] = m_stack.back(); + m_stack.pop_back(); + + for (unsigned int i = 0; i < p_expr.GetNumChildren() - 1; i++) + { + // Children are pushed on the stack in the reverse order from what + // SelectNthExpression::Alloc expects. + children[p_expr.GetNumChildren() - i - 1] = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(SelectNthExpression::Alloc(p_expr.GetAnnotations(), children)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const SelectRangeExpression& p_expr) +{ + const Expression& arrayExp = *m_stack.back(); + m_stack.pop_back(); + + const Expression& count = *m_stack.back(); + m_stack.pop_back(); + + const Expression& start = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + start, + count, + arrayExp, + *m_typeManager)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConditionalExpression& p_expr) +{ + const Expression& conditionExpression = *m_stack.back(); + m_stack.pop_back(); + const Expression& thenExpression = *m_stack.back(); + m_stack.pop_back(); + const Expression& elseExpression = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + conditionExpression, + thenExpression, + elseExpression)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ArrayLiteralExpression& p_expr) +{ + std::vector children; + children.reserve(p_expr.GetNumChildren()); + + for (size_t i = 0; i < p_expr.GetNumChildren(); i++) + { + children.push_back(m_stack.back()); + m_stack.pop_back(); + } + + FF2_ASSERT(p_expr.GetType().Primitive() == Type::Array); + const ArrayType& exprType = static_cast(CopyType(p_expr.GetType())); + + AddExpression(ArrayLiteralExpression::Alloc(p_expr.GetAnnotations(), + exprType, + children, + p_expr.GetId())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LetExpression& p_expr) +{ + std::vector children(p_expr.GetNumChildren() - 1); + const Expression& value = *m_stack.back(); + m_stack.pop_back(); + + for (size_t i = 0; i < p_expr.GetNumChildren() - 1; i++) + { + const size_t index = p_expr.GetNumChildren() - i - 2; + children[index] = std::make_pair(p_expr.GetBound()[index].first, m_stack.back()); + m_stack.pop_back(); + } + + AddExpression(LetExpression::Alloc(p_expr.GetAnnotations(), children, &value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const BlockExpression& p_expr) +{ + std::vector children(p_expr.GetNumChildren()); + + for (size_t i = 0; i < p_expr.GetNumChildren(); i++) + { + const size_t index = p_expr.GetNumChildren() - i - 1; + children[index] = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(BlockExpression::Alloc(p_expr.GetAnnotations(), + &children[0], + static_cast(p_expr.GetNumChildren()), + p_expr.GetNumBound())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const BinaryOperatorExpression& p_expr) +{ + const size_t numChildren = p_expr.GetNumChildren(); + std::vector children(numChildren); + + for (size_t i = 0; i < numChildren; i++) + { + children[numChildren - i - 1] = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(BinaryOperatorExpression::Alloc(p_expr.GetAnnotations(), + children, + p_expr.GetOperator(), + *m_typeManager)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const UnaryOperatorExpression& p_expr) +{ + const Expression& child = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), child, p_expr.m_op)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const RangeReduceExpression& p_expr) +{ + const Expression& reduce = *m_stack.back(); + m_stack.pop_back(); + const Expression& low = *m_stack.back(); + m_stack.pop_back(); + const Expression& high = *m_stack.back(); + m_stack.pop_back(); + const Expression& initial = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + low, + high, + initial, + reduce, + p_expr.GetStepId(), + p_expr.GetReduceId())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ForEachLoopExpression& p_expr) +{ + const Expression& body = *m_stack.back(); + m_stack.pop_back(); + const Expression& next = *m_stack.back(); + m_stack.pop_back(); + const Expression& end = *m_stack.back(); + m_stack.pop_back(); + const Expression& begin = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + std::make_pair(&begin, &end), + next, + body, + p_expr.GetIteratorId(), + p_expr.GetVersion(), + p_expr.GetHint(), + boost::ref(*m_typeManager))); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ComplexRangeLoopExpression& p_expr) +{ + const Expression& loopCondition = *m_stack.back(); + m_stack.pop_back(); + const Expression& body = *m_stack.back(); + m_stack.pop_back(); + const Expression& step = *m_stack.back(); + m_stack.pop_back(); + const Expression& high = *m_stack.back(); + m_stack.pop_back(); + const Expression& low = *m_stack.back(); + m_stack.pop_back(); + const Expression& precondition = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + std::make_pair(&low, &high), + step, + body, + precondition, + loopCondition, + CopyType(p_expr.GetStepType()), + p_expr.GetStepId(), + p_expr.GetVersion())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const MutationExpression& p_expr) +{ + const Expression* right = m_stack.back(); + m_stack.pop_back(); + const Expression* left = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *left, *right)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const MatchExpression& p_expr) +{ + const Expression* action = m_stack.back(); + m_stack.pop_back(); + const MatchSubExpression* pattern + = boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value, *pattern, *action, p_expr.IsOverlapping())); +} + + +// bool +// FreeForm2::CopyingVisitor::AlternativeVisit(const MatchWordExpression& p_expr) +// { +// FSM::WordConstraint constraint = CopyWordConstraint(p_expr.GetConstraint()); + +// const size_t numEffects = p_expr.GetNumEffects(); +// std::vector children(numEffects); + +// for (size_t i = 0; i < numEffects; i++) +// { +// p_expr.BeginEffects()[i]->Accept(*this); +// children[i] +// = boost::polymorphic_downcast(m_stack.back()); +// m_stack.pop_back(); +// } + +// AddExpression(MatchWordExpression::Alloc(p_expr.GetAnnotations(), +// constraint, +// children.size(), +// children.empty() ? NULL : &children[0])); +// return true; +// } + + +// void +// FreeForm2::CopyingVisitor::Visit(const MatchWordExpression& p_expr) +// { +// // Should be handled by AlternativeVisit, above. +// FF2_ASSERT(false); +// } + + +// void +// FreeForm2::CopyingVisitor::Visit(const MatchLiteralExpression& p_expr) +// { +// FF2_ASSERT(p_expr.GetNumChildren() == 1); +// const Expression* child = m_stack.back(); +// m_stack.pop_back(); + +// AddExpression(boost::make_shared(p_expr.GetAnnotations(), *child, p_expr.m_int)); +// } + + +// void +// FreeForm2::CopyingVisitor::Visit(const MatchCurrentWordExpression& p_expr) +// { +// FF2_ASSERT(p_expr.GetNumChildren() == 0); +// AddExpression(boost::make_shared(p_expr.GetAnnotations(), p_expr.m_offset, p_expr.m_matchType)); +// } + + +void +FreeForm2::CopyingVisitor::Visit(const MatchOperatorExpression& p_expr) +{ + const size_t numChildren = p_expr.GetNumChildren(); + std::vector children(numChildren); + FF2_ASSERT(numChildren > 0); + + for (size_t i = 0; i < numChildren; i++) + { + children[numChildren - i - 1] + = boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + } + + AddExpression(MatchOperatorExpression::Alloc(p_expr.GetAnnotations(), + &children[0], + children.size(), + p_expr.GetOperator())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const MatchGuardExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetNumChildren() == 1); + const Expression* guard = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *guard)); +} + + + +void +FreeForm2::CopyingVisitor::Visit(const MatchBindExpression& p_expr) +{ + const MatchSubExpression* init + = boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *init, p_expr.m_id)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const MemberAccessExpression& p_expr) +{ + const Expression& container = *m_stack.back(); + m_stack.pop_back(); + + // Look up the CompoundType::Member struct in the next expression type. + // Becase types are being copied, the member from p_expr.GetType() will not + // be the same as the one from the expression at the top of the stack. + FF2_ASSERT(CompoundType::IsCompoundType(container.GetType())); + const CompoundType& type = static_cast(container.GetType()); + const CompoundType::Member* member = type.FindMember(p_expr.GetMemberInfo().m_name); + FF2_ASSERT(member != NULL); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), container, *member, p_expr.GetVersion())); +} + + +//void +//FreeForm2::CopyingVisitor::Visit(const NeuralInputResultExpression& p_expr) +//{ +// FF2_ASSERT(p_expr.GetNumChildren() == 1); +// const Expression* child = m_stack.back(); +// m_stack.pop_back(); +// +// AddExpression(boost::make_shared(p_expr.GetAnnotations(), p_expr.m_index, *child)); +//} + + +//void +//FreeForm2::CopyingVisitor::Visit(const ObjectMethodExpression& p_expr) +//{ +// FF2_ASSERT(p_expr.GetNumChildren() == 1 + p_expr.m_numParameters); +// +// const Expression* child = m_stack.back(); +// m_stack.pop_back(); +// +// if (child->GetType().Primitive() == Type::Object) +// { +// std::vector parameters(p_expr.m_numParameters); +// for (size_t i = 0; i < p_expr.m_numParameters; i++) +// { +// parameters[p_expr.m_numParameters - i - 1] = m_stack.back(); +// m_stack.pop_back(); +// } +// +// // Look up the CompoundType::Member in the next expression type. +// // Because types are being copied, the member from p_expr.GetType() will not +// // be the same as the one from the expression at the top of the stack. +// FF2_ASSERT(p_expr.GetType() != child->GetType()); +// FF2_ASSERT(CompoundType::IsCompoundType(child->GetType())); +// const CompoundType& type = static_cast(child->GetType()); +// const CompoundType::Member* member = type.FindMember(p_expr.m_member->m_name); +// FF2_ASSERT(member != NULL); +// AddExpression(ObjectMethodExpression::Alloc(p_expr.GetAnnotations(), *child, *member, parameters)); +// } +// else +// { +// AddExpression(ObjectMethodExpression::Alloc(p_expr.GetAnnotations(), *child, p_expr.m_method)); +// } +//} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralIntExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), p_expr.GetConstantValue().m_int)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralUInt64Expression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), p_expr.GetConstantValue().m_uint64)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralInt32Expression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), p_expr.GetConstantValue().m_int32)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralUInt32Expression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), p_expr.GetConstantValue().m_uint32)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ArrayLengthExpression& p_expr) +{ + const Expression* arrayLiteral = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *arrayLiteral)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ArrayDereferenceExpression& p_expr) +{ + const Expression* index = m_stack.back(); + m_stack.pop_back(); + const Expression* arrayExpression = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + *arrayExpression, + *index, + p_expr.GetVersion())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToFloatExpression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToIntExpression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToUInt64Expression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToInt32Expression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToUInt32Expression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToBoolExpression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ConvertToImperativeExpression& p_expr) +{ + const Expression* value = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), *value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const DeclarationExpression& p_expr) +{ + const Expression& init = *m_stack.back(); + m_stack.pop_back(); + + boost::shared_ptr expr( + boost::make_shared(p_expr.GetAnnotations(), + CopyType(p_expr.GetDeclaredType()), + init, + p_expr.HasVoidValue(), + p_expr.GetId(), + p_expr.GetVersion())); + AddExpression(expr); +} + + +void +FreeForm2::CopyingVisitor::Visit(const DirectPublishExpression& p_expr) +{ + const Expression& value = *m_stack.back(); + m_stack.pop_back(); + + const size_t numIndices = p_expr.GetNumIndices(); + std::vector indices(numIndices); + + for (size_t i = 0; i < numIndices; i++) + { + indices[numIndices - i - 1] = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(DirectPublishExpression::Alloc(p_expr.GetAnnotations(), + p_expr.GetFeatureName(), + &indices[0], + static_cast(indices.size()), + value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ExternExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetData(), + CopyType(p_expr.GetType()), + p_expr.GetId(), + boost::ref(*m_typeManager))); +} + + +void +FreeForm2::CopyingVisitor::Visit(const FunctionExpression& p_expr) +{ + const size_t numParameters = p_expr.GetNumParameters(); + std::vector parameters(numParameters); + + const Expression* body = m_stack.back(); + m_stack.pop_back(); + + for (size_t i = 0; i < numParameters; i++) + { + parameters[numParameters - i - 1] + = p_expr.GetParameters()[numParameters - i - 1]; + parameters[numParameters - i - 1].m_parameter + = static_cast(m_stack.back()); + m_stack.pop_back(); + } + + boost::shared_ptr expr(new FunctionExpression(p_expr.GetAnnotations(), + static_cast(CopyType(p_expr.GetFunctionType())), + p_expr.GetName(), + parameters, + *body)); + AddExpression(expr); +} + + +bool +FreeForm2::CopyingVisitor::AlternativeVisit(const FunctionCallExpression& p_expr) +{ + const size_t numParameters = p_expr.GetNumParameters(); + std::vector parameters(numParameters); + + for (size_t i = 0; i < numParameters; i++) + { + p_expr.GetParameters()[i]->Accept(*this); + parameters[i] = m_stack.back(); + m_stack.pop_back(); + } + + const Expression* function; + const FunctionExpression* functionExpression = dynamic_cast(&p_expr.GetFunction()); + + if (functionExpression != nullptr) + { + if (m_functionTranslation.find(functionExpression) == m_functionTranslation.end()) + { + p_expr.GetFunction().Accept(*this); + m_functionTranslation.insert(std::make_pair(functionExpression, static_cast(m_stack.back()))); + m_stack.pop_back(); + } + + function = m_functionTranslation[functionExpression]; + } + else + { + p_expr.GetFunction().Accept(*this); + function = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(FunctionCallExpression::Alloc(p_expr.GetAnnotations(), + *function, + parameters)); + + return true; +} + + +void +FreeForm2::CopyingVisitor::Visit(const FunctionCallExpression& p_expr) +{ + // Handled in AlternativeVisit. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralFloatExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetConstantValue().m_float)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralBoolExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetConstantValue().m_bool)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralVoidExpression& p_expr) +{ + m_stack.push_back(&LiteralVoidExpression::GetInstance()); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralStreamExpression& p_expr) +{ + const size_t numChildren = p_expr.GetNumChildren(); + std::vector children(numChildren); + + for (size_t i = 0; i < numChildren; i++) + { + children[numChildren - i - 1] = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(LiteralStreamExpression::Alloc(p_expr.GetAnnotations(), + &children[0], + numChildren, + p_expr.GetId())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralWordExpression& p_expr) +{ + const Expression* word = NULL; + const Expression* offset = NULL; + const Expression* attribute = NULL; + const Expression* length = NULL; + const Expression* candidate = NULL; + + if (p_expr.m_candidate != NULL) + { + candidate = m_stack.back(); + m_stack.pop_back(); + } + + if (p_expr.m_length != NULL) + { + length = m_stack.back(); + m_stack.pop_back(); + } + + if (p_expr.m_attribute != NULL) + { + attribute = m_stack.back(); + m_stack.pop_back(); + } + + offset = m_stack.back(); + m_stack.pop_back(); + word = m_stack.back(); + m_stack.pop_back(); + + if (p_expr.m_isHeader) + { + FF2_ASSERT(attribute == NULL && length == NULL && candidate == NULL); + AddExpression( + boost::make_shared(p_expr.GetAnnotations(), + *word, + *offset, + p_expr.GetId())); + } + else + { + AddExpression( + boost::make_shared(p_expr.GetAnnotations(), + *word, + *offset, + attribute, + length, + candidate, + p_expr.GetId())); + } +} + + +void +FreeForm2::CopyingVisitor::Visit(const LiteralInstanceHeaderExpression& p_expr) +{ + const Expression* instanceLength = m_stack.back(); + m_stack.pop_back(); + const Expression* rank = m_stack.back(); + m_stack.pop_back(); + const Expression* instanceCount = m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + *instanceCount, + *rank, + *instanceLength)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const FeatureRefExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.m_index)); +} + + +//bool +//FreeForm2::CopyingVisitor::AlternativeVisit(const FSMExpression& p_expr) +//{ +// size_t stackSize = m_stack.size(); +// // CopyFSM copy(*this, p_expr); +// // p_expr.Accept(copy); +// // m_stack.push_back(©.GetCopy(*m_owner)); +// FF2_ASSERT(m_stack.size() == stackSize + 1); +// return true; +//} +// +// +//void +//FreeForm2::CopyingVisitor::Visit(const FSMExpression& p_expr) +//{ +// // We handle FSMExpressions via AlternativeVisit. +// Unreachable(__FILE__, __LINE__); +//} + + +void +FreeForm2::CopyingVisitor::Visit(const FeatureSpecExpression& p_expr) +{ + const Expression& body = *m_stack.back(); + m_stack.pop_back(); + + boost::shared_ptr featureMapCopy = + boost::make_shared(); + + BOOST_FOREACH (const FeatureSpecExpression::PublishFeatureMap::value_type& featureNameToType, *p_expr.GetPublishFeatureMap()) + { + featureMapCopy->insert(FeatureSpecExpression::PublishFeatureMap::value_type(featureNameToType.first, + CopyType(featureNameToType.second))); + } + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + featureMapCopy, + body, + p_expr.GetFeatureSpecType(), + p_expr.GetType().Primitive() != Type::Void)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const FeatureGroupSpecExpression& p_expr) +{ + std::vector featureSpecs; + + for (int i = 0; i < p_expr.GetFeatureSpecs().size(); ++i) + { + featureSpecs.insert(featureSpecs.begin(), boost::polymorphic_downcast(m_stack.back())); + m_stack.pop_back(); + } + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetName(), + featureSpecs, + p_expr.IsExtendedExperimental(), + p_expr.IsSmallExperimental(), + p_expr.IsBlockLevelFeature(), + p_expr.IsBodyBlockFeature(), + p_expr.IsForwardIndexFeature(), + p_expr.GetMetaStreamName())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const PhiNodeExpression& p_expr) +{ + AddExpression(PhiNodeExpression::Alloc(p_expr.GetAnnotations(), + p_expr.GetVersion(), + p_expr.GetIncomingVersionsCount(), + p_expr.GetIncomingVersions())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const PublishExpression& p_expr) +{ + const Expression& value = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetFeatureName(), value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ReturnExpression& p_expr) +{ + const Expression& value = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + value)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const StreamDataExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.m_requestsLength)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const UpdateStreamDataExpression& p_expr) +{ + m_stack.push_back(&UpdateStreamDataExpression::GetInstance()); +} + + +void +FreeForm2::CopyingVisitor::Visit(const VariableRefExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetId(), + p_expr.GetVersion(), + CopyType(p_expr.GetType()))); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ImportFeatureExpression& p_expr) +{ + if (p_expr.GetType().Primitive() == Type::Array) + { + const ArrayType& type = static_cast(p_expr.GetType()); + const std::vector dimensions(type.GetDimensions(), + type.GetDimensions() + type.GetDimensionCount()); + AddExpression( + boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetFeatureName(), + dimensions, + p_expr.GetId(), + boost::ref(*m_typeManager))); + } + else + { + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetFeatureName(), + p_expr.GetId())); + } +} + + +void +FreeForm2::CopyingVisitor::Visit(const StateExpression& p_expr) +{ + boost::shared_ptr expr(new StateExpression(p_expr.GetAnnotations())); + m_owner->AddExpression(expr); + + expr->m_id = p_expr.m_id; + + // Copy state actions. + { + std::list::const_iterator iter = p_expr.m_actions.begin(); + for (; iter != p_expr.m_actions.end(); ++iter) + { + iter->m_action->Accept(*this); + StateExpression::Action a = { iter->m_matchType, m_stack.back() }; + m_stack.pop_back(); + expr->m_actions.push_back(a); + } + } + + // Copy state transitions. + { + std::list::const_iterator iter = p_expr.m_transitions.begin(); + for (; iter != p_expr.m_transitions.end(); ++iter) + { + StateExpression::Transition t = { iter->m_matchType, NULL, iter->m_destinationId, NULL }; + if (iter->m_condition) + { + iter->m_condition->Accept(*this); + t.m_condition = m_stack.back(); + m_stack.pop_back(); + } + + if (iter->m_leavingAction) + { + iter->m_leavingAction->Accept(*this); + t.m_leavingAction = m_stack.back(); + m_stack.pop_back(); + } + + expr->m_transitions.push_back(t); + } + } + + m_stack.push_back(expr.get()); +} + + +bool +FreeForm2::CopyingVisitor::AlternativeVisit(const StateMachineExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().Primitive() == Type::StateMachine); + const TypeImpl& copiedType = CopyType(p_expr.GetType()); + const StateMachineType& type = static_cast(copiedType); + + // The state machine has already been copied; do not re-copy. + if (type.HasDefinition()) + { + boost::shared_ptr expr = type.GetDefinition(); + FF2_ASSERT(expr.get() != nullptr); + + m_stack.push_back(expr.get()); + } + else + { + FF2_ASSERT(type.IsSameAs(type, false)); + p_expr.GetInitializer().Accept(*this); + const TypeInitializerExpression& init + = *boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + + const size_t numStates = p_expr.GetNumChildren() - 1; + std::vector states; + states.reserve(numStates); + for (size_t i = 0; i < numStates; i++) + { + p_expr.GetChildren()[i]->Accept(*this); + states.push_back(boost::polymorphic_downcast(m_stack.back())); + m_stack.pop_back(); + } + + boost::shared_ptr expr( + StateMachineExpression::Alloc(p_expr.GetAnnotations(), + type, + init, + states.size() > 0 ? &states[0] : NULL, + states.size(), + p_expr.GetStartStateId())); + AddExpression(expr); + FF2_ASSERT(type.HasDefinition()); + } + return true; +} + + +void +FreeForm2::CopyingVisitor::Visit(const StateMachineExpression& p_expr) +{ + // Handled by AlternativeVisit. + FF2_UNREACHABLE(); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) +{ + std::vector machineInstances(p_expr.GetNumMachineInstances()); + + for (size_t i = 0; i < p_expr.GetNumMachineInstances(); ++i) + { + size_t index = p_expr.GetNumMachineInstances() - i - 1; + machineInstances[index].m_machineExpression = boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + machineInstances[index].m_machineDeclaration = boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + } + + const Expression* duplicateTermInformation = nullptr; + + if (p_expr.GetDuplicateTermInformation() != nullptr) + { + p_expr.GetDuplicateTermInformation()->Accept(*this); + duplicateTermInformation = m_stack.back(); + m_stack.pop_back(); + } + + const Expression* numQueryPaths = nullptr; + + if (p_expr.GetNumQueryPaths() != nullptr) + { + p_expr.GetNumQueryPaths()->Accept(*this); + numQueryPaths = m_stack.back(); + m_stack.pop_back(); + } + + const Expression* queryPathCandidates = nullptr; + + if (p_expr.GetQueryPathCandidates() != nullptr) + { + p_expr.GetQueryPathCandidates()->Accept(*this); + queryPathCandidates = m_stack.back(); + m_stack.pop_back(); + } + + const Expression* queryLength = nullptr; + + if (p_expr.GetQueryLength() != nullptr) + { + p_expr.GetQueryLength()->Accept(*this); + queryLength = m_stack.back(); + m_stack.pop_back(); + } + + const Expression* tupleOfInterestCount = nullptr; + + if (p_expr.GetTupleOfInterestCount() != nullptr) + { + p_expr.GetTupleOfInterestCount()->Accept(*this); + tupleOfInterestCount = m_stack.back(); + m_stack.pop_back(); + } + + const Expression* tuplesOfInterest = nullptr; + + if (p_expr.GetTuplesOfInterest() != nullptr) + { + p_expr.GetTuplesOfInterest()->Accept(*this); + tuplesOfInterest = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(ExecuteStreamRewritingStateMachineGroupExpression::Alloc(p_expr.GetAnnotations(), + &machineInstances[0], + static_cast(machineInstances.size()), + p_expr.GetNumBound(), + p_expr.GetMachineIndexID(), + p_expr.GetMachineArraySize(), + p_expr.GetStreamRewritingType(), + duplicateTermInformation, + numQueryPaths, + queryPathCandidates, + queryLength, + tupleOfInterestCount, + tuplesOfInterest, + p_expr.IsNearChunk(), + p_expr.GetMinChunkNumber())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ExecuteMachineExpression& p_expr) +{ + std::vector> yieldActions(p_expr.GetNumYieldActions()); + + for (size_t i = 0; i < p_expr.GetNumYieldActions(); ++i) + { + size_t index = p_expr.GetNumYieldActions() - i - 1; + yieldActions[index].first = p_expr.GetYieldActions()[index].first; + yieldActions[index].second = m_stack.back(); + m_stack.pop_back(); + } + + const Expression& machine = *m_stack.back(); + m_stack.pop_back(); + + const Expression& stream = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(ExecuteMachineExpression::Alloc(p_expr.GetAnnotations(), + stream, + machine, + yieldActions.size() > 0 ? &yieldActions[0] : NULL, + yieldActions.size())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ExecuteMachineGroupExpression& p_expr) +{ + std::vector machineInstances(p_expr.GetNumMachineInstances()); + + for (size_t i = 0; i < p_expr.GetNumMachineInstances(); ++i) + { + size_t index = p_expr.GetNumMachineInstances() - i - 1; + machineInstances[index].m_machineExpression = boost::polymorphic_downcast(m_stack.back()); + m_stack.pop_back(); + machineInstances[index].m_machineDeclaration = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(ExecuteMachineGroupExpression::Alloc(p_expr.GetAnnotations(), + &machineInstances[0], + static_cast(machineInstances.size()), + p_expr.GetNumBound())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const YieldExpression& p_expr) +{ + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + p_expr.GetMachineName(), + p_expr.GetName())); +} + + +void +FreeForm2::CopyingVisitor::Visit(const RandFloatExpression& p_expr) +{ + m_stack.push_back(&RandFloatExpression::GetInstance()); +} + + +void +FreeForm2::CopyingVisitor::Visit(const RandIntExpression& p_expr) +{ + const Expression& upperBound = *m_stack.back(); + m_stack.pop_back(); + const Expression& lowerBound = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + lowerBound, + upperBound)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const ThisExpression& p_expr) +{ + FF2_ASSERT(CompoundType::IsCompoundType(p_expr.GetType()) + || p_expr.GetType().Primitive() == Type::Unknown); + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + CopyType(p_expr.GetType()))); +} + + +void +FreeForm2::CopyingVisitor::Visit(const UnresolvedAccessExpression& p_expr) +{ + const Expression& object = *m_stack.back(); + m_stack.pop_back(); + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + object, + p_expr.GetMemberName(), + CopyType(p_expr.GetType()))); +} + + +bool +FreeForm2::CopyingVisitor::AlternativeVisit(const TypeInitializerExpression& p_expr) +{ + FF2_ASSERT(CompoundType::IsCompoundType(p_expr.GetType())); + const CompoundType& type = static_cast(CopyType(p_expr.GetType())); + + if (type.Primitive() == Type::StateMachine) + { + const StateMachineType& stateMachine = static_cast(type); + if (stateMachine.HasDefinition()) + { + const StateMachineExpression& expr = *stateMachine.GetDefinition(); + m_stack.push_back(&expr.GetInitializer()); + return true; + } + } + + std::vector inits(p_expr.BeginInitializers(), + p_expr.EndInitializers()); + + BOOST_FOREACH(TypeInitializerExpression::Initializer& init, inits) + { + const CompoundType::Member* member = type.FindMember(init.m_member->m_name); + FF2_ASSERT(member != NULL); + init.m_member = member; + + init.m_initializer->Accept(*this); + init.m_initializer = m_stack.back(); + m_stack.pop_back(); + } + + AddExpression(TypeInitializerExpression::Alloc(p_expr.GetAnnotations(), + type, + inits.size() > 0 ? &inits[0] : NULL, + inits.size())); + return true; +} + + +void +FreeForm2::CopyingVisitor::Visit(const TypeInitializerExpression& p_expr) +{ + // Handled by AlternativeVisit. + Unreachable(__FILE__, __LINE__); +} + + +void +FreeForm2::CopyingVisitor::Visit(const AggregateContextExpression& p_expr) +{ + const Expression& body = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), body)); +} + + +void +FreeForm2::CopyingVisitor::Visit(const DebugExpression& p_expr) +{ + const Expression& child = *m_stack.back(); + m_stack.pop_back(); + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), child, p_expr.GetChildText())); +} + + +void +FreeForm2::CopyingVisitor::VisitReference(const ArrayDereferenceExpression& p_expr) +{ + Visit(p_expr); +} + + +void +FreeForm2::CopyingVisitor::VisitReference(const VariableRefExpression& p_expr) +{ + Visit(p_expr); +} + + +void +FreeForm2::CopyingVisitor::VisitReference(const MemberAccessExpression& p_expr) +{ + Visit(p_expr); +} + + +void +FreeForm2::CopyingVisitor::VisitReference(const ThisExpression& p_expr) +{ + Visit(p_expr); +} + + +void +FreeForm2::CopyingVisitor::VisitReference(const UnresolvedAccessExpression& p_expr) +{ + Visit(p_expr); +} + + +size_t +FreeForm2::CopyingVisitor::StackSize() const +{ + return m_stack.size(); +} + + +size_t +FreeForm2::CopyingVisitor::StackIncrement() const +{ + return 1; +} + + +const FreeForm2::TypeImpl& +FreeForm2::CopyingVisitor::CopyType(const TypeImpl& p_type) +{ + switch (p_type.Primitive()) + { + case Type::Float: + { + FF2_ASSERT(&p_type == &TypeImpl::GetFloatInstance(p_type.IsConst())); + return p_type; + } + + case Type::Int: + { + FF2_ASSERT(&p_type == &TypeImpl::GetIntInstance(p_type.IsConst())); + return p_type; + } + + case Type::UInt64: + { + FF2_ASSERT(&p_type == &TypeImpl::GetUInt64Instance(p_type.IsConst())); + return p_type; + } + + case Type::Int32: + { + FF2_ASSERT(&p_type == &TypeImpl::GetInt32Instance(p_type.IsConst())); + return p_type; + } + + case Type::UInt32: + { + FF2_ASSERT(&p_type == &TypeImpl::GetUInt32Instance(p_type.IsConst())); + return p_type; + } + + case Type::Bool: + { + FF2_ASSERT(&p_type == &TypeImpl::GetBoolInstance(p_type.IsConst())); + return p_type; + } + + case Type::Array: + { + const ArrayType& type = static_cast(p_type); + return m_typeManager->GetArrayType(type); + } + + case Type::Struct: + { + const StructType& type = static_cast(p_type); + return m_typeManager->GetStructType(type); + } + + case Type::Object: + { + const ObjectType& type = static_cast(p_type); + return m_typeManager->GetObjectType(type); + } + + case Type::Function: + { + const FunctionType& type = static_cast(p_type); + return m_typeManager->GetFunctionType(type); + } + + case Type::StateMachine: + { + // Check if the type has already been copied. + const StateMachineType& type = static_cast(p_type); + const TypeImpl* copied = m_typeManager->GetTypeInfo(p_type.GetName()); + if (copied != NULL) + { + FF2_ASSERT(copied->Primitive() == Type::StateMachine); + return *copied; + } + else + { + // Copy the type without the implementing expression. + std::vector members(type.BeginMembers(), type.EndMembers()); + BOOST_FOREACH (CompoundType::Member& member, members) + { + // Prevent self-reference. + FF2_ASSERT(!member.m_type->IsSameAs(type, true)); + + member.m_type = &CopyType(*member.m_type); + } + const StateMachineType& copiedType + = m_typeManager->GetStateMachineType(type.GetName(), + members.size() > 0 ? &members[0] : NULL, + members.size(), + boost::weak_ptr()); + + // Copy the StateMachineExpression, which should set the definition. + type.GetDefinition()->Accept(*this); + FF2_ASSERT(&m_stack.back()->GetType() == &copiedType); + FF2_ASSERT(copiedType.GetDefinition().get() != NULL); + m_stack.pop_back(); + return copiedType; + } + } + + case Type::Void: + { + FF2_ASSERT(&p_type == &TypeImpl::GetVoidInstance()); + return p_type; + } + + case Type::Stream: + { + FF2_ASSERT(&p_type == &TypeImpl::GetStreamInstance(p_type.IsConst())); + return p_type; + } + + case Type::Word: + { + FF2_ASSERT(&p_type == &TypeImpl::GetWordInstance(p_type.IsConst())); + return p_type; + } + + case Type::InstanceHeader: + { + FF2_ASSERT(&p_type == &TypeImpl::GetInstanceHeaderInstance(p_type.IsConst())); + return p_type; + } + + case Type::BodyBlockHeader: + { + FF2_ASSERT(&p_type == &TypeImpl::GetBodyBlockHeaderInstance(p_type.IsConst())); + return p_type; + } + + case Type::Unknown: + { + FF2_ASSERT(&p_type == &TypeImpl::GetUnknownType()); + return p_type; + } + + case Type::Invalid: + { + FF2_ASSERT(&p_type == &TypeImpl::GetInvalidType()); + return p_type; + } + + default: + { + Unreachable(__FILE__, __LINE__); + } + } +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.h new file mode 100644 index 000000000000..d3841c18034d --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/CopyingVisitor.h @@ -0,0 +1,159 @@ +#pragma once + +#ifndef FREEFORM2_COPYING_VISITOR_H +#define FREEFORM2_COPYING_VISITOR_H + +#include +#include +#include "Expression.h" +#include +#include "Visitor.h" +// #include "WordConstraint.h" +#include + +namespace FreeForm2 +{ + class AllocationExpression; + class SimpleExpressionOwner; + class ExpressionOwner; + + class CopyingVisitor : public Visitor + { + public: + // Create a new instance. + CopyingVisitor(); + CopyingVisitor(const boost::shared_ptr& p_owner, + const boost::shared_ptr& p_typeManager); + + // Get the ExpressionOwner used by the class to own the + // new Expression nodes. + boost::shared_ptr GetExpressionOwner() const; + + // Get the TypeManager used by the class to own new Types. + boost::shared_ptr GetTypeManager() const; + + // Add an expression to the expression owner. + void AddExpressionToOwner(const boost::shared_ptr& p_expr); + + // Add an expression to the owner and the back of the stack. + void AddExpression(const boost::shared_ptr& p_expr); + + // Return the new syntax tree created by this class. + const Expression* GetSyntaxTree() const; + + // Get the stack used by this visitor to create copies of the expression tree. + std::vector& GetStack(); + + // Method inherited from Visitor. + virtual void Visit(const SelectNthExpression& p_expr) override; + virtual void Visit(const SelectRangeExpression& p_expr) override; + virtual void Visit(const ConditionalExpression& p_expr) override; + virtual void Visit(const ArrayLiteralExpression& p_expr) override; + virtual void Visit(const LetExpression& p_expr) override; + virtual void Visit(const BlockExpression& p_expr) override; + virtual void Visit(const BinaryOperatorExpression& p_expr) override; + virtual void Visit(const RangeReduceExpression& p_expr) override; + virtual void Visit(const ForEachLoopExpression& p_expr) override; + virtual void Visit(const ComplexRangeLoopExpression& p_expr) override; + virtual void Visit(const MutationExpression& p_expr) override; + virtual void Visit(const MatchExpression& p_expr) override; + // virtual bool AlternativeVisit(const MatchWordExpression& p_expr) override; + // virtual void Visit(const MatchWordExpression& p_expr) override; + // virtual void Visit(const MatchLiteralExpression& p_expr) override; + // virtual void Visit(const MatchCurrentWordExpression& p_expr) override; + virtual void Visit(const MatchOperatorExpression& p_expr) override; + virtual void Visit(const MatchGuardExpression& p_expr) override; + virtual void Visit(const MatchBindExpression& p_expr) override; + virtual void Visit(const MemberAccessExpression& p_expr) override; + // virtual void Visit(const NeuralInputResultExpression& p_expr) override; + // virtual void Visit(const ObjectMethodExpression& p_expr) override; + virtual void Visit(const ArrayLengthExpression& p_expr) override; + virtual void Visit(const ArrayDereferenceExpression& p_expr) override; + virtual void Visit(const ConvertToFloatExpression& p_expr) override; + virtual void Visit(const ConvertToIntExpression& p_expr) override; + virtual void Visit(const ConvertToUInt64Expression& p_expr) override; + virtual void Visit(const ConvertToInt32Expression& p_expr) override; + virtual void Visit(const ConvertToUInt32Expression& p_expr) override; + virtual void Visit(const ConvertToBoolExpression& p_expr) override; + virtual void Visit(const ConvertToImperativeExpression& p_expr) override; + virtual void Visit(const DeclarationExpression& p_expr) override; + virtual void Visit(const DirectPublishExpression& p_expr) override; + virtual void Visit(const ExternExpression& p_expr) override; + virtual void Visit(const FunctionExpression& p_expr) override; + virtual bool AlternativeVisit(const FunctionCallExpression& p_expr) override; + virtual void Visit(const FunctionCallExpression& p_expr) override; + virtual void Visit(const LiteralIntExpression& p_expr) override; + virtual void Visit(const LiteralUInt64Expression& p_expr) override; + virtual void Visit(const LiteralInt32Expression& p_expr) override; + virtual void Visit(const LiteralUInt32Expression& p_expr) override; + virtual void Visit(const LiteralFloatExpression& p_expr) override; + virtual void Visit(const LiteralBoolExpression& p_expr) override; + virtual void Visit(const LiteralVoidExpression& p_expr) override; + virtual void Visit(const LiteralStreamExpression& p_expr) override; + virtual void Visit(const LiteralWordExpression& p_expr) override; + virtual void Visit(const LiteralInstanceHeaderExpression& p_expr) override; + virtual void Visit(const FeatureRefExpression& p_expr) override; + // virtual bool AlternativeVisit(const FSMExpression& p_expr) override; + // virtual void Visit(const FSMExpression& p_expr) override; + virtual void Visit(const UnaryOperatorExpression& p_expr) override; + virtual void Visit(const FeatureSpecExpression& p_expr) override; + virtual void Visit(const FeatureGroupSpecExpression& p_expr) override; + virtual void Visit(const PhiNodeExpression& p_expr) override; + virtual void Visit(const PublishExpression& p_expr) override; + virtual void Visit(const ReturnExpression& p_expr) override; + virtual void Visit(const StreamDataExpression& p_expr) override; + virtual void Visit(const UpdateStreamDataExpression& p_expr) override; + virtual void Visit(const VariableRefExpression& p_expr) override; + virtual void Visit(const ImportFeatureExpression& p_expr) override; + virtual void Visit(const StateExpression& p_expr) override; + virtual bool AlternativeVisit(const StateMachineExpression& p_expr) override; + virtual void Visit(const StateMachineExpression& p_expr) override; + virtual void Visit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) override; + virtual void Visit(const ExecuteMachineExpression& p_expr) override; + virtual void Visit(const ExecuteMachineGroupExpression& p_expr) override; + virtual void Visit(const YieldExpression& p_expr) override; + virtual void Visit(const RandFloatExpression& p_expr) override; + virtual void Visit(const RandIntExpression& p_expr) override; + virtual void Visit(const ThisExpression& p_expr) override; + virtual void Visit(const UnresolvedAccessExpression& p_expr) override; + virtual bool AlternativeVisit(const TypeInitializerExpression& p_expr) override; + virtual void Visit(const TypeInitializerExpression& p_expr) override; + // virtual void Visit(const DocumentContextExpression& p_expr) override; + // virtual void Visit(const DocumentAggregationCacheExpression& p_expr) override; + virtual void Visit(const AggregateContextExpression& p_expr) override; + // virtual void Visit(const DocumentListExpression& p_expr) override; + virtual void Visit(const DebugExpression& p_expr) override; + + virtual void VisitReference(const ArrayDereferenceExpression& p_expr) override; + virtual void VisitReference(const VariableRefExpression& p_expr) override; + virtual void VisitReference(const MemberAccessExpression& p_expr) override; + virtual void VisitReference(const ThisExpression& p_expr) override; + virtual void VisitReference(const UnresolvedAccessExpression& p_expr) override; + + virtual size_t StackSize() const override; + virtual size_t StackIncrement() const override; + + // Copy a word constraint, including all Expressions within the + // constraint. + // FSM::WordConstraint CopyWordConstraint(const FSM::WordConstraint& p_constraint); + + protected: + + // Copy a type so that it is owned by this copy's TypeManager. + const TypeImpl& CopyType(const TypeImpl& p_type); + + // Expression owner for the new AST. + boost::shared_ptr m_owner; + + // The type manager used with both the new and old AST. + boost::shared_ptr m_typeManager; + + // Temporary expression stack. + std::vector m_stack; + + // A map of pointers to FunctionExpressions in the old tree and the new one. + std::map m_functionTranslation; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.cpp new file mode 100644 index 000000000000..46f786f7532e --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.cpp @@ -0,0 +1,133 @@ +#include "FunctionInlineVisitor.h" + +#include +#include +#include "LetExpression.h" +#include "Function.h" +#include "FunctionType.h" +#include "FreeForm2Assert.h" +#include "RefExpression.h" +#include "TypeImpl.h" +#include "TypeUtil.h" +#include + +FreeForm2::FunctionInlineVisitor::FunctionInlineVisitor(const boost::shared_ptr& p_owner, + const boost::shared_ptr& p_typeManager, + VariableID p_variableId) + : CopyingVisitor(p_owner, p_typeManager), + m_variableId(p_variableId) +{ +} + + +FreeForm2::VariableID +FreeForm2::FunctionInlineVisitor::GetVariableId() +{ + return m_variableId; +} + + +bool +FreeForm2::FunctionInlineVisitor::AlternativeVisit(const FunctionCallExpression& p_expr) +{ + const FunctionType& funType = p_expr.GetFunctionType(); + FF2_ASSERT(funType.GetParameterCount() == p_expr.GetNumParameters()); + + std::vector letValues; + letValues.reserve(funType.GetParameterCount()); + + std::vector parameterExpressions; + parameterExpressions.reserve(funType.GetParameterCount()); + + const FunctionExpression& functionExpression = dynamic_cast(p_expr.GetFunction()); + const std::vector& functionParams = functionExpression.GetParameters(); + FF2_ASSERT(functionParams.size() == funType.GetParameterCount()); + + for (size_t i = 0; i < funType.GetParameterCount(); i++) + { + // Visit the parameter to make sure that its type has been determined. + // (It could be a FunctionCallExpression for instance) + p_expr.GetParameters()[i]->Accept(*this); + const Expression* parameter = m_stack.back(); + m_stack.pop_back(); + + const TypeImpl* formalType = funType.BeginParameters()[i]; + FF2_ASSERT(formalType != nullptr && parameter != nullptr); + const TypeImpl& paramType = parameter->GetType(); + FF2_ASSERT(paramType.Primitive() != Type::Unknown); + + // Try to assign the type of the parameter to the type of the Function input. + FF2_ASSERT(TypeUtil::IsAssignable(*formalType, paramType)); + + if (formalType->Primitive() == Type::Unknown) + { + formalType = ¶mType; + } + else if (*formalType != paramType) + { + FF2_ASSERT(TypeUtil::IsConvertible(paramType, *formalType)); + auto expr = TypeUtil::Convert(*parameter, formalType->Primitive()); + AddExpressionToOwner(expr); + parameter = expr.get(); + } + m_parameterTypeTranslation.insert(std::make_pair(functionParams[i].m_parameter->GetId(), ¶mType)); + parameterExpressions.push_back(parameter); + } + + // Determine new variable ids. + FF2_ASSERT(m_newVariableIdMapping.empty()); + for (size_t i = 0; i < funType.GetParameterCount(); i++) + { + VariableID newVariableID = m_variableId; + ++m_variableId.m_value; + m_newVariableIdMapping.insert(std::make_pair(functionParams[i].m_parameter->GetId(), newVariableID)); + letValues.push_back(std::make_pair(newVariableID, parameterExpressions[i])); + } + + // Then visit the Function Body. This should replace unknown types with known types. + // And old variable ids with the new ids. + functionExpression.GetBody().Accept(*this); + const Expression* newFunctionBody = m_stack.back(); + m_stack.pop_back(); + + // Clear out the variable id mapping. + m_newVariableIdMapping.clear(); + + // Add a Let expression containing the parameter values in the Invoke statement + // and the function's body. + AddExpression(LetExpression::Alloc(p_expr.GetAnnotations(), + letValues, + newFunctionBody)); + + + return true; +} + + +void +FreeForm2::FunctionInlineVisitor::Visit(const VariableRefExpression& p_expr) +{ + const auto find = m_parameterTypeTranslation.find(p_expr.GetId()); + if (find != m_parameterTypeTranslation.end()) + { + const auto newVariableIdFind = m_newVariableIdMapping.find(p_expr.GetId()); + FF2_ASSERT(newVariableIdFind != m_newVariableIdMapping.end()); + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + newVariableIdFind->second, + p_expr.GetVersion(), + *find->second)); + } + else + { + CopyingVisitor::Visit(p_expr); + } +} + + +void +FreeForm2::FunctionInlineVisitor::Visit(const ReturnExpression& p_expr) +{ + // The FunctionInlineVisitor should never process a ReturnExpression. + FF2_UNREACHABLE(); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.h new file mode 100644 index 000000000000..7bf723386ea8 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/FunctionInlineVisitor.h @@ -0,0 +1,38 @@ +#pragma once + +#include "CopyingVisitor.h" +#include + +namespace FreeForm2 +{ + // This class inlines function calls, inserting the function body directly + // into the resultant expression tree. This class only works with + // S-expression trees (or trees which have an implicit function return). + class FunctionInlineVisitor : public CopyingVisitor + { + public: + FunctionInlineVisitor(const boost::shared_ptr& p_owner, + const boost::shared_ptr& p_typeManager, + VariableID p_variableId); + + virtual bool AlternativeVisit(const FunctionCallExpression& p_expr) override; + virtual void Visit(const ReturnExpression& p_expr) override; + virtual void Visit(const VariableRefExpression& p_expr) override; + + // Returns the variable id counter. + VariableID GetVariableId(); + + private: + // A map containing type translations for function parameters. + std::map m_parameterTypeTranslation; + + // A map containing mappings from old variable ids to new variable ids. + // This is necessary for all variables within a lambda. + std::map m_newVariableIdMapping; + + // Counter to keep track of the next variable ID. This allows us + // to ensure that the variables assigned for each function call have + // unique ids. + VariableID m_variableId; + }; +} \ No newline at end of file diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/NoOpVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/NoOpVisitor.h new file mode 100644 index 000000000000..8d7b01ecd921 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/NoOpVisitor.h @@ -0,0 +1,90 @@ +#pragma once + +#ifndef FREEFORM2_NOOP_VISITOR_H +#define FREEFORM2_NOOP_VISITOR_H + +#include "Visitor.h" + +namespace FreeForm2 +{ + // Visitor that implements methods for every expression class that do nothing. + class NoOpVisitor : public Visitor + { + public: + // Method inherited from Visitor. + virtual void Allocate(const Allocation&) {} + + virtual void Visit(const SelectNthExpression&) override {} + virtual void Visit(const SelectRangeExpression&) override {} + virtual void Visit(const ConditionalExpression&) override {} + virtual void Visit(const ArrayLiteralExpression&) override {} + virtual void Visit(const LetExpression&) override {} + virtual void Visit(const BlockExpression&) override {} + virtual void Visit(const BinaryOperatorExpression&) override {} + virtual void Visit(const RangeReduceExpression&) override {} + virtual void Visit(const ForEachLoopExpression&) override {} + virtual void Visit(const ComplexRangeLoopExpression&) override {} + virtual void Visit(const MutationExpression&) override {} + virtual void Visit(const MatchExpression&) override {} + virtual void Visit(const MatchOperatorExpression&) override {} + virtual void Visit(const MatchGuardExpression&) override {} + virtual void Visit(const MatchBindExpression&) override {} + virtual void Visit(const MemberAccessExpression&) override {} + virtual void Visit(const PhiNodeExpression&) override {} + virtual void Visit(const PublishExpression&) override {} + virtual void Visit(const ReturnExpression&) override {} + virtual void Visit(const ArrayLengthExpression&) override {} + virtual void Visit(const ArrayDereferenceExpression&) override {} + virtual void Visit(const ConvertToFloatExpression&) override {} + virtual void Visit(const ConvertToIntExpression&) override {} + virtual void Visit(const ConvertToUInt64Expression&) override {} + virtual void Visit(const ConvertToInt32Expression&) override {} + virtual void Visit(const ConvertToUInt32Expression&) override {} + virtual void Visit(const ConvertToBoolExpression&) override {} + virtual void Visit(const ConvertToImperativeExpression&) override {} + virtual void Visit(const DeclarationExpression&) override {} + virtual void Visit(const DirectPublishExpression&) override {} + virtual void Visit(const ExternExpression&) override {} + virtual void Visit(const FunctionExpression&) override {} + virtual void Visit(const FunctionCallExpression&) override {} + virtual void Visit(const LiteralIntExpression&) override {} + virtual void Visit(const LiteralUInt64Expression&) override {} + virtual void Visit(const LiteralInt32Expression&) override {} + virtual void Visit(const LiteralUInt32Expression&) override {} + virtual void Visit(const LiteralFloatExpression&) override {} + virtual void Visit(const LiteralBoolExpression&) override {} + virtual void Visit(const LiteralVoidExpression&) override {} + virtual void Visit(const LiteralStreamExpression&) override {} + virtual void Visit(const LiteralWordExpression&) override {} + virtual void Visit(const LiteralInstanceHeaderExpression&) override {} + virtual void Visit(const FeatureRefExpression&) override {} + virtual void Visit(const UnaryOperatorExpression&) override {} + virtual void Visit(const FeatureSpecExpression&) override {} + virtual void Visit(const FeatureGroupSpecExpression&) override {} + virtual void Visit(const StreamDataExpression&) override {} + virtual void Visit(const UpdateStreamDataExpression&) override {} + virtual void Visit(const VariableRefExpression&) override {} + virtual void Visit(const ImportFeatureExpression&) override {} + virtual void Visit(const StateExpression&) override {} + virtual void Visit(const StateMachineExpression&) override {} + virtual void Visit(const ExecuteStreamRewritingStateMachineGroupExpression&) override {} + virtual void Visit(const ExecuteMachineExpression&) override {} + virtual void Visit(const ExecuteMachineGroupExpression&) override {} + virtual void Visit(const YieldExpression&) override {} + virtual void Visit(const RandFloatExpression&) override {} + virtual void Visit(const RandIntExpression&) override {} + virtual void Visit(const ThisExpression&) override {} + virtual void Visit(const UnresolvedAccessExpression&) override {} + virtual void Visit(const TypeInitializerExpression&) override {}; + virtual void Visit(const AggregateContextExpression&) override {}; + virtual void Visit(const DebugExpression&) override {} + + virtual void VisitReference(const ArrayDereferenceExpression&) override {} + virtual void VisitReference(const VariableRefExpression&) override {} + virtual void VisitReference(const MemberAccessExpression&) override {} + virtual void VisitReference(const ThisExpression&) override {} + virtual void VisitReference(const UnresolvedAccessExpression&) override {} + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.cpp new file mode 100644 index 000000000000..4ac0b9f5cf68 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.cpp @@ -0,0 +1,128 @@ +#include "ObjectResolutionVisitor.h" + +#include +#include +#include +#include "CompoundType.h" +#include "FreeForm2Assert.h" +#include "MemberAccessExpression.h" +#include "Mutation.h" +#include "RefExpression.h" +#include +#include "StateMachine.h" +#include "TypeManager.h" + +FreeForm2::ObjectResolutionVisitor::ObjectResolutionVisitor() +{ +} + + +bool +FreeForm2::ObjectResolutionVisitor::AlternativeVisit(const StateMachineExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().Primitive() == Type::StateMachine); + const TypeImpl& copiedType = CopyType(p_expr.GetType()); + FF2_ASSERT(copiedType.IsSameAs(p_expr.GetType(), false)); + const StateMachineType& machineType = static_cast(copiedType); + + m_thisTypeStack.push(&machineType); + const bool result = CopyingVisitor::AlternativeVisit(p_expr); + FF2_ASSERT(m_thisTypeStack.top() == &machineType); + m_thisTypeStack.pop(); + + // Assert that we are correct in not calling the CopyingVisitor::Visit + // method for this expression. + FF2_ASSERT(result); + return true; +} + + +void +FreeForm2::ObjectResolutionVisitor::Visit(const ThisExpression& p_expr) +{ + if (m_thisTypeStack.empty()) + { + std::ostringstream err; + err << "Invalid this reference: not in object scope"; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + const CompoundType& currentThisType = *m_thisTypeStack.top(); + if (p_expr.GetType().Primitive() != Type::Unknown + && !p_expr.GetType().IsSameAs(currentThisType, false)) + { + std::ostringstream err; + err << "Object types not compatible. Expected type: " + << currentThisType << "; found type: " + << p_expr.GetType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + currentThisType)); +} + + +void +FreeForm2::ObjectResolutionVisitor::Visit(const UnresolvedAccessExpression& p_expr) +{ + const Expression& object = *m_stack.back(); + m_stack.pop_back(); + + if (object.GetType().Primitive() != Type::StateMachine) + { + std::ostringstream err; + err << "Unable to resolve member" << p_expr.GetMemberName() + << " on type " << object.GetType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + const StateMachineType& type = static_cast(object.GetType()); + + const std::string memberName + = StateMachineExpression::GetAugmentedMemberName(type.GetName(), p_expr.GetMemberName()); + + const CompoundType::Member* member = type.FindMember(memberName); + if (member == NULL) + { + std::ostringstream err; + err << "Unable to resolve member " << p_expr.GetMemberName() + << " on type " << object.GetType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + if (!member->m_type->IsSameAs(p_expr.GetType(), false) + && p_expr.GetType().Primitive() != Type::Unknown) + { + std::ostringstream err; + err << "expected member " << p_expr.GetMemberName() + << " to be type " << p_expr.GetType() + << " but encountered type " << *member->m_type; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + AddExpression(boost::make_shared(p_expr.GetAnnotations(), + object, + *member, + 0)); +} + + +bool +FreeForm2::ObjectResolutionVisitor::AlternativeVisit(const TypeInitializerExpression& p_expr) +{ + FF2_ASSERT(p_expr.GetType().Primitive() == Type::StateMachine); + const TypeImpl& copiedType = CopyType(p_expr.GetType()); + FF2_ASSERT(copiedType.IsSameAs(p_expr.GetType(), false)); + const StateMachineType& machineType = static_cast(copiedType); + + m_thisTypeStack.push(&machineType); + const bool result = CopyingVisitor::AlternativeVisit(p_expr); + FF2_ASSERT(m_thisTypeStack.top() == &machineType); + m_thisTypeStack.pop(); + + // Assert that we are correct in not calling the CopyingVisitor::Visit + // method for this expression. + FF2_ASSERT(result); + return true; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.h new file mode 100644 index 000000000000..236fbe292a19 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ObjectResolutionVisitor.h @@ -0,0 +1,36 @@ +#pragma once +#ifndef FREEFORM2_OBJECT_RESOLUTION_VISITOR_H +#define FREEFORM2_OBJECT_RESOLUTION_VISITOR_H + +#include "CopyingVisitor.h" +#include + +namespace FreeForm2 +{ + class CompoundType; + + // The ObjectResolutionVisitor is responsible for two-pass object type + // annotating. For example with StateMachineExpressions, when parsed all + // ThisExpressions that are children of the StateMachineExpression are of + // type Unknown, and all unknown variables turn into + // UnresolveAccessExpressions. Since the type can only be created after + // parsing is complete, this information must be added in a second pass. + class ObjectResolutionVisitor : public CopyingVisitor + { + public: + // Create an ObjectResolutionVisitor using a new ExpressionOwner + // and TypeManager. + ObjectResolutionVisitor(); + + virtual bool AlternativeVisit(const StateMachineExpression& p_expr); + virtual void Visit(const ThisExpression& p_expr); + virtual void Visit(const UnresolvedAccessExpression& p_expr); + virtual bool AlternativeVisit(const TypeInitializerExpression& p_expr); + + private: + // The current type with which a ThisExpression should be annotated. + std::stack m_thisTypeStack; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.cpp new file mode 100644 index 000000000000..2750c69c897a --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.cpp @@ -0,0 +1,228 @@ +#include "OperandPromotionVisitor.h" + +#include "SimpleExpressionOwner.h" +#include "FreeForm2Assert.h" + +#include "Conditional.h" +#include "ConvertExpression.h" +#include "Expression.h" +#include "Function.h" +#include "FunctionType.h" +#include "OperatorExpression.h" +#include "SimpleExpressionOwner.h" +#include "TypeUtil.h" + +#include +#include +#include +#include +#include + +namespace +{ + // Unify the types of loop counter variables starting at an iterator + // position. The iterator must be a writable iterator. + template + void UnifyLoopCounters(const Iter p_iter, + unsigned int p_numCounters, + FreeForm2::TypeManager& p_typeManager, + FreeForm2::SimpleExpressionOwner& p_owner) + { + using namespace FreeForm2; + BOOST_CONCEPT_ASSERT((boost_concepts::WritableIterator)); + const TypeImpl* unifiedType = &TypeImpl::GetUnknownType(); + + Iter iter = p_iter; + for (size_t i = 0; i < p_numCounters; ++i, ++iter) + { + unifiedType = &TypeUtil::Unify(*unifiedType, (*iter)->GetType(), p_typeManager, false, true); + } + + if (!unifiedType->IsValid()) + { + std::ostringstream err; + err << "Loop bounds must be of a unifiable type."; + throw ParseError(err.str(), (*p_iter)->GetSourceLocation()); + } + + iter = p_iter; + for (size_t i = 0; i < p_numCounters; ++i, ++iter) + { + const Expression& expression = **iter; + + if (!expression.GetType().IsSameAs(*unifiedType, true)) + { + const TypeImpl& stackType = expression.GetType(); + if (TypeUtil::IsConvertible(stackType, *unifiedType)) + { + boost::shared_ptr convert( + TypeUtil::Convert(expression, unifiedType->Primitive())); + + p_owner.AddExpression(convert); + *iter = convert.get(); + } + else + { + std::ostringstream err; + err << "Expected a type convertible to " << *unifiedType + << "got type: " << expression.GetType(); + throw ParseError(err.str(), expression.GetSourceLocation()); + } + } + } + } +} + + +bool +FreeForm2::OperandPromotionVisitor::AlternativeVisit(const BinaryOperatorExpression& p_expr) +{ + // Handle via Visit: overridden to ensure Visit is called. + return false; +} + + +void +FreeForm2::OperandPromotionVisitor::Visit(const BinaryOperatorExpression& p_expr) +{ + std::vector::reverse_iterator iter = GetStack().rbegin(); + + const TypeImpl& parameterType = p_expr.GetChildType().AsConstType(); + + for (size_t i = 0; i < p_expr.GetNumChildren(); ++i, ++iter) + { + FF2_ASSERT(iter != GetStack().rend()); + const Expression& expression = **iter; + if (!expression.GetType().IsSameAs(parameterType, true)) + { + const TypeImpl& stackType = expression.GetType(); + if (TypeUtil::IsConvertible(stackType, parameterType)) + { + boost::shared_ptr convert( + TypeUtil::Convert(expression, parameterType.Primitive())); + + AddExpressionToOwner(convert); + *iter = convert.get(); + } + else + { + std::ostringstream err; + err << "Expected a type convertible to " << parameterType + << "got type: " << expression.GetType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + } + } + + CopyingVisitor::Visit(p_expr); +} + + +void +FreeForm2::OperandPromotionVisitor::Visit(const FunctionCallExpression& p_expr) +{ + auto iter = GetStack().rbegin(); + + const FunctionType& type = p_expr.GetFunctionType(); + FunctionType::ParameterIterator parameterTypes = type.BeginParameters(); + + for (size_t i = 0; i < p_expr.GetNumParameters(); ++i, ++iter) + { + FF2_ASSERT(iter != GetStack().rend()); + const Expression& expression = **iter; + const TypeImpl& parameterType = *parameterTypes[p_expr.GetNumParameters() - i - 1]; + + if (!expression.GetType().IsSameAs(parameterType, true)) + { + const TypeImpl& stackType = expression.GetType(); + if (TypeUtil::IsConvertible(stackType, parameterType)) + { + boost::shared_ptr convert( + TypeUtil::Convert(expression, parameterType.Primitive())); + + AddExpressionToOwner(convert); + *iter = convert.get(); + } + else + { + std::ostringstream err; + err << "Expected a type convertible to " << parameterType + << "got type: " << expression.GetType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + } + } + + CopyingVisitor::Visit(p_expr); +} + + +void +FreeForm2::OperandPromotionVisitor::Visit(const ConditionalExpression& p_expr) +{ + FF2_ASSERT(GetStack().size() >= 3); + + auto iter = GetStack().rbegin() + 1; // Skip condition expression. + const TypeImpl& unifiedType = p_expr.GetType(); + + for (size_t i = 0; i < 2; ++i, ++iter) + { + FF2_ASSERT(iter != GetStack().rend()); + const Expression& expression = **iter; + + if (!expression.GetType().IsSameAs(unifiedType, true)) + { + const TypeImpl& stackType = expression.GetType(); + if (TypeUtil::IsConvertible(stackType, unifiedType)) + { + boost::shared_ptr convert( + TypeUtil::Convert(expression, unifiedType.Primitive())); + + AddExpressionToOwner(convert); + *iter = convert.get(); + } + else + { + std::ostringstream err; + err << "Expected a type convertible to " << unifiedType + << "got type: " << expression.GetType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + } + } + + CopyingVisitor::Visit(p_expr); +} + + +void +FreeForm2::OperandPromotionVisitor::Visit(const RangeReduceExpression& p_expr) +{ + FF2_ASSERT(GetStack().size() >= 4); + + UnifyLoopCounters(GetStack().rbegin() + 1, 2, *m_typeManager, *m_owner); + + CopyingVisitor::Visit(p_expr); +} + + +void +FreeForm2::OperandPromotionVisitor::Visit(const ForEachLoopExpression& p_expr) +{ + FF2_ASSERT(GetStack().size() >= 4); + + UnifyLoopCounters(GetStack().rbegin() + 1, 3, *m_typeManager, *m_owner); + + CopyingVisitor::Visit(p_expr); +} + + +void +FreeForm2::OperandPromotionVisitor::Visit(const ComplexRangeLoopExpression& p_expr) +{ + FF2_ASSERT(GetStack().size() >= 6); + + UnifyLoopCounters(GetStack().rbegin() + 2, 3, *m_typeManager, *m_owner); + + CopyingVisitor::Visit(p_expr); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.h new file mode 100644 index 000000000000..bd17ebd94e15 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/OperandPromotionVisitor.h @@ -0,0 +1,40 @@ +#pragma once + +#ifndef FREEFORM2_OPERAND_PROMOTION_VISITOR_H +#define FREEFORM2_OPERAND_PROMOTION_VISITOR_H + +#include "CopyingVisitor.h" + +#include +#include + +namespace FreeForm2 +{ + class SimpleExpressionOwner; + class ExpressionOwner; + class Expression; + + class OperandPromotionVisitor : public CopyingVisitor + { + public: + // Allow promotion in binary operator. + virtual bool AlternativeVisit(const BinaryOperatorExpression& p_expr); + virtual void Visit(const BinaryOperatorExpression& p_expr); + + // Allow promotion in the function call expression. + virtual void Visit(const FunctionCallExpression& p_expr); + + // Allow promotion in conditional. + virtual void Visit(const ConditionalExpression& p_expr); + + // Allow promotion in loop structures. + virtual void Visit(const RangeReduceExpression& p_expr); + virtual void Visit(const ForEachLoopExpression& p_expr); + virtual void Visit(const ComplexRangeLoopExpression& p_expr); + + // Promote void type object method expressions to statements. + // virtual void Visit(const ObjectMethodExpression& p_expr); + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.cpp new file mode 100644 index 000000000000..6600e8c5d3f1 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.cpp @@ -0,0 +1,17 @@ +#include "ProcessFeaturesUsed.h" + +#include "RefExpression.h" +#include + +FreeForm2::ProcessFeaturesUsedVisitor::ProcessFeaturesUsedVisitor( + DynamicRank::INeuralNetFeatures& p_features) + : m_features(p_features) +{ +} + + +void +FreeForm2::ProcessFeaturesUsedVisitor::Visit(const FeatureRefExpression& p_expr) +{ + m_features.ProcessFeature(p_expr.m_index); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.h new file mode 100644 index 000000000000..01f0e3531245 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/ProcessFeaturesUsed.h @@ -0,0 +1,30 @@ +#pragma once + +#ifndef FREEFORM2_PROCESS_FEATURES_USED_H +#define FREEFORM2_PROCESS_FEATURES_USED_H + +#include "NoOpVisitor.h" + +namespace DynamicRank +{ + class IFeatureMap; + class INeuralNetFeatures; +}; + +namespace FreeForm2 +{ + // Class to collect the set of features used by a program. + class ProcessFeaturesUsedVisitor : public NoOpVisitor + { + public: + ProcessFeaturesUsedVisitor(DynamicRank::INeuralNetFeatures& p_features); + + // Methods inherited from Visitor. + virtual void Visit(const FeatureRefExpression& p_expr); + + private: + DynamicRank::INeuralNetFeatures& m_features; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.cpp new file mode 100644 index 000000000000..2f12b295fae4 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.cpp @@ -0,0 +1,599 @@ +#include "TypeCheckingVisitor.h" + +#include "ArrayType.h" +#include "BlockExpression.h" +#include +#include "Conditional.h" +#include "ConvertExpression.h" +#include "DebugExpression.h" +#include "Declaration.h" +#include "Expression.h" +#include "FeatureSpec.h" +#include "FreeForm2Assert.h" +#include "Function.h" +#include "LetExpression.h" +#include "LiteralExpression.h" +#include "Match.h" +#include "Mutation.h" +#include "NoOpVisitor.h" +#include "FunctionType.h" +#include "Publish.h" +#include "RangeReduceExpression.h" +#include "RefExpression.h" +#include +#include "StateMachine.h" +#include "TypeImpl.h" +#include "TypeUtil.h" + +FreeForm2::TypeCheckingVisitor::TypeCheckingVisitor() + : m_lastExpressionReturns(true), + m_hasSideEffects(false) +{ +} + + +void +FreeForm2::TypeCheckingVisitor::AssertSideEffects(const FreeForm2::SourceLocation& p_sourceLocation) const +{ + if (!m_hasSideEffects) + { + std::ostringstream err; + err << "Statement does not have side effects."; + throw ParseError(err.str(), p_sourceLocation); + } +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const BlockExpression& p_expr) +{ + unsigned int numChildren = static_cast(p_expr.GetNumChildren()); + + if (m_lastExpressionReturns) + { + numChildren--; + } + + for (unsigned int i = 0; i < numChildren; i++) + { + const FreeForm2::Expression& child = p_expr.GetChild(i); + + if (m_functions.size() > 0 && m_functions.top().m_allPathsReturn) + { + std::ostringstream err; + err << "Unreachable code detected."; + throw ParseError(err.str(), child.GetSourceLocation()); + } + + m_hasSideEffects = false; + child.Accept(*this); + AssertSideEffects(child.GetSourceLocation()); + } + + if (m_lastExpressionReturns) + { + if (m_functions.size() > 0 && m_functions.top().m_allPathsReturn) + { + std::ostringstream err; + err << "Unreachable code detected."; + throw ParseError(err.str(), p_expr.GetChild(numChildren).GetSourceLocation()); + } + + p_expr.GetChild(numChildren).Accept(*this); + } + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const ConditionalExpression& p_expr) +{ + const bool lastSiblingReturns = m_lastExpressionReturns; + const bool conditionalReturnsValue = p_expr.GetType().Primitive() != Type::Void; + + // We don't care if the condition has side effects, only the then and + // else clauses. + m_lastExpressionReturns = conditionalReturnsValue; + m_hasSideEffects = false; + p_expr.GetThen().Accept(*this); + if (!(m_hasSideEffects || m_lastExpressionReturns)) + { + AssertSideEffects(p_expr.GetThen().GetSourceLocation()); + } + + const bool ifExprReturns = m_functions.size() > 0 && m_functions.top().m_allPathsReturn; + + if (ifExprReturns) + { + m_functions.top().m_allPathsReturn = false; + } + + if (&p_expr.GetElse() != &FreeForm2::LiteralVoidExpression::GetInstance()) + { + m_lastExpressionReturns = conditionalReturnsValue; + m_hasSideEffects = false; + p_expr.GetElse().Accept(*this); + if (!(m_hasSideEffects || m_lastExpressionReturns)) + { + AssertSideEffects(p_expr.GetElse().GetSourceLocation()); + } + + if (m_functions.size() > 0) + { + m_functions.top().m_allPathsReturn &= ifExprReturns; + } + } + + UniformExpressionVisitor::Visit(p_expr); + + m_lastExpressionReturns = lastSiblingReturns; + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const MatchExpression& p_expr) +{ + const bool lastSiblingReturns = m_lastExpressionReturns; + + p_expr.GetPattern().Accept(*this); + p_expr.GetValue().Accept(*this); + + // Ignore the value and pattern of the match; they are not expected to + // have side effects. + m_lastExpressionReturns = false; + m_hasSideEffects = false; + p_expr.GetAction().Accept(*this); + AssertSideEffects(p_expr.GetAction().GetSourceLocation()); + + UniformExpressionVisitor::Visit(p_expr); + + m_lastExpressionReturns = lastSiblingReturns; + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const ConvertToImperativeExpression& p_expr) +{ + m_hasSideEffects = false; + p_expr.GetChild().Accept(*this); + AssertSideEffects(p_expr.GetChild().GetSourceLocation()); + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const RangeReduceExpression& p_expr) +{ + FF2_ASSERT(m_variableTypes.find(p_expr.GetReduceId()) == m_variableTypes.end() + && m_variableTypes.find(p_expr.GetStepId()) == m_variableTypes.end()); + const auto reduceKey = p_expr.GetReduceId(); + m_variableTypes.insert( + std::make_pair(reduceKey, &p_expr.GetReduceExpression().GetType())); + const auto stepKey = p_expr.GetStepId(); + m_variableTypes.insert( + std::make_pair(stepKey, &p_expr.GetLow().GetType())); + + const bool lastSiblingReturns = m_lastExpressionReturns; + + // Range-reduce expression only returns something if the reduction + // variable is non-void. + m_lastExpressionReturns = p_expr.GetType().Primitive() != Type::Void; + + p_expr.GetReduceExpression().Accept(*this); + m_hasSideEffects |= m_lastExpressionReturns; + AssertSideEffects(p_expr.GetReduceExpression().GetSourceLocation()); + + m_lastExpressionReturns = lastSiblingReturns; + + m_variableTypes.erase(reduceKey); + m_variableTypes.erase(stepKey); + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const ForEachLoopExpression& p_expr) +{ + FF2_ASSERT(m_variableTypes.find(p_expr.GetIteratorId()) == m_variableTypes.end()); + const auto insertKey = p_expr.GetIteratorId(); + m_variableTypes.insert( + std::make_pair(insertKey, &p_expr.GetIteratorType())); + + const bool lastSiblingReturns = m_lastExpressionReturns; + m_lastExpressionReturns = false; + p_expr.GetBody().Accept(*this); + AssertSideEffects(p_expr.GetBody().GetSourceLocation()); + m_lastExpressionReturns = lastSiblingReturns; + + m_variableTypes.erase(insertKey); + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const ComplexRangeLoopExpression& p_expr) +{ + FF2_ASSERT(m_variableTypes.find(p_expr.GetStepId()) == m_variableTypes.end()); + const auto insertKey = p_expr.GetStepId(); + m_variableTypes.insert(std::make_pair(insertKey, &p_expr.GetStepType())); + + const bool lastSiblingReturns = m_lastExpressionReturns; + m_lastExpressionReturns = false; + p_expr.GetBody().Accept(*this); + AssertSideEffects(p_expr.GetBody().GetSourceLocation()); + m_lastExpressionReturns = lastSiblingReturns; + + m_variableTypes.erase(insertKey); + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const AggregateContextExpression& p_expr) +{ + const bool lastSiblingReturns = m_lastExpressionReturns; + m_lastExpressionReturns = false; + p_expr.GetBody().Accept(*this); + AssertSideEffects(p_expr.GetBody().GetSourceLocation()); + m_lastExpressionReturns = lastSiblingReturns; + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const FunctionExpression& p_expr) +{ + // Save the state of the name->type mapping to restore it after the function body + // is evaluated. + const auto savedVariableTypesMap = m_variableTypes; + + // Bind all parameters for type checking purposes. + BOOST_FOREACH (auto& param, p_expr.GetParameters()) + { + FF2_ASSERT(m_variableTypes.find(param.m_parameter->GetId()) == m_variableTypes.end()); + m_variableTypes.insert(std::make_pair(param.m_parameter->GetId(), ¶m.m_parameter->GetType())); + } + + FunctionState functionState; + functionState.m_returnType = &p_expr.GetFunctionType().GetReturnType(); + functionState.m_allPathsReturn = false; + + m_functions.push(functionState); + + p_expr.GetBody().Accept(*this); + + UniformExpressionVisitor::Visit(p_expr); + + if (!m_functions.top().m_allPathsReturn) + { + std::ostringstream err; + err << "Not all code paths return a value."; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + m_functions.pop(); + + // Restore the name->type mapping. + m_variableTypes = savedVariableTypesMap; + + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const FunctionCallExpression& p_expr) +{ + const FunctionType& functionType = p_expr.GetFunctionType(); + + FF2_ASSERT(p_expr.GetNumParameters() == functionType.GetParameterCount()); + + for (unsigned int i = 0; i < functionType.GetParameterCount(); ++i) + { + p_expr.GetParameters()[i]->Accept(*this); + + // If a parameter in a function is marked as mutable, the parameter passed + // into a function call must be a mutable l-value. + if (!functionType.BeginParameters()[i]->IsConst()) + { + if (p_expr.GetParameters()[i]->GetType().IsConst()) + { + throw ParseError("Parameter must be a mutable l-value.", + p_expr.GetParameters()[i]->GetSourceLocation()); + } + + // Ref-parameters must be variable ref expressions for the moment. + const VariableRefExpression* refExpression + = dynamic_cast(p_expr.GetParameters()[i]); + + if (refExpression == nullptr) + { + throw ParseError("Parameter must be a non-array variable name.", + p_expr.GetParameters()[i]->GetSourceLocation()); + } + } + } + + p_expr.GetFunction().Accept(*this); + + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const FeatureSpecExpression& p_expr) +{ + m_publishFeatureMap = p_expr.GetPublishFeatureMap().get(); + + m_lastExpressionReturns = p_expr.GetType().Primitive() != Type::Void; + p_expr.GetBody().Accept(*this); + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const StateMachineExpression& p_expr) +{ + m_hasSideEffects = true; + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const LetExpression& p_expr) +{ + for (unsigned int i = 0; i < p_expr.GetNumChildren() - 1; i++) + { + FF2_ASSERT(m_variableTypes.find(p_expr.GetBound()[i].first) == m_variableTypes.end()); + m_variableTypes.insert( + std::make_pair(p_expr.GetBound()[i].first, &p_expr.GetBound()[i].second->GetType())); + } + + const bool lastSiblingReturns = m_lastExpressionReturns; + m_lastExpressionReturns = true; + p_expr.GetValue().Accept(*this); + m_lastExpressionReturns = lastSiblingReturns; + + for (unsigned int i = 0; i < p_expr.GetNumChildren() - 1; i++) + { + auto find = m_variableTypes.find(p_expr.GetBound()[i].first); + FF2_ASSERT(find != m_variableTypes.end()); + m_variableTypes.erase(find); + } + + UniformExpressionVisitor::Visit(p_expr); + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) +{ + const bool lastSiblingReturns = m_lastExpressionReturns; + m_lastExpressionReturns = false; + + // Process all machine declarations before processing machine execution + // expressions, otherwise machine execution will reference undeclared + // variables. + for (unsigned int i = 0; i < p_expr.GetNumMachineInstances(); i++) + { + m_hasSideEffects = false; + p_expr.GetMachineInstances()[i].m_machineDeclaration->Accept(*this); + AssertSideEffects(p_expr.GetMachineInstances()[i].m_machineDeclaration->GetSourceLocation()); + } + + FF2_ASSERT(m_variableTypes.find(p_expr.GetMachineIndexID()) == m_variableTypes.end()); + const auto insertKey = p_expr.GetMachineIndexID(); + m_variableTypes.insert(std::make_pair(insertKey, &FreeForm2::TypeImpl::GetUInt32Instance(true))); + + for (unsigned int i = 0; i < p_expr.GetNumMachineInstances(); i++) + { + p_expr.GetMachineInstances()[i].m_machineExpression->Accept(*this); + } + + m_variableTypes.erase(insertKey); + + m_lastExpressionReturns = lastSiblingReturns; + + return true; +} + + +bool +FreeForm2::TypeCheckingVisitor::AlternativeVisit(const ExecuteMachineGroupExpression& p_expr) +{ + const bool lastSiblingReturns = m_lastExpressionReturns; + m_lastExpressionReturns = false; + + // Process all machine declarations before processing machine execution + // expressions, otherwise machine execution will reference undeclared + // variables. + for (unsigned int i = 0; i < p_expr.GetNumMachineInstances(); i++) + { + m_hasSideEffects = false; + p_expr.GetMachineInstances()[i].m_machineDeclaration->Accept(*this); + AssertSideEffects(p_expr.GetMachineInstances()[i].m_machineDeclaration->GetSourceLocation()); + } + + for (unsigned int i = 0; i < p_expr.GetNumMachineInstances(); i++) + { + p_expr.GetMachineInstances()[i].m_machineExpression->Accept(*this); + } + + m_lastExpressionReturns = lastSiblingReturns; + return true; +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const MutationExpression& p_expr) +{ + m_hasSideEffects = true; + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const DeclarationExpression& p_expr) +{ + FF2_ASSERT(m_variableTypes.find(p_expr.GetId()) == m_variableTypes.end()); + m_variableTypes.insert(std::make_pair(p_expr.GetId(), &p_expr.GetDeclaredType())); + + m_hasSideEffects = true; + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const DirectPublishExpression& p_expr) +{ + m_hasSideEffects = true; + + FF2_ASSERT(m_publishFeatureMap != NULL); + FeatureSpecExpression::PublishFeatureMap::const_iterator featureNameToType = + m_publishFeatureMap->find(FeatureSpecExpression::FeatureName(p_expr.GetFeatureName())); + + // The lemon file should have already checked that the feature names + // being published are valid. + FF2_ASSERT(featureNameToType != m_publishFeatureMap->end()); + FF2_ASSERT(featureNameToType->second.Primitive() == Type::Array); + + const ArrayType& type = static_cast(featureNameToType->second); + + if (!TypeUtil::IsAssignable(type.GetChildType(), p_expr.GetValue().GetType())) + { + std::ostringstream err; + err << "Invalid publish type: " << p_expr.GetValue().GetType() + << "; expected type: " << type.GetChildType(); + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const PublishExpression& p_expr) +{ + m_hasSideEffects = true; + + FF2_ASSERT(m_publishFeatureMap != NULL); + FeatureSpecExpression::PublishFeatureMap::const_iterator featureNameToType = + m_publishFeatureMap->find(FeatureSpecExpression::FeatureName(p_expr.GetFeatureName())); + + // The lemon file should have already checked that the feature names + // being published are valid. + FF2_ASSERT(featureNameToType != m_publishFeatureMap->end()); + + if (!TypeUtil::IsAssignable(featureNameToType->second, p_expr.GetValue().GetType())) + { + std::ostringstream err; + err << "Invalid publish type: " << p_expr.GetValue().GetType() + << "; expected type: " << featureNameToType->second; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const ReturnExpression& p_expr) +{ + if (m_functions.size() == 0) + { + std::ostringstream err; + err << "Return statements can only be used in functions."; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + if (!TypeUtil::IsAssignable(*m_functions.top().m_returnType, p_expr.GetValue().GetType())) + { + std::ostringstream err; + err << "Invalid return type: " << p_expr.GetValue().GetType() + << "; expected type: " << *m_functions.top().m_returnType; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + m_hasSideEffects = true; + m_functions.top().m_allPathsReturn = true; + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const ImportFeatureExpression& p_expr) +{ + FF2_ASSERT(m_variableTypes.find(p_expr.GetId()) == m_variableTypes.end()); + m_variableTypes.insert(std::make_pair(p_expr.GetId(), &p_expr.GetType())); + + m_hasSideEffects = true; + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const VariableRefExpression& p_expr) +{ + auto find = m_variableTypes.find(p_expr.GetId()); + + if (find == m_variableTypes.end()) + { + std::ostringstream err; + err << "Variable referenced before its declaration (ID " + << p_expr.GetId() << ")"; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + FF2_ASSERT(find->second != nullptr); + if (!find->second->IsSameAs(p_expr.GetType(), true)) + { + std::ostringstream err; + err << "Variable declaration and reference have different types. " + << "Got " << *find->second << " and " << p_expr.GetType() + << " for the declaration and reference, respectively " + << "(ID " << p_expr.GetId() << ")"; + throw ParseError(err.str(), p_expr.GetSourceLocation()); + } + + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const DebugExpression& p_expr) +{ + m_hasSideEffects = true; + UniformExpressionVisitor::Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::VisitReference(const VariableRefExpression& p_expr) +{ + Visit(p_expr); +} + + +void +FreeForm2::TypeCheckingVisitor::Visit(const Expression& p_expr) +{ + const TypeImpl& type = p_expr.GetType(); + FF2_ASSERT(type.Primitive() != Type::Unknown); + FF2_ASSERT(type.IsValid()); +} diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.h new file mode 100644 index 000000000000..8150b016539b --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/TypeCheckingVisitor.h @@ -0,0 +1,92 @@ +#pragma once + +#ifndef FREEFORM2_TYPE_CHECKING_VISITOR_H +#define FREEFORM2_TYPE_CHECKING_VISITOR_H + +#include "UniformExpressionVisitor.h" +#include "TypeImpl.h" +#include "FeatureSpec.h" +#include + +namespace FreeForm2 +{ + // Visitor that visits all expressions to ensure that + // type-checking occurs. + class TypeCheckingVisitor : public UniformExpressionVisitor + { + public: + // Construct a type checking visitor. + TypeCheckingVisitor(); + + // These expressions require special handling to ensure that their + // children have side effects. Blocks, conditionals, matches, and + // conversions-to-imperative have side effects iff all of their + // children have side effects. + virtual bool AlternativeVisit(const BlockExpression& p_expr) override; + virtual bool AlternativeVisit(const ConditionalExpression& p_expr) override; + virtual bool AlternativeVisit(const MatchExpression& p_expr) override; + virtual bool AlternativeVisit(const ConvertToImperativeExpression& p_expr) override; + virtual bool AlternativeVisit(const RangeReduceExpression& p_expr) override; + virtual bool AlternativeVisit(const ForEachLoopExpression& p_expr) override; + virtual bool AlternativeVisit(const ComplexRangeLoopExpression& p_expr) override; + virtual bool AlternativeVisit(const AggregateContextExpression& p_expr) override; + virtual bool AlternativeVisit(const FunctionExpression& p_expr) override; + virtual bool AlternativeVisit(const FunctionCallExpression& p_expr) override; + virtual bool AlternativeVisit(const FeatureSpecExpression& p_expr) override; + virtual bool AlternativeVisit(const StateMachineExpression& p_expr) override; + virtual bool AlternativeVisit(const LetExpression& p_expr) override; + virtual bool AlternativeVisit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) override; + virtual bool AlternativeVisit(const ExecuteMachineGroupExpression& p_expr) override; + + // Mutations and declarations always have side effects. + virtual void Visit(const MutationExpression& p_expr) override; + virtual void Visit(const DeclarationExpression& p_expr) override; + virtual void Visit(const DirectPublishExpression& p_expr) override; + virtual void Visit(const PublishExpression& p_expr) override; + virtual void Visit(const ReturnExpression& p_expr) override; + virtual void Visit(const ImportFeatureExpression& p_expr) override; + // virtual void Visit(const ObjectMethodExpression& p_expr) override; + virtual void Visit(const VariableRefExpression& p_expr) override; + virtual void Visit(const DebugExpression& p_expr) override; + virtual void VisitReference(const VariableRefExpression& p_expr) override; + + // Check types for all statements. + virtual void Visit(const Expression& p_expr) override; + + private: + // This flag indicates whether the last child expression of the + // current block acts as the return value of the block. + bool m_lastExpressionReturns; + + // This flag indicates whether the last visited expression tree has + // side effects. + bool m_hasSideEffects; + + // Store the map of published features to their return types so that + // publish expressions can be type checked against it. + const FeatureSpecExpression::PublishFeatureMap* m_publishFeatureMap; + + // Assert that the last visited expression tree has side effects. + void AssertSideEffects(const SourceLocation& p_sourceLocation) const; + + // This map is used for verifying variable declaration/reference type + // matching. + std::map m_variableTypes; + + // This structure holds information about the function that is currently + // being analyzed. + struct FunctionState + { + const TypeImpl* m_returnType; + bool m_allPathsReturn; + }; + + // A stack that holds the checking state for functions. + // Since FunctionCallExpression has a pointer to the FunctionExpression, + // function declarations can be nested in the expression tree, although + // this is forbidden by the grammar. + std::stack m_functions; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.cpp b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.cpp new file mode 100644 index 000000000000..46c8c281b8c0 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.cpp @@ -0,0 +1,581 @@ +#include "UniformExpressionVisitor.h" + +#include "Expression.h" +#include "Allocation.h" +#include "ArrayDereferenceExpression.h" +#include "ArrayLength.h" +#include "ArrayLiteralExpression.h" +#include "BlockExpression.h" +#include "Conditional.h" +#include "ConvertExpression.h" +#include "DebugExpression.h" +#include "Declaration.h" +#include "Extern.h" +#include "FeatureSpec.h" +#include "Function.h" +#include "LetExpression.h" +#include "LiteralExpression.h" +#include "Mutation.h" +#include "Match.h" +#include "MemberAccessExpression.h" +#include "OperatorExpression.h" +#include "PhiNode.h" +#include "Publish.h" +#include "RangeReduceExpression.h" +#include "RandExpression.h" +#include "RefExpression.h" +#include "SelectNth.h" +#include "StateMachine.h" +#include "StreamData.h" + + +void +FreeForm2::UniformExpressionVisitor::Visit(const SelectNthExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const SelectRangeExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConditionalExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ArrayLiteralExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LetExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const BlockExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const BinaryOperatorExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const RangeReduceExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ForEachLoopExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ComplexRangeLoopExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const MutationExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const MatchExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const MatchOperatorExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const MatchGuardExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const MatchBindExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const MemberAccessExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ArrayLengthExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ArrayDereferenceExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToFloatExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToIntExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToUInt64Expression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToInt32Expression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToUInt32Expression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToBoolExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ConvertToImperativeExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const DeclarationExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const DirectPublishExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ExternExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const FunctionExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const FunctionCallExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralIntExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralUInt64Expression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralInt32Expression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralUInt32Expression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralFloatExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralBoolExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralVoidExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralStreamExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralWordExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const LiteralInstanceHeaderExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const FeatureRefExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const UnaryOperatorExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const FeatureSpecExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const FeatureGroupSpecExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const PhiNodeExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const PublishExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ReturnExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const StreamDataExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const UpdateStreamDataExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const VariableRefExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ImportFeatureExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const StateExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const StateMachineExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ExecuteMachineExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ExecuteMachineGroupExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const YieldExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const RandFloatExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const RandIntExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const ThisExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const UnresolvedAccessExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const TypeInitializerExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const AggregateContextExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::VisitReference(const ArrayDereferenceExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::VisitReference(const VariableRefExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::VisitReference(const MemberAccessExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::VisitReference(const ThisExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::VisitReference(const UnresolvedAccessExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + + +void +FreeForm2::UniformExpressionVisitor::Visit(const DebugExpression& p_expr) +{ + const Expression& expr = p_expr; + Visit(expr); +} + diff --git a/src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.h b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.h new file mode 100644 index 000000000000..b4c9c39f0cdf --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/libs/Transform/UniformExpressionVisitor.h @@ -0,0 +1,90 @@ +#pragma once + +#ifndef FREEFORM2_UNIFORM_EXPRESSION_VISITOR_H +#define FREEFORM2_UNIFORM_EXPRESSION_VISITOR_H + +#include "Visitor.h" + +namespace FreeForm2 +{ + // Visitor that implements methods for every expression class that does nothing. + class UniformExpressionVisitor : public Visitor + { + public: + virtual void Visit(const Expression& p_expr) = 0; + + // Methods inherited from Visitor. + virtual void Visit(const SelectNthExpression& p_expr) override; + virtual void Visit(const SelectRangeExpression& p_expr) override; + virtual void Visit(const ConditionalExpression& p_expr) override; + virtual void Visit(const ArrayLiteralExpression& p_expr) override; + virtual void Visit(const LetExpression& p_expr) override; + virtual void Visit(const BlockExpression& p_expr) override; + virtual void Visit(const BinaryOperatorExpression& p_expr) override; + virtual void Visit(const RangeReduceExpression& p_expr) override; + virtual void Visit(const ForEachLoopExpression& p_expr) override; + virtual void Visit(const ComplexRangeLoopExpression& p_expr) override; + virtual void Visit(const MutationExpression& p_expr) override; + virtual void Visit(const MatchExpression& p_expr) override; + virtual void Visit(const MatchOperatorExpression& p_expr) override; + virtual void Visit(const MatchGuardExpression& p_expr) override; + virtual void Visit(const MatchBindExpression& p_expr) override; + virtual void Visit(const MemberAccessExpression& p_expr) override; + virtual void Visit(const ArrayLengthExpression& p_expr) override; + virtual void Visit(const ArrayDereferenceExpression& p_expr) override; + virtual void Visit(const ConvertToFloatExpression& p_expr) override; + virtual void Visit(const ConvertToIntExpression& p_expr) override; + virtual void Visit(const ConvertToUInt64Expression& p_expr) override; + virtual void Visit(const ConvertToInt32Expression& p_expr) override; + virtual void Visit(const ConvertToUInt32Expression& p_expr) override; + virtual void Visit(const ConvertToBoolExpression& p_expr) override; + virtual void Visit(const ConvertToImperativeExpression& p_expr) override; + virtual void Visit(const DeclarationExpression& p_expr) override; + virtual void Visit(const DirectPublishExpression& p_expr) override; + virtual void Visit(const ExternExpression& p_expr) override; + virtual void Visit(const FunctionExpression& p_expr) override; + virtual void Visit(const FunctionCallExpression& p_expr) override; + virtual void Visit(const LiteralIntExpression& p_expr) override; + virtual void Visit(const LiteralUInt64Expression& p_expr) override; + virtual void Visit(const LiteralInt32Expression& p_expr) override; + virtual void Visit(const LiteralUInt32Expression& p_expr) override; + virtual void Visit(const LiteralFloatExpression& p_expr) override; + virtual void Visit(const LiteralBoolExpression& p_expr) override; + virtual void Visit(const LiteralVoidExpression& p_expr) override; + virtual void Visit(const LiteralStreamExpression& p_expr) override; + virtual void Visit(const LiteralWordExpression& p_expr) override; + virtual void Visit(const LiteralInstanceHeaderExpression& p_expr) override; + virtual void Visit(const FeatureRefExpression& p_expr) override; + virtual void Visit(const UnaryOperatorExpression& p_expr) override; + virtual void Visit(const FeatureSpecExpression& p_expr) override; + virtual void Visit(const FeatureGroupSpecExpression& p_expr) override; + virtual void Visit(const PhiNodeExpression& p_expr) override; + virtual void Visit(const PublishExpression& p_expr) override; + virtual void Visit(const ReturnExpression& p_expr) override; + virtual void Visit(const StreamDataExpression& p_expr) override; + virtual void Visit(const UpdateStreamDataExpression& p_expr) override; + virtual void Visit(const VariableRefExpression& p_expr) override; + virtual void Visit(const ImportFeatureExpression& p_expr) override; + virtual void Visit(const StateExpression& p_expr) override; + virtual void Visit(const StateMachineExpression& p_expr) override; + virtual void Visit(const ExecuteStreamRewritingStateMachineGroupExpression& p_expr) override; + virtual void Visit(const ExecuteMachineExpression& p_expr) override; + virtual void Visit(const ExecuteMachineGroupExpression& p_expr) override; + virtual void Visit(const YieldExpression& p_expr) override; + virtual void Visit(const RandFloatExpression& p_expr) override; + virtual void Visit(const RandIntExpression& p_expr) override; + virtual void Visit(const ThisExpression& p_expr) override; + virtual void Visit(const UnresolvedAccessExpression& p_expr) override; + virtual void Visit(const TypeInitializerExpression& p_expr) override; + virtual void Visit(const AggregateContextExpression& p_expr) override; + virtual void Visit(const DebugExpression& p_expr) override; + + virtual void VisitReference(const ArrayDereferenceExpression& p_expr) override; + virtual void VisitReference(const VariableRefExpression& p_expr) override; + virtual void VisitReference(const MemberAccessExpression& p_expr) override; + virtual void VisitReference(const ThisExpression&) override; + virtual void VisitReference(const UnresolvedAccessExpression&) override; + }; +} + +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/test/CMakeLists.txt b/src/transform/DynamicRank.FreeForm.Library/test/CMakeLists.txt new file mode 100644 index 000000000000..1e43c48ca0bb --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/test/CMakeLists.txt @@ -0,0 +1,51 @@ +cmake_minimum_required(VERSION 3.15) + +set(PROJECT_NAME FreeForm2Test) + +project(${PROJECT_NAME}) + +link_directories( + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Shared + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Expression + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Backend/llvm + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Transform + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Parse/SExpression/lib + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/External + ${CMAKE_CURRENT_SOURCE_DIR}/../../NeuralTree.Library/src + ) + +add_executable(${PROJECT_NAME} + ${CMAKE_CURRENT_SOURCE_DIR}/FreeFormLibTest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/SimpleFeatureMap.cpp + ) + +target_include_directories(${PROJECT_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../inc + ${CMAKE_CURRENT_SOURCE_DIR}/../../NeuralTree.Library/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Shared + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Expression + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Backend/llvm + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Transform + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/Parse/SExpression/inc + ${CMAKE_CURRENT_SOURCE_DIR}/../libs/External + ) + +target_link_libraries(${PROJECT_NAME} + -Wl,--no-as-needed + -Wl,--start-group + -lpthread + -lz + -ldl + -ltinfo + DRNeuralTreeLibrary + DRFreeFormLlvmBackendLibrary + DRFreeFormSharedLibrary + DRFreeFormExpressionLibrary + DRFreeFormTransformLibrary + DRFreeFormSExpressionLibrary + DRFreeFormLibrary + ${BOOST_LIB} + ${LLVM_LIB} + -Wl,--end-group + ) diff --git a/src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTest.cpp b/src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTest.cpp new file mode 100644 index 000000000000..f08ecf65df77 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTest.cpp @@ -0,0 +1,15 @@ +#include +#include "FreeFormLibTestSet.h" +#include + +using namespace LightGBM; + +int main() +{ + Log::Info("-------- FreeForm Library Test Starts --------"); + FreeFormLibTestSet::TestParser("(if (== Foo Bar) 1 0)"); + FreeFormLibTestSet::TestParser("(* (ln1 NumberOfCompleteMatches_IETBSatModel-IM-Prod) OriginalQueryMaxNumberOfPerfectMatches_BingClicks-Prod)"); + FreeFormLibTestSet::TestNeuralInputLoadSave(); + Log::Info("-------- FreeForm Library Test Finished --------"); + return 0; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTestSet.h b/src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTestSet.h new file mode 100644 index 000000000000..f110ec1c7050 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/test/FreeFormLibTestSet.h @@ -0,0 +1,142 @@ +#ifndef FREEFORMLIBTESTSET_H +#define FREEFORMLIBTESTSET_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "NeuralInputFreeForm2.h" +#include "FreeForm2CompilerFactory.h" +#include "FreeForm2Compiler.h" +#include "SimpleFeatureMap.h" +#include + +using namespace LightGBM; + +class FreeFormLibTestSet +{ + static void AssertIsNotNull(const void * ptr){ + if(ptr == nullptr) + { + Error("AssertIsNotNull Failed!"); + } + Log::Info("[PASS] AssertIsNotNull\n"); + } + + static void AssertAreEqual(double truth, double value, double error){ + double delta = truth - value; + delta = delta >= 0? delta: -delta; + if(delta > error) + { + Error("AssertAreEqual(double) Failed!"); + } + Log::Info("[PASS] AssertAreEqual(double)\n"); + } + + static void AssertAreEqual(int truth, int value){ + if(truth != value) + { + Error("AssertAreEqual(int) Failed!"); + } + Log::Info("[PASS] AssertAreEqual(int)\n"); + } + + static void AssertSzAreEqual(const char * cptr0, const char * cptr1){ + unsigned int offset = 0; + while(*(cptr0 + offset) && *(cptr1 + offset)) + ++offset; + if(*(cptr0 + offset) != *(cptr1 + offset)) + { + Error("AssertSzAreEqual Failed!"); + } + Log::Info("[PASS] AssertSzAreEqual\n"); + } + + static void Error(const char * s){ + Log::Fatal("%s", s); + throw s; + } + // Test base: + static void + TestInput( + DynamicRank::Config& p_config, + double p_result, + const char* p_serial, + const char* p_section) + { + SimpleFeatureMap map; + FreeForm2::CompiledNeuralInputLoader loader("FreeForm2"); + std::auto_ptr input(loader(p_config, p_section, map)); + AssertIsNotNull(input.get()); + std::unique_ptr compiler(FreeForm2::CompilerFactory::CreateExecutableCompiler( + FreeForm2::Compiler::c_defaultOptimizationLevel, + FreeForm2::CompilerFactory::SingleDocumentEvaluation)); + loader.Compile(*compiler); + AssertAreEqual(p_result, input->Evaluate(NULL), 0.0001); + + // Unfortunately, we have to write to disk to test this input, + // because i'm not sure how to get a FILE* backed by memory. This + // would be much easier if inputs wrote to iostreams. + const char* filename = "TestNeuralInputLoadSave.tmp"; + FILE* f = fopen(filename, "w"); + AssertIsNotNull(f); + input->Save(f, 0, map); + fclose(f); + + std::ifstream fs(filename); + std::stringstream buffer; + buffer << fs.rdbuf(); + std::string content = buffer.str(); + AssertSzAreEqual(p_serial, content.c_str()); + Log::Info("[PASS] TestInput\n"); + } + +public: + // Test functions: + static void TestParser(const char * freeform) + { + boost::shared_ptr featureMap(new SimpleFeatureMap()); + boost::shared_ptr input = boost::shared_ptr( + new FreeForm2::NeuralInputFreeForm2(std::string(freeform), "freeform2", *featureMap)); + AssertIsNotNull(input.get()); + + + std::unique_ptr comp(FreeForm2::CompilerFactory::CreateExecutableCompiler(2)); + input->Compile(comp.get()); + } + + static void PRINT_MEM(void* ptr, size_t len) + { + printf("\nMEM %02X - %02X:\n", ptr, ptr + len - 1); + for(unsigned int i = 0; i < len; ++i) + { + printf("%02X ", *(unsigned char *)(ptr + i)); + } + printf("\n"); + } + + static void TestNeuralInputLoadSave() + { + std::string path(""); + std::map> config_map; + // PRINT_MEM((void *) &config_map, sizeof(config_map)); + const char* section = "Input:0"; + const char* transform = "FreeForm2"; + config_map[section]["Transform"] = transform; + config_map[section]["Line1"] = "(+ 1 2"; + config_map[section]["Line2"] = "1 2"; + config_map[section]["Line3"] = ")"; + // PRINT_MEM((void *) &config_map, sizeof(config_map)); + void* config_ptr = malloc(sizeof(DynamicRank::Config)); + memcpy(config_ptr, &path, sizeof(path)); + memcpy(config_ptr + sizeof(path), &config_map, sizeof(config_map)); + DynamicRank::Config config = *(DynamicRank::Config*) config_ptr; + TestInput(config, 6.0, "\n[Input:0]\nTransform=FreeForm2\nLine1=(+ 1 2\nLine2=1 2\nLine3=)\n", section); + } +}; +#endif diff --git a/src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.cpp b/src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.cpp new file mode 100644 index 000000000000..71dae5e3ba61 --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.cpp @@ -0,0 +1,85 @@ +#include "SimpleFeatureMap.h" +#include + + + +SimpleFeatureMap::SimpleFeatureMap() + : m_numberOfFeatures(0) +{ +} + +SimpleFeatureMap::~SimpleFeatureMap() +{ +} + + + +bool SimpleFeatureMap::GetExistingFeatureIndex( + const char* featureName, UInt32& featureIndex) const +{ + std::string featureNameStr(featureName); + std::map::const_iterator it = m_featureMap.find(featureNameStr); + + if (it == m_featureMap.end()) + { + return false; + } + featureIndex = it->second; + return true; +} + + +bool SimpleFeatureMap::ObtainFeatureIndex( + const char* featureName, UInt32& featureIndex) +{ + const std::string featureNameStr(featureName); + std::map::iterator it = m_featureMap.find(featureNameStr); + if (it != m_featureMap.end()) + { + featureIndex = it->second; + } + else + { + UInt32 iFeature = m_numberOfFeatures; + std::pair pair(featureNameStr, iFeature); + m_featureMap.insert(pair); + m_reverseFeatureMap.push_back(std::string(featureName)); + m_numberOfFeatures++; + featureIndex = iFeature; + } + return true; +} + + +bool SimpleFeatureMap::ObtainFeatureIndex( + const SIZED_STRING& featureName, UInt32& featureIndex) +{ + std::string localFeatureName((const char*)featureName.pbData, featureName.cbData); + return ObtainFeatureIndex(localFeatureName.c_str(), featureIndex); +} + +bool SimpleFeatureMap::GetFeatureName( + UInt32 featureIndex, char *featureName, UInt32 maxNameLength) const +{ + if (featureIndex < m_numberOfFeatures) + { + _snprintf_s( + featureName, + maxNameLength, + _TRUNCATE, + "%s", + m_reverseFeatureMap[featureIndex].c_str()); + return true; + } + return false; +} + +const std::string& SimpleFeatureMap::GetFeatureName(UInt32 featureIndex) const +{ + return m_reverseFeatureMap[featureIndex]; +} + +UInt32 SimpleFeatureMap::GetNumberOfFeatures() const +{ + return m_numberOfFeatures; +} diff --git a/src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.h b/src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.h new file mode 100644 index 000000000000..4f47e86bb6bd --- /dev/null +++ b/src/transform/DynamicRank.FreeForm.Library/test/SimpleFeatureMap.h @@ -0,0 +1,39 @@ +#pragma once + +#include "IFeatureMap.h" +#include "basic_types.h" + +#include +#include +#include + +#define _snprintf_s(a,b,c,...) snprintf(a,b,__VA_ARGS__) + + +class SimpleFeatureMap : public DynamicRank::IFeatureMap +{ +private: + std::map m_featureMap; + std::vector m_reverseFeatureMap; + UInt32 m_numberOfFeatures; + +public: + SimpleFeatureMap(); + ~SimpleFeatureMap(); + + // Get the index of a feature, or return false if there is no such feature. + bool GetExistingFeatureIndex(const char* featureName, UInt32& featureIndex) const; + + // Convert the FeatureName to the FeatureIndex. + bool ObtainFeatureIndex(const char* featureName, UInt32& featureIndex); + + // Convert the FeatureName to the FeatureIndex. + bool ObtainFeatureIndex(const SIZED_STRING& featureName, UInt32& featureIndex); + + // Convert the FeatureIndex to the FeatureName. + bool GetFeatureName(UInt32 featureIndex, char* featureName, UInt32 maxNameLength) const; + const std::string& GetFeatureName(UInt32 featureIndex) const; + + // The number of features within the map. + UInt32 GetNumberOfFeatures() const; +}; diff --git a/src/transform/NeuralTree.Library/CMakeLists.txt b/src/transform/NeuralTree.Library/CMakeLists.txt new file mode 100644 index 000000000000..e0f9f7c23809 --- /dev/null +++ b/src/transform/NeuralTree.Library/CMakeLists.txt @@ -0,0 +1,7 @@ +cmake_minimum_required(VERSION 3.15.0 FATAL_ERROR) + +project(NeuralTree.Library CXX) + +set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -std=c++11") + +add_subdirectory(src) diff --git a/src/transform/NeuralTree.Library/inc/CsHash.h b/src/transform/NeuralTree.Library/inc/CsHash.h new file mode 100644 index 000000000000..3d45dae56804 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/CsHash.h @@ -0,0 +1,157 @@ +#pragma once +#include "basic_types.h" +#include "MigratedApi.h" + +// This code is copyed from apsSDK/CsHash.h +// lookup3, by Bob Jenkins, public domain +class CsHash32 +{ +private: + CsHash32 (); // no NEW allowed + CsHash32 ( const CsHash32& ); // no copy allowed + CsHash32& operator= ( const CsHash32& ); // no assignment allowed + +public: + + // + // Mix(): cause every bit of a,b,c to affect 32 bits of a,b,c + // both forwards and in reverse. Same for pairs of bits in a,b,c. + // This can be used along with Final() to hash a fixed number of + // 4-byte integers, for example see the implementation of Guid(). + // + inline static void Mix( UInt32& a, UInt32& b, UInt32& c) + { + a -= c; a ^= rotl(c, 4); c += b; + b -= a; b ^= rotl(a, 6); a += c; + c -= b; c ^= rotl(b, 8); b += a; + a -= c; a ^= rotl(c,16); c += b; + b -= a; b ^= rotl(a,19); a += c; + c -= b; c ^= rotl(b, 4); b += a; + } + + // + // Final: cause every bit of a,b,c to affect every bit of c, only forward. + // Same for pairs of bits in a,b,c. It also causes b to be an OK hash. + // This is a good way to hash 1 or 2 or 3 integers: + // a = k1; b = k2; c = 0; + // CsHash32::Final(a,b,c); + // Use c (and maybe b) as the hash value + // + inline static void Final( UInt32& a, UInt32& b, UInt32& c) + { + c ^= b; c -= rotl(b,14); + a ^= c; a -= rotl(c,11); + b ^= a; b -= rotl(a,25); + c ^= b; c -= rotl(b,16); + a ^= c; a -= rotl(c,4); + b ^= a; b -= rotl(a,14); + c ^= b; c -= rotl(b,24); + } + + // + // Compute2: Compute two hash values for a byte array of known length + // + static void Compute2 ( + const void *pData, // byte array of known length + Size_t uSize, // size of pData + UInt32 uSeed1, // first seed + UInt32 uSeed2, // second seed + UInt32 *uHash1, // OUT: first hash value (may not be null) + UInt32 *uHash2); // OUT: second hash value (may not be null) + + // + // Compute: Compute a hash values for a byte array of known length + // + static const UInt32 Compute ( + const void *pData, // byte array of known length + Size_t uSize, // size of pData + UInt32 uSeed = 0) // seed for hash function + { + UInt32 uHash2 = 0; + CsHash32::Compute2(pData, uSize, uSeed, uHash2, &uSeed, &uHash2); + return uSeed; + } + + // + // String: hash of string of unknown length + // + static const UInt32 String ( + _In_z_ const char *pString, // ASCII string to hash case-sensitive + UInt32 uSeed = 0) // optional seed for hash + { + UInt32 uHash2 = 0; + Size_t uSize; + + uSize = strlen(pString); + CsHash32::Compute2(pString, uSize, uSeed, uHash2, &uSeed, &uHash2); + return uSeed; + } + + // + // StringI2: Produce two case-insensitive 32-bit hashes of an ASCII string + // The results are identical to Compute2() on an uppercased string + // + static void StringI2 ( + const char *pString, // ASCII string to hash case-insensitive + Size_t uSize, // length of string (required) + UInt32 uSeed1, // first seed + UInt32 uSeed2, // second seed + UInt32 *uHash1, // OUT: first hash + UInt32 *uHash2); // OUT: second hash + + // + // StringI: case insensitive hash of string of unknown length + // + static const UInt32 StringI ( + _In_z_ const char *pString, // ASCII string to hash case-insensitive + size_t len = (size_t)-1, + UInt32 uSeed = 0) // optional seed for hash + { + UInt32 uHash2 = 0; + Size_t uSize; + + uSize = (len == (size_t)-1) ? strlen(pString) : len; + CsHash32::StringI2(pString, uSize, uSeed, uHash2, &uSeed, &uHash2); + return uSeed; + } + +}; + +class CsHash64 +{ +private: + CsHash64 (); // no NEW allowed + CsHash64 ( const CsHash64& ); // no copy allowed + CsHash64& operator= ( const CsHash64& ); // no assignment allowed + +public: + + + // + // Compute hash for a byte array of known length. + // + static const UInt64 Compute ( + const void *pData, // byte array to hash + Size_t uSize, // length of pData + UInt64 uSeed = 0 ) // seed to hash function; 0 is an OK value + { + UInt32 uHash1, uHash2; + CsHash32::Compute2( pData, uSize, (UInt32) uSeed, (UInt32) (uSeed >> 32), &uHash1, &uHash2); + return uHash1 | (((UInt64)uHash2) << 32); + } + + // + // case-insensitive hash of null terminated string + // produce the same hash as Compute on an uppercased string + // + static const UInt64 StringI ( + const char *pString, + Size_t uSize, + UInt64 uSeed) + { + UInt32 uHash1, uHash2; + CsHash32::StringI2( pString, uSize, (UInt32) uSeed, (UInt32) (uSeed >> 32), &uHash1, &uHash2); + return uHash1 | (((UInt64)uHash2) << 32); + } + +}; diff --git a/src/transform/NeuralTree.Library/inc/FeaSpecConfig.h b/src/transform/NeuralTree.Library/inc/FeaSpecConfig.h new file mode 100644 index 000000000000..ad663cb8572b --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/FeaSpecConfig.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include +using namespace std; + +namespace DynamicRank{ + +class Config{ +private: + string _path; + map> _config; + Config(map>& config){ + _config= config; + }; +public: + static Config* GetRawConfiguration(string str); + bool DoesSectionExist(char* section); + bool DoesParameterExist(const char* section, const char* parameterName); + bool GetStringParameter(const char * section, const char* parameterName, string& value); + bool GetStringParameter(const char * section, const char* parameterName, char* value, size_t valueSize); + bool GetDoubleParameter(const char* section, const char* parameterName, double* value); + double GetDoubleParameter(const char* section, const char* parameterName, double defaultValue); + bool GetBoolParameter(const char* section, const char* parameterName, bool* value); + bool GetBoolParameter(const char* section, const char* parameterName, bool defaultValue); +}; +} \ No newline at end of file diff --git a/src/transform/NeuralTree.Library/inc/IFeatureMap.h b/src/transform/NeuralTree.Library/inc/IFeatureMap.h new file mode 100644 index 000000000000..3e956dfe305b --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/IFeatureMap.h @@ -0,0 +1,34 @@ +#pragma once + +#include "basic_types.h" + +namespace DynamicRank +{ + +// A Feature is defined by any of the following three parameters +// FeatureName : The Name of the Feature +// FeatureIndex : The Index of the Feature according to the FeatureMap +// FeatureName <-> FeatureIndex conversion is a combination of the above two +class IFeatureMap +{ +public: + // Destructor + virtual ~IFeatureMap() {} + + // Convert the FeatureName to the FeatureIndex. + virtual bool ObtainFeatureIndex(const char* p_featureName, + UInt32& p_featureIndex) = 0; + + // Convert the FeatureName to the FeatureIndex. + virtual bool ObtainFeatureIndex(const SIZED_STRING& p_featureName, + UInt32& p_featureIndex) = 0; + + // Convert the FeatureIndex to the FeatureName. + virtual bool GetFeatureName(UInt32 p_featureIndex, + char* p_featureName, + UInt32 p_maxNameLength) const = 0; + + // The number of features within the map. + virtual UInt32 GetNumberOfFeatures() const = 0; +}; +} diff --git a/src/transform/NeuralTree.Library/inc/INeuralNetFeatures.h b/src/transform/NeuralTree.Library/inc/INeuralNetFeatures.h new file mode 100644 index 000000000000..20893093f926 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/INeuralNetFeatures.h @@ -0,0 +1,21 @@ +#pragma once +#include "IFeatureMap.h" + +namespace DynamicRank +{ + +// Interface to apply custom process on all features input. This is especially useful for +// on-demand feature extraction. +class INeuralNetFeatures +{ +public: + // Process a given feature index. The feature index is the value returned from a FeatureMap. + virtual void ProcessFeature(UInt32 p_featureIndex) = 0; + + // Same as above but will also pass in a vector of strings for all the segments. + // An empty vector means that the feature is part of the main L2 ranker. + virtual void ProcessFeature(UInt32 p_featureIndex, const std::vector& p_segments) = 0; +}; + +} + diff --git a/src/transform/NeuralTree.Library/inc/MigratedApi.h b/src/transform/NeuralTree.Library/inc/MigratedApi.h new file mode 100644 index 000000000000..f85d0340b31b --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/MigratedApi.h @@ -0,0 +1,19 @@ +#pragma once + +template +constexpr std::size_t countof(T const (&)[N]) noexcept +{ + return N; +} + +static inline +unsigned int rotl (const unsigned int x, int bits) +{ + const unsigned int n = ((bits % 32) + 32) % 32; + return (x << n) | (x >> (32 - n)); +} + +#ifdef _snprintf_s +#undef _snprintf_s +#endif +#define _snprintf_s(a,b,c,...) snprintf(a,b,__VA_ARGS__) diff --git a/src/transform/NeuralTree.Library/inc/NeuralInput.h b/src/transform/NeuralTree.Library/inc/NeuralInput.h new file mode 100644 index 000000000000..c671a062d065 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/NeuralInput.h @@ -0,0 +1,633 @@ +#pragma once + +#include "basic_types.h" +#include "IFeatureMap.h" +#include "FeaSpecConfig.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "MigratedApi.h" + + +class NeuralInputTest; + + +namespace DynamicRank +{ + +struct UnionBondInput; +struct FreeForm2CodeBondData; +struct NeuralInputBondData; +struct NeuralInputUnaryBondData; +struct NeuralInputLinearBondData; +struct NeuralInputLogLinearBondData; +struct NeuralInputRationalBondData; +struct NeuralInputBucketBondData; +struct NeuralInputTanhBondData; +struct NeuralInputTanhUnaryBondData; +struct NeuralInputUseAsFloatBondData; + +// Base class representing an input value for a neural net +class NeuralInput : boost::noncopyable +{ +protected: + NeuralInput(); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + + // Construct from Bond. + explicit NeuralInput(const NeuralInputBondData& p_data); + + virtual ~NeuralInput(); + + // Just for special batch serializaton of freeform2. + virtual bool IsFreeForm2() const { return false; }; + + virtual void BatchSerialize(const std::vector& /*p_inputs*/, FreeForm2CodeBondData& /*p_blob*/) const {}; + + virtual void BatchUnSerialize(const std::vector& /*p_inputs*/, const FreeForm2CodeBondData& /*p_blob*/) const {}; + + // Fill the correct input name and data field. + virtual void FillBond(UnionBondInput& p_data) const = 0; + + // Fill bond structure of NueralInput class. + void FillBondData(NeuralInputBondData& p_data) const; + + virtual double Evaluate(UInt32 input[]) const = 0; + virtual double EvaluateInput(UInt32 input) const; + + // Evaluate an input for a document in all documents context. + virtual double Evaluate(UInt32** p_featureVectorArray, + UInt32 p_currentDocument, + UInt32 p_documentCount) const; + + // Get the minimum and maximum possible outputs for this input node. + virtual double GetMin() const = 0; + virtual double GetMax() const = 0; + + virtual UInt32 GetAssociatedFeature() const; + virtual void GetAllAssociatedFeatures(std::vector& associatedFeaturesList) const = 0; + + virtual double GetSlope() const; + virtual double GetIntercept() const; + + virtual bool Save(FILE *fpOutput, size_t nInputId,const IFeatureMap& p_featureMap) const; + + // Is same Input. + virtual bool Equal(const NeuralInput* p_input) const = 0; + + // Compare the internal members are equal. + // This is not virtual and just compare the base class members. + bool EqualInternal(const NeuralInput* p_input) const; + + void SetSegments(const std::vector& p_segments); + + virtual bool Train(double dblLearningRate, double outputHigh, + double outputLow, double dblOutputDelta, + UInt32 inputHigh[], UInt32 inputLow[]); + + // load Segments into m_segments + void LoadSegments(DynamicRank::Config& p_config, const char *szSection); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + virtual size_t GetSize() const = 0; + + // Get the list of segments this input belongs to. Empty means it belongs to base ranker. + const std::vector& GetSegments() const; + +private: + // Segments to which this input applies; a size zero vector means the input is part of the main L2 ranker + boost::scoped_ptr > m_segments; + + friend class NeuralInputTest; +}; + + +// An exception that is thrown when trying to bulk-optimize a list of NeuralInputs if at least +// one of them is not compatible with bulk optimization. +class BulkOptimizationUnsupported : public std::runtime_error +{ +public: + explicit BulkOptimizationUnsupported(const std::string& p_message); +}; + + +// Base class that provides a replacement function to evaluate a set of neural inputs. +class BulkNeuralInput : boost::noncopyable +{ +protected: + BulkNeuralInput(); + +public: + + virtual ~BulkNeuralInput(); + + // Fill the correct input name and data field. + virtual void FillBond(UnionBondInput& p_data) const = 0; + + virtual bool Equal(const BulkNeuralInput* p_other) const = 0; + + // Add all the features referenced by this BulkNeuralInput object into the + // INeuralNetFetures reference. + /* virtual void AddNeuralNetFeatures(INeuralNetFeatures& p_neuralNetFeatures) const = 0; */ + + // Evaluates the neural inputs and places the result of their evaluation in the + // corresponding indices in the p_output array. + virtual void Evaluate(const UInt32 p_input[], float p_output[]) const = 0; + + // Evaluates the neural inputs in all documents context and places the result + // of their evalution in the corresponding indices in the p_output array. + virtual void Evaluate(UInt32** p_featureVectorArray, + UInt32 p_currentDocument, + UInt32 p_documentCount, + float p_output[]) const; + + // Populates p_associatedFeaturesList with all the feature indices that this + // BulkNeuralInput object requires. + virtual void GetAllAssociatedFeatures(std::vector& p_associatedFeaturesList) const = 0; + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + virtual size_t GetSize() const = 0; +}; + + +// Base class representing an input value for a neural net +// that is transformed by a function that can be expressed as a function +// with only 1 integer input +class NeuralInputUnary : public NeuralInput +{ +public: + + // Construct from Bond. + explicit NeuralInputUnary(const NeuralInputUnaryBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& /*p_data*/) const {}; + + // Fill bond structure of NeuralInputUnary class. + void FillBondData(NeuralInputUnaryBondData& p_data) const; + + // Copy all members. + void CopyFrom(const NeuralInputUnary& p_neuralInputUnary); + + static bool ReadAssociatedFeature(DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap, + UInt32 *piFeature); + + UInt32 GetAssociatedFeature() const; + void GetAllAssociatedFeatures(std::vector& associatedFeaturesList) const; + double Evaluate(UInt32 input[]) const; + + bool Save(FILE *fpOutput, size_t nInputId,const IFeatureMap& p_featureMap) const; + + // Check objects are equal. + virtual bool Equal(const NeuralInput* p_input) const; + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + + // Get feature. + UInt32 GetFeature() const; + +protected: + + // Default constructor for unit tests. + NeuralInputUnary(); + + NeuralInputUnary(int iFeature); + + NeuralInputUnary(int iFeature, IFeatureMap& p_featureMap); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + + // Index of this input in the input vector + UInt32 m_iFeature; +}; + + +// Linear input, using only a slope and an intercept +class NeuralInputLinear : public NeuralInputUnary +{ +protected: + double m_slope; + double m_intercept; + + // Default constructor. + NeuralInputLinear(); + + NeuralInputLinear(int id, double slope, double intercept); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + + // Construct from Bond. + explicit NeuralInputLinear(const NeuralInputLinearBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputLinear class. + void FillBondData(NeuralInputLinearBondData& p_data) const; + + // Get member variables. + double GetSlope() const; + double GetIntercept() const; + + double GetMin() const; + double GetMax() const; + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + + // Same input. + virtual bool Equal(const NeuralInput* p_input) const; + + static NeuralInputLinear *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + +private: + + friend class NeuralInputTest; +}; + + +// Maps an input x to x/(c + x) where c is a damping factor +class NeuralInputRational : public NeuralInputUnary +{ +protected: + double m_dblDampingFactor; + NeuralInputRational(int p_id, double p_dblDampingFactor); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + + // Construct from Bond. + explicit NeuralInputRational(const NeuralInputRationalBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputRational class. + void FillBondData(NeuralInputRationalBondData& p_data) const; + + // Get member variables. + double GetDampingFactor() const; + + double GetMin() const; + double GetMax() const; + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + + // Same input. + virtual bool Equal(const NeuralInput* p_input) const; + + static NeuralInputRational *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + +private: + + friend class NeuralInputTest; +}; + + +// LogLinear input, applying the log and then transforming +// using slope and intercept. +class NeuralInputLogLinear : public NeuralInputLinear +{ +protected: + NeuralInputLogLinear(int id, double slope, double intercept); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + + // Construct from Bond. + explicit NeuralInputLogLinear(const NeuralInputLogLinearBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputLogLinear class. + void FillBondData(NeuralInputLogLinearBondData& p_data) const; + + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + + // Same input. + virtual bool Equal(const NeuralInput* p_input) const; + + + static NeuralInputLogLinear *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + +private: + + friend class NeuralInputTest; +}; + + +// Bucket Input, transforming the input to 0 or 1 depending on whether or input +// falls within defined bucket +class NeuralInputBucket : public NeuralInputUnary +{ +protected: + bool m_fMinInclusive; + bool m_fMaxInclusive; + + UInt32 m_nMinValue; + UInt32 m_nMaxValue; + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + + NeuralInputBucket(int p_id, double p_min, bool p_mininclusive, double p_max, bool p_maxinclusive); + + // Construct from Bond. + explicit NeuralInputBucket(const NeuralInputBucketBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputBucket class. + void FillBondData(NeuralInputBucketBondData& p_data) const; + + double GetMin() const; + double GetMax() const; + + // Get member variables. + bool GetMinInclusive() const; + bool GetMaxInclusive() const; + UInt32 GetMinValue() const; + UInt32 GetMaxValue() const; + + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + + // Same input. + virtual bool Equal(const NeuralInput* p_input) const; + + static NeuralInputBucket *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; +}; + + +// A caching wrapper around an underlying input +class NeuralInputCached : public NeuralInputUnary +{ +protected: + size_t m_cacheSize; + boost::scoped_array m_resultCache; + boost::scoped_ptr m_input; + + NeuralInputCached(size_t nCacheSize, NeuralInputUnary *pChild); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + ~NeuralInputCached(); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Get the wrapped input. + const NeuralInputUnary* GetBaseInput() const; + + double GetMin() const; + double GetMax() const; + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId,const IFeatureMap& p_featureMap) const; + + // Check objects are equal. + virtual bool Equal(const NeuralInput* p_input) const; + + bool Train(double dblLearningRate, double outputHigh, + double outputLow, double dblOutputDelta, + UInt32 inputHigh[], UInt32 inputLow[]); + + static NeuralInputUnary *Load(size_t nCacheSize, + NeuralInputUnary *pChild); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + +private: + + friend class NeuralInputTest; +}; + + +class NeuralInputTanh : public NeuralInput +{ +public: + + // Construct from Bond. + explicit NeuralInputTanh(const NeuralInputTanhBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputTanh class. + void FillBondData(NeuralInputTanhBondData& p_data) const; + + double GetMin() const; + double GetMax() const; + double Evaluate(UInt32 input[]) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + bool Train(double dblLearningRate, double outputHigh, + double outputLow, double dblOutputDelta, + UInt32 inputHigh[], UInt32 inputLow[]); + + static NeuralInputTanh *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Check objects are equal. + virtual bool Equal(const NeuralInput* p_input) const; + + void GetAllAssociatedFeatures(std::vector& associatedFeaturesList) const; + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + +protected: + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +private: + + NeuralInputTanh(); + + static const int c_maxInputs=30; + + size_t m_cInputs; + + bool m_locked; + + // Index of this input in the input vector. + UInt32 m_rgId[c_maxInputs]; + + double m_rgWeights[c_maxInputs]; + + double m_threshold; + + friend class NeuralInputTest; +}; + + +class NeuralInputTanhUnary : public NeuralInputUnary +{ +public: + + // Construct from Bond. + explicit NeuralInputTanhUnary(const NeuralInputTanhUnaryBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputTanhUnary class. + void FillBondData(NeuralInputTanhUnaryBondData& p_data) const; + + double GetMin() const; + double GetMax() const; + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + + // Check objects are equal. + virtual bool Equal(const NeuralInput* p_input) const; + + bool Train(double dblLearningRate, double outputHigher, + double outputLower, double dblOutputDelta, + UInt32 inputHigh[], UInt32 inputLow[]); + + static NeuralInputTanhUnary *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + + // Get member variables. + double GetWeight() const; + double GetThreshold() const; + +protected: + + NeuralInputTanhUnary(UInt32 iFeature, double dblWeights, + double dblThreshold, bool fLocked); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +private: + + bool m_fLocked; + + double m_dblWeights; + + double m_dblThreshold; + + friend class NeuralInputTest; +}; + + +class NeuralInputUseAsFloat : public NeuralInputUnary +{ +protected: + + // Default constructor. + NeuralInputUseAsFloat(); + + NeuralInputUseAsFloat(UInt32 iFeature); + + // Get the size of external memory owned by this object. + size_t GetExternalSize() const; + +public: + + // Construct from Bond. + explicit NeuralInputUseAsFloat(const NeuralInputUseAsFloatBondData& p_data); + + // Fill the correct input name and data field. + void FillBond(UnionBondInput& p_data) const; + + // Fill bond structure of NeuralInputUseAsFloat class. + void FillBondData(NeuralInputUseAsFloatBondData& p_data) const; + + // Copy all members. + void CopyFrom(const NeuralInputUseAsFloat& p_neuralInputUseAsFloat); + + double EvaluateInput(UInt32 input) const; + bool Save(FILE *fpOutput, size_t nInputId, const IFeatureMap& p_featureMap) const; + + // Check objects are equal. + virtual bool Equal(const NeuralInput* p_input) const; + + double GetMin() const; + double GetMax() const; + + + static NeuralInputUseAsFloat *Load( + DynamicRank::Config& p_config, + const char *szSection, + IFeatureMap& p_featureMap); + + // Get the size of this object, including internal and external memory + // (memory accessed through pointers or objects contain pointers e.g. std::string, std::vector, etc.). + size_t GetSize() const; + +private: + + friend class NeuralInputTest; +}; +} diff --git a/src/transform/NeuralTree.Library/inc/NeuralInputFactory.h b/src/transform/NeuralTree.Library/inc/NeuralInputFactory.h new file mode 100644 index 000000000000..325a4ebc4c21 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/NeuralInputFactory.h @@ -0,0 +1,121 @@ +#pragma once + +#include +#include "NeuralInput.h" +#include +#include + +namespace DynamicRank +{ + +class BulkNeuralInput; +class NeuralNet; +class NeuralInput; +class IFeatureMap; +struct UnionBondInput; + +// NeuralInputFactory loads neural inputs from config files using a +// registration paradigm, in order to support different types of inputs for +// different situations. +class NeuralInputFactory : private boost::noncopyable +{ +public: + // Shortcut for configuration type. + typedef DynamicRank::Config IConfiguration; + + // Class that knows how to load a particular transform. + class Loader + { + public: + // Shortcut definition for shared pointer of this type. + typedef boost::shared_ptr Ptr; + + virtual ~Loader(); + + // Functor to create a NeuralInput given appropriate inputs. + virtual NeuralInput* operator()(IConfiguration& p_config, + const char* p_section, + IFeatureMap& p_featureMap) const = 0; + + // Use for all NeuralInputs. + virtual NeuralInput* FromBond(const UnionBondInput& p_data) const = 0; + + // Used for BulkNeuralInput. + //virtual BulkNeuralInput* FromBulkBond(const UnionBondInput& p_data) const = 0; + }; + + + // Create a neural input factory that knows how to create a variety of NeuralInput classes. + NeuralInputFactory(); + + virtual ~NeuralInputFactory(); + + typedef NeuralInput* (*LoadFunction)(IConfiguration& p_config, + const char* p_section, + IFeatureMap& p_featureMap); + + + // Add registration for a particular transform, throwing exceptions if + // registration fails, or if p_loader is NULL. + void AddTransform(const char* p_transform, LoadFunction p_loader, bool p_replace = false); + + // Add registration for a particular transform, throwing exceptions if + // registration fails, or if p_loader is NULL. + void AddTransform(const char* p_transform, Loader::Ptr p_loader, bool p_replace = false); + + NeuralInput* FromBond(const UnionBondInput& p_data) const; + + //BulkNeuralInput* FromBulkBond(const UnionBondInput& p_data) const; + + // Remove all transforms from the neural input factory. + void ClearTransforms(); + + // Load a transform, returning NULL if loading fails. + NeuralInput* Load(const char* p_transform, + IConfiguration& p_config, + const char* p_section, + IFeatureMap& p_featureMap) const; + + // Read the transform name for an input and load using Load function above. + NeuralInput* Load(IConfiguration& p_config, + int p_ID, + IFeatureMap& p_featureMap) const; + + // Small template class to adapt functions returning subtypes of + // NeuralInput to return NeuralInput. + template + static NeuralInput* LoadAdapt(IConfiguration& p_config, + const char* p_section, + IFeatureMap& p_featureMap) + { + return fun(p_config, p_section, p_featureMap); + } + +private: + + // Map of neural input transform names to loading functions. + typedef std::map TransformMap; + TransformMap m_transform; +}; + +// BulkNeuralInputFactory converts a set of NeuralInput objects into a +// BulkNeuralInput that produces the same effect as evaluating all the +// inputs separately. +class BulkNeuralInputFactory : public boost::noncopyable +{ +public: + // A mapping from NeuralInput to offset in the output array. + typedef std::pair InputAndIndex; + + // Create an instance of BulkNeuralInput from a list of NeuralInput objects and their + // offset in the output array. Throws an exception if at least one of the NeuralInput + // objects does not support bulk optimization. + virtual std::unique_ptr ConvertToBulkInput(const std::vector& p_inputs, + IFeatureMap& p_featureMap) const = 0; + +}; + +} diff --git a/src/transform/NeuralTree.Library/inc/NeuralInputFreeForm2_types.h b/src/transform/NeuralTree.Library/inc/NeuralInputFreeForm2_types.h new file mode 100644 index 000000000000..d1e5cf030daf --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/NeuralInputFreeForm2_types.h @@ -0,0 +1,232 @@ + +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Tool : bondc, Version=3.0.1, Build=bond-git.retail.0 +// Template : Microsoft.Bond.Rules.dll#Rules_BOND_CPP.tt +// File : NeuralInputFreeForm2_types.h +// +// Changes to this file may cause incorrect behavior and will be lost when +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +#pragma once + + +#include +#if BOND_MAJOR_VERSION_MIN_SUPPORTED > 3 \ + || (BOND_MAJOR_VERSION_MIN_SUPPORTED == 3 && BOND_MINOR_VERSION_MIN_SUPPORTED > 1) +#error This file was generated by an older Bond compiler which is \ + incompatible with current Bond library. Please regenerate \ + with the latest Bond compiler. +#endif + +#include +#include +#include "NeuralInput_types.h" + +namespace DynamicRank +{ + +// This is per ranker level info: it merge all freeform2 code. +struct FreeForm2CodeBondData +{ + // 1: optional vector m_offsets + std::vector m_offsets; + + // 2: optional string m_code + std::string m_code; + + FreeForm2CodeBondData() + { + } + + + // Compiler generated copy ctor OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + FreeForm2CodeBondData(const FreeForm2CodeBondData& /*_bond_rhs*/) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + FreeForm2CodeBondData(FreeForm2CodeBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< std::vector >::value + && std::is_nothrow_move_constructible< std::string >::value + )) + : m_offsets(std::move(_bond_rhs.m_offsets)), + m_code(std::move(_bond_rhs.m_code)) + { + } +#endif + + + template + explicit + FreeForm2CodeBondData(Allocator* _bond_allocator) + : m_offsets(*_bond_allocator), + m_code(*_bond_allocator) + { + } + + + // Compiler generated operator= OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + FreeForm2CodeBondData& operator=(const FreeForm2CodeBondData& _bond_rhs) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + FreeForm2CodeBondData& operator=(FreeForm2CodeBondData&& _bond_rhs) + { + FreeForm2CodeBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const FreeForm2CodeBondData& _bond_other) const + { + return true + && (m_offsets == _bond_other.m_offsets) + && (m_code == _bond_other.m_code); + } + + + bool operator!=(const FreeForm2CodeBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(FreeForm2CodeBondData& _bond_other) + { + using std::swap; + swap(m_offsets, _bond_other.m_offsets); + swap(m_code, _bond_other.m_code); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* /*_bond_name*/, const char* /*_bond_full_name*/) + { + } +}; + + +inline void swap(FreeForm2CodeBondData& _bond_left, FreeForm2CodeBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputFreeForm2BondData +struct NeuralInputFreeForm2BondData : public ::DynamicRank::NeuralInputBondData +{ + // 1: optional vector m_features + std::vector m_features; + + // 2: optional string m_input + // Source code. In production case we want to not save it to save space. + std::string m_input; + + NeuralInputFreeForm2BondData() { + InitMetadata("NeuralInputFreeForm2BondData", "DynamicRank.NeuralInputFreeForm2BondData"); + } + + + NeuralInputFreeForm2BondData(const NeuralInputFreeForm2BondData& _bond_rhs) + : ::DynamicRank::NeuralInputBondData(_bond_rhs), + m_features(_bond_rhs.m_features), + m_input(_bond_rhs.m_input) { + InitMetadata("NeuralInputFreeForm2BondData", "DynamicRank.NeuralInputFreeForm2BondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputFreeForm2BondData(NeuralInputFreeForm2BondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputBondData >::value + && std::is_nothrow_move_constructible< std::vector >::value + && std::is_nothrow_move_constructible< std::string >::value + )) + : ::DynamicRank::NeuralInputBondData(std::move(_bond_rhs)), + m_features(std::move(_bond_rhs.m_features)), + m_input(std::move(_bond_rhs.m_input)) + { + } +#endif + + + template + explicit + NeuralInputFreeForm2BondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputBondData(_bond_allocator), + m_features(*_bond_allocator), + m_input(*_bond_allocator) { + InitMetadata("NeuralInputFreeForm2BondData", "DynamicRank.NeuralInputFreeForm2BondData"); + } + + + // Compiler generated operator= OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + NeuralInputFreeForm2BondData& operator=(const NeuralInputFreeForm2BondData& _bond_rhs) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputFreeForm2BondData& operator=(NeuralInputFreeForm2BondData&& _bond_rhs) + { + NeuralInputFreeForm2BondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputFreeForm2BondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + && (m_features == _bond_other.m_features) + && (m_input == _bond_other.m_input); + } + + + bool operator!=(const NeuralInputFreeForm2BondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputFreeForm2BondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputBondData::swap(_bond_other); + swap(m_features, _bond_other.m_features); + swap(m_input, _bond_other.m_input); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputBondData::InitMetadata(_bond_name, _bond_full_name); + } +}; + + +inline void swap(NeuralInputFreeForm2BondData& _bond_left, NeuralInputFreeForm2BondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +} // namespace DynamicRank diff --git a/src/transform/NeuralTree.Library/inc/NeuralInput_types.h b/src/transform/NeuralTree.Library/inc/NeuralInput_types.h new file mode 100644 index 000000000000..f64c1e8bca58 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/NeuralInput_types.h @@ -0,0 +1,1261 @@ + +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Tool : bondc, Version=3.0.1, Build=bond-git.retail.0 +// Template : Microsoft.Bond.Rules.dll#Rules_BOND_CPP.tt +// File : NeuralInput_types.h +// +// Changes to this file may cause incorrect behavior and will be lost when +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +#pragma once + + +#include +#if BOND_MAJOR_VERSION_MIN_SUPPORTED > 3 \ + || (BOND_MAJOR_VERSION_MIN_SUPPORTED == 3 && BOND_MINOR_VERSION_MIN_SUPPORTED > 1) +#error This file was generated by an older Bond compiler which is \ + incompatible with current Bond library. Please regenerate \ + with the latest Bond compiler. +#endif + +#include +#include +#include + +namespace DynamicRank +{ + +// NeuralInputBondData +struct NeuralInputBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional nullable> m_segments + bond::nullable > m_segments; + + NeuralInputBondData() { + InitMetadata("NeuralInputBondData", "DynamicRank.NeuralInputBondData"); + } + + + NeuralInputBondData(const NeuralInputBondData& _bond_rhs) + : name(_bond_rhs.name.get_allocator()), + m_segments(_bond_rhs.m_segments) { + InitMetadata("NeuralInputBondData", "DynamicRank.NeuralInputBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputBondData(NeuralInputBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< bond::nullable > >::value + )) + : name(std::move(_bond_rhs.name)), + m_segments(std::move(_bond_rhs.m_segments)) + { + } +#endif + + + template + explicit + NeuralInputBondData(Allocator* _bond_allocator) + : name(*_bond_allocator), + m_segments(*_bond_allocator) { + InitMetadata("NeuralInputBondData", "DynamicRank.NeuralInputBondData"); + } + + + NeuralInputBondData& operator=(const NeuralInputBondData& _bond_rhs) + { + NeuralInputBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputBondData& operator=(NeuralInputBondData&& _bond_rhs) + { + NeuralInputBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputBondData& _bond_other) const + { + return true + // skip bond_meta::full_name field 'name' + && (m_segments == _bond_other.m_segments); + } + + + bool operator!=(const NeuralInputBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputBondData& _bond_other) + { + using std::swap; + // skip bond_meta::full_name field 'name' + swap(m_segments, _bond_other.m_segments); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* /*_bond_name*/, const char* _bond_full_name) + { + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputBondData& _bond_left, NeuralInputBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// BulkNeuralInputBondData +struct BulkNeuralInputBondData +{ + // 1: optional vector m_features + std::vector m_features; + + // 2: optional string m_code + std::string m_code; + + // 3: optional uint32 m_input + uint32_t m_input; + + BulkNeuralInputBondData() + : m_input() + { + } + + + // Compiler generated copy ctor OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + BulkNeuralInputBondData(const BulkNeuralInputBondData& /*_bond_rhs*/) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + BulkNeuralInputBondData(BulkNeuralInputBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< std::vector >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< uint32_t >::value + )) + : m_features(std::move(_bond_rhs.m_features)), + m_code(std::move(_bond_rhs.m_code)), + m_input(std::move(_bond_rhs.m_input)) + { + } +#endif + + + template + explicit + BulkNeuralInputBondData(Allocator* _bond_allocator) + : m_features(*_bond_allocator), + m_code(*_bond_allocator), + m_input() + { + } + + + // Compiler generated operator= OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + BulkNeuralInputBondData& operator=(const BulkNeuralInputBondData& _bond_rhs) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + BulkNeuralInputBondData& operator=(BulkNeuralInputBondData&& _bond_rhs) + { + BulkNeuralInputBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const BulkNeuralInputBondData& _bond_other) const + { + return true + && (m_features == _bond_other.m_features) + && (m_code == _bond_other.m_code) + && (m_input == _bond_other.m_input); + } + + + bool operator!=(const BulkNeuralInputBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(BulkNeuralInputBondData& _bond_other) + { + using std::swap; + swap(m_features, _bond_other.m_features); + swap(m_code, _bond_other.m_code); + swap(m_input, _bond_other.m_input); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* /*_bond_name*/, const char* /*_bond_full_name*/) + { + } +}; + + +inline void swap(BulkNeuralInputBondData& _bond_left, BulkNeuralInputBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputUnaryBondData +struct NeuralInputUnaryBondData : public ::DynamicRank::NeuralInputBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional uint32 m_iFeature + uint32_t m_iFeature; + + NeuralInputUnaryBondData() + : m_iFeature() { + InitMetadata("NeuralInputUnaryBondData", "DynamicRank.NeuralInputUnaryBondData"); + } + + + NeuralInputUnaryBondData(const NeuralInputUnaryBondData& _bond_rhs) + : ::DynamicRank::NeuralInputBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()), + m_iFeature(_bond_rhs.m_iFeature) { + InitMetadata("NeuralInputUnaryBondData", "DynamicRank.NeuralInputUnaryBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputUnaryBondData(NeuralInputUnaryBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< uint32_t >::value + )) + : ::DynamicRank::NeuralInputBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)), + m_iFeature(std::move(_bond_rhs.m_iFeature)) + { + } +#endif + + + template + explicit + NeuralInputUnaryBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputBondData(_bond_allocator), + name(*_bond_allocator), + m_iFeature() { + InitMetadata("NeuralInputUnaryBondData", "DynamicRank.NeuralInputUnaryBondData"); + } + + + NeuralInputUnaryBondData& operator=(const NeuralInputUnaryBondData& _bond_rhs) + { + NeuralInputUnaryBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputUnaryBondData& operator=(NeuralInputUnaryBondData&& _bond_rhs) + { + NeuralInputUnaryBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputUnaryBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + // skip bond_meta::full_name field 'name' + && (m_iFeature == _bond_other.m_iFeature); + } + + + bool operator!=(const NeuralInputUnaryBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputUnaryBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + swap(m_iFeature, _bond_other.m_iFeature); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputUnaryBondData& _bond_left, NeuralInputUnaryBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputLinearBondData +struct NeuralInputLinearBondData : public ::DynamicRank::NeuralInputUnaryBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional double m_slope + double m_slope; + + // 3: optional double m_intercept + double m_intercept; + + NeuralInputLinearBondData() + : m_slope(), + m_intercept() { + InitMetadata("NeuralInputLinearBondData", "DynamicRank.NeuralInputLinearBondData"); + } + + + NeuralInputLinearBondData(const NeuralInputLinearBondData& _bond_rhs) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()), + m_slope(_bond_rhs.m_slope), + m_intercept(_bond_rhs.m_intercept) { + InitMetadata("NeuralInputLinearBondData", "DynamicRank.NeuralInputLinearBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputLinearBondData(NeuralInputLinearBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputUnaryBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< double >::value + )) + : ::DynamicRank::NeuralInputUnaryBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)), + m_slope(std::move(_bond_rhs.m_slope)), + m_intercept(std::move(_bond_rhs.m_intercept)) + { + } +#endif + + + template + explicit + NeuralInputLinearBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_allocator), + name(*_bond_allocator), + m_slope(), + m_intercept() { + InitMetadata("NeuralInputLinearBondData", "DynamicRank.NeuralInputLinearBondData"); + } + + + NeuralInputLinearBondData& operator=(const NeuralInputLinearBondData& _bond_rhs) + { + NeuralInputLinearBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputLinearBondData& operator=(NeuralInputLinearBondData&& _bond_rhs) + { + NeuralInputLinearBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputLinearBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + // skip bond_meta::full_name field 'name' + && (m_slope == _bond_other.m_slope) + && (m_intercept == _bond_other.m_intercept); + } + + + bool operator!=(const NeuralInputLinearBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputLinearBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputUnaryBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + swap(m_slope, _bond_other.m_slope); + swap(m_intercept, _bond_other.m_intercept); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputUnaryBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputLinearBondData& _bond_left, NeuralInputLinearBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputLogLinearBondData +struct NeuralInputLogLinearBondData : public ::DynamicRank::NeuralInputLinearBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + NeuralInputLogLinearBondData() { + InitMetadata("NeuralInputLogLinearBondData", "DynamicRank.NeuralInputLogLinearBondData"); + } + + + NeuralInputLogLinearBondData(const NeuralInputLogLinearBondData& _bond_rhs) + : ::DynamicRank::NeuralInputLinearBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()) { + InitMetadata("NeuralInputLogLinearBondData", "DynamicRank.NeuralInputLogLinearBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputLogLinearBondData(NeuralInputLogLinearBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputLinearBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + )) + : ::DynamicRank::NeuralInputLinearBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)) + { + } +#endif + + + template + explicit + NeuralInputLogLinearBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputLinearBondData(_bond_allocator), + name(*_bond_allocator) { + InitMetadata("NeuralInputLogLinearBondData", "DynamicRank.NeuralInputLogLinearBondData"); + } + + + NeuralInputLogLinearBondData& operator=(const NeuralInputLogLinearBondData& _bond_rhs) + { + NeuralInputLogLinearBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputLogLinearBondData& operator=(NeuralInputLogLinearBondData&& _bond_rhs) + { + NeuralInputLogLinearBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputLogLinearBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)); + // skip bond_meta::full_name field 'name' + } + + + bool operator!=(const NeuralInputLogLinearBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputLogLinearBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputLinearBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputLinearBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputLogLinearBondData& _bond_left, NeuralInputLogLinearBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputRationalBondData +struct NeuralInputRationalBondData : public ::DynamicRank::NeuralInputUnaryBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional double m_dblDampingFactor + double m_dblDampingFactor; + + NeuralInputRationalBondData() + : m_dblDampingFactor() { + InitMetadata("NeuralInputRationalBondData", "DynamicRank.NeuralInputRationalBondData"); + } + + + NeuralInputRationalBondData(const NeuralInputRationalBondData& _bond_rhs) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()), + m_dblDampingFactor(_bond_rhs.m_dblDampingFactor) { + InitMetadata("NeuralInputRationalBondData", "DynamicRank.NeuralInputRationalBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputRationalBondData(NeuralInputRationalBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputUnaryBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< double >::value + )) + : ::DynamicRank::NeuralInputUnaryBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)), + m_dblDampingFactor(std::move(_bond_rhs.m_dblDampingFactor)) + { + } +#endif + + + template + explicit + NeuralInputRationalBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_allocator), + name(*_bond_allocator), + m_dblDampingFactor() { + InitMetadata("NeuralInputRationalBondData", "DynamicRank.NeuralInputRationalBondData"); + } + + + NeuralInputRationalBondData& operator=(const NeuralInputRationalBondData& _bond_rhs) + { + NeuralInputRationalBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputRationalBondData& operator=(NeuralInputRationalBondData&& _bond_rhs) + { + NeuralInputRationalBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputRationalBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + // skip bond_meta::full_name field 'name' + && (m_dblDampingFactor == _bond_other.m_dblDampingFactor); + } + + + bool operator!=(const NeuralInputRationalBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputRationalBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputUnaryBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + swap(m_dblDampingFactor, _bond_other.m_dblDampingFactor); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputUnaryBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputRationalBondData& _bond_left, NeuralInputRationalBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputBucketBondData +struct NeuralInputBucketBondData : public ::DynamicRank::NeuralInputUnaryBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional bool m_fMinInclusive + bool m_fMinInclusive; + + // 3: optional bool m_fMaxInclusive + bool m_fMaxInclusive; + + // 4: optional uint32 m_nMinValue + uint32_t m_nMinValue; + + // 5: optional uint32 m_nMaxValue + uint32_t m_nMaxValue; + + NeuralInputBucketBondData() + : m_fMinInclusive(), + m_fMaxInclusive(), + m_nMinValue(), + m_nMaxValue() { + InitMetadata("NeuralInputBucketBondData", "DynamicRank.NeuralInputBucketBondData"); + } + + + NeuralInputBucketBondData(const NeuralInputBucketBondData& _bond_rhs) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()), + m_fMinInclusive(_bond_rhs.m_fMinInclusive), + m_fMaxInclusive(_bond_rhs.m_fMaxInclusive), + m_nMinValue(_bond_rhs.m_nMinValue), + m_nMaxValue(_bond_rhs.m_nMaxValue) { + InitMetadata("NeuralInputBucketBondData", "DynamicRank.NeuralInputBucketBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputBucketBondData(NeuralInputBucketBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputUnaryBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< bool >::value + && std::is_nothrow_move_constructible< uint32_t >::value + )) + : ::DynamicRank::NeuralInputUnaryBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)), + m_fMinInclusive(std::move(_bond_rhs.m_fMinInclusive)), + m_fMaxInclusive(std::move(_bond_rhs.m_fMaxInclusive)), + m_nMinValue(std::move(_bond_rhs.m_nMinValue)), + m_nMaxValue(std::move(_bond_rhs.m_nMaxValue)) + { + } +#endif + + + template + explicit + NeuralInputBucketBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_allocator), + name(*_bond_allocator), + m_fMinInclusive(), + m_fMaxInclusive(), + m_nMinValue(), + m_nMaxValue() { + InitMetadata("NeuralInputBucketBondData", "DynamicRank.NeuralInputBucketBondData"); + } + + + NeuralInputBucketBondData& operator=(const NeuralInputBucketBondData& _bond_rhs) + { + NeuralInputBucketBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputBucketBondData& operator=(NeuralInputBucketBondData&& _bond_rhs) + { + NeuralInputBucketBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputBucketBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + // skip bond_meta::full_name field 'name' + && (m_fMinInclusive == _bond_other.m_fMinInclusive) + && (m_fMaxInclusive == _bond_other.m_fMaxInclusive) + && (m_nMinValue == _bond_other.m_nMinValue) + && (m_nMaxValue == _bond_other.m_nMaxValue); + } + + + bool operator!=(const NeuralInputBucketBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputBucketBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputUnaryBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + swap(m_fMinInclusive, _bond_other.m_fMinInclusive); + swap(m_fMaxInclusive, _bond_other.m_fMaxInclusive); + swap(m_nMinValue, _bond_other.m_nMinValue); + swap(m_nMaxValue, _bond_other.m_nMaxValue); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputUnaryBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputBucketBondData& _bond_left, NeuralInputBucketBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputTanhBondData +struct NeuralInputTanhBondData : public ::DynamicRank::NeuralInputBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional uint32 m_cInputs + uint32_t m_cInputs; + + // 3: optional bool m_locked + bool m_locked; + + // 4: optional double m_threshold + double m_threshold; + + // 5: optional vector m_rgId + std::vector m_rgId; + + // 6: optional vector m_rgWeights + std::vector m_rgWeights; + + NeuralInputTanhBondData() + : m_cInputs(), + m_locked(), + m_threshold() { + InitMetadata("NeuralInputTanhBondData", "DynamicRank.NeuralInputTanhBondData"); + } + + + NeuralInputTanhBondData(const NeuralInputTanhBondData& _bond_rhs) + : ::DynamicRank::NeuralInputBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()), + m_cInputs(_bond_rhs.m_cInputs), + m_locked(_bond_rhs.m_locked), + m_threshold(_bond_rhs.m_threshold), + m_rgId(_bond_rhs.m_rgId), + m_rgWeights(_bond_rhs.m_rgWeights) { + InitMetadata("NeuralInputTanhBondData", "DynamicRank.NeuralInputTanhBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputTanhBondData(NeuralInputTanhBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< uint32_t >::value + && std::is_nothrow_move_constructible< bool >::value + && std::is_nothrow_move_constructible< double >::value + && std::is_nothrow_move_constructible< std::vector >::value + && std::is_nothrow_move_constructible< std::vector >::value + )) + : ::DynamicRank::NeuralInputBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)), + m_cInputs(std::move(_bond_rhs.m_cInputs)), + m_locked(std::move(_bond_rhs.m_locked)), + m_threshold(std::move(_bond_rhs.m_threshold)), + m_rgId(std::move(_bond_rhs.m_rgId)), + m_rgWeights(std::move(_bond_rhs.m_rgWeights)) + { + } +#endif + + + template + explicit + NeuralInputTanhBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputBondData(_bond_allocator), + name(*_bond_allocator), + m_cInputs(), + m_locked(), + m_threshold(), + m_rgId(*_bond_allocator), + m_rgWeights(*_bond_allocator) { + InitMetadata("NeuralInputTanhBondData", "DynamicRank.NeuralInputTanhBondData"); + } + + + NeuralInputTanhBondData& operator=(const NeuralInputTanhBondData& _bond_rhs) + { + NeuralInputTanhBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputTanhBondData& operator=(NeuralInputTanhBondData&& _bond_rhs) + { + NeuralInputTanhBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputTanhBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + // skip bond_meta::full_name field 'name' + && (m_cInputs == _bond_other.m_cInputs) + && (m_locked == _bond_other.m_locked) + && (m_threshold == _bond_other.m_threshold) + && (m_rgId == _bond_other.m_rgId) + && (m_rgWeights == _bond_other.m_rgWeights); + } + + + bool operator!=(const NeuralInputTanhBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputTanhBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + swap(m_cInputs, _bond_other.m_cInputs); + swap(m_locked, _bond_other.m_locked); + swap(m_threshold, _bond_other.m_threshold); + swap(m_rgId, _bond_other.m_rgId); + swap(m_rgWeights, _bond_other.m_rgWeights); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputTanhBondData& _bond_left, NeuralInputTanhBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputTanhUnaryBondData +struct NeuralInputTanhUnaryBondData : public ::DynamicRank::NeuralInputUnaryBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + // 2: optional bool m_fLocked + bool m_fLocked; + + // 3: optional double m_dblWeights + double m_dblWeights; + + // 4: optional double m_dblThreshold + double m_dblThreshold; + + NeuralInputTanhUnaryBondData() + : m_fLocked(), + m_dblWeights(), + m_dblThreshold() { + InitMetadata("NeuralInputTanhUnaryBondData", "DynamicRank.NeuralInputTanhUnaryBondData"); + } + + + NeuralInputTanhUnaryBondData(const NeuralInputTanhUnaryBondData& _bond_rhs) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()), + m_fLocked(_bond_rhs.m_fLocked), + m_dblWeights(_bond_rhs.m_dblWeights), + m_dblThreshold(_bond_rhs.m_dblThreshold) { + InitMetadata("NeuralInputTanhUnaryBondData", "DynamicRank.NeuralInputTanhUnaryBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputTanhUnaryBondData(NeuralInputTanhUnaryBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputUnaryBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< bool >::value + && std::is_nothrow_move_constructible< double >::value + )) + : ::DynamicRank::NeuralInputUnaryBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)), + m_fLocked(std::move(_bond_rhs.m_fLocked)), + m_dblWeights(std::move(_bond_rhs.m_dblWeights)), + m_dblThreshold(std::move(_bond_rhs.m_dblThreshold)) + { + } +#endif + + + template + explicit + NeuralInputTanhUnaryBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_allocator), + name(*_bond_allocator), + m_fLocked(), + m_dblWeights(), + m_dblThreshold() { + InitMetadata("NeuralInputTanhUnaryBondData", "DynamicRank.NeuralInputTanhUnaryBondData"); + } + + + NeuralInputTanhUnaryBondData& operator=(const NeuralInputTanhUnaryBondData& _bond_rhs) + { + NeuralInputTanhUnaryBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputTanhUnaryBondData& operator=(NeuralInputTanhUnaryBondData&& _bond_rhs) + { + NeuralInputTanhUnaryBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputTanhUnaryBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)) + // skip bond_meta::full_name field 'name' + && (m_fLocked == _bond_other.m_fLocked) + && (m_dblWeights == _bond_other.m_dblWeights) + && (m_dblThreshold == _bond_other.m_dblThreshold); + } + + + bool operator!=(const NeuralInputTanhUnaryBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputTanhUnaryBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputUnaryBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + swap(m_fLocked, _bond_other.m_fLocked); + swap(m_dblWeights, _bond_other.m_dblWeights); + swap(m_dblThreshold, _bond_other.m_dblThreshold); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputUnaryBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputTanhUnaryBondData& _bond_left, NeuralInputTanhUnaryBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// UnionNeuralInputTanhBondData +struct UnionNeuralInputTanhBondData +{ + // 1: optional bool m_cached + bool m_cached; + + // 2: optional nullable m_neuralInputTanhBondData + bond::nullable< ::DynamicRank::NeuralInputTanhBondData> m_neuralInputTanhBondData; + + // 3: optional nullable m_neuralInputTanhUnaryBondData + bond::nullable< ::DynamicRank::NeuralInputTanhUnaryBondData> m_neuralInputTanhUnaryBondData; + + UnionNeuralInputTanhBondData() + : m_cached() + { + } + + + // Compiler generated copy ctor OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + UnionNeuralInputTanhBondData(const UnionNeuralInputTanhBondData& /*_bond_rhs*/) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + UnionNeuralInputTanhBondData(UnionNeuralInputTanhBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< bool >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputTanhBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputTanhUnaryBondData> >::value + )) + : m_cached(std::move(_bond_rhs.m_cached)), + m_neuralInputTanhBondData(std::move(_bond_rhs.m_neuralInputTanhBondData)), + m_neuralInputTanhUnaryBondData(std::move(_bond_rhs.m_neuralInputTanhUnaryBondData)) + { + } +#endif + + + template + explicit + UnionNeuralInputTanhBondData(Allocator* _bond_allocator) + : m_cached(), + m_neuralInputTanhBondData(*_bond_allocator), + m_neuralInputTanhUnaryBondData(*_bond_allocator) + { + } + + + // Compiler generated operator= OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + UnionNeuralInputTanhBondData& operator=(const UnionNeuralInputTanhBondData& _bond_rhs) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + UnionNeuralInputTanhBondData& operator=(UnionNeuralInputTanhBondData&& _bond_rhs) + { + UnionNeuralInputTanhBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const UnionNeuralInputTanhBondData& _bond_other) const + { + return true + && (m_cached == _bond_other.m_cached) + && (m_neuralInputTanhBondData == _bond_other.m_neuralInputTanhBondData) + && (m_neuralInputTanhUnaryBondData == _bond_other.m_neuralInputTanhUnaryBondData); + } + + + bool operator!=(const UnionNeuralInputTanhBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(UnionNeuralInputTanhBondData& _bond_other) + { + using std::swap; + swap(m_cached, _bond_other.m_cached); + swap(m_neuralInputTanhBondData, _bond_other.m_neuralInputTanhBondData); + swap(m_neuralInputTanhUnaryBondData, _bond_other.m_neuralInputTanhUnaryBondData); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* /*_bond_name*/, const char* /*_bond_full_name*/) + { + } +}; + + +inline void swap(UnionNeuralInputTanhBondData& _bond_left, UnionNeuralInputTanhBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +// NeuralInputUseAsFloatBondData +struct NeuralInputUseAsFloatBondData : public ::DynamicRank::NeuralInputUnaryBondData +{ + // 1: required_optional bond_meta::full_name name + std::string name; + + NeuralInputUseAsFloatBondData() { + InitMetadata("NeuralInputUseAsFloatBondData", "DynamicRank.NeuralInputUseAsFloatBondData"); + } + + + NeuralInputUseAsFloatBondData(const NeuralInputUseAsFloatBondData& _bond_rhs) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_rhs), + name(_bond_rhs.name.get_allocator()) { + InitMetadata("NeuralInputUseAsFloatBondData", "DynamicRank.NeuralInputUseAsFloatBondData"); + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputUseAsFloatBondData(NeuralInputUseAsFloatBondData&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< ::DynamicRank::NeuralInputUnaryBondData >::value + && std::is_nothrow_move_constructible< std::string >::value + )) + : ::DynamicRank::NeuralInputUnaryBondData(std::move(_bond_rhs)), + name(std::move(_bond_rhs.name)) + { + } +#endif + + + template + explicit + NeuralInputUseAsFloatBondData(Allocator* _bond_allocator) + : ::DynamicRank::NeuralInputUnaryBondData(_bond_allocator), + name(*_bond_allocator) { + InitMetadata("NeuralInputUseAsFloatBondData", "DynamicRank.NeuralInputUseAsFloatBondData"); + } + + + NeuralInputUseAsFloatBondData& operator=(const NeuralInputUseAsFloatBondData& _bond_rhs) + { + NeuralInputUseAsFloatBondData(_bond_rhs).swap(*this); + return *this; + } + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + NeuralInputUseAsFloatBondData& operator=(NeuralInputUseAsFloatBondData&& _bond_rhs) + { + NeuralInputUseAsFloatBondData(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const NeuralInputUseAsFloatBondData& _bond_other) const + { + return true + && (static_cast(*this) == static_cast(_bond_other)); + // skip bond_meta::full_name field 'name' + } + + + bool operator!=(const NeuralInputUseAsFloatBondData& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(NeuralInputUseAsFloatBondData& _bond_other) + { + using std::swap; + ::DynamicRank::NeuralInputUnaryBondData::swap(_bond_other); + // skip bond_meta::full_name field 'name' + } + + + struct Schema; + + +protected: + void InitMetadata(const char* _bond_name, const char* _bond_full_name) + { + ::DynamicRank::NeuralInputUnaryBondData::InitMetadata(_bond_name, _bond_full_name); + this->name = _bond_full_name; + } +}; + + +inline void swap(NeuralInputUseAsFloatBondData& _bond_left, NeuralInputUseAsFloatBondData& _bond_right) +{ + _bond_left.swap(_bond_right); +} + + + +} // namespace DynamicRank diff --git a/src/transform/NeuralTree.Library/inc/UnionBondInput_types.h b/src/transform/NeuralTree.Library/inc/UnionBondInput_types.h new file mode 100644 index 000000000000..dc94d12a2299 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/UnionBondInput_types.h @@ -0,0 +1,245 @@ + +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Tool : bondc, Version=3.0.1, Build=bond-git.retail.0 +// Template : Microsoft.Bond.Rules.dll#Rules_BOND_CPP.tt +// File : UnionBondInput_types.h +// +// Changes to this file may cause incorrect behavior and will be lost when +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +#pragma once + + +#include +#if BOND_MAJOR_VERSION_MIN_SUPPORTED > 3 \ + || (BOND_MAJOR_VERSION_MIN_SUPPORTED == 3 && BOND_MINOR_VERSION_MIN_SUPPORTED > 1) +#error This file was generated by an older Bond compiler which is \ + incompatible with current Bond library. Please regenerate \ + with the latest Bond compiler. +#endif + +#include +#include +#include +#include "NeuralInput_types.h" +//#include "NeuralInputBSpline_types.h" +//#include "NeuralInputFreeForm_types.h" +//#include "NeuralInputTree_types.h" +//#include "NeuralInputFreeForm_types.h" +//#include "NeuralInputMultiple_types.h +//#include "NeuralInputBM252_types.h" +//#include "NeuralInputNgramBM25_types.h" +#include "NeuralInputFreeForm2_types.h" + +namespace DynamicRank +{ + +// UnionBondInput +struct UnionBondInput +{ + // 1: optional string m_inputType + // Used to decide which input is. linear loglinear rational bspline bsplinebasis bucket decisiontree tanh freeform floatdata aggregatedfreeform <---Note: Research this new case. Could be evil. sumbucket sumcomparisons sumbucketcomparison sumlinear sumloglinear sumgreater sumdivisor bm25f2 logbm25f2 ngrambm25f logngrambm25f perfbm25f perflogbm25f freeform2 <--- This input factory will be in dynamicranker.freeform.library. bulkinput <-- For bulky compiled input. + std::string m_inputType; + + // 2: optional nullable m_linear + bond::nullable< ::DynamicRank::NeuralInputLinearBondData> m_linear; + + // 3: optional nullable m_loglinear + bond::nullable< ::DynamicRank::NeuralInputLogLinearBondData> m_loglinear; + + // 4: optional nullable m_rational + bond::nullable< ::DynamicRank::NeuralInputRationalBondData> m_rational; + + //// 5: optional nullable m_bspline + //bond::nullable< ::DynamicRank::NeuralInputBSplineBondData> m_bspline; + + //// 6: optional nullable m_bsplinebasis + //bond::nullable< ::DynamicRank::NeuralInputBSplineBasisFunctionBondData> m_bsplinebasis; + + // 7: optional nullable m_bucket + bond::nullable< ::DynamicRank::NeuralInputBucketBondData> m_bucket; + + //// 8: optional nullable m_decisiontree + //bond::nullable< ::DynamicRank::NeuralInputTreeBondData> m_decisiontree; + + // 10: optional nullable m_tanh + bond::nullable< ::DynamicRank::UnionNeuralInputTanhBondData> m_tanh; + + //// 11: optional nullable m_freeform + //bond::nullable< ::DynamicRank::NeuralInputFreeFormBondData> m_freeform; + + // 12: optional nullable m_floatdata + bond::nullable< ::DynamicRank::NeuralInputUseAsFloatBondData> m_floatdata; + + //// 13: optional nullable m_aggregatedfreeform + //bond::nullable< ::DynamicRank::NeuralInputFreeFormBondData> m_aggregatedfreeform; + + //// 14: optional nullable m_sumbucket + //bond::nullable< ::DynamicRank::NeuralInputSumBucketBondData> m_sumbucket; + + //// 16: optional nullable m_sumcomparisons + //bond::nullable< ::DynamicRank::NeuralInputSumComparisonsBondData> m_sumcomparisons; + + //// 17: optional nullable m_sumbucketcomparisons + //bond::nullable< ::DynamicRank::NeuralInputSumBucketComparisonBondData> m_sumbucketcomparisons; + + //// 18: optional nullable m_sumlinear + //bond::nullable< ::DynamicRank::NeuralInputSumLinearBondData> m_sumlinear; + + //// 19: optional nullable m_sumloglinear + //bond::nullable< ::DynamicRank::NeuralInputSumLogLinearBondData> m_sumloglinear; + + //// 20: optional nullable m_sumgreater + //bond::nullable< ::DynamicRank::NeuralInputSumGreaterBondData> m_sumgreater; + + //// 21: optional nullable m_sumdivisor + //bond::nullable< ::DynamicRank::NeuralInputSumDivisorBondData> m_sumdivisor; + + //// 22: optional nullable m_bm25f2 + //bond::nullable< ::DynamicRank::NeuralInputLinearBM25BondData> m_bm25f2; + + //// 23: optional nullable m_logbm25f2 + //bond::nullable< ::DynamicRank::NeuralInputLogLinearBM25BondData> m_logbm25f2; + + //// 24: optional nullable m_ngrambm25f + //bond::nullable< ::DynamicRank::NeuralInputLinearNgramBM25BondData> m_ngrambm25f; + + //// 25: optional nullable m_logngrambm25f + //bond::nullable< ::DynamicRank::NeuralInputLogLinearNgramBM25BondData> m_logngrambm25f; + + //// 26: optional nullable m_perfbm25f + //bond::nullable< ::DynamicRank::NeuralInputLinearBM25BondData> m_perfbm25f; + + //// 27: optional nullable m_perflogbm25f + //bond::nullable< ::DynamicRank::NeuralInputLogLinearBM25BondData> m_perflogbm25f; + + // 28: optional nullable m_freeform2 + bond::nullable< ::DynamicRank::NeuralInputFreeForm2BondData> m_freeform2; + + // 29: optional nullable m_bulkinput + bond::nullable< ::DynamicRank::BulkNeuralInputBondData> m_bulkinput; + + UnionBondInput() + { + } + + + // Compiler generated copy ctor OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + UnionBondInput(const UnionBondInput& /*_bond_rhs*/) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + UnionBondInput(UnionBondInput&& _bond_rhs) BOND_NOEXCEPT_IF((true + && std::is_nothrow_move_constructible< std::string >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputLinearBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputLogLinearBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputRationalBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputBucketBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::UnionNeuralInputTanhBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputUseAsFloatBondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::NeuralInputFreeForm2BondData> >::value + && std::is_nothrow_move_constructible< bond::nullable< ::DynamicRank::BulkNeuralInputBondData> >::value + )) + : m_inputType(std::move(_bond_rhs.m_inputType)), + m_linear(std::move(_bond_rhs.m_linear)), + m_loglinear(std::move(_bond_rhs.m_loglinear)), + m_rational(std::move(_bond_rhs.m_rational)), + m_bucket(std::move(_bond_rhs.m_bucket)), + m_tanh(std::move(_bond_rhs.m_tanh)), + m_floatdata(std::move(_bond_rhs.m_floatdata)), + m_freeform2(std::move(_bond_rhs.m_freeform2)), + m_bulkinput(std::move(_bond_rhs.m_bulkinput)) + { + } +#endif + + + template + explicit + UnionBondInput(Allocator* _bond_allocator) + : m_inputType(*_bond_allocator), + m_linear(*_bond_allocator), + m_loglinear(*_bond_allocator), + m_rational(*_bond_allocator), + m_bucket(*_bond_allocator), + m_tanh(*_bond_allocator), + m_floatdata(*_bond_allocator), + m_freeform2(*_bond_allocator), + m_bulkinput(*_bond_allocator) + { + } + + + // Compiler generated operator= OK +#ifndef BOND_NO_CXX11_DEFAULTED_FUNCTIONS + UnionBondInput& operator=(const UnionBondInput& _bond_rhs) = default; +#endif + + +#ifndef BOND_NO_CXX11_RVALUE_REFERENCES + UnionBondInput& operator=(UnionBondInput&& _bond_rhs) + { + UnionBondInput(std::move(_bond_rhs)).swap(*this); + return *this; + } +#endif + + + bool operator==(const UnionBondInput& _bond_other) const + { + return true + && (m_inputType == _bond_other.m_inputType) + && (m_linear == _bond_other.m_linear) + && (m_loglinear == _bond_other.m_loglinear) + && (m_rational == _bond_other.m_rational) + && (m_bucket == _bond_other.m_bucket) + && (m_tanh == _bond_other.m_tanh) + && (m_floatdata == _bond_other.m_floatdata) + && (m_freeform2 == _bond_other.m_freeform2) + && (m_bulkinput == _bond_other.m_bulkinput); + } + + + bool operator!=(const UnionBondInput& _bond_other) const + { + return !(*this == _bond_other); + } + + + void swap(UnionBondInput& _bond_other) + { + using std::swap; + swap(m_inputType, _bond_other.m_inputType); + swap(m_linear, _bond_other.m_linear); + swap(m_loglinear, _bond_other.m_loglinear); + swap(m_rational, _bond_other.m_rational); + swap(m_bucket, _bond_other.m_bucket); + swap(m_tanh, _bond_other.m_tanh); + swap(m_floatdata, _bond_other.m_floatdata); + swap(m_freeform2, _bond_other.m_freeform2); + swap(m_bulkinput, _bond_other.m_bulkinput); + } + + + struct Schema; + + +protected: + void InitMetadata(const char* /*_bond_name*/, const char* /*_bond_full_name*/) + { + } +}; + + +inline void swap(UnionBondInput& _bond_left, UnionBondInput& _bond_right) +{ + _bond_left.swap(_bond_right); +} +} // namespace DynamicRank diff --git a/src/transform/NeuralTree.Library/inc/basic_types.h b/src/transform/NeuralTree.Library/inc/basic_types.h new file mode 100644 index 000000000000..54770d7a5285 --- /dev/null +++ b/src/transform/NeuralTree.Library/inc/basic_types.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include + +#ifdef _In_z_ +#undef _In_z_ +#endif +#define _In_z_ + +typedef uint32_t UInt32; +typedef uint8_t UInt8; +typedef uint64_t UInt64; + +typedef int64_t Int64; +typedef int32_t Int32; +typedef unsigned long DWORD; + +typedef size_t Size_t; + +#define MAX_UINT32 ((UInt32)-1) + +struct SIZED_STRING +{ + union + { + const UInt8 *pbData; + const char *pcData; + }; + size_t cbData; +}; + +#define _TRUNCATE ((size_t)-1) + +// A utility class for creating temporary SIZED_STRINGs +class CStackSizedString : public SIZED_STRING +{ +public: + // Null-terminated input + CStackSizedString(_In_z_ const char *szValue) + { +#if defined(_MSC_VER) && _MSC_VER >= 1910 + // Suppression: error C26490: Don't use reinterpret_cast. + // reinterpreting the bits of the char* to the UInt* (both 8 bits per item) is necessary here + [[gsl::suppress(type.1)]] +#endif + pbData = reinterpret_cast(szValue); + cbData = strlen(szValue); + } + + // Name/size pair + CStackSizedString( + const UInt8 *pbValue, + size_t cbValue) + { + pbData = pbValue; + cbData = cbValue; + } + + // Name/size pair + CStackSizedString( + const char *pcValue, + size_t cbValue) + { + pcData = pcValue; + cbData = cbValue; + } + +private: + // prevent heap allocation + void *operator new(size_t); +}; + +#define UINT_MAX 0xffffffffu + + +// convenience macros to make SIZED_STRINGs more usable, particularly with the FEX Document class and FexSprintf +// expand a SIZED_STRING: +#define SIZED_STR(sizedstr) (char*)sizedstr.pbData, sizedstr.cbData +// reverse, for printing to FexSprintf: +#define SIZED_STR_REV(sizedstr) sizedstr.cbData, sizedstr.pcData +// References, for loading with document->GetField( SIZED_STR_REF(str) ); +#define SIZED_STR_REF(sizedstr) &sizedstr.pcData, &sizedstr.cbData +// Convert a SIZED_STRING to std::basic_string +#define SIZED_STR_STL(sizedstr) (sizedstr.cbData>0 ? std::string(SIZED_STR(sizedstr)) : std::string()) + +typedef int BOOL; + +#ifndef FALSE +#define FALSE 0 +#endif + +#ifndef TRUE +#define TRUE 1 +#endif diff --git a/src/transform/NeuralTree.Library/src/CMakeLists.txt b/src/transform/NeuralTree.Library/src/CMakeLists.txt new file mode 100644 index 000000000000..499a46d144b1 --- /dev/null +++ b/src/transform/NeuralTree.Library/src/CMakeLists.txt @@ -0,0 +1,35 @@ +set(PROJECT_NAME DRNeuralTreeLibrary) + +set(Headers + ../inc/basic_types.h + ../inc/CsHash.h + ../inc/FeaSpecConfig.h + ../inc/IFeatureMap.h + ../inc/INeuralNetFeatures.h + ../inc/MigratedApi.h + ../inc/NeuralInput_types.h + ../inc/NeuralInput.h + ../inc/NeuralInputFactory.h + ../inc/NeuralInputFreeForm2_types.h + ../inc/UnionBondInput_types.h +) +source_group("Headers" FILES ${Headers}) + +set(Sources + "CsHash.cpp" + "FeaSpecConfig.cpp" + "NeuralInput.cpp" + "NeuralInputFactory.cpp" +) +source_group("Sources" FILES ${Sources}) + +set(ALL_FILES + ${Headers} + ${Sources} +) + +add_library(${PROJECT_NAME} STATIC ${ALL_FILES}) + +target_include_directories(${PROJECT_NAME} PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/../inc" +) diff --git a/src/transform/NeuralTree.Library/src/CsHash.cpp b/src/transform/NeuralTree.Library/src/CsHash.cpp new file mode 100644 index 000000000000..a2a343de1650 --- /dev/null +++ b/src/transform/NeuralTree.Library/src/CsHash.cpp @@ -0,0 +1,135 @@ +#include "CsHash.h" + +/* + * Copied and modified from http://burtleburtle.net/bob/c/lookup3.c, + * where it is Public Domain. + */ + +#define UPPER(c) ((UInt32)((((c) >= 'a') && ((c) <= 'z')) ? (c) - ('a' - 'A') : (c))) + + +// Compute2: Compute two hashes of an array of bytes. +// The first hash is slightly better mixed than the second hash. +void CsHash32::Compute2 ( + const void *pData, // byte array to hash; may be null if uSize==0 + Size_t uSize, // length of pData + UInt32 uSeed1, // first seed + UInt32 uSeed2, // second seed + UInt32 *uHash1, // OUT: first hash (may not be null) + UInt32 *uHash2) // OUT: second hash (may not be null) +{ + UInt32 a,b,c; + const UInt32 *k = (const UInt32 *)pData; // read 32-bit chunks + const UInt8 *k1; + + // Set up the internal state + a = b = c = 0xdeadbeef + ((UInt32)uSize) + uSeed1; + c += uSeed2; + + // all but last block: aligned reads and affect 32 bits of (a,b,c) + while (uSize > 12) + { + a += k[0]; + b += k[1]; + c += k[2]; + Mix(a,b,c); + uSize -= 12; + k += 3; + } + + // handle the last (probably partial) block + k1 = (const UInt8 *)k; + switch(uSize) + { + case 12: c+=k[2]; b+=k[1]; a+=k[0]; break; + case 11: c+=((UInt32)k1[10])<<16; // fall through + case 10: c+=((UInt32)k1[9])<<8; // fall through + case 9 : c+=(UInt32)k1[8]; // fall through + case 8 : b+=k[1]; a+=k[0]; break; + case 7 : b+=((UInt32)k1[6])<<16; // fall through + case 6 : b+=((UInt32)k1[5])<<8; // fall through + case 5 : b+=((UInt32)k1[4]); // fall through + case 4 : a+=k[0]; break; + case 3 : a+=((UInt32)k1[2])<<16; // fall through + case 2 : a+=((UInt32)k1[1])<<8; // fall through + case 1 : a+=k1[0]; break; + case 0 : + *uHash1 = c; + *uHash2 = b; + return; + } + + Final(a,b,c); + *uHash1 = c; + *uHash2 = b; + return; +} + + +// Hash a string of unknown length case insensitive. I can't just call +// Compute() without allocating a copy of the string, which could have +// complications because there's no max length for strings. +void CsHash32::StringI2 ( + const char *pString, + Size_t uSize, + UInt32 uSeed1, + UInt32 uSeed2, + UInt32 *uHash1, + UInt32 *uHash2) +{ + UInt32 a,b,c; + const UInt8 *k; + + k = (const UInt8 *) pString; + + // Set up the internal state + a = b = c = 0xdeadbeef + ((UInt32)uSize) + uSeed1; + c += uSeed2; + + // all but the last block: affect some 32 bits of (a,b,c) + while (uSize > 12) + { + a += UPPER(k[0]); + a += UPPER(k[1])<<8; + a += UPPER(k[2])<<16; + a += UPPER(k[3])<<24; + b += UPPER(k[4]); + b += UPPER(k[5])<<8; + b += UPPER(k[6])<<16; + b += UPPER(k[7])<<24; + c += UPPER(k[8]); + c += UPPER(k[9])<<8; + c += UPPER(k[10])<<16; + c += UPPER(k[11])<<24; + Mix(a,b,c); + uSize -= 12; + k += 12; + } + + // last block: affect all 32 bits of (c) + switch(uSize) // all the case statements fall through + { + case 12: c+=UPPER(k[11])<<24; + case 11: c+=UPPER(k[10])<<16; + case 10: c+=UPPER(k[9])<<8; + case 9 : c+=UPPER(k[8]); + case 8 : b+=UPPER(k[7])<<24; + case 7 : b+=UPPER(k[6])<<16; + case 6 : b+=UPPER(k[5])<<8; + case 5 : b+=UPPER(k[4]); + case 4 : a+=UPPER(k[3])<<24; + case 3 : a+=UPPER(k[2])<<16; + case 2 : a+=UPPER(k[1])<<8; + case 1 : a+=UPPER(k[0]); + break; + case 0 : + *uHash1 = c; + *uHash2 = b; + return; + } + + Final(a,b,c); + *uHash1 = c; + *uHash2 = b; + return; +} diff --git a/src/transform/NeuralTree.Library/src/FeaSpecConfig.cpp b/src/transform/NeuralTree.Library/src/FeaSpecConfig.cpp new file mode 100644 index 000000000000..fd729e8b79bc --- /dev/null +++ b/src/transform/NeuralTree.Library/src/FeaSpecConfig.cpp @@ -0,0 +1,144 @@ +#include "FeaSpecConfig.h" +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace boost; +using namespace DynamicRank; +using namespace LightGBM; + + +Config* Config::GetRawConfiguration(string str){ + regex sectionPattern("\\[(Input:\\d+)\\]"); + regex linePattern("(Line\\d+)\\=([\\s\\S]*)"); + regex equalPattern("(.*)\\=([\\s\\S]*)"); + smatch result; + string sectionId; + map sectionContent; + sectionContent.clear(); + map> config; + size_t pos = 0; + string line, delimiter = "\n"; + while ((pos = str.find(delimiter)) != string::npos) { + line = str.substr(0, pos); + trim(line); + if (!line.empty()) { + if (regex_match(line, result, sectionPattern)) { + if (!sectionContent.empty()) { + config.insert(map>::value_type(sectionId, sectionContent)); + } + sectionId = result[1]; + sectionContent.clear(); + } + else if (regex_match(line, result, linePattern)) { + sectionContent.insert(map::value_type(result[1], result[2])); + } + else if (regex_match(line, result, equalPattern)) { + sectionContent.insert(map::value_type(result[1], result[2])); + } + else { + Log::Warning("Cannot resolve pattern '%s'. Ignore it.", line.c_str()); + } + } + str.erase(0, pos + delimiter.length()); + } + // Insert last section. + config.insert(map>::value_type(sectionId, sectionContent)); + return new Config(config); +} + + +bool Config::DoesSectionExist(char* section) { + string sectionNameStr(section); + return _config.count(sectionNameStr) > 0; +} + + +bool Config::DoesParameterExist(const char* section, const char* parameterName) { + string sectionNameStr(section); + string paramNameStr(parameterName); + return _config[sectionNameStr].count(parameterName) > 0; +} + + +bool Config::GetStringParameter(const char* section, const char* parameterName, string& value) { + string sectionNameStr(section); + string parameterNameStr(parameterName); + map sectionContent = _config[sectionNameStr]; + if (sectionContent.count(parameterName) == 0) { + return false; + } + else { + value = sectionContent[parameterNameStr]; + return true; + } +} + + +bool Config::GetStringParameter(const char* section, const char* parameterName, char* value, size_t valueSize) { + string sectionNameStr(section); + string parameterNameStr(parameterName); + map sectionContent = _config[sectionNameStr]; + if (sectionContent.count(parameterName) == 0) { + return false; + } + else { + string valueStr = sectionContent[parameterNameStr]; + if (valueStr.length() > valueSize) + return false; + else { + strcpy(value, valueStr.data()); + return true; + } + + } +} + + +bool Config::GetDoubleParameter(const char* section, const char* parameterName, double* value) { + string sectionNameStr(section); + string parameterNameStr(parameterName); + map sectionContent = _config[sectionNameStr]; + if (sectionContent.count(parameterName) == 0) { + return false; + } + else { + *value = atof(sectionContent[parameterNameStr].c_str()); + return true; + } +} + + +double Config::GetDoubleParameter(const char* section, const char* parameterName, double defaultValue) { + double value; + if (Config::GetDoubleParameter(section, parameterName, &value)) + return value; + else + return defaultValue; +} + + +bool Config::GetBoolParameter(const char* section, const char* parameterName, bool* value) { + string sectionNameStr(section); + string parameterNameStr(parameterName); + map sectionContent = _config[sectionNameStr]; + if (sectionContent.count(parameterName) == 0) { + return false; + } + else { + *value = boost::lexical_cast(sectionContent[parameterNameStr]); + return true; + } +} + +bool Config::GetBoolParameter(const char* section, const char* parameterName, bool defaultValue) { + bool value; + if (Config::GetBoolParameter(section, parameterName, &value)) + return value; + else + return defaultValue; +} diff --git a/src/transform/NeuralTree.Library/src/NeuralInput.cpp b/src/transform/NeuralTree.Library/src/NeuralInput.cpp new file mode 100644 index 000000000000..0340b26e2d2f --- /dev/null +++ b/src/transform/NeuralTree.Library/src/NeuralInput.cpp @@ -0,0 +1,1944 @@ +#include +#include +#include "NeuralInput.h" +#include +#include "UnionBondInput_types.h" +#include "MigratedApi.h" +#include + +using namespace LightGBM; +using namespace DynamicRank; + +const static std::vector c_empty; + +// Use caching for the results of neural inputs. Remove this line to not use +// cached results. +#define INPUT_CACHE + + +NeuralInput::NeuralInput() +{ +} + + +NeuralInput::NeuralInput(const NeuralInputBondData& p_data) +{ + if (p_data.m_segments.hasvalue()) + { + m_segments.reset(new std::vector(p_data.m_segments.value())); + } +} + + +NeuralInput::~NeuralInput() +{ +} + + +void +NeuralInput::FillBondData(NeuralInputBondData& p_data) const +{ + if (m_segments != nullptr) + { + p_data.m_segments.set(*m_segments); + } +} + + +double +NeuralInput::EvaluateInput(UInt32 /*input*/) const +{ + // This virtual method should never be called + // There should be a version in the derived class + return 0.0; +} + + +double +NeuralInput::Evaluate(UInt32** p_featureVectorArray, + UInt32 p_currentDocument, + UInt32 p_documentCount) const +{ + // If derived class needs list based evaluation, it should have a version. + // Else call freeform evaluate function. + if (p_currentDocument >= p_documentCount) + { + return 0.0; + } + return Evaluate(p_featureVectorArray[p_currentDocument]); +} + + +UInt32 +NeuralInput::GetAssociatedFeature() const +{ + return static_cast(-1); +} + +double +NeuralInput::GetSlope() const +{ + return 1.0; +} + +double +NeuralInput::GetIntercept() const +{ + return 0; +} + +void +NeuralInput::LoadSegments(DynamicRank::Config& p_config, + const char* szSection) +{ + char szSegments[1024]; + if (!p_config.GetStringParameter( + szSection, + "Segments", + szSegments, + sizeof(szSegments))) + { + // Sectors not specified; will leave m_segments as a length zero vector + return; + } + + std::string tmp(szSegments); + std::vector segmentNames; + boost::split(segmentNames, tmp, boost::is_any_of(", ")); + for (std::vector::iterator it = segmentNames.begin(); it != segmentNames.end(); ++it) + { + if (!it->empty()) + { + boost::algorithm::to_lower(*it); + if (!m_segments) + { + m_segments.reset(new std::vector()); + } + m_segments->push_back(*it); + } + } +} + + +void +NeuralInput::SetSegments(const std::vector& p_segments) +{ + if (p_segments.empty()) + { + m_segments.reset(); + } + else + { + m_segments.reset(new std::vector(p_segments)); + } +} + + +bool +NeuralInput::Save(FILE* fpOutput, + size_t nInputId, + const IFeatureMap& /*p_featureMap*/) const +{ + // Format for an input in the config file is + // [Input:index] + // Name=... + // Transform={linear, bucket, loglinear, etc.} + // (transform-specific config) + + // All we can do here is write out the section name and the + // feature name + fprintf(fpOutput, "\n[Input:%Iu]\n", nInputId); + + // Print out the segments, if applicable + if (m_segments && m_segments->size() != 0) + { + fprintf(fpOutput, "Segments="); + char separatingComma[2] = {0, 0}; + + for (std::vector::const_iterator iter = m_segments->begin(); + iter != m_segments->end(); + ++iter) + { + fprintf(fpOutput, "%s%s", separatingComma, (*iter).c_str()); + separatingComma[0] = ','; + } + fprintf(fpOutput, "\n"); + } + + return true; +} + + +bool +NeuralInput::EqualInternal(const NeuralInput* p_input) const +{ + if (!p_input) + { + return false; + } + + if ((m_segments.get() == nullptr) != (p_input->m_segments.get() == nullptr)) + { + return false; + } + + if (m_segments.get() != nullptr && p_input->m_segments.get() != nullptr && *m_segments != *p_input->m_segments) + { + return false; + } + return true; +} + + +bool +NeuralInput::Train(double /*dblLearningRate*/, + double /*outputHigh*/, + double /*outputLow*/, + double /*dblOutputDelta*/, + UInt32 /*inputHigh*/[], + UInt32 /*inputLow*/[]) +{ + // Default implementation does nothing. + return false; +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInput; +Sizeof_NeuralInput sizeof_NeuralInput; +#endif + +// Use static assert to force the user to consider the external size and serialization if a new member is added. +// For each new member please add it to the template serialize and generate all the bin files for all the nets in searchgold. +// If the new member has an external size, then the size has to be reported within this function. +#ifdef DEBUG +BOOST_STATIC_ASSERT(sizeof(NeuralInput) == 16); +#else +BOOST_STATIC_ASSERT(sizeof(NeuralInput) == 16); +#endif + + +size_t +NeuralInput::GetExternalSize() const +{ + size_t s = 0; + if (m_segments) + { + const std::vector& v = *m_segments; + s += sizeof(std::string) * v.capacity(); + for (size_t i = 0; i < v.size(); ++i) + { + s += v[i].capacity(); + } + } + return s; +} + + +const std::vector& +NeuralInput::GetSegments() const +{ + if (m_segments) + return *m_segments; + else + return c_empty; +} + + +BulkNeuralInput::BulkNeuralInput() +{ +} + + +BulkNeuralInput::~BulkNeuralInput() +{ +} + + +void BulkNeuralInput::Evaluate(UInt32** p_featureVectorArray, + UInt32 p_currentDocument, + UInt32 p_documentCount, + float p_output[]) const +{ + if (p_currentDocument >= p_documentCount) + { + return; + } + Evaluate(p_featureVectorArray[p_currentDocument], p_output); +} + + +bool +NeuralInputUnary::ReadAssociatedFeature(DynamicRank::Config& p_config, + const char* szSection, + IFeatureMap& p_featureMap, + UInt32* piFeature) +{ + *piFeature = static_cast(-1); + + if (szSection == NULL) + { + return false; + } + + char szName[256]; + if (p_config.GetStringParameter(szSection, "Name", szName, sizeof(szName))) + { + if (!p_featureMap.ObtainFeatureIndex(szName, *piFeature)) + { + Log::Warning("DR:ReadAssociatedFeature: Could not find index for feature name: %s in section: %s", szName, szSection); + return false; // input error, invalid feature name + } + return true; + } + Log::Warning("DR:ReadAssociatedFeature: Could not find 'Name' of the feature for section: %s", szSection); + return false; // feature is not specified +} + + +UInt32 +NeuralInputUnary::GetAssociatedFeature() const +{ + return m_iFeature; +} + + +void +NeuralInputUnary::GetAllAssociatedFeatures(std::vector& associatedFeaturesList) const +{ + associatedFeaturesList.push_back(m_iFeature); +} + + +NeuralInputUnary::NeuralInputUnary() : NeuralInput(), m_iFeature(0) +{ +} + + +NeuralInputUnary::NeuralInputUnary(int p_feature) : NeuralInput(), m_iFeature(p_feature) +{ +} + + +NeuralInputUnary::NeuralInputUnary(const NeuralInputUnaryBondData& p_data) + : NeuralInput(p_data), + m_iFeature(p_data.m_iFeature) +{ +} + + +void +NeuralInputUnary::FillBondData(NeuralInputUnaryBondData& p_data) const +{ + // Fill base class. + NeuralInput::FillBondData(p_data); + p_data.m_iFeature = m_iFeature; +} + + +bool +NeuralInputUnary::Save(FILE* fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + NeuralInput::Save(fpOutput, nInputId, p_featureMap); + + // Create a stack buffer for the feature name + char rgInput[1024]; + if (!p_featureMap.GetFeatureName(m_iFeature, rgInput, 1024)) + { + return false; + } + + fprintf(fpOutput, "Name=%s\n", rgInput); + + return true; +} + + +void +NeuralInputUnary::CopyFrom(const NeuralInputUnary& p_neuralInputUnary) +{ + m_iFeature = p_neuralInputUnary.m_iFeature; +} + + +bool +NeuralInputUnary::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInput::EqualInternal(p_input)) + { + return false; + } + + const NeuralInputUnary* other = dynamic_cast(p_input); + if (!other) + { + return false; + } + if (m_iFeature != other->m_iFeature) + { + return false; + } + return true; +} + + +double +NeuralInputUnary::Evaluate(UInt32 input[]) const +{ + return EvaluateInput(input[m_iFeature]); +} + + +UInt32 +NeuralInputUnary::GetFeature() const +{ + return m_iFeature; +} + + +size_t +NeuralInputUnary::GetSize() const +{ + return sizeof(*this) + GetExternalSize(); +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInputUnary; +Sizeof_NeuralInputUnary sizeof_NeuralInputUnary; +#endif + +// Use static assert to force the user to consider the external size and serialization if a new member is added. +// For each new member please add it to the template serialize and generate all the bin files for all the nets in searchgold. +// If the new member has an external size, then the size has to be reported within this function. +#ifdef DEBUG +BOOST_STATIC_ASSERT(sizeof(NeuralInputUnary) - sizeof(NeuralInput) == 8); +#else +BOOST_STATIC_ASSERT(sizeof(NeuralInputUnary) - sizeof(NeuralInput) == 8); +#endif + + +size_t +NeuralInputUnary::GetExternalSize() const +{ + return NeuralInput::GetExternalSize(); +} + + +NeuralInputLinear* +NeuralInputLinear::Load(DynamicRank::Config& p_config, + const char* szSection, + IFeatureMap& p_featureMap) +{ + UInt32 iFeature; + + if (!NeuralInputUnary::ReadAssociatedFeature(p_config,szSection, p_featureMap, &iFeature)) + { + return NULL; + } + + double slope = 0.0; + double intercept = 0.0; + if ((!p_config.GetDoubleParameter(szSection, "Slope", &slope)) || + (!p_config.GetDoubleParameter(szSection, "Intercept", &intercept))) + { + Log::Warning("NeuralInputLinear::Load: Slope or Intercept not provided (double value) for section: %s", szSection); + return NULL; + } + + NeuralInputLinear* pNeuralInputLinear = new NeuralInputLinear(iFeature, slope, intercept); + return pNeuralInputLinear; +} + + +bool +NeuralInputLinear::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInputUnary::Equal(p_input)) + { + return false; + } + + const NeuralInputLinear* other = dynamic_cast(p_input); + if (!other) + { + return false; + } + if (m_slope != other->m_slope + || m_intercept != other->m_intercept + || m_iFeature != other->m_iFeature) + { + return false; + } + + return true; +} + + +NeuralInputLinear::NeuralInputLinear() : NeuralInputUnary(0), m_slope(0.0), m_intercept(0.0) +{ +} + + +NeuralInputLinear::NeuralInputLinear( + int id, + double slope, + double intercept) + : NeuralInputUnary(id), m_slope(slope), m_intercept(intercept) +{ +} + + +NeuralInputLinear::NeuralInputLinear(const NeuralInputLinearBondData& p_data) + : NeuralInputUnary(p_data), + m_slope(p_data.m_slope), + m_intercept(p_data.m_intercept) +{ +} + + +void +NeuralInputLinear::FillBond(UnionBondInput& p_data) const +{ + p_data.m_inputType = "linear"; + // Fill bond struct of the class. + NeuralInputLinearBondData data; + FillBondData(data); + p_data.m_linear.set(data); +} + + +void +NeuralInputLinear::FillBondData(NeuralInputLinearBondData& p_data) const +{ + // Fill base class. + NeuralInputUnary::FillBondData(p_data); + p_data.m_slope = m_slope; + p_data.m_intercept = m_intercept; +} + + +double +NeuralInputLinear::GetMax() const +{ + // The output of this node is 0 mean, 1 standard deviation. For the max + // value of the output, we assume that it is 1 standard deviation away from + // the mean. + return 1; +} + + +double +NeuralInputLinear::GetMin() const +{ + // The output of this node is 0 mean, 1 standard deviation. For the min + // value of the output, we assume that it is 1 standard deviation away from + // the mean. Further, since the min input to the node is 0, we cannot have + // a value that is less than m_intercept. + return ((m_intercept > -1) ? m_intercept : -1); +} + + +double +NeuralInputLinear::GetSlope() const +{ + return m_slope; +} + + +double +NeuralInputLinear::GetIntercept() const +{ + return m_intercept; +} + + +double +NeuralInputLinear::EvaluateInput(UInt32 val) const +{ + return ((double)val * m_slope) + m_intercept; +} + + +bool +NeuralInputLinear::Save(FILE* fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + // The base class can write out the generic portion + NeuralInputUnary::Save(fpOutput, nInputId, p_featureMap); + fprintf(fpOutput, "Transform=linear\n"); + fprintf(fpOutput, "Slope=%lg\n", m_slope); + fprintf(fpOutput, "Intercept=%lg\n", m_intercept); + + return true; +} + + +size_t +NeuralInputLinear::GetSize() const +{ + return sizeof(*this) + GetExternalSize(); +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInputLinear; +Sizeof_NeuralInputLinear sizeof_NeuralInputLinear; +#endif + +// Use static assert to force the user to consider the external size and serialization if a new member is added. +// For each new member please add it to the template serialize and generate all the bin files for all the nets in searchgold. +// If the new member has an external size, then the size has to be reported within this function. +#ifdef DEBUG +BOOST_STATIC_ASSERT(sizeof(NeuralInputLinear) - sizeof(NeuralInputUnary) == 16); +#else +BOOST_STATIC_ASSERT(sizeof(NeuralInputLinear) - sizeof(NeuralInputUnary) == 16); +#endif + + +size_t +NeuralInputLinear::GetExternalSize() const +{ + return NeuralInputUnary::GetExternalSize(); +} + + +NeuralInputLogLinear* +NeuralInputLogLinear::Load(DynamicRank::Config& p_config, + const char* szSection, + IFeatureMap& p_featureMap) +{ + UInt32 iFeature; + + if (!NeuralInputUnary::ReadAssociatedFeature(p_config,szSection, p_featureMap, &iFeature)) + { + return NULL; + } + + double slope = 0.0; + double intercept = 0.0; + if ((!p_config.GetDoubleParameter(szSection, "Slope", &slope)) || + (!p_config.GetDoubleParameter(szSection, "Intercept", &intercept))) + { + return NULL; + } + + NeuralInputLogLinear *pNeuralInputLogLinear = new NeuralInputLogLinear(iFeature, slope, intercept); + return pNeuralInputLogLinear; +} + + +NeuralInputLogLinear::NeuralInputLogLinear(int id, + double slope, + double intercept) + : NeuralInputLinear(id, slope, intercept) +{ +} + + +NeuralInputLogLinear::NeuralInputLogLinear(const NeuralInputLogLinearBondData& p_data) + : NeuralInputLinear(p_data) +{ +} + + +void +NeuralInputLogLinear::FillBond(UnionBondInput& p_data) const +{ + p_data.m_inputType = "loglinear"; + // Fill bond struct of the class. + NeuralInputLogLinearBondData data; + FillBondData(data); + p_data.m_loglinear.set(data); +} + + +void +NeuralInputLogLinear::FillBondData(NeuralInputLogLinearBondData& p_data) const +{ + // Fill base class. + NeuralInputLinear::FillBondData(p_data); +} + + +bool +NeuralInputLogLinear::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInputLinear::Equal(p_input)) + { + return false; + } + + if (!dynamic_cast(p_input)) + { + return false; + } + + return true; +} + + +double +NeuralInputLogLinear::EvaluateInput(UInt32 val) const +{ + return (log((double)(val)+1) * m_slope) + m_intercept; +} + + +bool +NeuralInputLogLinear::Save(FILE *fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + // The base class can write out the generic portion + NeuralInputUnary::Save(fpOutput, nInputId, p_featureMap); + fprintf(fpOutput, "Transform=loglinear\n"); + fprintf(fpOutput, "Slope=%lg\n", m_slope); + fprintf(fpOutput, "Intercept=%lg\n", m_intercept); + + return true; +} + + +size_t +NeuralInputLogLinear::GetSize() const +{ + return sizeof(*this) + GetExternalSize(); +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInputLogLinear; +Sizeof_NeuralInputLogLinear sizeof_NeuralInputLogLinear; +#endif + +// Use static assert to force the user to consider the external size and serialization if a new member is added. +// For each new member please add it to the template serialize and generate all the bin files for all the nets in searchgold. +// If the new member has an external size, then the size has to be reported within this function. +#ifdef DEBUG +BOOST_STATIC_ASSERT(sizeof(NeuralInputLogLinear) - sizeof(NeuralInputLinear) == 0); +#else +BOOST_STATIC_ASSERT(sizeof(NeuralInputLogLinear) - sizeof(NeuralInputLinear) == 0); +#endif + + +size_t +NeuralInputLogLinear::GetExternalSize() const +{ + return NeuralInputLinear::GetExternalSize(); +} + + +NeuralInputBucket::NeuralInputBucket(int p_id, + double p_min, + bool p_mininclusive, + double p_max, + bool p_maxinclusive) + : NeuralInputUnary(p_id) +{ + m_fMinInclusive = p_mininclusive; + m_fMaxInclusive = p_maxinclusive; + if (p_min < 0) + { + m_nMinValue = 0; + } + else + { + m_nMinValue = static_cast(p_min); + if (!p_mininclusive) + { + m_nMinValue++; + } + } + + m_nMaxValue = static_cast(p_max); + if (p_maxinclusive) + { + m_nMaxValue++; + } +} + + +NeuralInputBucket::NeuralInputBucket(const NeuralInputBucketBondData& p_data) + : NeuralInputUnary(p_data), + m_fMinInclusive(p_data.m_fMinInclusive), + m_fMaxInclusive(p_data.m_fMaxInclusive), + m_nMinValue(p_data.m_nMinValue), + m_nMaxValue(p_data.m_nMaxValue) +{ +} + + +void +NeuralInputBucket::FillBond(UnionBondInput& p_data) const +{ + p_data.m_inputType = "bucket"; + // Fill bond struct of the class. + NeuralInputBucketBondData data; + FillBondData(data); + p_data.m_bucket.set(data); +} + + +void +NeuralInputBucket::FillBondData(NeuralInputBucketBondData& p_data) const +{ + // Fill base class. + NeuralInputUnary::FillBondData(p_data); + p_data.m_fMinInclusive = m_fMinInclusive; + p_data.m_fMaxInclusive = m_fMaxInclusive; + p_data.m_nMinValue = m_nMinValue; + p_data.m_nMaxValue = m_nMaxValue; +} + + +bool +NeuralInputBucket::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInputUnary::Equal(p_input)) + { + return false; + } + + const NeuralInputBucket* other = dynamic_cast(p_input); + if (!other) + { + return false; + } + if (m_fMinInclusive != other->m_fMinInclusive + || m_fMaxInclusive != other->m_fMaxInclusive + || m_nMinValue != other->m_nMinValue + || m_nMaxValue != other->m_nMaxValue + || m_iFeature != other->m_iFeature) + { + return false; + } + + return true; +} + + +bool +NeuralInputBucket::GetMinInclusive() const +{ + return m_fMinInclusive; +} + + +bool +NeuralInputBucket::GetMaxInclusive() const +{ + return m_fMaxInclusive; +} + + +UInt32 +NeuralInputBucket::GetMinValue() const +{ + return m_nMinValue; +} + + +UInt32 +NeuralInputBucket::GetMaxValue() const +{ + return m_nMaxValue; +} + + +double +NeuralInputBucket::GetMax() const +{ + // The output of this node is 0 or 1. So the max value is 1. + return 1; +} + + +double +NeuralInputBucket::GetMin() const +{ + // The output of this node is 0 or 1. So the max value is 0. + return 0; +} + + +double NeuralInputBucket::EvaluateInput(UInt32 val) const +{ + if (m_nMinValue <= val && val < m_nMaxValue) + { + return 1.0; + } + + return 0.0; +} + + +bool +NeuralInputBucket::Save(FILE* fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + // The base class can write out the generic portion + NeuralInputUnary::Save(fpOutput, nInputId, p_featureMap); + fprintf(fpOutput, "Transform=bucket\n"); + double dblMinValue = (double)m_nMinValue; + double dblMaxValue = (double)m_nMaxValue; + + if (!m_fMinInclusive) + { + // Min by default is inclusive, so if we have to write out a + // non-exclusive form we have to subtract. + dblMinValue -= 1.0; + } + + if (m_fMaxInclusive) + { + // Max by default is exclusive, so we have to write out the + // exclusive variant + dblMaxValue -= 1.0; + } + + fprintf(fpOutput, "MinValue=%lf\n", dblMinValue); + fprintf(fpOutput, "MaxValue=%lf\n", dblMaxValue); + fprintf(fpOutput, "MinInclusive=%s\n", m_fMinInclusive ? "true" : "false"); + fprintf(fpOutput, "MaxInclusive=%s\n", m_fMaxInclusive ? "true" : "false"); + + return true; +} + + +NeuralInputBucket* +NeuralInputBucket::Load(DynamicRank::Config& p_config, + const char* szSection, + IFeatureMap& p_featureMap) +{ + UInt32 iFeature; + + if (!NeuralInputUnary::ReadAssociatedFeature( + p_config,szSection, p_featureMap, &iFeature)) + { + return NULL; + } + + double min = 0.0; + double max = 0.0; + bool minincl = false; + bool maxincl = false; + if ((!p_config.GetDoubleParameter(szSection, "MinValue", &min)) || + (!p_config.GetDoubleParameter(szSection, "MaxValue", &max)) || + (!p_config.GetBoolParameter(szSection, "MinInclusive", &minincl)) || + (!p_config.GetBoolParameter(szSection, "MaxInclusive", &maxincl))) + { + return NULL; + } + + NeuralInputBucket *pNeuralInputBucket = + new NeuralInputBucket(iFeature, min, minincl, max, maxincl); + if( !pNeuralInputBucket ) + { + return NULL; + } + + return pNeuralInputBucket; +} + + +size_t +NeuralInputBucket::GetSize() const +{ + return sizeof(*this) + GetExternalSize(); +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInputBucket; +Sizeof_NeuralInputBucket sizeof_NeuralInputBucket; +#endif + +size_t +NeuralInputBucket::GetExternalSize() const +{ + return NeuralInputUnary::GetExternalSize(); +} + + +NeuralInputRational* +NeuralInputRational::Load(DynamicRank::Config& p_config, + const char* szSection, + IFeatureMap& p_featureMap) +{ + UInt32 iFeature; + + if (!NeuralInputUnary::ReadAssociatedFeature(p_config,szSection, p_featureMap, &iFeature)) + { + return NULL; + } + + double dblDampingFactor = p_config.GetDoubleParameter(szSection, "DampingFactor", 0.0); + if (dblDampingFactor <= 0.0) + { + return NULL; + } + + NeuralInputRational* pNeuralInputRational = new NeuralInputRational(iFeature, dblDampingFactor); + return pNeuralInputRational; +} + + +NeuralInputRational::NeuralInputRational(int p_id, + double p_dblDampingFactor) + : NeuralInputUnary(p_id), + // Take the absolute value in order to make the field + // nicely defined (no poles). + m_dblDampingFactor(fabs(p_dblDampingFactor)) +{ +} + + +NeuralInputRational::NeuralInputRational(const NeuralInputRationalBondData& p_data) + : NeuralInputUnary(p_data), + m_dblDampingFactor(p_data.m_dblDampingFactor) +{ +} + + +void +NeuralInputRational::FillBond(UnionBondInput& p_data) const +{ + p_data.m_inputType = "rational"; + // Fill bond struct of the class. + NeuralInputRationalBondData data; + FillBondData(data); + p_data.m_rational.set(data); +} + + +void +NeuralInputRational::FillBondData(NeuralInputRationalBondData& p_data) const +{ + // Fill base class. + NeuralInputUnary::FillBondData(p_data); + p_data.m_dblDampingFactor = m_dblDampingFactor; +} + + +bool +NeuralInputRational::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInputUnary::Equal(p_input)) + { + return false; + } + + const NeuralInputRational* other = dynamic_cast(p_input); + if (!other) + { + return false; + } + if (m_dblDampingFactor != other->m_dblDampingFactor) + { + return false; + } + + return true; +} + + +double +NeuralInputRational::EvaluateInput(UInt32 input) const +{ + double dblOutput = (double)input / ((double)input + m_dblDampingFactor); + + return dblOutput; + +} + + +double +NeuralInputRational::GetDampingFactor() const +{ + return m_dblDampingFactor; +} + + +double +NeuralInputRational::GetMin() const +{ + return 0.0; +} + + +double +NeuralInputRational::GetMax() const +{ + return 1.0; +} + + +bool +NeuralInputRational::Save(FILE* fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + // The base class can write out the generic portion + NeuralInputUnary::Save(fpOutput, nInputId, p_featureMap); + fprintf(fpOutput, "Transform=rational\n"); + fprintf(fpOutput, "DampingFactor=%lg\n", m_dblDampingFactor); + + return true; +} + + +size_t +NeuralInputRational::GetSize() const +{ + return sizeof(*this) + GetExternalSize(); +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInputRational; +Sizeof_NeuralInputRational sizeof_NeuralInputRational; +#endif + +// Use static assert to force the user to consider the external size and serialization if a new member is added. +// For each new member please add it to the template serialize and generate all the bin files for all the nets in searchgold. +// If the new member has an external size, then the size has to be reported within this function. +#ifdef DEBUG +BOOST_STATIC_ASSERT(sizeof(NeuralInputRational) - sizeof(NeuralInputUnary) == 8); +#else +BOOST_STATIC_ASSERT(sizeof(NeuralInputRational) - sizeof(NeuralInputUnary) == 8); +#endif + + +size_t +NeuralInputRational::GetExternalSize() const +{ + return NeuralInputUnary::GetExternalSize(); +} + + +NeuralInputCached::~NeuralInputCached() +{ +} + + +NeuralInputCached::NeuralInputCached(size_t nCacheSize, + NeuralInputUnary* pChild) + : NeuralInputUnary(pChild->GetAssociatedFeature()) +{ + SetSegments(pChild->GetSegments()); + + m_input.reset(pChild); + m_resultCache.reset(new double[nCacheSize]); + + if (!m_resultCache.get()) + { + m_cacheSize = 0; // not enough memory to cache results + return; + } + m_cacheSize = nCacheSize; + for (UInt32 i = 0; i < nCacheSize ; i++) + { + m_resultCache[i] = m_input->EvaluateInput(i); + } +} + + +void +NeuralInputCached::FillBond(UnionBondInput& p_data) const +{ + m_input->FillBond(p_data); +} + + +bool +NeuralInputCached::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInputUnary::Equal(p_input)) + { + return false; + } + + const NeuralInputCached* other = dynamic_cast(p_input); + if (!other) + { + return false; + } + if (m_cacheSize != other->m_cacheSize + || !m_input->Equal(other->m_input.get())) + { + return false; + } + + return true; +} + + +double +NeuralInputCached::EvaluateInput(UInt32 val) const +{ + if (val < m_cacheSize) + { + return m_resultCache[val]; + } + + return m_input->EvaluateInput(val); +} + + +double +NeuralInputCached::GetMax() const +{ + return m_input->GetMax(); +} + + +double +NeuralInputCached::GetMin() const +{ + return m_input->GetMin(); +} + + +bool +NeuralInputCached::Save(FILE* fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + NeuralInputUnary* input = const_cast(m_input.get()); + input->SetSegments(GetSegments()); + + return m_input->Save(fpOutput, nInputId, p_featureMap); +} + + +bool +NeuralInputCached::Train(double dblLearningRate, double outputHigh, + double outputLow, double dblOutputDelta, + UInt32 inputHigh[], UInt32 inputLow[]) +{ + bool ret; + + // Zero out cache since training can change the transform. + ret = m_input->Train(dblLearningRate, outputHigh, outputLow, + dblOutputDelta, inputHigh, inputLow); + if (ret) + m_cacheSize = 0; + return ret; +} + + +NeuralInputUnary* +NeuralInputCached::Load(size_t nCacheSize, + NeuralInputUnary *pChild) +{ +#ifdef INPUT_CACHE + + if (pChild) + { + return new NeuralInputCached(nCacheSize, pChild); + } +#endif + + return pChild; +} + + +const NeuralInputUnary* +NeuralInputCached::GetBaseInput() const +{ + return m_input.get(); +} + + +size_t +NeuralInputCached::GetSize() const +{ + return sizeof(*this) + GetExternalSize(); +} + + +// Do a 'set COMPUTE_BOOST_STATIC_ASSERT=1' command, then build in order to get error output that contains the struct size. +#if defined(COMPUTE_BOOST_STATIC_ASSERT) +template struct Sizeof_NeuralInputCached; +Sizeof_NeuralInputCached sizeof_NeuralInputCached; +#endif + +// Use static assert to force the user to consider the external size and serialization if a new member is added. +// For each new member please add it to the template serialize and generate all the bin files for all the nets in searchgold. +// If the new member has an external size, then the size has to be reported within this function. +#ifdef DEBUG +BOOST_STATIC_ASSERT(sizeof(NeuralInputCached) - sizeof(NeuralInputUnary) == 24); +#else +BOOST_STATIC_ASSERT(sizeof(NeuralInputCached) - sizeof(NeuralInputUnary) == 24); +#endif + + +size_t +NeuralInputCached::GetExternalSize() const +{ + size_t externalSize = 0; + + if (m_cacheSize && m_resultCache.get()) + { + externalSize += m_cacheSize * sizeof(m_resultCache[0]); + } + + externalSize += m_input->GetSize(); + + externalSize += NeuralInputUnary::GetExternalSize(); + + return externalSize; +} + + +NeuralInputTanh::NeuralInputTanh() + : m_cInputs(0), + m_locked(false), + m_threshold(0.0) +{ + memset(m_rgId, 0, c_maxInputs * sizeof(m_rgId[0])); + memset(m_rgWeights, 0, c_maxInputs * sizeof(m_rgWeights[0])); +} + + +NeuralInputTanh::NeuralInputTanh(const NeuralInputTanhBondData& p_data) + : NeuralInput(p_data), + m_cInputs(p_data.m_cInputs), + m_locked(p_data.m_locked), + m_threshold(p_data.m_threshold) +{ + for (size_t i = 0; i < p_data.m_rgId.size(); ++i) + { + m_rgId[i] = p_data.m_rgId[i]; + } + + for (size_t i = 0; i < p_data.m_rgWeights.size(); ++i) + { + m_rgWeights[i] = p_data.m_rgWeights[i]; + } +} + + +void +NeuralInputTanh::FillBond(UnionBondInput& p_data) const +{ + p_data.m_inputType = "tanh"; + // Fill bond struct of the class. + NeuralInputTanhBondData data; + FillBondData(data); + UnionNeuralInputTanhBondData unionData; + unionData.m_cached = false; + unionData.m_neuralInputTanhBondData.set(data); + p_data.m_tanh.set(unionData); +} + + +void +NeuralInputTanh::FillBondData(NeuralInputTanhBondData& p_data) const +{ + // Fill base class. + NeuralInput::FillBondData(p_data); + p_data.m_cInputs = static_cast(m_cInputs); + p_data.m_locked = m_locked; + p_data.m_threshold = m_threshold; + for (size_t i = 0; i < c_maxInputs; i++) + { + p_data.m_rgId.push_back(m_rgId[i]); + p_data.m_rgWeights.push_back(m_rgWeights[i]); + } +} + + +bool +NeuralInputTanh::Equal(const NeuralInput* p_input) const +{ + if (!NeuralInput::EqualInternal(p_input)) + { + return false; + } + + const NeuralInputTanh* other = dynamic_cast(p_input); + if (!other) + { + return false; + } + if (m_cInputs != other->m_cInputs + || m_locked != other->m_locked + || m_threshold != other->m_threshold + || memcmp(m_rgId, other->m_rgId, c_maxInputs * sizeof(m_rgId[0])) + || memcmp(m_rgWeights, other->m_rgWeights, c_maxInputs * sizeof(m_rgWeights[0]))) + { + return false; + } + + return true; +} + + +NeuralInputTanh* +NeuralInputTanh::Load(DynamicRank::Config& p_config, + const char* szSection, + IFeatureMap& p_featureMap) +{ + char buff[256]; + + // This transform is different than the others because it may have + // "Name" defined and still have multiple inputs... + if (szSection == NULL) + { + return NULL; + } + + NeuralInputTanh *result = new NeuralInputTanh(); + result->m_cInputs = 0; + result->m_locked = p_config.GetBoolParameter(szSection, "locked", false); + result->m_threshold = p_config.GetDoubleParameter(szSection, "Threshold", 0.0); + + int inputs = 0; + bool featurePresentInFeatureMap = true; + while (inputs < result->c_maxInputs) + { + if (inputs == 0) + { + // strcpy_s(buff, sizeof(buff), "Name"); + strcpy(buff, "Name"); + } + else + { + if (_snprintf_s(buff, sizeof(buff), _TRUNCATE, "Name:%d", inputs + 1) == -1) + { + return NULL; + } + } + if (!p_config.GetStringParameter(szSection, buff, buff, sizeof(buff))) + { + break; + } + + // continue to read configuration even if its not present in the feature + // map, it will be needed for cloning the net + if (!p_featureMap.ObtainFeatureIndex(buff, result->m_rgId[inputs])) + { + featurePresentInFeatureMap = false; + } + + if (inputs == 0) + { + strcpy(buff, "Weight"); + } + else + { + if (_snprintf_s(buff, sizeof(buff), _TRUNCATE, "Weight:%d", inputs + 1) == -1) + { + return NULL; + } + } + + result->m_rgWeights[inputs] = p_config.GetDoubleParameter(szSection, buff, 0.01); + + // Unless we increment input count, the value we stored in the array + // m_rgWeights wont matter. If this variable is set to false once, + // it will be false for all the subsequent iterations of the loop + // there cannot be name:3 present without name:2 in the featuremap + // this is as per the old behavior as of Dec 15, 2008. + if (featurePresentInFeatureMap) + { + result->m_cInputs++; + } + ++inputs; + } + + return result; +} + + +void +NeuralInputTanh::GetAllAssociatedFeatures(std::vector& associatedFeaturesList) const +{ + for( size_t i = 0; i < m_cInputs; ++i ) + { + associatedFeaturesList.push_back(m_rgId[i]); + } +} + + +double +NeuralInputTanh::GetMin() const +{ + return -1.0; +} + + +double +NeuralInputTanh::GetMax() const +{ + return 1.0; +} + + +double +NeuralInputTanh::Evaluate(UInt32 input[]) const +{ + double sum = m_threshold; + for (size_t i = 0; i < m_cInputs; ++i) + { + sum += log((double)input[m_rgId[i]]+1) * m_rgWeights[i]; + } + return tanh(sum); +} + + +bool +NeuralInputTanh::Save(FILE *fpOutput, + size_t nInputId, + const IFeatureMap& p_featureMap) const +{ + // All we can do here is write out the section name and the + // feature name + NeuralInput::Save (fpOutput, nInputId, p_featureMap); + + fprintf(fpOutput, "transform=tanh\n"); + + // Create a stack buffer for the feature name + char rgInput[1024]; + fprintf(fpOutput, "Threshold=%lg\n", m_threshold); + fprintf(fpOutput, "Locked=%s\n", m_locked?"TRUE":"FALSE"); + + for (size_t i=0; i