From a1a48da55a29d1c40a5cabce831eb802f9de5612 Mon Sep 17 00:00:00 2001 From: "Randall C. O'Reilly" Date: Thu, 24 Oct 2024 03:11:51 -0700 Subject: [PATCH] stroop SOA working -- lots of moving parts on that one. --- ch9/stroop/stroop.go | 118 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 100 insertions(+), 18 deletions(-) diff --git a/ch9/stroop/stroop.go b/ch9/stroop/stroop.go index 41ac930..c195dd4 100644 --- a/ch9/stroop/stroop.go +++ b/ch9/stroop/stroop.go @@ -11,12 +11,15 @@ package main import ( "embed" "math" + "reflect" + "strings" "cogentcore.org/core/base/errors" "cogentcore.org/core/base/randx" "cogentcore.org/core/core" "cogentcore.org/core/icons" "cogentcore.org/core/math32" + "cogentcore.org/core/math32/minmax" "cogentcore.org/core/tensor/table" "cogentcore.org/core/tree" "github.com/emer/emergent/v2/econfig" @@ -119,7 +122,7 @@ var ParamSets = params.Sets{ "Layer.Act.Dt.VmTau": "30", }}, }, - "SOATraining": { + "SOATesting": { {Sel: "Layer", Desc: "no decay", Params: params.Params{ "Layer.Act.Init.Decay": "0", @@ -363,12 +366,33 @@ func (ss *Sim) CycleThresholdStop() { if ss.Context.Mode == etime.Train { return } - cyc := ss.Loops.Stacks[etime.Test].Loops[etime.Cycle] + mode := ss.Context.Mode + cycl := ss.Loops.Stacks[mode].Loops[etime.Cycle] + cyc := cycl.Counter.Cur out := ss.Net.LayerByName("Output") outact := out.Pools[0].Inhib.Act.Max - if outact > 0.51 { - ss.Stats.SetFloat("RT", float64(cyc.Counter.Cur)) - cyc.SkipToMax() + if mode == etime.Validate { + tbl := ss.SOA + trl := ss.Loops.Stacks[mode].Loops[etime.Trial].Counter.Cur + // soa := int(tbl.Float("SOA", trl)) + mxc := int(tbl.Float("MaxCycles", trl)) + islate := strings.Contains(ss.Stats.String("TrialName"), "latestim") + if islate { + if outact > 0.51 { + ss.Stats.SetFloat("RT", float64(cyc)) + cycl.SkipToMax() + } + } else { + if cyc > mxc { + cycl.SkipToMax() + } + } + // fmt.Println(trl, tnm, soa, mxc) + } else { + if outact > 0.51 { + ss.Stats.SetFloat("RT", float64(cyc)) + cycl.SkipToMax() + } } } @@ -389,6 +413,11 @@ func (ss *Sim) ConfigLoops() { AddTime(etime.Trial, ss.Test.Rows). AddTime(etime.Cycle, 200) + man.AddStack(etime.Validate). + AddTime(etime.Epoch, 1). + AddTime(etime.Trial, ss.SOA.Rows). + AddTime(etime.Cycle, 200) + leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code @@ -397,7 +426,7 @@ func (ss *Sim) ConfigLoops() { stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() { ss.ApplyInputs() }) - if m == etime.Test { + if m == etime.Test || m == etime.Validate { stack.Loops[etime.Cycle].Main.Add("CycleThresholdStop", func() { ss.CycleThresholdStop() }) @@ -406,15 +435,6 @@ func (ss *Sim) ConfigLoops() { man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun) - // Add Testing - trainEpoch := man.GetLoop(etime.Train, etime.Epoch) - trainEpoch.OnStart.Add("TestAtInterval", func() { - if (ss.Config.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.TestInterval == 0) { - // Note the +1 so that it doesn't occur at the 0th timestep. - ss.TestAll() - } - }) - ///////////////////////////////////////////// // Logging @@ -456,7 +476,13 @@ func (ss *Sim) ApplyInputs() { // ss.Stats.SetString("TrialName", evi.(*env.FreqTable).String()) } else { out.Type = leabra.CompareLayer - ss.Stats.SetString("TrialName", evi.(*env.FixedTable).TrialName.Cur) + trlnm := evi.(*env.FixedTable).TrialName.Cur + ss.Stats.SetString("TrialName", trlnm) + if ctx.Mode == etime.Validate { + if !strings.Contains(trlnm, "latestim") || strings.Contains(trlnm, "Both") { + ss.Net.InitActs() + } + } } for _, lnm := range lays { ly := ss.Net.LayerByName(lnm) @@ -498,7 +524,9 @@ func (ss *Sim) TestAll() { func (ss *Sim) InitStats() { ss.Stats.SetFloat("SSE", 0.0) ss.Stats.SetFloat("RT", math.NaN()) + ss.Stats.SetFloat("SOA", 0) ss.Stats.SetString("TrialName", "") + ss.Stats.SetString("GroupName", "") ss.Logs.InitErrStats() // inits TrlErr, FirstZero, LastZero, NZero } @@ -512,8 +540,16 @@ func (ss *Sim) StatCounters() { trnEpc := ss.Loops.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur ss.Stats.SetInt("Epoch", trnEpc) trl := ss.Stats.Int("Trial") - if ctx.Mode != etime.Train { + if ctx.Mode == etime.Test { trl = trl % 3 + } else if ctx.Mode == etime.Validate { + trl := ss.Loops.Stacks[mode].Loops[etime.Trial].Counter.Cur + soa := ss.SOA.Float("SOA", trl) + ss.Stats.SetFloat("SOA", soa) + gp := trl / 22 + conds := []string{"Color_Conf", "Color_Cong", "Word_Conf", "Word_Cong"} + ss.Stats.SetString("GroupName", conds[gp]) + trl = trl % 22 } ss.Stats.SetInt("Trial", trl) ss.Stats.SetInt("Cycle", int(ctx.Cycle)) @@ -554,6 +590,7 @@ func (ss *Sim) ConfigLogs() { ss.Logs.AddCounterItems(etime.Run, etime.Epoch, etime.Trial, etime.Cycle) ss.Logs.AddStatStringItem(etime.AllModes, etime.AllTimes, "RunName") ss.Logs.AddStatStringItem(etime.AllModes, etime.Trial, "TrialName") + ss.Logs.AddStatStringItem(etime.Validate, etime.Trial, "GroupName") ss.Logs.AddStatAggItem("SSE", etime.Run, etime.Epoch, etime.Trial) ss.Logs.AddStatAggItem("AvgSSE", etime.Run, etime.Epoch, etime.Trial) @@ -561,6 +598,16 @@ func (ss *Sim) ConfigLogs() { ss.Logs.AddStatAggItem("RT", etime.Epoch, etime.Trial) + ss.Logs.AddItem(&elog.Item{ + Name: "SOA", + Type: reflect.Float64, + FixMax: true, + Range: minmax.F32{Max: 220}, + Write: elog.WriteMap{ + etime.Scope(etime.Validate, etime.Trial): func(ctx *elog.Context) { + ctx.SetFloat64(ss.Stats.Float("SOA")) + }}}) + ss.Logs.AddPerTrlMSec("PerTrlMSec", etime.Run, etime.Epoch, etime.Trial) ss.Logs.PlotItems("RT") @@ -570,7 +617,9 @@ func (ss *Sim) ConfigLogs() { // don't plot certain combinations we don't use ss.Logs.NoPlot(etime.Train, etime.Cycle) ss.Logs.NoPlot(etime.Test, etime.Cycle) + ss.Logs.NoPlot(etime.Test, etime.Epoch) ss.Logs.NoPlot(etime.Test, etime.Run) + ss.Logs.SetMetaScope(etime.Scope(etime.Validate, etime.Trial), "Plot", "true") ss.Logs.SetMeta(etime.Train, etime.Run, "LegendCol", "RunName") ss.Logs.SetMeta(etime.Test, etime.Trial, "Points", "true") @@ -587,6 +636,19 @@ func (ss *Sim) ConfigLogs() { ss.Logs.SetMeta(etime.Test, etime.Trial, "TrialName:On", "+") ss.Logs.SetMeta(etime.Train, etime.Epoch, "PctErr:On", "+") + + ss.Logs.SetMeta(etime.Validate, etime.Trial, "XAxis", "SOA") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "Legend", "GroupName") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "Points", "true") + + ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:Min", "-22") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:Max", "22") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:FixMin", "true") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:FixMax", "true") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:FixMin", "true") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:FixMax", "true") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:Min", "0") + ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:Max", "250") } // Log is the main logging function, handles special things for different scopes @@ -605,6 +667,11 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) { case time == etime.Cycle: return case time == etime.Trial: + if mode == etime.Validate { + if !strings.Contains(ss.Stats.String("TrialName"), "latestim") { + return + } + } ss.TrialStats() ss.StatCounters() } @@ -648,7 +715,7 @@ func (ss *Sim) MakeToolbar(p *tree.Plan) { }, }) - ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test}) + ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test, etime.Validate}) ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Test Init", Icon: icons.Update, Tooltip: "Initialize testing to start over -- if Test Step doesn't work, then do this.", @@ -656,10 +723,25 @@ func (ss *Sim) MakeToolbar(p *tree.Plan) { Func: func() { ss.Params.SetAllSheet("Testing") ss.SetPFCParams() + ev := ss.Envs.ByMode(etime.Validate) + ev.Init(0) ss.Loops.ResetCountersByMode(etime.Test) }, }) + ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "SOA Init", Icon: icons.Update, + Tooltip: "Initialize SOA testing to start over -- if Test Step doesn't work, then do this.", + Active: egui.ActiveStopped, + Func: func() { + ss.Params.SetAllSheet("Testing") + ss.Params.SetAllSheet("SOATesting") + ss.SetPFCParams() + ev := ss.Envs.ByMode(etime.Validate) + ev.Init(0) + ss.Loops.ResetCountersByMode(etime.Validate) + }, + }) + //////////////////////////////////////////////// tree.Add(p, func(w *core.Separator) {}) ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog",