Skip to content

Commit 07e7e61

Browse files
committed
Add tests for LSE & Softmax functions
Updated TestUMathsCatSnippets with unit tests for each function. Regenerated UMathsCatSnippets using CodeSnip to include the LSE and Softmax routines.
1 parent 98f4bdd commit 07e7e61

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

tests/Cat-Maths/TestUMathsCatSnippets.pas

+65-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ TestMathsCatSnippets = class(TTestCase)
2323
procedure TestPowNZN_EOverflow;
2424
procedure TestDigitPowerSum_EOverflow;
2525
procedure TestDigitPowerSum_EArgumentException;
26+
procedure TestLSE_EArgumentException;
2627
function EqualArrays(const Left, Right: TBytes): Boolean;
2728
function ReverseArray(const A: TBytes): TBytes;
2829
published
@@ -52,7 +53,7 @@ TestMathsCatSnippets = class(TTestCase)
5253
procedure TestMaxOfArray_Integer;
5354
procedure TestMaxOfArray_Int64;
5455
procedure TestMaxOfArray_Single;
55-
procedure TestMaxOfArray_Double;
56+
procedure TestMaxOfArray_Double; // required by LSE
5657
procedure TestMaxOfArray_Extended;
5758
procedure TestPowNZN; // required by DigitPowerSum
5859
procedure TestPowNZZ;
@@ -83,6 +84,8 @@ TestMathsCatSnippets = class(TTestCase)
8384
procedure TestDigitPowerSum; // required by IsNarcissistic
8485
procedure TestIsPalindromic;
8586
procedure TestIsNarcissistic;
87+
procedure TestLSE; // required by SoftMax
88+
procedure TestSoftMax;
8689
end;
8790

8891
implementation
@@ -753,6 +756,31 @@ procedure TestMathsCatSnippets.TestLCD;
753756
CheckEquals(9, LCD(-9, -9), 'LCD(-9, -9)');
754757
end;
755758

759+
procedure TestMathsCatSnippets.TestLSE;
760+
const
761+
Fudge = 0.000001;
762+
A1: array [1..7] of Double = (-35.0, 20.78, 42.56, -27.8, 41.576, 0.0, 57.945);
763+
A2: array [1..7] of Double = (-35.0, 20.78, 42.56, -27.8, 41.576, 0.0, 20.78);
764+
A5: array [1..3] of Double = (-430.0, -399.83, -300.00);
765+
A6: array [1..10] of Double = (-12.0, 4.0, -6.0, 11.0, 10.0, 3.0, -3.0, 9.0, -8.0, 7.0);
766+
begin
767+
// Hand calculated
768+
CheckTrue(SameValue(57.945000285961067157769252279369, LSE(A1)), '#1');
769+
// Calculated using http://mycalcsolutions.com/calculator?mathematics;stat_prob;softmax
770+
CheckTrue(SameValue(42.87759, LSE(A2), Fudge), '#2');
771+
CheckTrue(SameValue(-35.0, LSE([-35.0]), Fudge), '#3');
772+
CheckTrue(SameValue(0.0, LSE([0.0]), Fudge), '#4');
773+
CheckTrue(SameValue(-300.0, LSE(A5), Fudge), '#5');
774+
CheckTrue(SameValue(11.420537, LSE(A6), Fudge), '#6');
775+
// Check empty array exception
776+
CheckException(TestLSE_EArgumentException, EArgumentException, 'EArgumentException');
777+
end;
778+
779+
procedure TestMathsCatSnippets.TestLSE_EArgumentException;
780+
begin
781+
LSE([]);
782+
end;
783+
756784
procedure TestMathsCatSnippets.TestMaxOfArray_Double;
757785
var
758786
A: TDoubleDynArray;
@@ -1203,6 +1231,42 @@ procedure TestMathsCatSnippets.TestResizeRect_B;
12031231
CheckEquals(-4, RectHeight(R), '3: RectHeight');
12041232
end;
12051233

1234+
procedure TestMathsCatSnippets.TestSoftMax;
1235+
1236+
function ArraysEqual(const Left, Right: array of Double): Boolean;
1237+
const
1238+
Fudge = 0.000001;
1239+
var
1240+
Idx: Integer;
1241+
begin
1242+
Result := True;
1243+
if Length(Left) <> Length(Right) then
1244+
Exit(False);
1245+
for Idx := Low(Left) to High(Left) do
1246+
if not SameValue(Left[Idx], Right[Idx], Fudge) then
1247+
Exit(False);
1248+
end;
1249+
const
1250+
A1: array [1..7] of Double = (-35.0, 20.78, 42.56, -27.8, 41.576, 0.0, 57.945);
1251+
E1: array [1..7] of Double = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0);
1252+
A2: array [1..7] of Double = (-35.0, 20.78, 42.56, -27.8, 41.576, 0.0, 20.78);
1253+
E2: array [1..7] of Double = (0.0, 0.0, 0.727901, 0.0, 0.272099, 0.0, 0.0);
1254+
A5: array [1..3] of Double = (-430.0, -399.83, -300.0);
1255+
E5: array [1..3] of Double = (0.0, 0.0, 1.0);
1256+
A6: array [1..10] of Double = (-12.0, 4.0, -6.0, 11.0, 10.0, 3.0, -3.0, 9.0, -8.0, 7.0);
1257+
E6: array [1..10] of Double = (0.0, 0.000599, 0.0, 0.656694, 0.241584, 0.00022, 0.000001, 0.088874, 0, 0.012028);
1258+
A7: array [1..3] of Double = (1430.0, 1430.83, 1440.47);
1259+
E7: array [1..3] of Double = (0.000028, 0.000065, 0.999907);
1260+
begin
1261+
CheckTrue(ArraysEqual(E1, SoftMax(A1)), '#1');
1262+
CheckTrue(ArraysEqual(E2, SoftMax(A2)), '#2');
1263+
CheckTrue(ArraysEqual([1.0], SoftMax([-35.0])), '#3');
1264+
CheckTrue(ArraysEqual([1.0], SoftMax([0.0])), '#4');
1265+
CheckTrue(ArraysEqual(E5, SoftMax(A5)), '#6');
1266+
CheckTrue(ArraysEqual(E6, SoftMax(A6)), '#6');
1267+
CheckTrue(ArraysEqual(E7, SoftMax(A7)), '#7');
1268+
end;
1269+
12061270
procedure TestMathsCatSnippets.TestStretchRect_A;
12071271
var
12081272
R0, R1, R2: TRect;

tests/Cat-Maths/UMathsCatSnippets.pas

+56-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* The unit is copyright © 2005-2024 by Peter Johnson & Contributors and is
77
* licensed under the MIT License (https://opensource.org/licenses/MIT).
88
*
9-
* Generated on : Thu, 09 Jan 2025 15:04:31 GMT.
9+
* Generated on : Fri, 10 Jan 2025 09:52:09 GMT.
1010
* Generated by : DelphiDabbler CodeSnip Release 4.24.0.
1111
*
1212
* The latest version of CodeSnip is available from the CodeSnip GitHub project
@@ -227,6 +227,13 @@ function IsRectNormal(const R: Windows.TRect): Boolean;
227227
}
228228
function LCD(A, B: Integer): Integer;
229229

230+
{
231+
Returns the logarithm of the sum of the exponentials of the given array of
232+
floating pointing point numbers.
233+
An EArgumentException exception is raised if the array is empty.
234+
}
235+
function LSE(const A: array of Double): Double;
236+
230237
{
231238
Returns the maximum value contained in the given array of double precision
232239
floating point values.
@@ -488,6 +495,15 @@ function SignOfInt(const Value: Int64): Integer;
488495
}
489496
procedure SimplifyFraction(var Num, Denom: Int64);
490497

498+
{
499+
Applies the softmax function to each element of floating point array A and
500+
normalizes them into a probability distribution proportional to the
501+
exponentials of the elements of A. The normalised values are returned as an
502+
array of the same size as A.
503+
An EArgumentException exception is raised if A is empty.
504+
}
505+
function SoftMax(const A: array of Double): Types.TDoubleDynArray;
506+
491507
{
492508
Stretches rectangle R by the given scaling factors and returns the result.
493509
The rectangle's width is scaled by ScalingX and its height by ScalingY.
@@ -1198,6 +1214,27 @@ function LCD(A, B: Integer): Integer;
11981214
Result := Abs((A * B)) div GCD(A, B);
11991215
end;
12001216

1217+
{
1218+
Returns the logarithm of the sum of the exponentials of the given array of
1219+
floating pointing point numbers.
1220+
An EArgumentException exception is raised if the array is empty.
1221+
}
1222+
function LSE(const A: array of Double): Double;
1223+
var
1224+
MaxElem: Double;
1225+
Elem: Double;
1226+
Sum: Double;
1227+
begin
1228+
if System.Length(A) = 0 then
1229+
raise SysUtils.EArgumentException.Create('Empty array');
1230+
// Using the centering "trick": see https://rpubs.com/FJRubio/LSE
1231+
MaxElem := MaxOfArray(A);
1232+
Sum := 0.0;
1233+
for Elem in A do
1234+
Sum := Sum + System.Exp(Elem - MaxElem);
1235+
Result := System.Ln(Sum) + MaxElem;
1236+
end;
1237+
12011238
{
12021239
Returns the maximum value contained in the given array of double precision
12031240
floating point values.
@@ -1950,6 +1987,24 @@ procedure SimplifyFraction(var Num, Denom: Int64);
19501987
Denom := Denom div CommonFactor;
19511988
end;
19521989

1990+
{
1991+
Applies the softmax function to each element of floating point array A and
1992+
normalizes them into a probability distribution proportional to the
1993+
exponentials of the elements of A. The normalised values are returned as an
1994+
array of the same size as A.
1995+
An EArgumentException exception is raised if A is empty.
1996+
}
1997+
function SoftMax(const A: array of Double): Types.TDoubleDynArray;
1998+
var
1999+
LSEOfA: Double;
2000+
Idx: Integer;
2001+
begin
2002+
LSEOfA := LSE(A); // raise EArgumentException if A is empty
2003+
System.SetLength(Result, System.Length(A));
2004+
for Idx := 0 to Pred(System.Length(A)) do
2005+
Result[Idx] := System.Exp(A[Idx] - LSEOfA);
2006+
end;
2007+
19532008
{
19542009
Stretches rectangle R by the given scaling factors and returns the result.
19552010
The rectangle's width is scaled by ScalingX and its height by ScalingY.

0 commit comments

Comments
 (0)