Skip to content

Commit

Permalink
stroop SOA working -- lots of moving parts on that one.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoreilly committed Oct 24, 2024
1 parent 2d9e34e commit a1a48da
Showing 1 changed file with 100 additions and 18 deletions.
118 changes: 100 additions & 18 deletions ch9/stroop/stroop.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ package main
import (
"embed"
"math"
"reflect"
"strings"

"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/randx"
"cogentcore.org/core/core"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32"
"cogentcore.org/core/math32/minmax"
"cogentcore.org/core/tensor/table"
"cogentcore.org/core/tree"
"github.com/emer/emergent/v2/econfig"
Expand Down Expand Up @@ -119,7 +122,7 @@ var ParamSets = params.Sets{
"Layer.Act.Dt.VmTau": "30",
}},
},
"SOATraining": {
"SOATesting": {
{Sel: "Layer", Desc: "no decay",
Params: params.Params{
"Layer.Act.Init.Decay": "0",
Expand Down Expand Up @@ -363,12 +366,33 @@ func (ss *Sim) CycleThresholdStop() {
if ss.Context.Mode == etime.Train {
return
}
cyc := ss.Loops.Stacks[etime.Test].Loops[etime.Cycle]
mode := ss.Context.Mode
cycl := ss.Loops.Stacks[mode].Loops[etime.Cycle]
cyc := cycl.Counter.Cur
out := ss.Net.LayerByName("Output")
outact := out.Pools[0].Inhib.Act.Max
if outact > 0.51 {
ss.Stats.SetFloat("RT", float64(cyc.Counter.Cur))
cyc.SkipToMax()
if mode == etime.Validate {
tbl := ss.SOA
trl := ss.Loops.Stacks[mode].Loops[etime.Trial].Counter.Cur
// soa := int(tbl.Float("SOA", trl))
mxc := int(tbl.Float("MaxCycles", trl))
islate := strings.Contains(ss.Stats.String("TrialName"), "latestim")
if islate {
if outact > 0.51 {
ss.Stats.SetFloat("RT", float64(cyc))
cycl.SkipToMax()
}
} else {
if cyc > mxc {
cycl.SkipToMax()
}
}
// fmt.Println(trl, tnm, soa, mxc)
} else {
if outact > 0.51 {
ss.Stats.SetFloat("RT", float64(cyc))
cycl.SkipToMax()
}
}
}

Expand All @@ -389,6 +413,11 @@ func (ss *Sim) ConfigLoops() {
AddTime(etime.Trial, ss.Test.Rows).
AddTime(etime.Cycle, 200)

man.AddStack(etime.Validate).
AddTime(etime.Epoch, 1).
AddTime(etime.Trial, ss.SOA.Rows).
AddTime(etime.Cycle, 200)

leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing
leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code

Expand All @@ -397,7 +426,7 @@ func (ss *Sim) ConfigLoops() {
stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() {
ss.ApplyInputs()
})
if m == etime.Test {
if m == etime.Test || m == etime.Validate {
stack.Loops[etime.Cycle].Main.Add("CycleThresholdStop", func() {
ss.CycleThresholdStop()
})
Expand All @@ -406,15 +435,6 @@ func (ss *Sim) ConfigLoops() {

man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun)

// Add Testing
trainEpoch := man.GetLoop(etime.Train, etime.Epoch)
trainEpoch.OnStart.Add("TestAtInterval", func() {
if (ss.Config.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.TestInterval == 0) {
// Note the +1 so that it doesn't occur at the 0th timestep.
ss.TestAll()
}
})

/////////////////////////////////////////////
// Logging

Expand Down Expand Up @@ -456,7 +476,13 @@ func (ss *Sim) ApplyInputs() {
// ss.Stats.SetString("TrialName", evi.(*env.FreqTable).String())
} else {
out.Type = leabra.CompareLayer
ss.Stats.SetString("TrialName", evi.(*env.FixedTable).TrialName.Cur)
trlnm := evi.(*env.FixedTable).TrialName.Cur
ss.Stats.SetString("TrialName", trlnm)
if ctx.Mode == etime.Validate {
if !strings.Contains(trlnm, "latestim") || strings.Contains(trlnm, "Both") {
ss.Net.InitActs()
}
}
}
for _, lnm := range lays {
ly := ss.Net.LayerByName(lnm)
Expand Down Expand Up @@ -498,7 +524,9 @@ func (ss *Sim) TestAll() {
func (ss *Sim) InitStats() {
ss.Stats.SetFloat("SSE", 0.0)
ss.Stats.SetFloat("RT", math.NaN())
ss.Stats.SetFloat("SOA", 0)
ss.Stats.SetString("TrialName", "")
ss.Stats.SetString("GroupName", "")
ss.Logs.InitErrStats() // inits TrlErr, FirstZero, LastZero, NZero
}

Expand All @@ -512,8 +540,16 @@ func (ss *Sim) StatCounters() {
trnEpc := ss.Loops.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur
ss.Stats.SetInt("Epoch", trnEpc)
trl := ss.Stats.Int("Trial")
if ctx.Mode != etime.Train {
if ctx.Mode == etime.Test {
trl = trl % 3
} else if ctx.Mode == etime.Validate {
trl := ss.Loops.Stacks[mode].Loops[etime.Trial].Counter.Cur
soa := ss.SOA.Float("SOA", trl)
ss.Stats.SetFloat("SOA", soa)
gp := trl / 22
conds := []string{"Color_Conf", "Color_Cong", "Word_Conf", "Word_Cong"}
ss.Stats.SetString("GroupName", conds[gp])
trl = trl % 22
}
ss.Stats.SetInt("Trial", trl)
ss.Stats.SetInt("Cycle", int(ctx.Cycle))
Expand Down Expand Up @@ -554,13 +590,24 @@ func (ss *Sim) ConfigLogs() {
ss.Logs.AddCounterItems(etime.Run, etime.Epoch, etime.Trial, etime.Cycle)
ss.Logs.AddStatStringItem(etime.AllModes, etime.AllTimes, "RunName")
ss.Logs.AddStatStringItem(etime.AllModes, etime.Trial, "TrialName")
ss.Logs.AddStatStringItem(etime.Validate, etime.Trial, "GroupName")

ss.Logs.AddStatAggItem("SSE", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddStatAggItem("AvgSSE", etime.Run, etime.Epoch, etime.Trial)
ss.Logs.AddErrStatAggItems("TrlErr", etime.Run, etime.Epoch, etime.Trial)

ss.Logs.AddStatAggItem("RT", etime.Epoch, etime.Trial)

ss.Logs.AddItem(&elog.Item{
Name: "SOA",
Type: reflect.Float64,
FixMax: true,
Range: minmax.F32{Max: 220},
Write: elog.WriteMap{
etime.Scope(etime.Validate, etime.Trial): func(ctx *elog.Context) {
ctx.SetFloat64(ss.Stats.Float("SOA"))
}}})

ss.Logs.AddPerTrlMSec("PerTrlMSec", etime.Run, etime.Epoch, etime.Trial)

ss.Logs.PlotItems("RT")
Expand All @@ -570,7 +617,9 @@ func (ss *Sim) ConfigLogs() {
// don't plot certain combinations we don't use
ss.Logs.NoPlot(etime.Train, etime.Cycle)
ss.Logs.NoPlot(etime.Test, etime.Cycle)
ss.Logs.NoPlot(etime.Test, etime.Epoch)
ss.Logs.NoPlot(etime.Test, etime.Run)
ss.Logs.SetMetaScope(etime.Scope(etime.Validate, etime.Trial), "Plot", "true")
ss.Logs.SetMeta(etime.Train, etime.Run, "LegendCol", "RunName")

ss.Logs.SetMeta(etime.Test, etime.Trial, "Points", "true")
Expand All @@ -587,6 +636,19 @@ func (ss *Sim) ConfigLogs() {
ss.Logs.SetMeta(etime.Test, etime.Trial, "TrialName:On", "+")

ss.Logs.SetMeta(etime.Train, etime.Epoch, "PctErr:On", "+")

ss.Logs.SetMeta(etime.Validate, etime.Trial, "XAxis", "SOA")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "Legend", "GroupName")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "Points", "true")

ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:Min", "-22")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:Max", "22")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:FixMin", "true")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "SOA:FixMax", "true")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:FixMin", "true")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:FixMax", "true")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:Min", "0")
ss.Logs.SetMeta(etime.Validate, etime.Trial, "RT:Max", "250")
}

// Log is the main logging function, handles special things for different scopes
Expand All @@ -605,6 +667,11 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) {
case time == etime.Cycle:
return
case time == etime.Trial:
if mode == etime.Validate {
if !strings.Contains(ss.Stats.String("TrialName"), "latestim") {
return
}
}
ss.TrialStats()
ss.StatCounters()
}
Expand Down Expand Up @@ -648,18 +715,33 @@ func (ss *Sim) MakeToolbar(p *tree.Plan) {
},
})

ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test})
ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test, etime.Validate})

ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Test Init", Icon: icons.Update,
Tooltip: "Initialize testing to start over -- if Test Step doesn't work, then do this.",
Active: egui.ActiveStopped,
Func: func() {
ss.Params.SetAllSheet("Testing")
ss.SetPFCParams()
ev := ss.Envs.ByMode(etime.Validate)
ev.Init(0)
ss.Loops.ResetCountersByMode(etime.Test)
},
})

ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "SOA Init", Icon: icons.Update,
Tooltip: "Initialize SOA testing to start over -- if Test Step doesn't work, then do this.",
Active: egui.ActiveStopped,
Func: func() {
ss.Params.SetAllSheet("Testing")
ss.Params.SetAllSheet("SOATesting")
ss.SetPFCParams()
ev := ss.Envs.ByMode(etime.Validate)
ev.Init(0)
ss.Loops.ResetCountersByMode(etime.Validate)
},
})

////////////////////////////////////////////////
tree.Add(p, func(w *core.Separator) {})
ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog",
Expand Down

0 comments on commit a1a48da

Please sign in to comment.