Skip to content

Commit

Permalink
Fix PR #68 test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
lcorcodilos committed Jun 15, 2021
1 parent 590b64b commit 9ca9055
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
7 changes: 3 additions & 4 deletions TIMBER/Analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""

from TIMBER.CollectionOrganizer import CollectionOrganizer
from TIMBER.Utilities.CollectionGen import BuildCollectionDict, GetKeyValForBranch, StructDef, StructObj
from TIMBER.Tools.Common import GenerateHash, GetHistBinningTuple, CompileCpp, ConcatCols, GetStandardFlags, ExecuteCmd
from clang import cindex
from collections import OrderedDict
Expand Down Expand Up @@ -1041,9 +1040,9 @@ def DrawTemplates(self,hGroup,saveLocation,projection='X',projectionArgs=(),file
down.SetLineColor(ROOT.kBlue)

leg = ROOT.TLegend(0.7,0.7,0.9,0.9)
leg.AddEntry(nominal.GetValue(),'Nominal','lf')
leg.AddEntry(up.GetValue(),'Up','l')
leg.AddEntry(down.GetValue(),'Down','l')
leg.AddEntry(nominal,'Nominal','lf')
leg.AddEntry(up,'Up','l')
leg.AddEntry(down,'Down','l')

up.Draw('same hist')
down.Draw('same hist')
Expand Down
14 changes: 10 additions & 4 deletions TIMBER/CollectionOrganizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,20 @@ def _parsetype(self, t):
nRVecs = len(re.findall('ROOT::VecOps::RVec<',t))
if nRVecs == 0:
collType = t
isVect = False
else:
isVect = True
collType = t.strip()
collType = re.sub('ROOT::VecOps::RVec<','',collType,count=1)
collType = re.sub('>','',collType,count=1)
collType += ' &'
if 'Bool_t' in collType:
collType = collType.replace('Bool_t&','std::_Bit_reference')
collType = collType.replace('Bool_t &','Bool_t&').replace('Bool_t&','std::_Bit_reference')

if collType == ' &':
collType = ''

return collType
return collType, isVect

def AddCollection(self, c):
'''Add a collection to tracking.
Expand Down Expand Up @@ -91,20 +93,22 @@ def AddBranch(self, b, btype=''):
'''
collname = b.split('_')[0]
varname = '_'.join(b.split('_')[1:])
typeStr = self._parsetype(btype)
typeStr, isVect = self._parsetype(btype)

if typeStr == False or varname == '' or 'n'+collname not in self._baseBranches:
matches = [m for m in self._otherBranches.keys() if (m.startswith(collname) and '_'.join(m.split('_')[1:]) != '')]
if len(matches) == 0:
self._otherBranches[b] = {
'type': typeStr,
'isVect': isVect,
'alias': False
}
else:
if varname != '':
self.AddCollection(collname)
self._collectionDict[collname][varname] = {
'type': typeStr,
'isVect': isVect,
'alias': False
}
for match in matches:
Expand All @@ -114,6 +118,7 @@ def AddBranch(self, b, btype=''):
self.AddCollection(collname)
self._collectionDict[collname][varname] = {
'type': typeStr,
'isVect': isVect,
'alias': False
}

Expand Down Expand Up @@ -160,7 +165,8 @@ def BuildCppCollection(self,collection,node,silent=True):
newNode = node
attributes = []
for aname in self.GetCollectionAttributes(collection):
attributes.append('%s %s'%(self._collectionDict[collection][aname]['type'], aname))
if self._collectionDict[collection][aname]['isVect']:
attributes.append('%s %s'%(self._collectionDict[collection][aname]['type'], aname))

if collection+'s' not in self._builtCollections:
self._builtCollections.append(collection+'s')
Expand Down
2 changes: 1 addition & 1 deletion TIMBER/Utilities/CollectionGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def GetKeyValForBranch(rdf, bname, includeType=True):
collType = str(rdf.GetColumnType(bname)).replace('ROOT::VecOps::RVec<','')
if collType.endswith('>'): collType = collType[:-1]
collType += '&'
if 'Bool_t' in collType: collType = collType.replace('Bool_t&','std::_Bit_reference')
if 'Bool_t' in collType: collType = collType.replace('Bool_t &','Bool_t&').replace('Bool_t&','std::_Bit_reference')
if includeType:
out = (collname, collType+' '+varname)
else:
Expand Down
6 changes: 2 additions & 4 deletions test/test_Analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ def test_snapshot(self):
# assert True

def test_Correction(self):
c = Correction('testWeight','test/test_weight.cc')
c = Correction('test_weight','test/test_weight.cc')
self.a.Define('Jet_pt0','Jet_pt[0]')
self.a.AddCorrection(c,{'pt':'Jet_pt0'})
self.a.MakeWeightCols()
htemplate = ROOT.TH1F('th1','',100,0,1000)
hgroup = self.a.MakeTemplateHistos(htemplate,'Jet_pt0')
print ('TESTING GetWeightName: %s'%(self.a.GetWeightName(c,'up','')))
print ([hgroup[h].GetName() for h in hgroup.keys()])
assert self.a.GetWeightName(c,'up','') == 'weight__test_weight_up'
self.a.DrawTemplates(hgroup, './')
pass

def test_CommonVars(self):
assert sorted(self.a.CommonVars(["Muon","Tau"])) == sorted(['phi', 'pt', 'charge', 'eta', 'mass', 'genPartIdx', 'jetIdx'])
Expand Down

0 comments on commit 9ca9055

Please sign in to comment.