From b2ff4e91fd5c7be44711bf60ede81ff7501c2572 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Wed, 24 Mar 2021 13:02:01 +0000 Subject: [PATCH 01/65] Updated TreeStat's taxon set panel to be more like BEAUti --- src/dr/app/treestat/CharactersPanel.java | 11 +- src/dr/app/treestat/StatisticsPanel.java | 6 +- src/dr/app/treestat/TaxonSetsPanel.java | 968 +++++++++++++++++------ src/dr/app/treestat/TreeStatData.java | 15 +- src/dr/app/treestat/TreeStatFrame.java | 7 +- 5 files changed, 757 insertions(+), 250 deletions(-) diff --git a/src/dr/app/treestat/CharactersPanel.java b/src/dr/app/treestat/CharactersPanel.java index f2481ad486..7351797ea9 100644 --- a/src/dr/app/treestat/CharactersPanel.java +++ b/src/dr/app/treestat/CharactersPanel.java @@ -34,6 +34,7 @@ import java.util.ArrayList; import dr.app.gui.table.TableSorter; +import dr.evolution.util.Taxon; import jam.table.TableRenderer; import jam.framework.Exportable; @@ -408,7 +409,10 @@ public void actionPerformed(ActionEvent ae) { int saved1 = charactersTable.getSelectedRow(); int saved2 = statesTable.getSelectedRow(); int[] rows = excludedTaxaTable.getSelectedRows(); - ArrayList exclList = new ArrayList(treeStatData.allTaxa); + ArrayList exclList = new ArrayList<>(); + for (Taxon taxon : treeStatData.allTaxa) { + exclList.add(taxon.getId()); + } exclList.removeAll(selectedState.taxa); for (int row : rows) { selectedState.taxa.add(exclList.get(row)); @@ -556,7 +560,10 @@ public Object getValueAt(int row, int col) { if (included) { return selectedState.taxa.get(row); } else { - ArrayList exclList = new ArrayList(treeStatData.allTaxa); + ArrayList exclList = new ArrayList<>(); + for (Taxon taxon : treeStatData.allTaxa) { + exclList.add(taxon.getId()); + } exclList.removeAll(selectedState.taxa); return exclList.get(row); } diff --git a/src/dr/app/treestat/StatisticsPanel.java b/src/dr/app/treestat/StatisticsPanel.java index b421f7f0a0..4a80eb508d 100644 --- a/src/dr/app/treestat/StatisticsPanel.java +++ b/src/dr/app/treestat/StatisticsPanel.java @@ -326,7 +326,7 @@ public TreeSummaryStatistic createStatistic(TreeSummaryStatistic.Factory factory if (factory.allowsTaxonList()) { - for (Object taxonSet : treeStatData.taxonSets) { + for (Object taxonSet : treeStatData.taxonSets.values()) { taxonSetCombo.addItem(taxonSet); } @@ -413,9 +413,7 @@ public void itemStateChanged(ItemEvent e) { taxa.setId(t.name); //Iterator iter = t.taxa.iterator(); for (Object aTaxa : t.taxa) { - String id = (String) aTaxa; - Taxon taxon = new Taxon(id); - taxa.addTaxon(taxon); + taxa.addTaxon((Taxon) aTaxa); } statistic.setTaxonList(taxa); } else { diff --git a/src/dr/app/treestat/TaxonSetsPanel.java b/src/dr/app/treestat/TaxonSetsPanel.java index 2912e7a130..85c05b5a1e 100644 --- a/src/dr/app/treestat/TaxonSetsPanel.java +++ b/src/dr/app/treestat/TaxonSetsPanel.java @@ -1,5 +1,5 @@ /* - * TaxonSetsPanel.java + * TaxonSetPanel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * @@ -25,185 +25,734 @@ package dr.app.treestat; +import dr.app.beauti.ComboBoxRenderer; +import dr.app.beauti.util.PanelUtils; +import dr.app.gui.table.DateCellEditor; +import dr.evolution.util.Taxa; +import dr.evolution.util.Taxon; +import dr.evolution.util.TaxonList; import jam.framework.Exportable; +import jam.panels.ActionPanel; import jam.table.TableRenderer; +import jam.util.IconUtils; +import java.awt.*; +import java.awt.event.*; +import java.util.ArrayList; +import java.util.Collections; import javax.swing.*; +import javax.swing.border.TitledBorder; +import javax.swing.event.DocumentEvent; +import javax.swing.event.DocumentListener; import javax.swing.event.ListSelectionEvent; import javax.swing.event.ListSelectionListener; +import javax.swing.plaf.BorderUIResource; import javax.swing.table.AbstractTableModel; -import java.awt.*; -import java.awt.event.ActionEvent; -import java.util.ArrayList; - - +import javax.swing.table.JTableHeader; +import javax.swing.table.TableColumn; +import javax.swing.table.TableColumnModel; +import java.util.List; + +/** + * @author Andrew Rambaut + * @author Alexei Drummond + * @version $Id: TaxaPanel.java,v 1.1 2006/09/05 13:29:34 rambaut Exp $ + */ public class TaxonSetsPanel extends JPanel implements Exportable { + private static final long serialVersionUID = -3138832889782090814L; + + protected String TAXA; + protected String TAXON; - /** - * - */ - private static final long serialVersionUID = -9013475414423166476L; TreeStatFrame frame = null; TreeStatData treeStatData = null; TreeStatData.TaxonSet selectedTaxonSet = null; - JScrollPane scrollPane1 = null; - JTable taxonSetsTable = null; - TaxonSetsTableModel taxonSetsTableModel = null; + // private TaxonList taxa = null; + protected JTable taxonSetsTable = null; + private TableColumnModel tableColumnModel; + protected TaxonSetsTableModel taxonSetsTableModel = new TaxonSetsTableModel(); + ComboBoxRenderer comboBoxRenderer = new ComboBoxRenderer(); - JScrollPane scrollPane2 = null; - JTable excludedTaxaTable = null; - TaxaTableModel excludedTaxaTableModel = null; + protected JPanel taxonSetEditingPanel = null; - JScrollPane scrollPane3 = null; - JTable includedTaxaTable = null; - TaxaTableModel includedTaxaTableModel = null; + protected TreeStatData.TaxonSet currentTaxonSet = null; - public TaxonSetsPanel(TreeStatFrame frame, TreeStatData treeStatData) { + protected final List includedTaxa = new ArrayList<>(); + protected final List excludedTaxa = new ArrayList<>(); - this.frame = frame; - setOpaque(false); + private JTextField excludedTaxaSearchField = new JTextField(); + + protected JTable excludedTaxaTable = null; + protected TaxaTableModel excludedTaxaTableModel = null; + private JLabel excludedTaxaLabel = new JLabel(); + protected JComboBox excludedTaxonSetsComboBox = null; + protected boolean excludedSelectionChanging = false; + + private JTextField includedTaxaSearchField = new JTextField(); + + protected JTable includedTaxaTable = null; + protected TaxaTableModel includedTaxaTableModel = null; + private JLabel includedTaxaLabel = new JLabel(); + protected JComboBox includedTaxonSetsComboBox = null; + protected boolean includedSelectionChanging = false; + + public TaxonSetsPanel(TreeStatFrame parent, TreeStatData treeStatData) { + + this.frame = parent; this.treeStatData = treeStatData; + setText(false); - Icon addIcon = null, removeIcon = null, includeIcon = null, excludeIcon = null; - try { - addIcon = new ImageIcon(dr.app.util.Utils.getImage(this, "images/add.png")); - removeIcon = new ImageIcon(dr.app.util.Utils.getImage(this, "images/minus.png")); - includeIcon = new ImageIcon(dr.app.util.Utils.getImage(this, "images/include.png")); - excludeIcon = new ImageIcon(dr.app.util.Utils.getImage(this, "images/exclude.png")); - } catch (Exception e) { - // do nothing - } + // Taxon Sets + initTaxonSetsTable(taxonSetsTableModel); + + initTableColumn(); - JPanel buttonPanel = new JPanel(new BorderLayout()); - buttonPanel.setOpaque(false); - JButton importButton = new JButton(frame.getImportAction()); - importButton.setFocusable(false); - importButton.putClientProperty("JButton.buttonType", "textured"); - importButton.setMargin(new Insets(4,4,4,4)); - buttonPanel.add(importButton, BorderLayout.WEST); - buttonPanel.add(new JLabel(" To define taxon sets, first import a list of taxa (i.e., from the trees to be analysed)"), BorderLayout.SOUTH); + initPanel(addTaxonSetAction, removeTaxonSetAction); + } - // Taxon Sets - taxonSetsTableModel = new TaxonSetsTableModel(); - TableSorter sorter = new TableSorter(taxonSetsTableModel); - taxonSetsTable = new JTable(sorter); - sorter.addTableModelListener(taxonSetsTable); + protected void setText(boolean useStarBEAST) { + if (useStarBEAST) { + TAXA = "Species"; + TAXON = "Species set"; + } else { + TAXA = "Taxa"; + TAXON = "Taxon set"; + } + } - taxonSetsTable.getColumnModel().getColumn(0).setCellRenderer( - new TableRenderer(SwingConstants.LEFT, new Insets(0, 4, 0, 4))); + protected void initPanel(Action addTaxonSetAction, Action removeTaxonSetAction) { + JScrollPane scrollPane1 = new JScrollPane(taxonSetsTable, + JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); - taxonSetsTable.getSelectionModel().addListSelectionListener(new ListSelectionListener() { - public void valueChanged(ListSelectionEvent evt) { taxonSetsTableSelectionChanged(); } - }); + ActionPanel actionPanel1 = new ActionPanel(false); + actionPanel1.setAddAction(addTaxonSetAction); + actionPanel1.setRemoveAction(removeTaxonSetAction); - scrollPane1 = new JScrollPane(taxonSetsTable, - JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); + addTaxonSetAction.setEnabled(false); + removeTaxonSetAction.setEnabled(false); - JPanel buttonPanel1 = createAddRemoveButtonPanel(addTaxonSetAction, addIcon, "Create a new taxon set", - removeTaxonSetAction, removeIcon, "Remove a taxon set", - javax.swing.BoxLayout.X_AXIS); + JPanel controlPanel1 = new JPanel(new FlowLayout(FlowLayout.LEFT)); + controlPanel1.add(actionPanel1); // Excluded Taxon List excludedTaxaTableModel = new TaxaTableModel(false); - sorter = new TableSorter(excludedTaxaTableModel); - excludedTaxaTable = new JTable(sorter); - sorter.addTableModelListener(excludedTaxaTable); + excludedTaxaTable = new JTable(excludedTaxaTableModel); excludedTaxaTable.getColumnModel().getColumn(0).setCellRenderer( - new TableRenderer(SwingConstants.LEFT, new Insets(0, 4, 0, 4))); + new TableRenderer(SwingConstants.LEFT, new Insets(0, 4, 0, 4))); + excludedTaxaTable.getColumnModel().getColumn(0).setMinWidth(20); excludedTaxaTable.getSelectionModel().addListSelectionListener(new ListSelectionListener() { - public void valueChanged(ListSelectionEvent evt) { excludedTaxaTableSelectionChanged(); } + public void valueChanged(ListSelectionEvent evt) { + excludedTaxaTableSelectionChanged(); + } }); - scrollPane2 = new JScrollPane(excludedTaxaTable, - JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); + JScrollPane scrollPane2 = new JScrollPane(excludedTaxaTable, + JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); + + includedTaxonSetsComboBox = new JComboBox(new String[]{TAXON.toLowerCase() + "..."}); + excludedTaxonSetsComboBox = new JComboBox(new String[]{TAXON.toLowerCase() + "..."}); - JPanel buttonPanel2 = createAddRemoveButtonPanel(includeTaxonAction, includeIcon, "Include selected taxa in the taxon set", - excludeTaxonAction, excludeIcon, "Exclude selected taxa from the taxon set", - javax.swing.BoxLayout.Y_AXIS); + includedTaxaLabel.setText(""); + excludedTaxaLabel.setText(""); - // Included Taxon List + Box panel1 = new Box(BoxLayout.X_AXIS); + panel1.add(new JLabel("Select: ")); + panel1.setOpaque(false); + excludedTaxonSetsComboBox.setOpaque(false); + panel1.add(excludedTaxonSetsComboBox); + + // Included Taxon List includedTaxaTableModel = new TaxaTableModel(true); - sorter = new TableSorter(includedTaxaTableModel); - includedTaxaTable = new JTable(sorter); - sorter.addTableModelListener(includedTaxaTable); + includedTaxaTable = new JTable(includedTaxaTableModel); includedTaxaTable.getColumnModel().getColumn(0).setCellRenderer( - new TableRenderer(SwingConstants.LEFT, new Insets(0, 4, 0, 4))); + new TableRenderer(SwingConstants.LEFT, new Insets(0, 4, 0, 4))); + includedTaxaTable.getColumnModel().getColumn(0).setMinWidth(20); includedTaxaTable.getSelectionModel().addListSelectionListener(new ListSelectionListener() { - public void valueChanged(ListSelectionEvent evt) { includedTaxaTableSelectionChanged(); } + public void valueChanged(ListSelectionEvent evt) { + includedTaxaTableSelectionChanged(); + } }); + includedTaxaTable.doLayout(); + + JScrollPane scrollPane3 = new JScrollPane(includedTaxaTable, + JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, + JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); + + Box panel2 = new Box(BoxLayout.X_AXIS); + panel2.add(new JLabel("Select: ")); + panel2.setOpaque(false); + includedTaxonSetsComboBox.setOpaque(false); + panel2.add(includedTaxonSetsComboBox); + + Icon includeIcon = null, excludeIcon = null; + try { + includeIcon = new ImageIcon(IconUtils.getImage(TreeStatApp.class, "images/include.png")); + excludeIcon = new ImageIcon(IconUtils.getImage(TreeStatApp.class, "images/exclude.png")); + } catch (Exception e) { + // do nothing + } - scrollPane3 = new JScrollPane(includedTaxaTable, - JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, - JScrollPane.HORIZONTAL_SCROLLBAR_NEVER); + JPanel buttonPanel = createAddRemoveButtonPanel(includeTaxonAction, includeIcon, "Include selected " + + TAXA.toLowerCase() + " in the " + TAXON.toLowerCase(), + excludeTaxonAction, excludeIcon, "Exclude selected " + TAXA.toLowerCase() + + " from the " + TAXON.toLowerCase(), BoxLayout.Y_AXIS); + + taxonSetEditingPanel = new JPanel(); + taxonSetEditingPanel.setBorder(BorderFactory.createTitledBorder("")); + taxonSetEditingPanel.setOpaque(false); + taxonSetEditingPanel.setLayout(new GridBagLayout()); + + excludedTaxaSearchField.setColumns(12); +// excludedTaxaSearchField.putClientProperty("JTextField.variant", "search"); + excludedTaxaSearchField.putClientProperty("Quaqua.TextField.style","search"); + excludedTaxaSearchField.putClientProperty("Quaqua.TextField.sizeVariant","small"); + includedTaxaSearchField.setColumns(12); +// includedTaxaSearchField.putClientProperty("JTextField.variant", "search"); + includedTaxaSearchField.putClientProperty("Quaqua.TextField.style","search"); + includedTaxaSearchField.putClientProperty("Quaqua.TextField.sizeVariant","small"); - JPanel panel = new JPanel(); - panel.setOpaque(false); - panel.setLayout(new GridBagLayout()); GridBagConstraints c = new GridBagConstraints(); - c.weightx = 0.333333; - c.weighty = 1; - c.fill = GridBagConstraints.BOTH; - c.anchor = GridBagConstraints.CENTER; - c.insets = new Insets(6,6,6,6); c.gridx = 0; c.gridy = 0; - panel.add(scrollPane1, c); - - c.weightx = 0.333333; + c.weightx = 0.5; c.weighty = 0; - c.fill = GridBagConstraints.NONE; - c.anchor = GridBagConstraints.LINE_START; - c.insets = new Insets(0,6,6,6); + c.fill = GridBagConstraints.HORIZONTAL; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(3, 6, 3, 0); + taxonSetEditingPanel.add(excludedTaxaSearchField, c); + c.gridx = 0; c.gridy = 1; - panel.add(buttonPanel1, c); - - c.weightx = 0.333333; + c.weightx = 0.5; c.weighty = 1; c.fill = GridBagConstraints.BOTH; c.anchor = GridBagConstraints.CENTER; - c.insets = new Insets(6,6,6,6); + c.insets = new Insets(0, 6, 0, 0); + taxonSetEditingPanel.add(scrollPane2, c); + + c.gridx = 0; + c.gridy = 2; + c.weightx = 0.5; + c.weighty = 0; + c.fill = GridBagConstraints.HORIZONTAL; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(0, 6, 3, 0); + taxonSetEditingPanel.add(excludedTaxaLabel, c); + + c.gridx = 0; + c.gridy = 3; + c.weightx = 0.5; + c.weighty = 0; + c.fill = GridBagConstraints.HORIZONTAL; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(0, 6, 3, 0); + taxonSetEditingPanel.add(panel1, c); + c.gridx = 1; c.gridy = 0; - panel.add(scrollPane2, c); - - c.weightx = 0.0; + c.weightx = 0; + c.weighty = 1; + c.gridheight = 4; c.fill = GridBagConstraints.NONE; c.anchor = GridBagConstraints.CENTER; - c.insets = new Insets(0,0,0,0); + c.insets = new Insets(12, 2, 12, 4); + taxonSetEditingPanel.add(buttonPanel, c); + c.gridx = 2; c.gridy = 0; - panel.add(buttonPanel2, c); + c.weightx = 0.5; + c.weighty = 0; + c.gridheight = 1; + c.fill = GridBagConstraints.HORIZONTAL; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(3, 0, 3, 6); + taxonSetEditingPanel.add(includedTaxaSearchField, c); - c.weightx = 0.333333; + c.gridx = 2; + c.gridy = 1; + c.weightx = 0.5; + c.weighty = 1; c.fill = GridBagConstraints.BOTH; c.anchor = GridBagConstraints.CENTER; - c.insets = new Insets(6,6,6,6); - c.gridx = 3; + c.insets = new Insets(0, 0, 0, 6); + taxonSetEditingPanel.add(scrollPane3, c); + + c.gridx = 2; + c.gridy = 2; + c.weightx = 0.5; + c.weighty = 0; + c.fill = GridBagConstraints.HORIZONTAL; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(0, 0, 3, 6); + taxonSetEditingPanel.add(includedTaxaLabel, c); + + c.gridx = 2; + c.gridy = 3; + c.weightx = 0.5; + c.weighty = 0; + c.fill = GridBagConstraints.HORIZONTAL; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(0, 0, 3, 6); + taxonSetEditingPanel.add(panel2, c); + + JPanel panel3 = new JPanel(); + panel3.setOpaque(false); + panel3.setLayout(new GridBagLayout()); + c = new GridBagConstraints(); + + c.gridx = 0; + c.gridy = 0; + c.weightx = 0.5; + c.weighty = 1; + c.fill = GridBagConstraints.BOTH; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(0, 0, 2, 12); + panel3.add(scrollPane1, c); + + c.gridx = 0; + c.gridy = 1; + c.weightx = 0; + c.weighty = 0; + c.fill = GridBagConstraints.NONE; + c.anchor = GridBagConstraints.WEST; + c.insets = new Insets(2, 0, 0, 12); + panel3.add(actionPanel1, c); + + c.gridx = 1; c.gridy = 0; - panel.add(scrollPane3, c); + c.weightx = 0.5; + c.weighty = 1; + c.fill = GridBagConstraints.BOTH; + c.anchor = GridBagConstraints.CENTER; + c.insets = new Insets(0, 0, 0, 0); + panel3.add(taxonSetEditingPanel, c); + + setOpaque(false); + setBorder(new BorderUIResource.EmptyBorderUIResource(new Insets(12, 12, 12, 12))); + setLayout(new BorderLayout(0, 0)); + add(panel3, BorderLayout.CENTER); + +// taxonSetsTable.addMouseListener(new MouseAdapter() { +// public void mouseClicked(MouseEvent e) { +// if (e.getClickCount() == 2) { +// JTable target = (JTable)e.getSource(); +// int row = target.getSelectedRow(); +// taxonSetsTableDoubleClicked(row); +// } +// } +// }); + + includedTaxaSearchField.getDocument().addDocumentListener(new DocumentListener() { + public void changedUpdate(DocumentEvent e) { + selectIncludedTaxa(includedTaxaSearchField.getText()); + } + + public void removeUpdate(DocumentEvent e) { + selectIncludedTaxa(includedTaxaSearchField.getText()); + } + + public void insertUpdate(DocumentEvent e) { + selectIncludedTaxa(includedTaxaSearchField.getText()); + } + } + ); + excludedTaxaSearchField.getDocument().addDocumentListener(new DocumentListener() { + public void changedUpdate(DocumentEvent e) { + selectExcludedTaxa(excludedTaxaSearchField.getText()); + } + + public void removeUpdate(DocumentEvent e) { + selectExcludedTaxa(excludedTaxaSearchField.getText()); + } + + public void insertUpdate(DocumentEvent e) { + selectExcludedTaxa(excludedTaxaSearchField.getText()); + } + } + ); + + + includedTaxaTable.addMouseListener(new MouseAdapter() { + public void mouseClicked(MouseEvent e) { + if (e.getClickCount() == 2) { + includeSelectedTaxa(); + } + } + }); + excludedTaxaTable.addMouseListener(new MouseAdapter() { + public void mouseClicked(MouseEvent e) { + if (e.getClickCount() == 2) { + excludeSelectedTaxa(); + } + } + }); + + includedTaxaTable.addFocusListener(new FocusAdapter() { + public void focusGained(FocusEvent focusEvent) { + excludedTaxaTable.clearSelection(); + } + }); + excludedTaxaTable.addFocusListener(new FocusAdapter() { + public void focusGained(FocusEvent focusEvent) { + includedTaxaTable.clearSelection(); + } + }); + + includedTaxaTable.getSelectionModel().addListSelectionListener(new ListSelectionListener() { + public void valueChanged(ListSelectionEvent e) { + if (!includedSelectionChanging) { + if (includedTaxonSetsComboBox.getSelectedIndex() != 0) { + includedTaxonSetsComboBox.setSelectedIndex(0); + } + includedTaxaSearchField.setText(""); + } + } + }); + includedTaxonSetsComboBox.addItemListener(new ItemListener() { + public void itemStateChanged(ItemEvent e) { + includedSelectionChanging = true; + includedTaxaTable.clearSelection(); + if (includedTaxonSetsComboBox.getSelectedIndex() > 0) { + String taxaName = includedTaxonSetsComboBox.getSelectedItem().toString(); + if (!taxaName.endsWith("...")) { + TreeStatData.TaxonSet taxonSet = treeStatData.taxonSets.get(taxaName); + if (taxonSet != null) { + for (int i = 0; i < taxonSet.taxa.getTaxonCount(); i++) { + Taxon taxon = taxonSet.taxa.getTaxon(i); + int index = includedTaxa.indexOf(taxon); + includedTaxaTable.getSelectionModel().addSelectionInterval(index, index); + + } + } + } + } + includedSelectionChanging = false; + } + }); + + excludedTaxaTable.getSelectionModel().addListSelectionListener(new ListSelectionListener() { + public void valueChanged(ListSelectionEvent e) { + if (!excludedSelectionChanging) { + if (excludedTaxonSetsComboBox.getSelectedIndex() != 0) { + excludedTaxonSetsComboBox.setSelectedIndex(0); + } + excludedTaxaSearchField.setText(""); + } + + } + }); + excludedTaxonSetsComboBox.addItemListener(new ItemListener() { + public void itemStateChanged(ItemEvent e) { + excludedSelectionChanging = true; + excludedTaxaTable.clearSelection(); + if (excludedTaxonSetsComboBox.getSelectedIndex() > 0) { + String taxaName = excludedTaxonSetsComboBox.getSelectedItem().toString(); + if (!taxaName.endsWith("...")) { + TreeStatData.TaxonSet taxonSet = treeStatData.taxonSets.get(taxaName); + if (taxonSet != null) { + for (int i = 0; i < taxonSet.taxa.getTaxonCount(); i++) { + Taxon taxon = taxonSet.taxa.getTaxon(i); + int index = excludedTaxa.indexOf(taxon); + excludedTaxaTable.getSelectionModel().addSelectionInterval(index, index); + + } + } + } + } + excludedSelectionChanging = false; + } + }); + + includedTaxaTable.doLayout(); + excludedTaxaTable.doLayout(); + } + + private void selectIncludedTaxa(String text) { + includedSelectionChanging = true; + includedTaxaTable.clearSelection(); + int index = 0; + for (Taxon taxon : includedTaxa) { + if (taxon.getId().contains(text)) { + includedTaxaTable.getSelectionModel().addSelectionInterval(index, index); + } + index ++; + + } + includedSelectionChanging = false; + } + + private void selectExcludedTaxa(String text) { + excludedSelectionChanging = true; + excludedTaxaTable.clearSelection(); + int index = 0; + for (Taxon taxon : excludedTaxa) { + if (taxon.getId().contains(text)) { + excludedTaxaTable.getSelectionModel().addSelectionInterval(index, index); + } + index ++; + + } + excludedSelectionChanging = false; + + } + + protected void initTableColumn() { + tableColumnModel = taxonSetsTable.getColumnModel(); + TableColumn tableColumn = tableColumnModel.getColumn(0); + tableColumn.setCellRenderer(new TableRenderer(SwingConstants.LEFT, new Insets(0, 4, 0, 4))); + tableColumn.setMinWidth(20); + } + + protected void initTaxonSetsTable(AbstractTableModel tableModel) { + taxonSetsTable = new JTable(tableModel); + taxonSetsTable.getSelectionModel().setSelectionMode(ListSelectionModel.SINGLE_SELECTION); + + taxonSetsTable.getSelectionModel().addListSelectionListener(new ListSelectionListener() { + public void valueChanged(ListSelectionEvent evt) { + taxonSetsTableSelectionChanged(); + } + }); + taxonSetsTable.doLayout(); + } + + protected void taxonSetChanged() { + currentTaxonSet.taxa.removeAllTaxa(); + for (Taxon anIncludedTaxa : includedTaxa) { + currentTaxonSet.taxa.addTaxon(anIncludedTaxa); + } + + setupTaxonSetsComboBoxes(); + + includedTaxaLabel.setText("" + includedTaxa.size() + " taxa included"); + excludedTaxaLabel.setText("" + excludedTaxa.size() + " taxa excluded"); + + frame.setDirty(); + } + + protected void resetPanel() { +// if (!treeStatData.hasData() || treeStatData.taxonSets == null || treeStatData.taxonSets.size() < 1) { +// setCurrentTaxonSet(null); +// } + } + + public JComponent getExportableComponent() { + return taxonSetsTable; + } + + private void taxonSetsTableSelectionChanged() { + if (taxonSetsTable.getSelectedRowCount() == 0) { + selectedTaxonSet = null; + removeTaxonSetAction.setEnabled(false); + } else { + String name = treeStatData.taxonSetNames.get(taxonSetsTable.getSelectedRow()); + selectedTaxonSet = treeStatData.taxonSets.get(name); + removeTaxonSetAction.setEnabled(true); + } + setCurrentTaxonSet(selectedTaxonSet); + includedTaxaTableModel.fireTableDataChanged(); + excludedTaxaTableModel.fireTableDataChanged(); + } + +// private void taxonSetsTableDoubleClicked(int row) { +// currentTaxonSet = (Taxa)taxonSets.get(row); +// +// Collections.sort(taxonSets); +// taxonSetsTableModel.fireTableDataChanged(); +// +// setCurrentTaxonSet(currentTaxonSet); +// +// int sel = taxonSets.indexOf(currentTaxonSet); +// taxonSetsTable.setRowSelectionInterval(sel, sel); +// } + + Action addTaxonSetAction = new AbstractAction("+") { - setLayout(new BorderLayout()); - setBorder(BorderFactory.createEmptyBorder(6,6,6,6)); - add(buttonPanel, BorderLayout.NORTH); - add(panel, BorderLayout.CENTER); + private static final long serialVersionUID = 20273987098143413L; + public void actionPerformed(ActionEvent ae) { + int taxonSetCount = treeStatData.taxonSets.size(); + TreeStatData.TaxonSet taxonSet = new TreeStatData.TaxonSet(); + taxonSet.name = "untitled" + taxonSetCount; + taxonSet.taxa = new Taxa(); + treeStatData.taxonSetNames.add(taxonSet.name); + treeStatData.taxonSets.put(taxonSet.name, taxonSet); + dataChanged(); + + int sel = treeStatData.taxonSets.size() - 1; + taxonSetsTable.setRowSelectionInterval(sel, sel); + + currentTaxonSet = taxonSet; + + taxonSetChanged(); + + taxonSetsTableModel.fireTableDataChanged(); + } + }; + + Action removeTaxonSetAction = new AbstractAction("-") { + + private static final long serialVersionUID = 6077578872870122265L; + + public void actionPerformed(ActionEvent ae) { + int row = taxonSetsTable.getSelectedRow(); + if (row != -1) { + String name = treeStatData.taxonSetNames.get(row); + TreeStatData.TaxonSet taxonSet = treeStatData.taxonSets.remove(name); + } + taxonSetChanged(); + + taxonSetsTableModel.fireTableDataChanged(); + + if (row >= treeStatData.taxonSets.size()) { + row = treeStatData.taxonSets.size() - 1; + } + if (row >= 0) { + taxonSetsTable.setRowSelectionInterval(row, row); + } else { + setCurrentTaxonSet(null); + } + } + }; + + protected void setCurrentTaxonSet(TreeStatData.TaxonSet taxonSet) { + + this.currentTaxonSet = taxonSet; + + includedTaxa.clear(); + excludedTaxa.clear(); + + if (currentTaxonSet != null) { + for (Taxon taxon : taxonSet.taxa) { + includedTaxa.add(taxon); + } + Collections.sort(includedTaxa); + + for (Taxon taxon : treeStatData.allTaxa) { + excludedTaxa.add(taxon); + } + excludedTaxa.removeAll(includedTaxa); + Collections.sort(excludedTaxa); + } + + setTaxonSetTitle(); + + setupTaxonSetsComboBoxes(); + + includedTaxaTableModel.fireTableDataChanged(); + excludedTaxaTableModel.fireTableDataChanged(); + } + + protected void setTaxonSetTitle() { + + if (currentTaxonSet == null) { + taxonSetEditingPanel.setBorder(BorderFactory.createTitledBorder("")); + taxonSetEditingPanel.setEnabled(false); + } else { + taxonSetEditingPanel.setEnabled(true); + taxonSetEditingPanel.setBorder(new TitledBorder(null, TAXON + ": " + currentTaxonSet.name, TitledBorder.DEFAULT_JUSTIFICATION, TitledBorder.ABOVE_TOP)); + } + } + + + protected void setupTaxonSetsComboBoxes() { + setupTaxonSetsComboBox(excludedTaxonSetsComboBox, excludedTaxa); + excludedTaxonSetsComboBox.setSelectedIndex(0); + setupTaxonSetsComboBox(includedTaxonSetsComboBox, includedTaxa); + includedTaxonSetsComboBox.setSelectedIndex(0); + } + + protected void setupTaxonSetsComboBox(JComboBox comboBox, List availableTaxa) { + comboBox.removeAllItems(); + + comboBox.addItem(TAXON.toLowerCase() + "..."); + for (TreeStatData.TaxonSet taxonSet : treeStatData.taxonSets.values()) { + // AR - as these comboboxes are just intended to be handy ways of selecting taxa, I have removed + // these requirements (it was just confusing why they weren't in the lists. +// if (taxa != currentTaxonSet) { +// if (isCompatible(taxa, availableTaxa)) { + comboBox.addItem(taxonSet.name); // have to add String, otherwise it will throw Exception to cast "taxa..." into Taxa +// } +// } + } + } + + /** + * The table on the left side of panel + */ + protected class TaxonSetsTableModel extends AbstractTableModel { + private static final long serialVersionUID = 3318461381525023153L; + + String[] columnNames = {"Taxon Set"}; + + public TaxonSetsTableModel() { + } + + public int getColumnCount() { + return columnNames.length; + } + + public String getColumnName(int column) { + return columnNames[column]; + } + + public int getRowCount() { + if (treeStatData == null) return 0; + return treeStatData.taxonSets.size(); + } + + public Object getValueAt(int rowIndex, int columnIndex) { + String name = treeStatData.taxonSetNames.get(rowIndex); + TreeStatData.TaxonSet taxonSet = treeStatData.taxonSets.get(name); + switch (columnIndex) { + case 0: + return taxonSet.name; + default: + throw new IllegalArgumentException("unknown column, " + columnIndex); + } + } + + public void setValueAt(Object aValue, int rowIndex, int columnIndex) { +// Taxa taxonSet = treeStatData.taxonSets.get(rowIndex); + TreeStatData.TaxonSet taxonSet = treeStatData.taxonSets.get(rowIndex); + switch (columnIndex) { + case 0: + taxonSet.name=aValue.toString(); + setTaxonSetTitle(); + break; + + default: + throw new IllegalArgumentException("unknown column, " + columnIndex); + } + } + + public boolean isCellEditable(int row, int col) { + return true; + } + + public Class getColumnClass(int columnIndex) { + switch (columnIndex) { + case 0: + return String.class; + default: + throw new IllegalArgumentException("unknown column, " + columnIndex); + } + } } - JPanel createAddRemoveButtonPanel(Action addAction, Icon addIcon, String addToolTip, - Action removeAction, Icon removeIcon, String removeToolTip, int axis) { + protected JPanel createAddRemoveButtonPanel(Action addAction, Icon addIcon, String addToolTip, + Action removeAction, Icon removeIcon, String removeToolTip, int axis) { - JPanel buttonPanel = new JPanel(); + JPanel buttonPanel = new JPanel(); buttonPanel.setLayout(new BoxLayout(buttonPanel, axis)); buttonPanel.setOpaque(false); JButton addButton = new JButton(addAction); @@ -211,8 +760,9 @@ JPanel createAddRemoveButtonPanel(Action addAction, Icon addIcon, String addTool addButton.setIcon(addIcon); addButton.setText(null); } - addButton.setToolTipText(addToolTip); - addButton.putClientProperty("JButton.buttonType", "textured"); + addButton.setToolTipText(addToolTip); + addButton.putClientProperty("JButton.buttonType", "roundRect"); + // addButton.putClientProperty("JButton.buttonType", "toolbar"); addButton.setOpaque(false); addAction.setEnabled(false); @@ -221,13 +771,14 @@ JPanel createAddRemoveButtonPanel(Action addAction, Icon addIcon, String addTool removeButton.setIcon(removeIcon); removeButton.setText(null); } - removeButton.setToolTipText(removeToolTip); - removeButton.putClientProperty("JButton.buttonType", "textured"); + removeButton.setToolTipText(removeToolTip); + removeButton.putClientProperty("JButton.buttonType", "roundRect"); +// removeButton.putClientProperty("JButton.buttonType", "toolbar"); removeButton.setOpaque(false); removeAction.setEnabled(false); buttonPanel.add(addButton); - buttonPanel.add(new JToolBar.Separator(new Dimension(6,6))); + buttonPanel.add(new JToolBar.Separator(new Dimension(6, 6))); buttonPanel.add(removeButton); return buttonPanel; @@ -242,18 +793,6 @@ public void dataChanged() { excludedTaxaTableModel.fireTableDataChanged(); } - private void taxonSetsTableSelectionChanged() { - if (taxonSetsTable.getSelectedRowCount() == 0) { - selectedTaxonSet = null; - removeTaxonSetAction.setEnabled(false); - } else { - selectedTaxonSet = treeStatData.taxonSets.get(taxonSetsTable.getSelectedRow()); - removeTaxonSetAction.setEnabled(true); - } - includedTaxaTableModel.fireTableDataChanged(); - excludedTaxaTableModel.fireTableDataChanged(); - } - private void excludedTaxaTableSelectionChanged() { if (excludedTaxaTable.getSelectedRowCount() == 0) { includeTaxonAction.setEnabled(false); @@ -270,128 +809,84 @@ private void includedTaxaTableSelectionChanged() { } } - public JComponent getExportableComponent() { - return this; - } + private void includeSelectedTaxa() { + int[] rows = excludedTaxaTable.getSelectedRows(); - Action addTaxonSetAction = new AbstractAction("+") { + List transfer = new ArrayList(); - /** - * - */ - private static final long serialVersionUID = 1831933175582860833L; - - public void actionPerformed(ActionEvent ae) { - TreeStatData.TaxonSet taxonSet = new TreeStatData.TaxonSet(); - taxonSet.name = "untitled"; - taxonSet.taxa = new ArrayList(); - treeStatData.taxonSets.add(taxonSet); - dataChanged(); + for (int r : rows) { + transfer.add(excludedTaxa.get(r)); + } - int sel = treeStatData.taxonSets.size() - 1; - taxonSetsTable.setRowSelectionInterval(sel, sel); - } - }; + includedTaxa.addAll(transfer); + Collections.sort(includedTaxa); - Action removeTaxonSetAction = new AbstractAction("-") { + excludedTaxa.removeAll(includedTaxa); - /** - * - */ - private static final long serialVersionUID = -8662527333546044639L; + includedTaxaTableModel.fireTableDataChanged(); + excludedTaxaTableModel.fireTableDataChanged(); - public void actionPerformed(ActionEvent ae) { - int saved = taxonSetsTable.getSelectedRow(); - int row = taxonSetsTable.getSelectedRow(); - if (row != -1) { - treeStatData.taxonSets.remove(row); - } - dataChanged(); - if (saved >= treeStatData.taxonSets.size()) saved = treeStatData.taxonSets.size() - 1; - taxonSetsTable.setRowSelectionInterval(saved, saved); - } - }; + includedTaxaTable.getSelectionModel().clearSelection(); + for (Taxon taxon : transfer) { + int row = includedTaxa.indexOf(taxon); + includedTaxaTable.getSelectionModel().addSelectionInterval(row, row); + } - Action includeTaxonAction = new AbstractAction("->") { - /** - * - */ - private static final long serialVersionUID = -1875904513948242608L; + taxonSetChanged(); + } - public void actionPerformed(ActionEvent ae) { - int saved = taxonSetsTable.getSelectedRow(); - int[] rows = excludedTaxaTable.getSelectedRows(); - ArrayList exclList = new ArrayList(treeStatData.allTaxa); - exclList.removeAll(selectedTaxonSet.taxa); - for (int row : rows) { - selectedTaxonSet.taxa.add(exclList.get(row)); - } - dataChanged(); - taxonSetsTable.setRowSelectionInterval(saved, saved); - } - }; + private void excludeSelectedTaxa() { + int[] rows = includedTaxaTable.getSelectedRows(); - Action excludeTaxonAction = new AbstractAction("<-") { + List transfer = new ArrayList(); - /** - * - */ - private static final long serialVersionUID = 4523480086490780822L; + for (int r : rows) { + transfer.add(includedTaxa.get(r)); + } - public void actionPerformed(ActionEvent ae) { - int saved = taxonSetsTable.getSelectedRow(); - int[] rows = includedTaxaTable.getSelectedRows(); - for (int i = rows.length - 1; i >= 0 ; i--) { - selectedTaxonSet.taxa.remove(rows[i]); - } - dataChanged(); - taxonSetsTable.setRowSelectionInterval(saved, saved); - } - }; + excludedTaxa.addAll(transfer); + Collections.sort(excludedTaxa); - class TaxonSetsTableModel extends AbstractTableModel { + includedTaxa.removeAll(excludedTaxa); - /** - * - */ - private static final long serialVersionUID = 219223813257870207L; + includedTaxaTableModel.fireTableDataChanged(); + excludedTaxaTableModel.fireTableDataChanged(); - public TaxonSetsTableModel() { + excludedTaxaTable.getSelectionModel().clearSelection(); + for (Taxon taxon : transfer) { + int row = excludedTaxa.indexOf(taxon); + excludedTaxaTable.getSelectionModel().addSelectionInterval(row, row); } - public int getColumnCount() { - return 1; - } + taxonSetChanged(); + } - public int getRowCount() { - return treeStatData.taxonSets.size(); - } + Action includeTaxonAction = new AbstractAction("->") { + /** + * + */ + private static final long serialVersionUID = 7510299673661594128L; - public Object getValueAt(int row, int col) { - return (treeStatData.taxonSets.get(row)).name; + public void actionPerformed(ActionEvent ae) { + includeSelectedTaxa(); } + }; - public void setValueAt(Object value, int row, int col) { - (treeStatData.taxonSets.get(row)).name = (String)value; - } + Action excludeTaxonAction = new AbstractAction("<-") { - public boolean isCellEditable(int row, int col) { - return true; - } + /** + * + */ + private static final long serialVersionUID = 449692708602410206L; - public String getColumnName(int column) { - return "Taxon Sets"; + public void actionPerformed(ActionEvent ae) { + excludeSelectedTaxa(); } + }; - public Class getColumnClass(int c) {return getValueAt(0, c).getClass();} - } - - class TaxaTableModel extends AbstractTableModel { + class TaxaTableModel extends AbstractTableModel { - /** - * - */ - private static final long serialVersionUID = 1559408662356843275L; + private static final long serialVersionUID = -8027482229525938010L; boolean included; public TaxaTableModel(boolean included) { @@ -403,35 +898,36 @@ public int getColumnCount() { } public int getRowCount() { - if (selectedTaxonSet == null) return 0; + if (currentTaxonSet == null) return 0; if (included) { - return selectedTaxonSet.taxa.size(); + return includedTaxa.size(); } else { - return treeStatData.allTaxa.size() - selectedTaxonSet.taxa.size(); + return excludedTaxa.size(); } } public Object getValueAt(int row, int col) { if (included) { - return selectedTaxonSet.taxa.get(row); + return includedTaxa.get(row).getId(); } else { - ArrayList exclList = new ArrayList(treeStatData.allTaxa); - exclList.removeAll(selectedTaxonSet.taxa); - return exclList.get(row); + return excludedTaxa.get(row).getId(); } } - public boolean isCellEditable(int row, int col) { - return false; - } + public boolean isCellEditable(int row, int col) { + return false; + } public String getColumnName(int column) { - if (included) return "Included Taxa"; - else return "Excluded Taxa"; + if (included) return "Included " + TAXA; + else return "Excluded " + TAXA; } - public Class getColumnClass(int c) {return getValueAt(0, c).getClass();} + public Class getColumnClass(int c) { + return getValueAt(0, c).getClass(); + } } + } diff --git a/src/dr/app/treestat/TreeStatData.java b/src/dr/app/treestat/TreeStatData.java index cdc817df03..87a4d7436f 100644 --- a/src/dr/app/treestat/TreeStatData.java +++ b/src/dr/app/treestat/TreeStatData.java @@ -26,13 +26,13 @@ package dr.app.treestat; import dr.app.treestat.statistics.SummaryStatisticDescription; +import dr.evolution.util.Taxa; +import dr.evolution.util.Taxon; +import dr.evolution.util.TaxonList; import org.jdom.Document; import org.jdom.Element; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; public class TreeStatData { public static final String version = "1.0"; @@ -41,14 +41,15 @@ public TreeStatData() { } // Data options - public Set allTaxa = new HashSet(); - public List taxonSets = new ArrayList(); + public Set allTaxa = new HashSet<>(); + public List taxonSetNames = new ArrayList<>(); + public Map taxonSets = new HashMap<>(); public List characters = new ArrayList(); public List statistics = new ArrayList(); public static class TaxonSet { String name; - List taxa; + Taxa taxa; public String toString() { return name; } } diff --git a/src/dr/app/treestat/TreeStatFrame.java b/src/dr/app/treestat/TreeStatFrame.java index 47fb7f296c..1c26277676 100644 --- a/src/dr/app/treestat/TreeStatFrame.java +++ b/src/dr/app/treestat/TreeStatFrame.java @@ -26,6 +26,7 @@ package dr.app.treestat; import dr.evolution.tree.TreeUtils; +import dr.evolution.util.Taxon; import jam.framework.Application; import jam.framework.DocumentFrame; import jam.util.IconUtils; @@ -34,6 +35,7 @@ import javax.swing.plaf.BorderUIResource; import java.awt.*; import java.io.*; +import java.util.HashSet; import dr.evolution.io.Importer; import dr.evolution.io.NexusImporter; @@ -190,7 +192,10 @@ protected void importFromFile(File file) throws IOException, Importer.ImportExce tree = importer.importTree(null); } - treeStatData.allTaxa = TreeUtils.getLeafSet(tree); + treeStatData.allTaxa = new HashSet(); + for (String taxonName : TreeUtils.getLeafSet(tree)) { + treeStatData.allTaxa.add(new Taxon(taxonName)); + } statusLabel.setText(Integer.toString(treeStatData.allTaxa.size()) + " taxa loaded."); reader.close(); From f8fc902553e540422d0bda1016cab5ab28161fb7 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Mon, 3 May 2021 08:58:26 +0100 Subject: [PATCH 02/65] Update to use VAqua Mac L&F instead of Quaqua --- build.xml | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/build.xml b/build.xml index 2db36273c4..6f8d00c2c4 100644 --- a/build.xml +++ b/build.xml @@ -255,6 +255,26 @@ + + + + + + + + + + + + + + + + + + + + @@ -262,6 +282,11 @@ + + + + + From ba10290b28512e3da8f4e787a072e6c9775c856c Mon Sep 17 00:00:00 2001 From: xji3 Date: Wed, 15 Jun 2022 14:09:56 -0500 Subject: [PATCH 03/65] Bug fix: Use LinkedHashSet in Model sets to reserve insertion order for CheckPoint save and load --- src/dr/app/checkpoint/BeastCheckpointer.java | 9 +++++++-- src/dr/inference/model/Model.java | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/dr/app/checkpoint/BeastCheckpointer.java b/src/dr/app/checkpoint/BeastCheckpointer.java index 9d8472c366..ac26de775d 100644 --- a/src/dr/app/checkpoint/BeastCheckpointer.java +++ b/src/dr/app/checkpoint/BeastCheckpointer.java @@ -286,7 +286,9 @@ protected boolean writeStateToFile(File file, long state, double lnL, MarkovChai //check up front if there are any TreeParameterModel objects for (Model model : Model.CONNECTED_MODEL_SET) { if (model instanceof TreeParameterModel) { - //System.out.println("\nDetected TreeParameterModel: " + ((TreeParameterModel) model).toString()); + if (DEBUG) { + System.out.println("\nSave TreeParameterModel: " + model.getClass().getSimpleName()); + } traitModels.add((TreeParameterModel) model); } } @@ -503,7 +505,7 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln // load the tree models last as we get the node heights from the tree (not the parameters which // which may not be associated with the right node - Set expectedTreeModelNames = new HashSet(); + Set expectedTreeModelNames = new LinkedHashSet<>(); //store list of TreeModels for debugging purposes ArrayList treeModelList = new ArrayList(); @@ -527,6 +529,9 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln //first add all TreeParameterModels to a list if (model instanceof TreeParameterModel) { + if (DEBUG) { + System.out.println("\nLoad TreeParameterModel: " + model.getClass().getSimpleName()); + } traitModels.add((TreeParameterModel)model); } diff --git a/src/dr/inference/model/Model.java b/src/dr/inference/model/Model.java index b67ac53839..96f37d305c 100644 --- a/src/dr/inference/model/Model.java +++ b/src/dr/inference/model/Model.java @@ -170,8 +170,8 @@ public int getListenerCount() { // set to store all created models - final static Set FULL_MODEL_SET = new HashSet(); - final static Set CONNECTED_MODEL_SET = new HashSet(); + final static Set FULL_MODEL_SET = new LinkedHashSet(); + final static Set CONNECTED_MODEL_SET = new LinkedHashSet(); } From a26af153d4f9210d4f919ab051cc7e2037a4a9f5 Mon Sep 17 00:00:00 2001 From: "Marc A. Suchard" Date: Mon, 20 Jun 2022 13:23:42 -0700 Subject: [PATCH 04/65] update sampled trait example --- .../continuous/RacRABV_homogeneous.xml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/Phylogeography/continuous/RacRABV_homogeneous.xml b/examples/Phylogeography/continuous/RacRABV_homogeneous.xml index 1f7ab84149..ec33bac5c1 100644 --- a/examples/Phylogeography/continuous/RacRABV_homogeneous.xml +++ b/examples/Phylogeography/continuous/RacRABV_homogeneous.xml @@ -442,7 +442,7 @@ - + @@ -461,7 +461,7 @@ - + @@ -486,7 +486,7 @@ - + @@ -501,7 +501,7 @@ - + @@ -579,6 +579,7 @@ + + From 43757f5a8e47fba214ccd11322526ad9bc5a1f60 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sun, 26 Jun 2022 17:46:53 -0700 Subject: [PATCH 05/65] Adding Include/Exclude to TreeIntervalsParser --- .../coalescent/TreeIntervalsParser.java | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodelxml/coalescent/TreeIntervalsParser.java b/src/dr/evomodelxml/coalescent/TreeIntervalsParser.java index 1019224114..cfa118a803 100644 --- a/src/dr/evomodelxml/coalescent/TreeIntervalsParser.java +++ b/src/dr/evomodelxml/coalescent/TreeIntervalsParser.java @@ -2,14 +2,19 @@ import dr.evolution.tree.Tree; import dr.evolution.tree.TreeUtils; +import dr.evolution.util.TaxonList; import dr.evomodel.coalescent.TreeIntervals; import dr.evomodel.tree.TreeModel; import dr.xml.*; +import java.util.ArrayList; +import java.util.List; + public class TreeIntervalsParser extends AbstractXMLObjectParser{ public static final String TREE_INTERVALS = "treeIntervals"; - public static final String TREE = "tree"; + public static final String INCLUDE = "include"; + public static final String EXCLUDE = "exclude"; public String getParserName() { return TREE_INTERVALS; @@ -19,8 +24,23 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { Tree tree = (Tree) xo.getChild(Tree.class); + TaxonList includeSubtree = null; + + if (xo.hasChildNamed(INCLUDE)) { + includeSubtree = (TaxonList) xo.getElementFirstChild(INCLUDE); + } + + List excludeSubtrees = new ArrayList<>(); + + if (xo.hasChildNamed(EXCLUDE)) { + XMLObject cxo = xo.getChild(EXCLUDE); + for (int i = 0; i < cxo.getChildCount(); i++) { + excludeSubtrees.add((TaxonList) cxo.getChild(i)); + } + } + try { - return new TreeIntervals(tree, null, null); + return new TreeIntervals(tree, includeSubtree, excludeSubtrees); } catch (TreeUtils.MissingTaxonException mte) { throw new XMLParseException("Taxon, " + mte + ", in " + getParserName() + " was not found in the tree."); } From fac4164a343bb05ebd2f8b48462e10668162907a Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Sat, 16 Jul 2022 19:27:40 +0200 Subject: [PATCH 06/65] small bug fix for GSS - better late than never --- .../MarginalLikelihoodEstimationGenerator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java index 572e4628c0..9d997ecb0e 100644 --- a/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java +++ b/src/dr/app/beauti/components/marginalLikelihoodEstimation/MarginalLikelihoodEstimationGenerator.java @@ -993,7 +993,7 @@ public void writeMLE(XMLWriter writer, MarginalLikelihoodEstimationOptions optio new Attribute.Default(PriorParsers.MEAN, "" + Math.log(2)), new Attribute.Default(PriorParsers.OFFSET, "" + 0.0) }); - writer.writeIDref("statistic", model.getPrefix() + ClockType.LOCAL_CLOCK + ".changes"); + writer.writeIDref("statistic", model.getPrefix() + "rateChanges"); writer.writeCloseTag(PriorParsers.POISSON_PRIOR); writer.writeOpenTag(PriorParsers.GAMMA_PRIOR, From 51a7c2d82e74527d1ee591347d41e8afdcacef03 Mon Sep 17 00:00:00 2001 From: "Marc A. Suchard" Date: Tue, 23 Aug 2022 09:53:45 -0700 Subject: [PATCH 07/65] untested patch for Mac build on newer OSs that do not find BEAGLE --- build.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.xml b/build.xml index 2db36273c4..4e77206ef7 100644 --- a/build.xml +++ b/build.xml @@ -662,7 +662,7 @@ useJavaXKey="true" icon="${common_dir}/icons/beast.icns" jvmversion="${jvm_version}" - vmoptions="-Xmx2048M" + vmoptions="-Xmx2048M -Djava.library.path=/usr/local/lib" arguments="-window -working -options" highresolutioncapable="true" version="${version}" From cbf48b303ff9e3905a9ce5680faaf280dc4e2765 Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Wed, 21 Sep 2022 10:14:28 -0700 Subject: [PATCH 08/65] manual merge Xiang fix to checking BEAGLE versions --- .../BeagleFunctionality.java | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/BeagleFunctionality.java b/src/dr/evomodel/treedatalikelihood/BeagleFunctionality.java index 67d41c1e7e..485eab14f0 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleFunctionality.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleFunctionality.java @@ -10,25 +10,36 @@ */ public class BeagleFunctionality { + private static boolean checkGTEVersion(int[] versionNumbers){ + int[] beagleVersionNumbers = BeagleInfo.getVersionNumbers(); + if (versionNumbers.length == 0 || beagleVersionNumbers.length == 0) + return false; + for (int i = 0; i < versionNumbers.length && i < beagleVersionNumbers.length; i++){ + if (beagleVersionNumbers[i] > versionNumbers[i]) + return true; + if (beagleVersionNumbers[i] < versionNumbers[i]) + return false; + } + return true; + } + public static boolean IS_THREAD_COUNT_COMPATIBLE() { - int[] versionNumbers = BeagleInfo.getVersionNumbers(); - return versionNumbers.length != 0 && versionNumbers[0] >= 3 && versionNumbers[1] >= 1; + return checkGTEVersion(new int[]{3,1}); } public static boolean IS_ODD_STATE_SSE_FIXED() { // SSE for odd state counts fixed in BEAGLE 3.1.3 - int[] versionNumbers = BeagleInfo.getVersionNumbers(); - return versionNumbers.length != 0 && versionNumbers[0] >= 3 && versionNumbers[1] >= 1 && versionNumbers[2] >= 3; + return checkGTEVersion(new int[]{3,1,3}); } static boolean IS_PRE_ORDER_SUPPORTED() { int[] versionNumbers = BeagleInfo.getVersionNumbers(); - return versionNumbers.length != 0 && versionNumbers[0] >= 3 && versionNumbers[1] >= 2; + return checkGTEVersion(new int[]{3,2}); } static boolean IS_MULTI_PARTITION_COMPATIBLE() { int[] versionNumbers = BeagleInfo.getVersionNumbers(); - return versionNumbers.length != 0 && versionNumbers[0] >= 3; + return checkGTEVersion(new int[]{3}); } public static List parseSystemPropertyIntegerArray(String propertyName) { From dd61c4e35262459860e5a4fb06cc9f7b0883a028 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Thu, 20 Oct 2022 22:42:41 +0100 Subject: [PATCH 09/65] Fix some logic in the parser --- src/dr/evomodelxml/coalescent/CoalescentSimulatorParser.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/evomodelxml/coalescent/CoalescentSimulatorParser.java b/src/dr/evomodelxml/coalescent/CoalescentSimulatorParser.java index 25988a54f3..28fe5d5233 100644 --- a/src/dr/evomodelxml/coalescent/CoalescentSimulatorParser.java +++ b/src/dr/evomodelxml/coalescent/CoalescentSimulatorParser.java @@ -82,7 +82,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (taxonLists.size() == 0) { if (subtrees.size() == 1) { return subtrees.get(0); - } if (constraintsTree==null){ + } else if (subtrees.size() == 0 && constraintsTree==null) { throw new XMLParseException("Expected at least one taxonList or two subtrees or a constraints tree in " + getParserName() + " element."); } From 2b9a06bba8de299e9e9c5b5bd656431d5cecd96e Mon Sep 17 00:00:00 2001 From: "Marc A. Suchard" Date: Tue, 6 Dec 2022 09:21:57 -0800 Subject: [PATCH 10/65] allow relative weights at root for Markov-modulated model --- .../MarkovModulatedFrequencyModel.java | 30 +++++++++++++++++-- .../MarkovModulatedSubstitutionModel.java | 7 +++-- ...arkovModulatedSubstitutionModelParser.java | 11 +++++-- 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/dr/evomodel/substmodel/MarkovModulatedFrequencyModel.java b/src/dr/evomodel/substmodel/MarkovModulatedFrequencyModel.java index b44bf4b97c..87035f728d 100644 --- a/src/dr/evomodel/substmodel/MarkovModulatedFrequencyModel.java +++ b/src/dr/evomodel/substmodel/MarkovModulatedFrequencyModel.java @@ -47,7 +47,8 @@ public class MarkovModulatedFrequencyModel extends FrequencyModel { - MarkovModulatedFrequencyModel(String name, List freqModels, Parameter switchingRates) { + MarkovModulatedFrequencyModel(String name, List freqModels, Parameter switchingRates, + Parameter relativeWeight) { super(name); this.freqModels = freqModels; int freqCount = 0; @@ -72,6 +73,12 @@ public class MarkovModulatedFrequencyModel extends FrequencyModel { DoubleMatrix2D d = new DenseDoubleMatrix2D(numBaseModel, numBaseModel); d.set(0, 0, 1.0); + + this.relativeWeight = relativeWeight; + if (relativeWeight != null) { + addVariable(relativeWeight); + } + checkRelativeWeight(); } public void setFrequency(int i, double value) { @@ -81,7 +88,13 @@ public void setFrequency(int i, double value) { public double getFrequency(int index) { int whichModel = index / stateCount; int whichState = index % stateCount; - double relativeFreq = freqModels.get(whichModel).getFrequency(whichState) / numBaseModel; + double relativeFreq = freqModels.get(whichModel).getFrequency(whichState); + + if (relativeWeight != null) { + relativeFreq *= relativeWeight.getParameterValue(whichModel); + } else { + relativeFreq /= numBaseModel; + } // Scale by stationary distribution over hidden classes if (numBaseModel > 1) { @@ -168,6 +181,18 @@ protected void handleVariableChangedEvent(Variable variable, int index, Paramete } } + private void checkRelativeWeight() { + if (relativeWeight != null) { + double sum = 0.0; + for (double x : relativeWeight.getParameterValues()) { + sum += x; + } + if (sum != 1.0) { + throw new IllegalArgumentException("Relative weights must sum to 1.0"); + } + } + } + protected void handleModelChangedEvent(Model model, Object object, int index) { fireModelChanged(); } @@ -186,6 +211,7 @@ public Parameter getFrequencyParameter() { private final int totalFreqCount; private final int stateCount; private final Parameter switchingRates; + private final Parameter relativeWeight; private double[] baseStationaryDistribution; private double[] storedBaseStationaryDistribution; diff --git a/src/dr/evomodel/substmodel/MarkovModulatedSubstitutionModel.java b/src/dr/evomodel/substmodel/MarkovModulatedSubstitutionModel.java index bbb031cfb3..51f215a2be 100644 --- a/src/dr/evomodel/substmodel/MarkovModulatedSubstitutionModel.java +++ b/src/dr/evomodel/substmodel/MarkovModulatedSubstitutionModel.java @@ -76,7 +76,7 @@ public MarkovModulatedSubstitutionModel(String name, Parameter switchingRates, DataType dataType, EigenSystem eigenSystem) { - this(name, baseModels, switchingRates, dataType, eigenSystem, null, false, null); + this(name, baseModels, switchingRates, dataType, eigenSystem, null, false, null, null); } public MarkovModulatedSubstitutionModel(String name, @@ -86,7 +86,8 @@ public MarkovModulatedSubstitutionModel(String name, EigenSystem eigenSystem, Parameter rateScalar, boolean geometricRates, - SiteRateModel gammaRateModel) { + SiteRateModel gammaRateModel, + Parameter relativeWeights) { // super(name, dataType, null, eigenSystem); super(name, dataType, null, null); @@ -121,7 +122,7 @@ public MarkovModulatedSubstitutionModel(String name, } // This constructor also checks that all models have the same base stateCount - freqModel = new MarkovModulatedFrequencyModel("mm", freqModels, switchingRates); + freqModel = new MarkovModulatedFrequencyModel("mm", freqModels, switchingRates, relativeWeights); addModel(freqModel); if (stateCount != stateSizes) { diff --git a/src/dr/evomodelxml/substmodel/MarkovModulatedSubstitutionModelParser.java b/src/dr/evomodelxml/substmodel/MarkovModulatedSubstitutionModelParser.java index 2f557bfad2..14bb794ac8 100644 --- a/src/dr/evomodelxml/substmodel/MarkovModulatedSubstitutionModelParser.java +++ b/src/dr/evomodelxml/substmodel/MarkovModulatedSubstitutionModelParser.java @@ -49,6 +49,7 @@ public class MarkovModulatedSubstitutionModelParser extends AbstractXMLObjectPar public static final String RATE_SCALAR = "rateScalar"; public static final String GEOMETRIC_RATES = "geometricRates"; public static final String RENORMALIZE = "renormalize"; + private static final String RELATIVE_WEIGHTS = "relativeWeights"; public String getParserName() { return MARKOV_MODULATED_MODEL; @@ -85,8 +86,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } + Parameter relativeWeights = null; + if (xo.hasChildNamed(RELATIVE_WEIGHTS)) { + relativeWeights = (Parameter) xo.getElementFirstChild(RELATIVE_WEIGHTS); + } + MarkovModulatedSubstitutionModel mmsm = new MarkovModulatedSubstitutionModel(xo.getId(), substModels, switchingRates, dataType, null, - rateScalar, geometricRates, siteRateModel); + rateScalar, geometricRates, siteRateModel, relativeWeights); if (xo.getAttribute(RENORMALIZE, false)) { mmsm.setNormalization(true); @@ -124,7 +130,8 @@ public XMLSyntaxRule[] getSyntaxRules() { AttributeRule.newBooleanRule(RENORMALIZE, true), new ElementRule(RATE_SCALAR, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true), - + new ElementRule(RELATIVE_WEIGHTS, + new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true), new ElementRule(SiteRateModel.class, true), }; } From c1320154ecb285fb32ba7eeac4d6f19f87576cfb Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:46:03 +0100 Subject: [PATCH 11/65] adding HMC operator set type --- src/dr/app/beauti/types/OperatorSetType.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dr/app/beauti/types/OperatorSetType.java b/src/dr/app/beauti/types/OperatorSetType.java index 9b09683ba1..6ddf6fac82 100644 --- a/src/dr/app/beauti/types/OperatorSetType.java +++ b/src/dr/app/beauti/types/OperatorSetType.java @@ -34,7 +34,8 @@ public enum OperatorSetType { FIXED_TREE_TOPOLOGY("fixed tree topology"), NEW_TREE_MIX("new tree operator mix"), ADAPTIVE_MULTIVARIATE("adaptive multivariate"), - CUSTOM("custom operator mix"); + CUSTOM("custom operator mix"), + HMC("Hamiltonian Monte Carlo"); OperatorSetType(String displayName) { this.displayName = displayName; From 57aacb77dfb2bad3ad2113a2ecf73a489c2c3457 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:47:34 +0100 Subject: [PATCH 12/65] adding skygrid with HMC operator type --- src/dr/app/beauti/types/OperatorType.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dr/app/beauti/types/OperatorType.java b/src/dr/app/beauti/types/OperatorType.java index c9455f3657..70df357186 100644 --- a/src/dr/app/beauti/types/OperatorType.java +++ b/src/dr/app/beauti/types/OperatorType.java @@ -66,6 +66,7 @@ public enum OperatorType { WIDE_EXCHANGE("wideExchange"), GMRF_GIBBS_OPERATOR("gmrfGibbsOperator"), SKY_GRID_GIBBS_OPERATOR("gmrfGibbsOperator"), + SKY_GRID_HMC_OPERATOR("gmrfHMCOperator"), // PRECISION_GMRF_OPERATOR("precisionGMRFOperator"), WILSON_BALDING("wilsonBalding"); From 280a6f267704793316bd0b0b21cc178ebc52f32e Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:49:44 +0100 Subject: [PATCH 13/65] adding HMC operator type to JComboBox --- src/dr/app/beauti/operatorspanel/OperatorsPanel.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dr/app/beauti/operatorspanel/OperatorsPanel.java b/src/dr/app/beauti/operatorspanel/OperatorsPanel.java index 44efaeb6a0..12ac35d0c3 100644 --- a/src/dr/app/beauti/operatorspanel/OperatorsPanel.java +++ b/src/dr/app/beauti/operatorspanel/OperatorsPanel.java @@ -66,7 +66,8 @@ public class OperatorsPanel extends BeautiPanel implements Exportable { OperatorSetType.FIXED_TREE_TOPOLOGY, OperatorSetType.NEW_TREE_MIX, OperatorSetType.ADAPTIVE_MULTIVARIATE, - OperatorSetType.CUSTOM + OperatorSetType.CUSTOM, + OperatorSetType.HMC }); public List operators = new ArrayList(); From 5eae8610d2a3e041f7f6629cb8e4cb372255ee50 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:50:14 +0100 Subject: [PATCH 14/65] adding to do for future development --- src/dr/app/beauti/options/ModelOptions.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dr/app/beauti/options/ModelOptions.java b/src/dr/app/beauti/options/ModelOptions.java index e0bd1d9525..98c9aa2cc7 100644 --- a/src/dr/app/beauti/options/ModelOptions.java +++ b/src/dr/app/beauti/options/ModelOptions.java @@ -137,6 +137,9 @@ public Parameter createParameterUniformPrior(String name, String description, Pr .initial(initial).uniformLower(uniformLower).uniformUpper(uniformUpper).build(parameters); } + //TODO think about having a createParameter method with a String priorID argument + //public Parameter createParameterGammaPrior(String name, String priorID, String description, ...) { + public Parameter createParameterGammaPrior(String name, String description, PriorScaleType scaleType, double initial, double shape, double scale, boolean priorFixed) { return createParameterGammaPrior(name, description, scaleType, initial, shape, scale, priorFixed, false); From 3eb5ee80dbbf4a1487827fface9b69637e9f1c20 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:52:04 +0100 Subject: [PATCH 15/65] adding to do for future development --- src/dr/app/beauti/options/Parameter.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dr/app/beauti/options/Parameter.java b/src/dr/app/beauti/options/Parameter.java index 47a6f4a737..42efb411d3 100644 --- a/src/dr/app/beauti/options/Parameter.java +++ b/src/dr/app/beauti/options/Parameter.java @@ -49,10 +49,13 @@ public class Parameter implements Serializable { private boolean meanInRealSpace = false; - // Required para + // Required parameters private String baseName; private final String description; + //TODO think about have an optional (could be public) parameter priorID to idref to the prior later on + //private String priorID; + private int dimensionWeight = 1; private final List subParameters = new ArrayList(); From 79f5eb87d6b423b2f94551c44ce2fe29eafb436a Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:54:02 +0100 Subject: [PATCH 16/65] writing IDref to previously defined parameter priors --- .../generator/ParameterPriorGenerator.java | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/src/dr/app/beauti/generator/ParameterPriorGenerator.java b/src/dr/app/beauti/generator/ParameterPriorGenerator.java index 2e3fe0cd4c..3602f555f7 100644 --- a/src/dr/app/beauti/generator/ParameterPriorGenerator.java +++ b/src/dr/app/beauti/generator/ParameterPriorGenerator.java @@ -32,7 +32,6 @@ import dr.app.beauti.util.XMLWriter; import dr.evolution.util.Taxa; import dr.evomodel.tree.DefaultTreeModel; -import dr.evomodel.tree.TreeModel; import dr.evomodelxml.tree.CTMCScalePriorParser; import dr.evomodelxml.tree.MonophylyStatisticParser; import dr.inference.model.ParameterParser; @@ -41,6 +40,7 @@ import dr.inferencexml.model.OneOnXPriorParser; import dr.util.Attribute; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -51,8 +51,14 @@ */ public class ParameterPriorGenerator extends Generator { + //map parameters to prior IDs, for use with HMC + private HashMap mapParameterToPrior; + public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] components) { super(options, components); + //TODO don't like this being here, but will see how things pan out as more HMC approaches are added + mapParameterToPrior = new HashMap(); + mapParameterToPrior.put("skygrid.precision", "skygrid.precision.prior"); } /** @@ -63,7 +69,6 @@ public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] compone public void writeParameterPriors(XMLWriter writer) { boolean first = true; - for (Map.Entry taxaBooleanEntry : options.taxonSetsMono.entrySet()) { if (taxaBooleanEntry.getValue()) { if (first) { @@ -80,7 +85,6 @@ public void writeParameterPriors(XMLWriter writer) { List parameters = options.selectParameters(); - for (Parameter parameter : parameters) { if (!(parameter.priorType == PriorType.NONE_TREE_PRIOR || parameter.priorType == PriorType.NONE_FIXED || @@ -112,6 +116,18 @@ private void writeCachedParameterPrior(Parameter parameter, XMLWriter writer) { * @param writer the writer */ public void writeParameterPrior(Parameter parameter, XMLWriter writer) { + + //if models need to have a prior defined before the priors block + /*if (reservedParameters.contains(parameter.getName())) { + + return; + }*/ + + if (mapParameterToPrior.keySet().contains(parameter.getName())) { + writePriorIdref(writer, parameter, mapParameterToPrior.get(parameter.getName())); + return; + } + if (parameter.priorType == PriorType.NONE_FIXED) { return; } @@ -349,4 +365,23 @@ private void writeParameterIdref(XMLWriter writer, Parameter parameter) { } } + private void writePriorIdref(XMLWriter writer, Parameter parameter, String priorID) { + switch (parameter.priorType) { + case GAMMA_PRIOR: + writer.writeIDref(PriorParsers.GAMMA_PRIOR, priorID); + break; + case LOGNORMAL_PRIOR: + writer.writeIDref(PriorParsers.LOG_NORMAL_PRIOR, priorID); + break; + case EXPONENTIAL_PRIOR: + writer.writeIDref(PriorParsers.EXPONENTIAL_PRIOR, priorID); + break; + case INVERSE_GAMMA_PRIOR: + writer.writeIDref(PriorParsers.INVGAMMA_PRIOR_CORRECT, priorID); + break; + default: + throw new IllegalArgumentException("Unknown or invalid prior defined on " + parameter.getName()); + } + } + } From 3135daa290dfc3cc2a5cebfad1f86b9e2d037937 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:55:50 +0100 Subject: [PATCH 17/65] adding HMC operator type to define tree transition kernels to use --- src/dr/app/beauti/options/PartitionTreeModel.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dr/app/beauti/options/PartitionTreeModel.java b/src/dr/app/beauti/options/PartitionTreeModel.java index c5d9df77b2..26e8ea1cfd 100644 --- a/src/dr/app/beauti/options/PartitionTreeModel.java +++ b/src/dr/app/beauti/options/PartitionTreeModel.java @@ -173,6 +173,7 @@ public List selectOperators(List operators) { boolean branchesInUse = false; boolean newTreeOperatorsInUse = false; boolean adaptiveMultivariateInUse = false; + boolean HMCinUse = false; // if not a fixed tree then sample tree space if (options.operatorSetType == OperatorSetType.DEFAULT) { @@ -185,6 +186,8 @@ public List selectOperators(List operators) { } else if (options.operatorSetType == OperatorSetType.ADAPTIVE_MULTIVARIATE) { newTreeOperatorsInUse = true; adaptiveMultivariateInUse = true; + } else if (options.operatorSetType == OperatorSetType.HMC) { + HMCinUse = true; } else { throw new IllegalArgumentException("Unknown operator set type"); } @@ -199,6 +202,7 @@ public List selectOperators(List operators) { getOperator("subtreeLeap").setUsed(newTreeOperatorsInUse); getOperator("FHSPR").setUsed(newTreeOperatorsInUse); + } return operators; } From 535338856868a8ceb7ca88fdd1de81c150422816 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 21:58:03 +0100 Subject: [PATCH 18/65] create skygrid HMC operator and fix a few old function arguments --- .../app/beauti/options/PartitionTreePrior.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/dr/app/beauti/options/PartitionTreePrior.java b/src/dr/app/beauti/options/PartitionTreePrior.java index 5facb872dd..65261a61ce 100644 --- a/src/dr/app/beauti/options/PartitionTreePrior.java +++ b/src/dr/app/beauti/options/PartitionTreePrior.java @@ -271,9 +271,11 @@ public void initModelParametersAndOpererators() { "demographic.indicators", OperatorType.SCALE_WITH_INDICATORS, 0.5, 2 * demoWeights); createOperatorUsing2Parameters("gmrfGibbsOperator", "gmrfGibbsOperator", "Gibbs sampler for GMRF Skyride", "skyride.logPopSize", "skyride.precision", OperatorType.GMRF_GIBBS_OPERATOR, 2, 2); - createOperatorUsing2Parameters("gmrfSkyGridGibbsOperator", "gmrfGibbsOperator", "Gibbs sampler for Bayesian SkyGrid", "skygrid.logPopSize", + createOperatorUsing2Parameters("gmrfSkyGridGibbsOperator", "skygrid.logPopSize", "Gibbs sampler for Bayesian SkyGrid", "skygrid.logPopSize", "skygrid.precision", OperatorType.SKY_GRID_GIBBS_OPERATOR, 1.0, 2); - createScaleOperator("skygrid.precision", "description", 0.75, 1.0); + createScaleOperator("skygrid.precision", "skygrid precision", 0.75, 1.0); + createOperatorUsing2Parameters("gmrfSkyGridHMCOperator", "Multiple", "HMC transition kernel for Bayesian SkyGrid", "skygrid.logPopSize", + "skygrid.precision", OperatorType.SKY_GRID_HMC_OPERATOR, 1.0, 2); createScaleOperator("yule.birthRate", demoTuning, demoWeights); @@ -341,10 +343,10 @@ public List selectParameters(List params) { params.add(getParameter("demographic.populationSizeChanges")); params.add(getParameter("demographic.populationMean")); } else if (nodeHeightPrior == TreePriorType.GMRF_SKYRIDE) { -// params.add(getParameter("skyride.popSize")); // force user to use GMRF, not allowed to change + // params.add(getParameter("skyride.popSize")); // force user to use GMRF prior, not allowed to change params.add(getParameter("skyride.precision")); } else if (nodeHeightPrior == TreePriorType.SKYGRID) { -// params.add(getParameter("skyride.popSize")); // force user to use GMRF, not allowed to change + // params.add(getParameter("skygrid.logPopSize")); // force user to use GMRF prior, not allowed to change params.add(getParameter("skygrid.precision")); } else if (nodeHeightPrior == TreePriorType.YULE || nodeHeightPrior == TreePriorType.YULE_CALIBRATION) { params.add(getParameter("yule.birthRate")); @@ -414,8 +416,12 @@ public List selectOperators(List ops) { } else if (nodeHeightPrior == TreePriorType.GMRF_SKYRIDE) { ops.add(getOperator("gmrfGibbsOperator")); } else if (nodeHeightPrior == TreePriorType.SKYGRID) { - ops.add(getOperator("gmrfSkyGridGibbsOperator")); - ops.add(getOperator("skygrid.precision")); + if (options.operatorSetType == OperatorSetType.HMC) { + ops.add(getOperator("gmrfSkyGridHMCOperator")); + } else { + ops.add(getOperator("gmrfSkyGridGibbsOperator")); + ops.add(getOperator("skygrid.precision")); + } } else if (nodeHeightPrior == TreePriorType.EXTENDED_SKYLINE) { ops.add(getOperator("demographic.populationMean")); ops.add(getOperator("demographic.popSize")); From afdfca8340011c65fd103724bfdc05e0713f9e3e Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 22:07:02 +0100 Subject: [PATCH 19/65] XML generation for skygrid gradient --- .../beauti/generator/TreePriorGenerator.java | 79 ++++++++++++++++++- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/dr/app/beauti/generator/TreePriorGenerator.java b/src/dr/app/beauti/generator/TreePriorGenerator.java index 571ec569ce..87e7c4810c 100644 --- a/src/dr/app/beauti/generator/TreePriorGenerator.java +++ b/src/dr/app/beauti/generator/TreePriorGenerator.java @@ -27,12 +27,14 @@ import dr.app.beauti.components.ComponentFactory; import dr.app.beauti.options.*; +import dr.app.beauti.types.OperatorSetType; import dr.app.beauti.types.StartingTreeType; import dr.app.beauti.types.TreePriorParameterizationType; import dr.app.beauti.types.TreePriorType; import dr.app.beauti.util.XMLWriter; import dr.evolution.util.Taxa; import dr.evolution.util.Units; +import dr.evomodel.coalescent.GMRFSkyrideGradient; import dr.evomodel.tree.DefaultTreeModel; import dr.evomodel.tree.TreeModel; import dr.evomodelxml.CSVExporterParser; @@ -45,11 +47,16 @@ import dr.evoxml.TaxaParser; import dr.inference.distribution.ExponentialDistributionModel; import dr.inference.distribution.ExponentialMarkovModel; +import dr.inference.distribution.GammaDistributionModel; +import dr.inference.model.CompoundParameter; import dr.inference.model.ParameterParser; -import dr.inferencexml.distribution.DistributionModelParser; -import dr.inferencexml.distribution.ExponentialMarkovModelParser; -import dr.inferencexml.distribution.MixedDistributionLikelihoodParser; +import dr.inferencexml.distribution.*; +import dr.inferencexml.hmc.CompoundGradientParser; +import dr.inferencexml.hmc.GradientWrapperParser; +import dr.inferencexml.hmc.JointGradientParser; +import dr.inferencexml.model.CompoundParameterParser; import dr.inferencexml.model.SumStatisticParser; +import dr.math.distributions.GammaDistribution; import dr.util.Attribute; import dr.xml.XMLParser; @@ -728,6 +735,72 @@ void writeMultiLociTreePriors(PartitionTreePrior prior, XMLWriter writer) { writer.writeCloseTag(GMRFSkyrideLikelihoodParser.SKYGRID_LIKELIHOOD); + //writing the gamma prior here so will need to prevent another one from being written in the priors block + //key use: using HMC on the skygrid parameters + writer.writeOpenTag(PriorParsers.GAMMA_PRIOR, + new Attribute[]{ + new Attribute.Default(XMLParser.ID, "skygrid.precision.prior"), + new Attribute.Default(GammaDistributionModelParser.SHAPE, 0.001), + new Attribute.Default(GammaDistributionModelParser.SCALE, 1000.0), + new Attribute.Default(GammaDistributionModelParser.OFFSET, 0.0) + } + ); + writer.writeIDref(ParameterParser.PARAMETER, "skygrid.precision"); + writer.writeCloseTag(PriorParsers.GAMMA_PRIOR); + + //add gradient information to XML file in case of an HMC transition kernel mix + if (options.operatorSetType == OperatorSetType.HMC) { + + writer.writeOpenTag(GMRFSkyrideGradientParser.NAME, + new Attribute[]{ + new Attribute.Default(XMLParser.ID, "gmrfGradientPop"), + new Attribute.Default(GMRFSkyrideGradientParser.WRT_PARAMETER, "logPopulationSizes") + } + ); + writer.writeIDref(GMRFSkyrideLikelihoodParser.SKYGRID_LIKELIHOOD, "skygrid"); + writer.writeCloseTag(GMRFSkyrideGradientParser.NAME); + + writer.writeOpenTag(CompoundParameterParser.COMPOUND_PARAMETER, + new Attribute[]{ + new Attribute.Default(XMLParser.ID, "skygrid.parameters") + } + ); + writer.writeIDref(ParameterParser.PARAMETER, "skygrid.precision"); + writer.writeIDref(ParameterParser.PARAMETER, "skygrid.logPopSize"); + writer.writeCloseTag(CompoundParameterParser.COMPOUND_PARAMETER); + + writer.writeOpenTag(GMRFSkyrideGradientParser.NAME, + new Attribute[]{ + new Attribute.Default(XMLParser.ID, "gmrfGradientPrec"), + new Attribute.Default(GMRFSkyrideGradientParser.WRT_PARAMETER, "precision") + } + ); + writer.writeIDref(GMRFSkyrideLikelihoodParser.SKYGRID_LIKELIHOOD, "skygrid"); + writer.writeCloseTag(GMRFSkyrideGradientParser.NAME); + + writer.writeOpenTag(JointGradientParser.SUM_DERIVATIVE2, + new Attribute[]{ + new Attribute.Default(XMLParser.ID, "joint.skygrid.precision") + } + ); + writer.writeIDref(GMRFSkyrideGradientParser.NAME, "gmrfGradientPrec"); + writer.writeOpenTag(GradientWrapperParser.NAME); + writer.writeIDref(PriorParsers.GAMMA_PRIOR, "skygrid.precision.prior"); + writer.writeIDref(ParameterParser.PARAMETER, "skygrid.precision"); + writer.writeCloseTag(GradientWrapperParser.NAME); + writer.writeCloseTag(JointGradientParser.SUM_DERIVATIVE2); + + writer.writeOpenTag(CompoundGradientParser.SUM_DERIVATIVE2, + new Attribute[]{ + new Attribute.Default(XMLParser.ID, "full.skygrid.gradient") + } + ); + writer.writeIDref(JointGradientParser.SUM_DERIVATIVE2, "joint.skygrid.precision"); + writer.writeIDref(GMRFSkyrideGradientParser.NAME, "gmrfGradientPop"); + writer.writeCloseTag(CompoundGradientParser.SUM_DERIVATIVE2); + + } + } else if (prior.getNodeHeightPrior() == TreePriorType.EXTENDED_SKYLINE) { final String tagName = VariableDemographicModelParser.MODEL_NAME; From 80554c27568e0bd6118e4bccb10467cb4f75c069 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 22:07:35 +0100 Subject: [PATCH 20/65] XML generation for skygrid HMC transition kernel --- .../beauti/generator/OperatorsGenerator.java | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/dr/app/beauti/generator/OperatorsGenerator.java b/src/dr/app/beauti/generator/OperatorsGenerator.java index 74b551a777..b4dd45ad5a 100644 --- a/src/dr/app/beauti/generator/OperatorsGenerator.java +++ b/src/dr/app/beauti/generator/OperatorsGenerator.java @@ -36,13 +36,17 @@ import dr.evomodelxml.coalescent.operators.GMRFSkyrideBlockUpdateOperatorParser; import dr.evomodelxml.coalescent.operators.SampleNonActiveGibbsOperatorParser; import dr.evomodelxml.operators.*; +import dr.inference.model.CompoundParameter; import dr.inference.model.ParameterParser; import dr.inference.operators.AdaptableVarianceMultivariateNormalOperator; import dr.inference.operators.OperatorSchedule; import dr.inference.operators.RandomWalkOperator; import dr.inference.operators.RateBitExchangeOperator; +import dr.inferencexml.SignTransformParser; +import dr.inferencexml.hmc.CompoundGradientParser; import dr.inferencexml.model.CompoundParameterParser; import dr.inferencexml.operators.*; +import dr.inferencexml.operators.hmc.HamiltonianMonteCarloOperatorParser; import dr.oldevomodel.substmodel.AbstractSubstitutionModel; import dr.oldevomodelxml.substmodel.GeneralSubstitutionModelParser; import dr.util.Attribute; @@ -218,6 +222,9 @@ private void writeOperator(Operator operator, XMLWriter writer) { case SKY_GRID_GIBBS_OPERATOR: writeSkyGridGibbsOperator(operator, prefix, writer); break; + case SKY_GRID_HMC_OPERATOR: + writeSkyGridHMCOperator(operator, prefix, writer); + break; case ADAPTIVE_MULTIVARIATE: writeAdaptiveMultivariateOperator(operator, writer); break; @@ -497,6 +504,34 @@ private void writeSkyGridGibbsOperator(Operator operator, String treePriorPrefix writer.writeCloseTag(GMRFSkyrideBlockUpdateOperatorParser.GRID_BLOCK_UPDATE_OPERATOR); } + private void writeSkyGridHMCOperator(Operator operator, String treePriorPrefix, XMLWriter writer) { + writer.writeOpenTag( + HamiltonianMonteCarloOperatorParser.HMC_OPERATOR, + new Attribute[]{ + getWeightAttribute(operator.getWeight()), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.N_STEPS, 50), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.STEP_SIZE, 1E-2), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.MODE, "vanilla"), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.GRADIENT_CHECK_COUNT, 0), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.GRADIENT_CHECK_TOLERANCE, 1E-1), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.PRECONDITIONING, "none"), + new Attribute.Default(HamiltonianMonteCarloOperatorParser.PRECONDITIONING_UPDATE_FREQUENCY, 100) + } + ); + writer.writeIDref(CompoundGradientParser.SUM_DERIVATIVE2, treePriorPrefix + "full.skygrid.gradient"); + writer.writeIDref(CompoundParameterParser.COMPOUND_PARAMETER, treePriorPrefix + "skygrid.parameters"); + writer.writeOpenTag( + SignTransformParser.NAME, + new Attribute[]{ + new Attribute.Default(TransformParsers.START, 1), + new Attribute.Default(TransformParsers.END, 1) + } + ); + writer.writeIDref(CompoundParameterParser.COMPOUND_PARAMETER, treePriorPrefix + "skygrid.parameters"); + writer.writeCloseTag(SignTransformParser.NAME); + writer.writeCloseTag(HamiltonianMonteCarloOperatorParser.HMC_OPERATOR); + } + private void writeGMRFGibbsOperator(Operator operator, String treePriorPrefix, XMLWriter writer) { writer.writeOpenTag( GMRFSkyrideBlockUpdateOperatorParser.BLOCK_UPDATE_OPERATOR, From 0f375ea7d0a2a3c09721d03fd7443c9423065e61 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 23:10:12 +0100 Subject: [PATCH 21/65] altering access for parsing purposes --- .../hmc/HamiltonianMonteCarloOperatorParser.java | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/dr/inferencexml/operators/hmc/HamiltonianMonteCarloOperatorParser.java b/src/dr/inferencexml/operators/hmc/HamiltonianMonteCarloOperatorParser.java index 5baf66f2ae..262c573ffe 100644 --- a/src/dr/inferencexml/operators/hmc/HamiltonianMonteCarloOperatorParser.java +++ b/src/dr/inferencexml/operators/hmc/HamiltonianMonteCarloOperatorParser.java @@ -44,19 +44,22 @@ public class HamiltonianMonteCarloOperatorParser extends AbstractXMLObjectParser { - private final static String HMC_OPERATOR = "hamiltonianMonteCarloOperator"; - private final static String N_STEPS = "nSteps"; - private final static String STEP_SIZE = "stepSize"; + public final static String HMC_OPERATOR = "hamiltonianMonteCarloOperator"; + public final static String N_STEPS = "nSteps"; + public final static String STEP_SIZE = "stepSize"; private final static String RANDOM_STEP_FRACTION = "randomStepCountFraction"; - private final static String PRECONDITIONING = "preconditioning"; + public final static String PRECONDITIONING = "preconditioning"; private final static String PRECONDITIONER = "preconditioner"; - private final static String GRADIENT_CHECK_COUNT = "gradientCheckCount"; + public final static String GRADIENT_CHECK_COUNT = "gradientCheckCount"; public final static String GRADIENT_CHECK_TOLERANCE = "gradientCheckTolerance"; private final static String MAX_ITERATIONS = "checkStepSizeMaxIterations"; private final static String REDUCTION_FACTOR = "checkStepSizeReductionFactor"; private final static String TARGET_ACCEPTANCE_PROBABILITY = "targetAcceptanceProbability"; private final static String INSTABILITY_HANDLER = "instabilityHandler"; private final static String MASK = "mask"; + //these are in the Skygrid+HMC XML files but were not (yet) defined here + public final static String MODE = "mode"; + public final static String PRECONDITIONING_UPDATE_FREQUENCY = "preconditioningUpdateFrequency"; @Override public String getParserName() { From d1a257858637fe6a2eca4ec3ca7997aca84218bb Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 23:10:43 +0100 Subject: [PATCH 22/65] altering access for parsing purposes --- src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java b/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java index d02519e65e..e8abb55726 100644 --- a/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java +++ b/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java @@ -41,8 +41,8 @@ */ public class GMRFSkyrideGradientParser extends AbstractXMLObjectParser { - private static final String NAME = "gmrfSkyrideGradient"; - private static final String WRT_PARAMETER = "wrtParameter"; + public static final String NAME = "gmrfSkyrideGradient"; + public static final String WRT_PARAMETER = "wrtParameter"; private static final String COALESCENT_INTERVAL = "coalescentInterval"; private static final String NODE_HEIGHTS = "nodeHeights"; From dc4f12cdbabc282a06e7b298f0b7f3e4fcfee73f Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Fri, 13 Jan 2023 23:11:14 +0100 Subject: [PATCH 23/65] altering access for parsing purposes --- src/dr/inferencexml/hmc/JointGradientParser.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/inferencexml/hmc/JointGradientParser.java b/src/dr/inferencexml/hmc/JointGradientParser.java index 434e47b6a7..2bc9f0d835 100644 --- a/src/dr/inferencexml/hmc/JointGradientParser.java +++ b/src/dr/inferencexml/hmc/JointGradientParser.java @@ -40,7 +40,7 @@ public class JointGradientParser extends AbstractXMLObjectParser { private final static String SUM_DERIVATIVE = "sumDerivative"; - private final static String SUM_DERIVATIVE2 = "jointGradient"; + public final static String SUM_DERIVATIVE2 = "jointGradient"; @Override public String getParserName() { From 9b0358ad0e0774ef79bdd6735b358ceb8fb7facd Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Tue, 28 Mar 2023 08:51:48 +0100 Subject: [PATCH 24/65] Renaming MMCC to HIPSTR --- src/dr/app/tools/TreeAnnotator.java | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/dr/app/tools/TreeAnnotator.java b/src/dr/app/tools/TreeAnnotator.java index bdf767d386..eab14e18d7 100644 --- a/src/dr/app/tools/TreeAnnotator.java +++ b/src/dr/app/tools/TreeAnnotator.java @@ -72,7 +72,7 @@ public class TreeAnnotator { enum Target { MAX_CLADE_CREDIBILITY("Maximum clade credibility tree"), - MAX_MARGINAL_CLADE_CREDIBILITY("Maximum marginal clade credibilities"), + HIPSTR("Highest independent posterior subtree reconstruction (HIPSTR)"), USER_TARGET_TREE("User target tree"); String desc; @@ -267,9 +267,9 @@ public TreeAnnotator(final int burninTrees, targetTree = new FlexibleTree(getMCCTree(burnin, cladeSystem, inputFileName)); break; } - case MAX_MARGINAL_CLADE_CREDIBILITY: { + case HIPSTR: { progressStream.println("Finding maximum marginal credibility tree..."); - targetTree = new FlexibleTree(getMMCCTree(cladeSystem)); + targetTree = new FlexibleTree(getHIPSTRTree(cladeSystem)); break; } default: throw new IllegalArgumentException("Unknown targetOption"); @@ -416,15 +416,15 @@ private Tree getMCCTree(int burnin, CladeSystem cladeSystem, String inputFileNam return bestTree; } - private Tree getMMCCTree(CladeSystem cladeSystem) { + private Tree getHIPSTRTree(CladeSystem cladeSystem) { CladeSystem.Clade rootClade = cladeSystem.getRootClade(); credibilityCache.clear(); - double score = findMMCCTree(cladeSystem, rootClade); + double score = findHIPSTRTree(cladeSystem, rootClade); - SimpleTree tree = new SimpleTree(buildMCCTree(cladeSystem, rootClade)); + SimpleTree tree = new SimpleTree(buildHIPSTRTree(cladeSystem, rootClade)); progressStream.println(); progressStream.println("Highest Log Marginal Clade Credibility: " + score); @@ -435,7 +435,7 @@ private Tree getMMCCTree(CladeSystem cladeSystem) { private Map credibilityCache = new HashMap<>(); - private double findMMCCTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { + private double findHIPSTRTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { double logCredibility = Math.log(clade.credibility); @@ -451,7 +451,7 @@ private double findMMCCTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { double leftLogCredibility = credibilityCache.getOrDefault(left, Double.NaN); if (Double.isNaN(leftLogCredibility)) { - leftLogCredibility = findMMCCTree(cladeSystem, left); + leftLogCredibility = findHIPSTRTree(cladeSystem, left); credibilityCache.put(left, leftLogCredibility); } CladeSystem.Clade right = cladeSystem.getCladeMap().get(subClade.snd); @@ -460,7 +460,7 @@ private double findMMCCTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { } double rightLogCredibility = credibilityCache.getOrDefault(right, Double.NaN); if (Double.isNaN(rightLogCredibility)) { - rightLogCredibility = findMMCCTree(cladeSystem, right); + rightLogCredibility = findHIPSTRTree(cladeSystem, right); credibilityCache.put(right, rightLogCredibility); } @@ -479,13 +479,13 @@ private double findMMCCTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { return logCredibility; } - private SimpleNode buildMCCTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { + private SimpleNode buildHIPSTRTree(CladeSystem cladeSystem, CladeSystem.Clade clade) { SimpleNode newNode = new SimpleNode(); if (clade.size == 1) { newNode.setTaxon(clade.taxon); } else { - newNode.addChild(buildMCCTree(cladeSystem, clade.bestLeft)); - newNode.addChild(buildMCCTree(cladeSystem, clade.bestRight)); + newNode.addChild(buildHIPSTRTree(cladeSystem, clade.bestLeft)); + newNode.addChild(buildHIPSTRTree(cladeSystem, clade.bestRight)); } return newNode; } @@ -1551,7 +1551,7 @@ public static void main(String[] args) throws IOException { Arguments arguments = new Arguments( new Arguments.Option[]{ - new Arguments.StringOption("type", new String[] { "mcc", "mmcc" }, false, "an option of 'mcc' or 'mmcc'"), + new Arguments.StringOption("type", new String[] { "mcc", "hipstr" }, false, "an option of 'mcc' or 'hipstr'"), new Arguments.StringOption("heights", new String[]{"keep", "median", "mean", "ca"}, false, "an option of 'keep', 'median', 'mean' or 'ca' (default)"), new Arguments.LongOption("burnin", "the number of states to be considered as 'burn-in'"), @@ -1630,8 +1630,8 @@ public static void main(String[] args) throws IOException { } Target target = Target.MAX_CLADE_CREDIBILITY; - if (arguments.hasOption("type") && arguments.getStringOption("type").equalsIgnoreCase("MMCC")) { - target = Target.MAX_MARGINAL_CLADE_CREDIBILITY; + if (arguments.hasOption("type") && arguments.getStringOption("type").equalsIgnoreCase("HIPSTR")) { + target = Target.HIPSTR; } if (arguments.hasOption("target")) { From cf0da572b23275115af2a56277785ee5b700d774 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Tue, 28 Mar 2023 18:35:03 +0100 Subject: [PATCH 25/65] Renaming MMCC to HIPSTR --- src/dr/app/tools/TreeAnnotator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/app/tools/TreeAnnotator.java b/src/dr/app/tools/TreeAnnotator.java index eab14e18d7..5f3c9e3482 100644 --- a/src/dr/app/tools/TreeAnnotator.java +++ b/src/dr/app/tools/TreeAnnotator.java @@ -268,7 +268,7 @@ public TreeAnnotator(final int burninTrees, break; } case HIPSTR: { - progressStream.println("Finding maximum marginal credibility tree..."); + progressStream.println("Finding highest independent posterior subtree reconstruction (HIPSTR) tree..."); targetTree = new FlexibleTree(getHIPSTRTree(cladeSystem)); break; } From 4ebdca2d0ae13aa114dd330e2f1914ee1767a3a0 Mon Sep 17 00:00:00 2001 From: Gabe Hassler Date: Tue, 28 Mar 2023 16:25:45 -0700 Subject: [PATCH 26/65] fixes ant compile error when default encoding is ASCII --- build.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build.xml b/build.xml index 33a0d6affd..923ccc8224 100644 --- a/build.xml +++ b/build.xml @@ -74,7 +74,8 @@ fork="true" includeantruntime="false" memoryinitialsize="256m" - memorymaximumsize="1024m"> + memorymaximumsize="1024m" + encoding="UTF-8"> From be768635def3398f82a87a07de79d84967e25098 Mon Sep 17 00:00:00 2001 From: GuyBaele Date: Thu, 13 Apr 2023 14:52:57 +0200 Subject: [PATCH 27/65] avoiding NPE in debug code and providing more debug output --- src/dr/app/checkpoint/BeastCheckpointer.java | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/dr/app/checkpoint/BeastCheckpointer.java b/src/dr/app/checkpoint/BeastCheckpointer.java index ac26de775d..d171fc8b13 100644 --- a/src/dr/app/checkpoint/BeastCheckpointer.java +++ b/src/dr/app/checkpoint/BeastCheckpointer.java @@ -525,6 +525,10 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln } System.out.println(); } + } else { + if (DEBUG) { + System.out.println("Not a TreeModel: " + model.getModelName()); + } } //first add all TreeParameterModels to a list @@ -533,6 +537,9 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln System.out.println("\nLoad TreeParameterModel: " + model.getClass().getSimpleName()); } traitModels.add((TreeParameterModel)model); + if (DEBUG) { + System.out.println("TreeParameterModel: " + model.getModelName()); + } } } @@ -590,6 +597,9 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln int edgeCount = Integer.parseInt(fields[0]); if (DEBUG) { System.out.println("edge count = " + edgeCount); + System.out.println("model: " + model.getId()); + System.out.println("linkedModels size = " + linkedModels.size()); + System.out.println(linkedModels.get(model.getId())); } //create data matrix of doubles to store information from list of TreeParameterModels @@ -662,7 +672,7 @@ protected long readStateFromFile(File file, MarkovChain markovChain, double[] ln if (DEBUG) { System.out.println("\nDouble checking:"); for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) { - if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) { + if (parameter.getParameterName() != null && parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) { System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0)); } } From 00ba04046010469c9a297cbd7806e38f08af3780 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 21 Apr 2023 11:49:53 +0200 Subject: [PATCH 28/65] Refactored SumParameter so it can take a mixture of statistics and parameters --- src/dr/evomodel/coalescent/TreeIntervals.java | 1 + src/dr/inference/model/SumParameter.java | 35 +++++++++++-------- .../model/SumParameterParser.java | 28 ++++++++------- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/dr/evomodel/coalescent/TreeIntervals.java b/src/dr/evomodel/coalescent/TreeIntervals.java index 233304034f..32d9833d6e 100644 --- a/src/dr/evomodel/coalescent/TreeIntervals.java +++ b/src/dr/evomodel/coalescent/TreeIntervals.java @@ -216,6 +216,7 @@ private void collectTimes(Tree tree, NodeRef node, Set excludeNodesBelo } if (!include || tree.isExternal(child)) { + // the mrca of the clade below that is being excluded becomes a sampling event intervals.addSampleEvent(tree.getNodeHeight(child)); } else { collectTimes(tree, child, excludeNodesBelow, intervals); diff --git a/src/dr/inference/model/SumParameter.java b/src/dr/inference/model/SumParameter.java index 1c39ec349a..62621da89b 100644 --- a/src/dr/inference/model/SumParameter.java +++ b/src/dr/inference/model/SumParameter.java @@ -25,6 +25,7 @@ package dr.inference.model; +import java.util.ArrayList; import java.util.List; /** @@ -32,11 +33,14 @@ */ public class SumParameter extends Parameter.Abstract implements VariableListener { - public SumParameter(List parameterList) { - this.parameterList = parameterList; - dimension = parameterList.size() == 1 ? 1 : parameterList.get(0).getDimension();; - for (Parameter p : parameterList) { - p.addVariableListener(this); + public SumParameter(List statisticList) { + this.statisticList = statisticList; + dimension = statisticList.size() == 1 ? 1 : statisticList.get(0).getDimension(); + for (Statistic s : statisticList) { + if (s instanceof Parameter) { + parameterList.add(((Parameter) s)); + ((Parameter) s).addVariableListener(this); + } } } @@ -73,15 +77,15 @@ protected void adoptValues(Parameter source) { public double getParameterValue(int dim) { double value = 0; - if (dimension == 1) { - value = parameterList.get(0).getParameterValue(0); - for (int i = 1; i < parameterList.get(0).getDimension(); i++) { - value += parameterList.get(0).getParameterValue(i); + if (statisticList.size() == 1) { + value = statisticList.get(0).getStatisticValue(0); + for (int i = 1; i < statisticList.get(0).getDimension(); i++) { + value += statisticList.get(0).getStatisticValue(i); } } else { - value = parameterList.get(0).getParameterValue(dim); - for (int i = 1; i < parameterList.size(); i++) { - value += parameterList.get(i).getParameterValue(dim); + value = statisticList.get(0).getStatisticValue(dim); + for (int i = 1; i < statisticList.size(); i++) { + value += statisticList.get(i).getStatisticValue(dim); } } return value; @@ -102,8 +106,8 @@ public void setParameterValueNotifyChangedAll(int dim, double value){ public String getParameterName() { if (getId() == null) { StringBuilder sb = new StringBuilder("sum"); - for (Parameter p : parameterList) { - sb.append(".").append(p.getId()); + for (Statistic s : statisticList) { + sb.append(".").append(s.getId()); } setId(sb.toString()); } @@ -134,7 +138,8 @@ public void variableChangedEvent(Variable variable, int index, ChangeType type) fireParameterChangedEvent(index,type); } - private final List parameterList; + private final List statisticList; + private final List parameterList = new ArrayList<>(); private final int dimension; private Bounds bounds = null; } diff --git a/src/dr/inferencexml/model/SumParameterParser.java b/src/dr/inferencexml/model/SumParameterParser.java index ee10d75146..7a9ea283e1 100644 --- a/src/dr/inferencexml/model/SumParameterParser.java +++ b/src/dr/inferencexml/model/SumParameterParser.java @@ -25,6 +25,7 @@ package dr.inferencexml.model; +import dr.inference.model.Statistic; import dr.inference.model.SumParameter; import dr.inference.model.Parameter; import dr.xml.*; @@ -41,30 +42,33 @@ public class SumParameterParser extends AbstractXMLObjectParser { public Object parseXMLObject(XMLObject xo) throws XMLParseException { - List paramList = new ArrayList(); + List statisticList = new ArrayList<>(); int dim = -1; for (int i = 0; i < xo.getChildCount(); ++i) { - Parameter parameter = (Parameter) xo.getChild(i); + Statistic s = (Statistic) xo.getChild(i); if (dim == -1) { - dim = parameter.getDimension(); + dim = s.getDimension(); } else { - if (parameter.getDimension() != dim) { - throw new XMLParseException("All parameters in sum '" + xo.getId() + "' must be the same length"); + if (s.getDimension() != dim) { + throw new XMLParseException("All statistics/parameters in sum '" + xo.getId() + "' must be the same length"); } } - paramList.add(parameter); + statisticList.add(s); } - boolean sumAll = xo.getBooleanAttribute(SUM_ALL); + boolean sumAll = statisticList.size() == 1; + if (xo.hasAttribute(SUM_ALL)) { + sumAll = xo.getBooleanAttribute(SUM_ALL); + } - if (sumAll && paramList.size() > 1) { + if (sumAll && statisticList.size() > 1) { throw new XMLParseException("To sum all the elements, only one parameter should be given"); } - if (!sumAll && paramList.size() < 2) { + if (!sumAll && statisticList.size() < 2) { throw new XMLParseException("For an element-wise sum, more than one parameter should be given"); } - return new SumParameter(paramList); + return new SumParameter(statisticList); } public XMLSyntaxRule[] getSyntaxRules() { @@ -72,11 +76,11 @@ public XMLSyntaxRule[] getSyntaxRules() { } private final XMLSyntaxRule[] rules = { - new ElementRule(Parameter.class,1,Integer.MAX_VALUE), + new ElementRule(Statistic.class,1,Integer.MAX_VALUE), }; public String getParserDescription() { - return "A element-wise sum of parameters."; + return "A element-wise sum of statistics or parameters."; } public Class getReturnType() { From 48b75569bbf531fd740a4809f1fb782b80206939 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 21 Apr 2023 12:29:22 +0200 Subject: [PATCH 29/65] TreeDataLikelihoodParser now takes the base SiteRateModel rather than specifically a GammaSiteRateModel --- .../treedatalikelihood/TreeDataLikelihoodParser.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 69ed137c2c..73b4082e13 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -267,7 +267,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { patternList = (PatternList) cxo.getChild(PatternList.class); patternLists.add(patternList); - GammaSiteRateModel siteRateModel = (GammaSiteRateModel) cxo.getChild(GammaSiteRateModel.class); + SiteRateModel siteRateModel = (SiteRateModel) cxo.getChild(SiteRateModel.class); siteRateModels.add(siteRateModel); FrequencyModel rootFreqModel = (FrequencyModel) xo.getChild(FrequencyModel.class); @@ -275,9 +275,6 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { BranchModel branchModel = (BranchModel) cxo.getChild(BranchModel.class); if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) xo.getChild(SubstitutionModel.class); - if (substitutionModel == null) { - substitutionModel = siteRateModel.getSubstitutionModel(); - } if (substitutionModel == null) { throw new XMLParseException("No substitution model available for partition " + k + " in DataTreeLikelihood: "+xo.getId()); } From b12bd6de0142c5d43652e18ead4e690e540c0558 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 21 Apr 2023 14:54:08 +0200 Subject: [PATCH 30/65] Adding an additional parameter which alters the gamma category widths. --- .../siteratemodel/GammaSiteRateModel.java | 96 ++++++++++--------- .../siteratemodel/GammaSiteModelParser.java | 28 +++++- .../TreeDataLikelihoodParser.java | 2 +- 3 files changed, 78 insertions(+), 48 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 93239e766c..d109b69be7 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -43,6 +43,10 @@ */ public class GammaSiteRateModel extends AbstractModel implements SiteRateModel, Citable { + public enum CategoryWidthType { + FASTEST, + GEOMETRIC + }; public GammaSiteRateModel(String name) { this( name, @@ -50,7 +54,7 @@ public GammaSiteRateModel(String name) { 1.0, null, 0, - null); + null, null, null); } public GammaSiteRateModel(String name, double alpha, int categoryCount) { @@ -59,25 +63,28 @@ public GammaSiteRateModel(String name, double alpha, int categoryCount) { 1.0, new Parameter.Default(alpha), categoryCount, - null); + null, null, null); } - public GammaSiteRateModel(String name, double pInvar) { + public GammaSiteRateModel(String name, double alpha, int categoryCount, double pInvar) { this( name, null, 1.0, - null, - 0, - new Parameter.Default(pInvar)); + new Parameter.Default(alpha), + categoryCount, + new Parameter.Default(pInvar), + null, null); } - public GammaSiteRateModel(String name, double alpha, int categoryCount, double pInvar) { + public GammaSiteRateModel(String name, double alpha, int categoryCount, double pInvar, double catWidth, CategoryWidthType categoryWidthType) { this( name, null, 1.0, new Parameter.Default(alpha), categoryCount, - new Parameter.Default(pInvar)); + new Parameter.Default(pInvar), + new Parameter.Default(catWidth), + categoryWidthType); } public GammaSiteRateModel( @@ -85,7 +92,7 @@ public GammaSiteRateModel( Parameter nuParameter, Parameter shapeParameter, int gammaCategoryCount, Parameter invarParameter) { - this(name, nuParameter, 1.0, shapeParameter, gammaCategoryCount, invarParameter); + this(name, nuParameter, 1.0, shapeParameter, gammaCategoryCount, invarParameter, null, null); } /** @@ -97,7 +104,9 @@ public GammaSiteRateModel( Parameter nuParameter, double muWeight, Parameter shapeParameter, int gammaCategoryCount, - Parameter invarParameter) { + Parameter invarParameter, + Parameter categoryWidthParameter, + CategoryWidthType categoryWidthType) { super(name); @@ -130,6 +139,15 @@ public GammaSiteRateModel( invarParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, 1)); } + this.categoryWidthParameter = categoryWidthParameter; + this.categoryWidthType = categoryWidthType; + if (categoryWidthParameter != null) { + this.categoryCount += 1; + + addVariable(categoryWidthParameter); + categoryWidthParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); + } + categoryRates = new double[this.categoryCount]; categoryProportions = new double[this.categoryCount]; @@ -165,35 +183,8 @@ public final double getAlpha() { return shapeParameter.getParameterValue(0); } - - public Parameter getAlphaParameter() { - return shapeParameter; - } - - public Parameter getPInvParameter() { - return invarParameter; - } - - public Parameter setRelativeRateParameter() { - return nuParameter; - } - - public void setAlphaParameter(Parameter parameter) { - if (shapeParameter != null) removeVariable(shapeParameter); - shapeParameter = parameter; - if (shapeParameter != null) addVariable(shapeParameter); - } - - public void setPInvParameter(Parameter parameter) { - if (invarParameter != null) removeVariable(invarParameter); - invarParameter = parameter; - if (invarParameter != null) addVariable(invarParameter); - } - - public void setRelativeRateParameter(Parameter parameter) { - if (nuParameter != null) removeVariable(nuParameter); - nuParameter = parameter; - if (nuParameter != null) addVariable(nuParameter); + public void setRelativeRateParameter(Parameter nu) { + this.nuParameter = nu; } // ***************************************************************** @@ -272,7 +263,9 @@ private void calculateCategoryRates() { if (shapeParameter != null) { final double a = shapeParameter.getParameterValue(0); + final double f = (1.0 + categoryWidthParameter.getParameterValue(0)); double mean = 0.0; + double sum = 0.0; final int gammaCatCount = categoryCount - cat; for (int i = 0; i < gammaCatCount; i++) { @@ -285,14 +278,22 @@ private void calculateCategoryRates() { mean += categoryRates[i + cat]; - categoryProportions[i + cat] = propVariable / gammaCatCount; + if (categoryWidthParameter != null && categoryWidthType == CategoryWidthType.GEOMETRIC && i > 0) { + categoryProportions[i + cat] = categoryProportions[i + cat - 1] * f; + } else if (categoryWidthParameter != null && categoryWidthType == CategoryWidthType.FASTEST && + i == (gammaCatCount - 1)) { + categoryProportions[i + cat] = f; + } else { + categoryProportions[i + cat] = 1.0; + } + sum += categoryProportions[i + cat]; } mean = (propVariable * mean) / gammaCatCount; for (int i = 0; i < gammaCatCount; i++) { - categoryRates[i + cat] /= mean; + categoryProportions[i + cat] /= sum; } } else { categoryRates[cat] = 1.0 / propVariable; @@ -309,10 +310,6 @@ private void calculateCategoryRates() { ratesKnown = true; } - public boolean hasInvariantSites() { - return invarParameter != null; - } - // ***************************************************************** // Interface ModelComponent // ***************************************************************** @@ -327,6 +324,8 @@ protected final void handleVariableChangedEvent(Variable variable, int index, Pa ratesKnown = false; } else if (variable == invarParameter) { ratesKnown = false; + } else if (variable == categoryWidthParameter) { + ratesKnown = false; } else if (variable == nuParameter) { ratesKnown = false; // MAS: I changed this because the rate parameter can affect the categories if the parameter is in siteModel and not clockModel } else { @@ -384,6 +383,10 @@ public double getStatisticValue(int dim) { */ private Parameter invarParameter; + private Parameter categoryWidthParameter; + + private CategoryWidthType categoryWidthType = null; + private boolean ratesKnown; private int categoryCount; @@ -417,7 +420,7 @@ public String getDescription() { } public List getCitations() { - if (getAlphaParameter() != null) { + if (shapeParameter != null) { return Collections.singletonList(CITATION); } else { return Collections.emptyList(); @@ -437,4 +440,5 @@ public List getCitations() { ); private SubstitutionModel substitutionModel; + } \ No newline at end of file diff --git a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java index 0b890f0449..721f9d564c 100644 --- a/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java +++ b/src/dr/evomodelxml/siteratemodel/GammaSiteModelParser.java @@ -55,6 +55,10 @@ public class GammaSiteModelParser extends AbstractXMLObjectParser { public static final String GAMMA_SHAPE = "gammaShape"; public static final String GAMMA_CATEGORIES = "gammaCategories"; public static final String PROPORTION_INVARIANT = "proportionInvariant"; + public static final String CATEGORY_WIDTH = "categoryWidth"; + public static final String TYPE = "type"; + public static final String FASTEST = "fastest"; + public static final String GEOMETRIC = "geometric"; public String getParserName() { return SITE_MODEL; @@ -102,13 +106,30 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { msg += "\n initial proportion of invariant sites = " + invarParam.getParameterValue(0); } + Parameter categoryWidthParameter = null; + GammaSiteRateModel.CategoryWidthType type = null; + if (xo.hasChildNamed(CATEGORY_WIDTH)) { + categoryWidthParameter = (Parameter) xo.getElementFirstChild(CATEGORY_WIDTH); + String typeString = xo.getChild(CATEGORY_WIDTH).getStringAttribute(TYPE); + try { + type = GammaSiteRateModel.CategoryWidthType.valueOf(typeString.toUpperCase()); + if (type == GammaSiteRateModel.CategoryWidthType.FASTEST) { + msg += "\n initial proportion of fastest sites = " + categoryWidthParameter.getParameterValue(0); + } else { + msg += "\n initial factor for increasing category width = " + categoryWidthParameter.getParameterValue(0); + } + } catch (IllegalArgumentException eae) { + throw new XMLParseException("Unknown category width type: " + typeString); + } + } + if (msg.length() > 0) { Logger.getLogger("dr.evomodel").info("\nCreating site rate model: " + msg); } else { Logger.getLogger("dr.evomodel").info("\nCreating site rate model."); } - GammaSiteRateModel siteRateModel = new GammaSiteRateModel(SITE_MODEL, muParam, muWeight, shapeParam, catCount, invarParam); + GammaSiteRateModel siteRateModel = new GammaSiteRateModel(SITE_MODEL, muParam, muWeight, shapeParam, catCount, invarParam, categoryWidthParameter, type); if (xo.hasChildNamed(SUBSTITUTION_MODEL)) { @@ -167,6 +188,11 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(PROPORTION_INVARIANT, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) + }, true), + + new ElementRule(CATEGORY_WIDTH, new XMLSyntaxRule[]{ + AttributeRule.newStringRule(TYPE, false), + new ElementRule(Parameter.class) }, true) }; diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 73b4082e13..40b763f0c9 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -274,7 +274,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { BranchModel branchModel = (BranchModel) cxo.getChild(BranchModel.class); if (branchModel == null) { - SubstitutionModel substitutionModel = (SubstitutionModel) xo.getChild(SubstitutionModel.class); + SubstitutionModel substitutionModel = (SubstitutionModel) cxo.getChild(SubstitutionModel.class); if (substitutionModel == null) { throw new XMLParseException("No substitution model available for partition " + k + " in DataTreeLikelihood: "+xo.getId()); } From 2f614920fc43b7b2fe0d02199cda5289c2bd327d Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Fri, 21 Apr 2023 15:53:00 +0200 Subject: [PATCH 31/65] Restoring some back compatibility to having substitution model in the SiteRateModel --- src/dr/evomodel/siteratemodel/GammaSiteRateModel.java | 5 ++--- .../treedatalikelihood/TreeDataLikelihoodParser.java | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index d109b69be7..9cbcc42d27 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -263,7 +263,6 @@ private void calculateCategoryRates() { if (shapeParameter != null) { final double a = shapeParameter.getParameterValue(0); - final double f = (1.0 + categoryWidthParameter.getParameterValue(0)); double mean = 0.0; double sum = 0.0; final int gammaCatCount = categoryCount - cat; @@ -279,10 +278,10 @@ private void calculateCategoryRates() { mean += categoryRates[i + cat]; if (categoryWidthParameter != null && categoryWidthType == CategoryWidthType.GEOMETRIC && i > 0) { - categoryProportions[i + cat] = categoryProportions[i + cat - 1] * f; + categoryProportions[i + cat] = categoryProportions[i + cat - 1] * (1.0 + categoryWidthParameter.getParameterValue(0)); } else if (categoryWidthParameter != null && categoryWidthType == CategoryWidthType.FASTEST && i == (gammaCatCount - 1)) { - categoryProportions[i + cat] = f; + categoryProportions[i + cat] = (1.0 + categoryWidthParameter.getParameterValue(0)); } else { categoryProportions[i + cat] = 1.0; } diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 40b763f0c9..53f51ee387 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -275,6 +275,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { BranchModel branchModel = (BranchModel) cxo.getChild(BranchModel.class); if (branchModel == null) { SubstitutionModel substitutionModel = (SubstitutionModel) cxo.getChild(SubstitutionModel.class); + if (substitutionModel == null && siteRateModel instanceof GammaSiteRateModel) { + substitutionModel = ((GammaSiteRateModel)siteRateModel).getSubstitutionModel(); + } if (substitutionModel == null) { throw new XMLParseException("No substitution model available for partition " + k + " in DataTreeLikelihood: "+xo.getId()); } From b6b06408f3680897f5159af901107e91266a2fb3 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sat, 22 Apr 2023 09:20:03 +0200 Subject: [PATCH 32/65] Erroneously incrementing category count --- src/dr/evomodel/siteratemodel/GammaSiteRateModel.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java index 9cbcc42d27..61e5d954f1 100644 --- a/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java +++ b/src/dr/evomodel/siteratemodel/GammaSiteRateModel.java @@ -142,8 +142,6 @@ public GammaSiteRateModel( this.categoryWidthParameter = categoryWidthParameter; this.categoryWidthType = categoryWidthType; if (categoryWidthParameter != null) { - this.categoryCount += 1; - addVariable(categoryWidthParameter); categoryWidthParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, 1)); } From 0fbeb98f97378fcaf5d2d4254ccae77491c93f85 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Sat, 22 Apr 2023 09:20:39 +0200 Subject: [PATCH 33/65] Adding getters for results --- src/dr/math/GeneralisedGaussLaguerreQuadrature.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java index 9cbfc830da..05f4a60e27 100644 --- a/src/dr/math/GeneralisedGaussLaguerreQuadrature.java +++ b/src/dr/math/GeneralisedGaussLaguerreQuadrature.java @@ -168,5 +168,12 @@ public double logIntegrate(UnivariateFunction f, double min){ } + public double[] getAbscissae() { + return abscissae; + } + + public double[] getCoefficients() { + return coefficients; + } } From cf320940b8b2e2c69a22554949ea1b440e7d8528 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 24 Apr 2023 15:09:53 +0100 Subject: [PATCH 34/65] Negative state numbers are invalid (and a few comments to fix) --- src/dr/evolution/alignment/SitePatterns.java | 6 +++--- src/dr/evolution/datatype/GeneralDataType.java | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dr/evolution/alignment/SitePatterns.java b/src/dr/evolution/alignment/SitePatterns.java index dcf4757bd9..ae92699450 100644 --- a/src/dr/evolution/alignment/SitePatterns.java +++ b/src/dr/evolution/alignment/SitePatterns.java @@ -420,7 +420,7 @@ private int addPattern(int[] pattern, int weight, double[][] uncertainty) { } /** - * @return true if the pattern is invariant + * @return true if the pattern contains a gap state */ private boolean isGapped(int[] pattern) { int len = pattern.length; @@ -434,7 +434,7 @@ private boolean isGapped(int[] pattern) { } /** - * @return true if the pattern is invariant + * @return true if the pattern contains an ambiguous state */ private boolean isAmbiguous(int[] pattern) { int len = pattern.length; @@ -448,7 +448,7 @@ private boolean isAmbiguous(int[] pattern) { } /** - * @return true if the pattern is invariant + * @return true if the pattern contains an unknown state */ private boolean isUnknown(int[] pattern) { int len = pattern.length; diff --git a/src/dr/evolution/datatype/GeneralDataType.java b/src/dr/evolution/datatype/GeneralDataType.java index cb502ae6f1..c0206f5f0f 100644 --- a/src/dr/evolution/datatype/GeneralDataType.java +++ b/src/dr/evolution/datatype/GeneralDataType.java @@ -199,7 +199,7 @@ public boolean[] getStateSet(int state) { boolean[] stateSet = new boolean[stateCount]; - if (state < states.size()) { + if (state >= 0 && state < states.size()) { State s = states.get(state); for (int i = 0; i < stateCount; i++) { From f5ee6b5f0a518d925503190714db6c99774c00b0 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 26 Apr 2023 13:31:29 +0100 Subject: [PATCH 35/65] Sequences/Alignments now detect if they have states that are not valid for the DataType --- src/dr/evolution/alignment/SimpleAlignment.java | 4 ++-- src/dr/evolution/datatype/GeneralDataType.java | 8 +++++++- src/dr/evolution/sequence/Sequence.java | 2 +- src/dr/evolution/sequence/UncertainSequence.java | 4 +++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/dr/evolution/alignment/SimpleAlignment.java b/src/dr/evolution/alignment/SimpleAlignment.java index 4db4f967c3..5176f96c7d 100644 --- a/src/dr/evolution/alignment/SimpleAlignment.java +++ b/src/dr/evolution/alignment/SimpleAlignment.java @@ -188,7 +188,7 @@ public void addSequence(Sequence sequence) { throw new IllegalArgumentException("Sequence's dataType does not match the alignment's"); } - int invalidCharAt = sequence.getInvalidChar(); + int invalidCharAt = sequence.getInvalidChar(dataType); if (invalidCharAt >= 0) throw new IllegalArgumentException("Sequence of " + sequence.getTaxon().getId() + " contains invalid char \'" + sequence.getChar(invalidCharAt) + "\' at index " + invalidCharAt); @@ -214,7 +214,7 @@ public void insertSequence(int position, Sequence sequence) { throw new IllegalArgumentException("Sequence's dataType does not match the alignment's"); } - int invalidCharAt = sequence.getInvalidChar(); + int invalidCharAt = sequence.getInvalidChar(dataType); if (invalidCharAt >= 0) throw new IllegalArgumentException("Sequence of " + sequence.getTaxon().getId() + " contains invalid char \'" + sequence.getChar(invalidCharAt) + "\' at index " + invalidCharAt); diff --git a/src/dr/evolution/datatype/GeneralDataType.java b/src/dr/evolution/datatype/GeneralDataType.java index c0206f5f0f..fe62b64341 100644 --- a/src/dr/evolution/datatype/GeneralDataType.java +++ b/src/dr/evolution/datatype/GeneralDataType.java @@ -127,7 +127,13 @@ public void addAmbiguity(String code, String[] ambiguousStates) { @Override public char[] getValidChars() { - return null; + char[] validChars = new char[stateMap.size()]; + int i = 0; + for (String state : stateMap.keySet()) { + validChars[i] = state.charAt(0); + i++; + } + return validChars; } /** diff --git a/src/dr/evolution/sequence/Sequence.java b/src/dr/evolution/sequence/Sequence.java index 9b8f630849..31371c4208 100644 --- a/src/dr/evolution/sequence/Sequence.java +++ b/src/dr/evolution/sequence/Sequence.java @@ -132,7 +132,7 @@ public void getChars(int srcBegin, int srcEnd, char[] dst, int dstBegin) { /** * search invalid character in the sequence by given data type, and return its index */ - public int getInvalidChar() { + public int getInvalidChar(DataType dataType) { final char[] validChars = dataType.getValidChars(); if (validChars != null) { String validString = new String(validChars); diff --git a/src/dr/evolution/sequence/UncertainSequence.java b/src/dr/evolution/sequence/UncertainSequence.java index 57de33b1a3..3adb6051dd 100644 --- a/src/dr/evolution/sequence/UncertainSequence.java +++ b/src/dr/evolution/sequence/UncertainSequence.java @@ -1,5 +1,7 @@ package dr.evolution.sequence; +import dr.evolution.datatype.DataType; + import java.util.ArrayList; import java.util.List; import java.util.StringTokenizer; @@ -50,7 +52,7 @@ public void setState(int index, int state) { } @Override - public int getInvalidChar() { + public int getInvalidChar(DataType dataType) { checkParsed(); From bde848ef0d5bcf2e157f907bcd7820b3a75e034a Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 26 Apr 2023 14:00:32 +0100 Subject: [PATCH 36/65] added some caching of validChars and stateSets for GeneralDataType --- .../evolution/datatype/GeneralDataType.java | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/src/dr/evolution/datatype/GeneralDataType.java b/src/dr/evolution/datatype/GeneralDataType.java index fe62b64341..a30441b5f8 100644 --- a/src/dr/evolution/datatype/GeneralDataType.java +++ b/src/dr/evolution/datatype/GeneralDataType.java @@ -127,15 +127,19 @@ public void addAmbiguity(String code, String[] ambiguousStates) { @Override public char[] getValidChars() { - char[] validChars = new char[stateMap.size()]; - int i = 0; - for (String state : stateMap.keySet()) { - validChars[i] = state.charAt(0); - i++; + if (validChars == null) { + validChars = new char[stateMap.size()]; + int i = 0; + for (String state : stateMap.keySet()) { + validChars[i] = state.charAt(0); + i++; + } } return validChars; } + private char[] validChars = null; + /** * Get state corresponding to a code * @@ -203,28 +207,34 @@ public int[] getStates(int state) { */ public boolean[] getStateSet(int state) { - boolean[] stateSet = new boolean[stateCount]; - - if (state >= 0 && state < states.size()) { - State s = states.get(state); - - for (int i = 0; i < stateCount; i++) { - stateSet[i] = false; - } - for (int i = 0, n = s.ambiguities.length; i < n; i++) { - stateSet[s.ambiguities[i]] = true; - } - } else if (state == states.size()) { - for (int i = 0; i < stateCount; i++) { - stateSet[i] = true; + boolean[] stateSet = stateSetMap.get(state); + + if (stateSet == null) { + stateSet = new boolean[stateCount]; + if (state >= 0 && state < states.size()) { + State s = states.get(state); + + for (int i = 0; i < stateCount; i++) { + stateSet[i] = false; + } + for (int i = 0, n = s.ambiguities.length; i < n; i++) { + stateSet[s.ambiguities[i]] = true; + } + } else if (state == states.size()) { + for (int i = 0; i < stateCount; i++) { + stateSet[i] = true; + } + } else { + throw new IllegalArgumentException("invalid state index"); } - } else { - throw new IllegalArgumentException("invalid state index"); + stateSetMap.put(state, stateSet); } return stateSet; } + private Map stateSetMap = new HashMap<>(); + /** * description of data type * From 4ce04e488383b54a4f687a8750c471b4b9e597e9 Mon Sep 17 00:00:00 2001 From: rambaut Date: Wed, 26 Apr 2023 14:47:25 +0100 Subject: [PATCH 37/65] AgeStatistic that changes a height statistic into an absolute age --- src/dr/app/beast/release_parsers.properties | 1 + src/dr/evomodel/tree/AgeStatistic.java | 83 +++++++++++++++++++ .../evomodelxml/tree/AgeStatisticParser.java | 82 ++++++++++++++++++ 3 files changed, 166 insertions(+) create mode 100644 src/dr/evomodel/tree/AgeStatistic.java create mode 100644 src/dr/evomodelxml/tree/AgeStatisticParser.java diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index 6dccdaf482..ba5a7ceeb3 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -287,6 +287,7 @@ dr.evomodelxml.tree.TreeLengthStatisticParser dr.evomodelxml.tree.NodeHeightsStatisticParser dr.evomodelxml.tree.TreeShapeStatisticParser dr.evomodelxml.tree.TMRCAStatisticParser +dr.evomodelxml.tree.AgeStatisticParser dr.evomodelxml.tree.MRCATraitStatisticParser dr.evomodelxml.tree.AncestralTraitParser dr.evomodelxml.tree.ExternalLengthStatisticParser diff --git a/src/dr/evomodel/tree/AgeStatistic.java b/src/dr/evomodel/tree/AgeStatistic.java new file mode 100644 index 0000000000..dc5b7685e3 --- /dev/null +++ b/src/dr/evomodel/tree/AgeStatistic.java @@ -0,0 +1,83 @@ +/* + * TMRCAStatistic.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.tree; + +import dr.evolution.tree.NodeRef; +import dr.evolution.tree.Tree; +import dr.evolution.tree.TreeUtils; +import dr.evolution.util.Taxon; +import dr.evolution.util.TaxonList; +import dr.inference.model.Statistic; + +import java.util.Set; + +/** + * A statistic that tracks the time of MRCA of a set of taxa + * + * @author Alexei Drummond + * @author Andrew Rambaut + * @version $Id: TMRCAStatistic.java,v 1.21 2005/07/11 14:06:25 rambaut Exp $ + */ +public class AgeStatistic extends Statistic.Abstract { + + public AgeStatistic(String name, Statistic heightStatistic) { + super(name); + this.heightStatistic = heightStatistic; + if (Taxon.getMostRecentDate() != null) { + isBackwards = Taxon.getMostRecentDate().isBackwards(); + mostRecentTipTime = Taxon.getMostRecentDate().getAbsoluteTimeValue(); + } else { + // give node heights or taxa don't have dates + mostRecentTipTime = Double.NaN; + isBackwards = false; + } + } + + public int getDimension() { + return 1; + } + + /** + * @return the height of the MRCA node. + */ + public double getStatisticValue(int dim) { + + if (!Double.isNaN(mostRecentTipTime)) { + if (isBackwards) { + return mostRecentTipTime + heightStatistic.getStatisticValue(dim); + } else { + return mostRecentTipTime - heightStatistic.getStatisticValue(dim); + } + } else { + return Double.NaN; + } + } + + private Statistic heightStatistic = null; + private final double mostRecentTipTime; + private final boolean isBackwards; + +} diff --git a/src/dr/evomodelxml/tree/AgeStatisticParser.java b/src/dr/evomodelxml/tree/AgeStatisticParser.java new file mode 100644 index 0000000000..1df12c7887 --- /dev/null +++ b/src/dr/evomodelxml/tree/AgeStatisticParser.java @@ -0,0 +1,82 @@ +/* + * TMRCAStatisticParser.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodelxml.tree; + +import dr.evolution.tree.Tree; +import dr.evolution.tree.TreeUtils; +import dr.evolution.util.Taxa; +import dr.evolution.util.TaxonList; +import dr.evomodel.tree.AgeStatistic; +import dr.evomodel.tree.TMRCAStatistic; +import dr.inference.model.Statistic; +import dr.xml.*; + +/** + * + * Converts a statistic or parameter into absolute age: + * + * + * + * + * @author Andrew Rambaut + */ +public class AgeStatisticParser extends AbstractXMLObjectParser { + + public static final String AGE_STATISTIC = "ageStatistic"; + + + public String getParserName() { + return AGE_STATISTIC; + } + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + String name = xo.getAttribute(Statistic.NAME, xo.getId()); + Statistic heightStatistic = (Statistic) xo.getChild(Statistic.class); + return new AgeStatistic(name, heightStatistic); + } + + //************************************************************************ + // AbstractXMLObjectParser implementation + //************************************************************************ + + public String getParserDescription() { + return "A statistic that converts a height statistic into an absolute age using the date of the most recent tip. "; + } + + public Class getReturnType() { + return AgeStatistic.class; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + new ElementRule(Statistic.class) + }; + +} From 0ee7a86263c53d7ad25c0acbbf26861fa120557d Mon Sep 17 00:00:00 2001 From: rambaut Date: Thu, 27 Apr 2023 18:06:15 +0100 Subject: [PATCH 38/65] Fix for issue #151 - show trace files in legend --- src/dr/inference/trace/LogFileTraces.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/dr/inference/trace/LogFileTraces.java b/src/dr/inference/trace/LogFileTraces.java index e832ce9dc2..817d29cbab 100644 --- a/src/dr/inference/trace/LogFileTraces.java +++ b/src/dr/inference/trace/LogFileTraces.java @@ -40,7 +40,8 @@ public class LogFileTraces extends AbstractTraceList { public LogFileTraces(String name, File file) { - this.name = name; + // trim off the extension if present. + this.name = name.toUpperCase().endsWith(".LOG") ? name.substring(0, name.length() - 4) : name; this.file = file; System.out.println("Loading log " + file.getAbsolutePath() + " ..."); } @@ -52,6 +53,13 @@ public String getName() { return name; } + /** + * @return the path of this traceset + */ + public String getFileName() { + return file.getName(); + } + /** * @return the path of this traceset */ From fd90f5b8ee6241975220d39f9ce9257fc55d1a1c Mon Sep 17 00:00:00 2001 From: rambaut Date: Thu, 27 Apr 2023 18:06:34 +0100 Subject: [PATCH 39/65] Update packaging --- .../mac/universalJavaApplicationStub | 238 ++++++++++++++---- 1 file changed, 186 insertions(+), 52 deletions(-) diff --git a/packaging_tools/mac/universalJavaApplicationStub b/packaging_tools/mac/universalJavaApplicationStub index e512c58513..43cb6837a0 100755 --- a/packaging_tools/mac/universalJavaApplicationStub +++ b/packaging_tools/mac/universalJavaApplicationStub @@ -11,14 +11,14 @@ # # # @author Tobias Fischer # # @url https://github.com/tofi86/universalJavaApplicationStub # -# @date 2018-07-29 # -# @version 3.0.3 # +# @date 2023-02-04 # +# @version 3.3.0 # # # ################################################################################## # # # The MIT License (MIT) # # # -# Copyright (c) 2014-2018 Tobias Fischer # +# Copyright (c) 2014-2023 Tobias Fischer # # # # Permission is hereby granted, free of charge, to any person obtaining a copy # # of this software and associated documentation files (the "Software"), to deal # @@ -166,6 +166,8 @@ if [ $exitcode -eq 0 ]; then JavaFolder="${AppleJavaFolder}" ResourcesFolder="${AppleResourcesFolder}" + # set expandable variables + APP_ROOT="${AppPackageFolder}" APP_PACKAGE="${AppPackageFolder}" JAVAROOT="${AppleJavaFolder}" USER_HOME="$HOME" @@ -180,7 +182,7 @@ if [ $exitcode -eq 0 ]; then # AppPackageRoot is the standard WorkingDirectory when the script is started WorkingDirectory="${AppPackageRoot}" fi - # expand variables $APP_PACKAGE, $JAVAROOT, $USER_HOME + # expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME WorkingDirectory=$(eval echo "${WorkingDirectory}") @@ -203,7 +205,7 @@ if [ $exitcode -eq 0 ]; then else JVMClassPath=${JVMClassPath_RAW} fi - # expand variables $APP_PACKAGE, $JAVAROOT, $USER_HOME + # expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME JVMClassPath=$(eval echo "${JVMClassPath}") # read the JVM Options in either Array or String style @@ -213,6 +215,8 @@ if [ $exitcode -eq 0 ]; then else JVMDefaultOptions=${JVMDefaultOptions_RAW} fi + # expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME (#84) + JVMDefaultOptions=$(eval echo "${JVMDefaultOptions}") # read StartOnMainThread and add as -XstartOnFirstThread JVMStartOnMainThread=$(plist_get_java ':StartOnMainThread') @@ -220,9 +224,14 @@ if [ $exitcode -eq 0 ]; then JVMDefaultOptions+=" -XstartOnFirstThread" fi - # read the JVM Arguments as an array and retain spaces + # read the JVM Arguments in either Array or String style (#76) and retain spaces IFS=$'\t\n' - MainArgs=($(xargs -n1 <<<$(plist_get_java ':Arguments'))) + MainArgs_RAW=$(plist_get_java ':Arguments' | xargs) + if [[ $MainArgs_RAW == *Array* ]] ; then + MainArgs=($(xargs -n1 <<<$(plist_get_java ':Arguments' | tr -d '\n' | sed -E 's/Array \{ *(.*) *\}/\1/g' | sed 's/ */ /g'))) + else + MainArgs=($(xargs -n1 <<<$(plist_get_java ':Arguments'))) + fi unset IFS # post processing of the array follows further below... @@ -240,7 +249,11 @@ else ResourcesFolder="${OracleResourcesFolder}" WorkingDirectory="${OracleJavaFolder}" + # set expandable variables APP_ROOT="${AppPackageFolder}" + APP_PACKAGE="${AppPackageFolder}" + JAVAROOT="${OracleJavaFolder}" + USER_HOME="$HOME" # read the MainClass name JVMMainClass="$(plist_get ':JVMMainClassName')" @@ -258,12 +271,12 @@ else JVMClassPath_RAW=$(plist_get ':JVMClassPath') if [[ $JVMClassPath_RAW == *Array* ]] ; then JVMClassPath=.$(plist_get ':JVMClassPath' | grep " " | sed 's/^ */:/g' | tr -d '\n' | xargs) - # expand variables $APP_PACKAGE, $JAVAROOT, $USER_HOME + # expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME JVMClassPath=$(eval echo "${JVMClassPath}") elif [[ ! -z ${JVMClassPath_RAW} ]] ; then JVMClassPath=${JVMClassPath_RAW} - # expand variables $APP_PACKAGE, $JAVAROOT, $USER_HOME + # expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME JVMClassPath=$(eval echo "${JVMClassPath}") else @@ -272,8 +285,11 @@ else # Do NOT expand the default 'AppName.app/Contents/Java/*' classpath (#42) fi - # read the JVM Default Options + # read the JVM Default Options by parsing the :JVMDefaultOptions + # and pulling all values starting with a dash (-) JVMDefaultOptions=$(plist_get ':JVMDefaultOptions' | grep -o " \-.*" | tr -d '\n' | xargs) + # expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME (#99) + JVMDefaultOptions=$(eval echo "${JVMDefaultOptions}") # read the Main Arguments from JVMArguments key as an array and retain spaces (see #46 for naming details) IFS=$'\t\n' @@ -287,6 +303,18 @@ else fi +# (#75) check for undefined icons or icon names without .icns extension and prepare +# an osascript statement for those cases when the icon can be shown in the dialog +DialogWithIcon="" +if [ ! -z ${CFBundleIconFile} ]; then + if [[ ${CFBundleIconFile} == *.icns ]] && [[ -f "${ResourcesFolder}/${CFBundleIconFile}" ]] ; then + DialogWithIcon=" with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" + elif [[ ${CFBundleIconFile} != *.icns ]] && [[ -f "${ResourcesFolder}/${CFBundleIconFile}.icns" ]] ; then + CFBundleIconFile+=".icns" + DialogWithIcon=" with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" + fi +fi + # JVMVersion: post processing and optional splitting if [[ ${JVMVersion} == *";"* ]]; then @@ -297,14 +325,14 @@ fi stub_logger "[JavaRequirement] JVM minimum version: ${JVMVersion}" stub_logger "[JavaRequirement] JVM maximum version: ${JVMMaxVersion}" -# MainArgs: replace occurences of $APP_ROOT with its content +# MainArgs: expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME MainArgsArr=() for i in "${MainArgs[@]}" do MainArgsArr+=("$(eval echo "$i")") done -# JVMOptions: replace occurences of $APP_ROOT with its content +# JVMOptions: expand variables $APP_PACKAGE, $APP_ROOT, $JAVAROOT, $USER_HOME JVMOptionsArr=() for i in "${JVMOptions[@]}" do @@ -315,14 +343,41 @@ done # internationalized messages ############################################ -LANG=$(defaults read -g AppleLocale) -stub_logger "[Language] $LANG" +# supported languages / available translations +stubLanguages=("de" "en" "es" "fr" "pt-BR" "zh") + +# read user preferred languages as defined in macOS System Preferences (#101) +stub_logger '[LanguageSearch] Checking preferred languages in macOS System Preferences...' +appleLanguages=($(defaults read -g AppleLanguages | grep '\s"' | tr -d ',' | xargs)) +stub_logger "[LanguageSearch] ... found [${appleLanguages[*]}]" + +language="" +for i in "${appleLanguages[@]}" +do + langValue="${i%-*}" + if [[ " ${stubLanguages[*]} " =~ " ${i} " ]]; then + stub_logger "[LanguageSearch] ... selected '$i' as the default language for the launcher stub" + language=${i} + break + elif [[ " ${stubLanguages[*]} " =~ " ${langValue} " ]]; then + stub_logger "[LanguageSearch] ... selected '$langValue' (from '$i') as the default language for the launcher stub" + language=${langValue} + break + fi +done +if [ -z "${language}" ]; then + language="en" + stub_logger "[LanguageSearch] ... selected fallback 'en' as the default language for the launcher stub" +fi +stub_logger "[Language] $language" + -# French localization -if [[ $LANG == fr* ]] ; then +case "${language}" in +# French +fr) MSG_ERROR_LAUNCHING="ERREUR au lancement de '${CFBundleName}'." MSG_MISSING_MAINCLASS="'MainClass' n'est pas spécifié.\nL'application Java ne peut pas être lancée." - MSG_JVMVERSION_REQ_INVALID="La syntaxe de la version Java demandée est invalide: %s\nVeuillez contacter le développeur de l'application." + MSG_JVMVERSION_REQ_INVALID="La syntaxe de la version de Java demandée est invalide: %s\nVeuillez contacter le développeur de l'application." MSG_NO_SUITABLE_JAVA="La version de Java installée sur votre système ne convient pas.\nCe programme nécessite Java %s" MSG_JAVA_VERSION_OR_LATER="ou ultérieur" MSG_JAVA_VERSION_LATEST="(dernière mise à jour)" @@ -330,10 +385,12 @@ if [[ $LANG == fr* ]] ; then MSG_NO_SUITABLE_JAVA_CHECK="Merci de bien vouloir installer la version de Java requise." MSG_INSTALL_JAVA="Java doit être installé sur votre système.\nRendez-vous sur java.com et suivez les instructions d'installation..." MSG_LATER="Plus tard" - MSG_VISIT_JAVA_DOT_COM="Visiter java.com" + MSG_VISIT_JAVA_DOT_COM="Java by Oracle" + MSG_VISIT_ADOPTIUM="Java by Adoptium" + ;; -# German localization -elif [[ $LANG == de* ]] ; then +# German +de) MSG_ERROR_LAUNCHING="FEHLER beim Starten von '${CFBundleName}'." MSG_MISSING_MAINCLASS="Die 'MainClass' ist nicht spezifiziert!\nDie Java-Anwendung kann nicht gestartet werden!" MSG_JVMVERSION_REQ_INVALID="Die Syntax der angeforderten Java-Version ist ungültig: %s\nBitte kontaktieren Sie den Entwickler der App." @@ -344,10 +401,12 @@ elif [[ $LANG == de* ]] ; then MSG_NO_SUITABLE_JAVA_CHECK="Stellen Sie sicher, dass die angeforderte Java-Version installiert ist." MSG_INSTALL_JAVA="Auf Ihrem System muss die 'Java'-Software installiert sein.\nBesuchen Sie java.com für weitere Installationshinweise." MSG_LATER="Später" - MSG_VISIT_JAVA_DOT_COM="java.com öffnen" + MSG_VISIT_JAVA_DOT_COM="Java von Oracle" + MSG_VISIT_ADOPTIUM="Java von Adoptium" + ;; -# Simplifyed Chinese localization -elif [[ $LANG == zh* ]] ; then +# Simplified Chinese +zh) MSG_ERROR_LAUNCHING="无法启动 '${CFBundleName}'." MSG_MISSING_MAINCLASS="没有指定 'MainClass'!\nJava程序无法启动!" MSG_JVMVERSION_REQ_INVALID="Java版本参数语法错误: %s\n请联系该应用的开发者。" @@ -358,10 +417,44 @@ elif [[ $LANG == zh* ]] ; then MSG_NO_SUITABLE_JAVA_CHECK="请确保系统中安装了所需的Java版本" MSG_INSTALL_JAVA="你需要在Mac中安装Java运行环境!\n访问 java.com 了解如何安装。" MSG_LATER="稍后" - MSG_VISIT_JAVA_DOT_COM="访问 java.com" - -# English default localization -else + MSG_VISIT_JAVA_DOT_COM="Java by Oracle" + MSG_VISIT_ADOPTIUM="Java by Adoptium" + ;; + +# Spanish +es) + MSG_ERROR_LAUNCHING="ERROR iniciando '${CFBundleName}'." + MSG_MISSING_MAINCLASS="¡'MainClass' no especificada!\n¡La aplicación Java no puede iniciarse!" + MSG_JVMVERSION_REQ_INVALID="La sintaxis de la versión Java requerida no es válida: %s\nPor favor, contacte con el desarrollador de la aplicación." + MSG_NO_SUITABLE_JAVA="¡No se encontró una versión de Java adecuada en su sistema!\nEste programa requiere Java %s" + MSG_JAVA_VERSION_OR_LATER="o posterior" + MSG_JAVA_VERSION_LATEST="(ultima actualización)" + MSG_JAVA_VERSION_MAX="superior a %s" + MSG_NO_SUITABLE_JAVA_CHECK="Asegúrese de instalar la versión Java requerida." + MSG_INSTALL_JAVA="¡Necesita tener JAVA instalado en su Mac!\nVisite java.com para consultar las instrucciones para su instalación..." + MSG_LATER="Más tarde" + MSG_VISIT_JAVA_DOT_COM="Java de Oracle" + MSG_VISIT_ADOPTIUM="Java de Adoptium" + ;; + +# Brazilian Portuguese +pt-BR) + MSG_ERROR_LAUNCHING="ERRO iniciando '${CFBundleName}'." + MSG_MISSING_MAINCLASS="'MainClass' não foi definida!\nA aplicação java não poderá ser iniciada!" + MSG_JVMVERSION_REQ_INVALID="A sintaxe da versão Java requerida não é valida: %s\nPor favor contacte o desenvolvedor dessa aplicação." + MSG_NO_SUITABLE_JAVA="Não foi encontrado uma versão Java compatível no seu sistema!\nEsta aplicação precisa do Java %s" + MSG_JAVA_VERSION_OR_LATER="ou maior" + MSG_JAVA_VERSION_LATEST="(última atualização)" + MSG_JAVA_VERSION_MAX="máxima %s" + MSG_NO_SUITABLE_JAVA_CHECK="Verifique se instalou a versão Java necessária." + MSG_INSTALL_JAVA="Você precisa instalar o JAVA no seu Mac!\nPor favor, visite java.com para instruções de instalação..." + MSG_LATER="Depois" + MSG_VISIT_JAVA_DOT_COM="Java por Oracle" + MSG_VISIT_ADOPTIUM="Java por Adoptium" + ;; + +# English | default +en|*) MSG_ERROR_LAUNCHING="ERROR launching '${CFBundleName}'." MSG_MISSING_MAINCLASS="'MainClass' isn't specified!\nJava application cannot be started!" MSG_JVMVERSION_REQ_INVALID="The syntax of the required Java version is invalid: %s\nPlease contact the App developer." @@ -372,8 +465,10 @@ else MSG_NO_SUITABLE_JAVA_CHECK="Make sure you install the required Java version." MSG_INSTALL_JAVA="You need to have JAVA installed on your Mac!\nVisit java.com for installation instructions..." MSG_LATER="Later" - MSG_VISIT_JAVA_DOT_COM="Visit java.com" -fi + MSG_VISIT_JAVA_DOT_COM="Java by Oracle" + MSG_VISIT_ADOPTIUM="Java by Adoptium" + ;; +esac @@ -456,7 +551,7 @@ function get_comparable_java_version() { ################################################################################ function is_valid_requirement_pattern() { local java_req=$1 - java8pattern='1\.[4-8](\.0)?(\.0_[0-9]+)?[*+]?' + java8pattern='1\.[4-8](\.[0-9]+)?(\.0_[0-9]+)?[*+]?' java9pattern='(9|1[0-9])(-ea|[*+]|(\.[0-9]+){1,2}[*+]?)?' # test matches either old Java versioning scheme (up to 1.8) or new scheme (starting with 9) if [[ ${java_req} =~ ^(${java8pattern}|${java9pattern})$ ]]; then @@ -489,20 +584,28 @@ if [ -n "$JAVA_HOME" ] ; then if [[ $JAVA_HOME == /* ]] ; then # if "$JAVA_HOME" starts with a Slash it's an absolute path JAVACMD="$JAVA_HOME/bin/java" + stub_logger "[JavaSearch] ... parsing JAVA_HOME as absolute path to the executable '$JAVACMD'" else # otherwise it's a relative path to "$AppPackageFolder" JAVACMD="$AppPackageFolder/$JAVA_HOME/bin/java" + stub_logger "[JavaSearch] ... parsing JAVA_HOME as relative path inside the App bundle to the executable '$JAVACMD'" fi JAVACMD_version=$(get_comparable_java_version $(get_java_version_from_cmd "${JAVACMD}")) else - stub_logger "[JavaSearch] ... didn't found JAVA_HOME" + stub_logger "[JavaSearch] ... haven't found JAVA_HOME" fi # check for any other or a specific Java version # also if $JAVA_HOME exists but isn't executable if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then - stub_logger "[JavaSearch] Checking for JavaVirtualMachines on the system ..." + + # add a warning in the syslog if JAVA_HOME is not executable or not found (#100) + if [ -n "$JAVA_HOME" ] ; then + stub_logger "[JavaSearch] ... but no 'java' executable was found at the JAVA_HOME location!" + fi + + stub_logger "[JavaSearch] Searching for JavaVirtualMachines on the system ..." # reset variables JAVACMD="" JAVACMD_version="" @@ -513,7 +616,7 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then # log exit cause stub_logger "[EXIT 4] ${MSG_JVMVERSION_REQ_INVALID_EXPANDED}" # display error message with AppleScript - osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_JVMVERSION_REQ_INVALID_EXPANDED}\" with title \"${CFBundleName}\" buttons {\" OK \"} default button 1 with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" + osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_JVMVERSION_REQ_INVALID_EXPANDED}\" with title \"${CFBundleName}\" buttons {\" OK \"} default button 1${DialogWithIcon}" # exit with error exit 4 fi @@ -523,7 +626,7 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then # log exit cause stub_logger "[EXIT 5] ${MSG_JVMVERSION_REQ_INVALID_EXPANDED}" # display error message with AppleScript - osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_JVMVERSION_REQ_INVALID_EXPANDED}\" with title \"${CFBundleName}\" buttons {\" OK \"} default button 1 with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" + osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_JVMVERSION_REQ_INVALID_EXPANDED}\" with title \"${CFBundleName}\" buttons {\" OK \"} default button 1${DialogWithIcon}" # exit with error exit 5 fi @@ -531,15 +634,41 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then # find installed JavaVirtualMachines (JDK + JRE) allJVMs=() - # read JDK's from '/usr/libexec/java_home -V' command - while read -r line; do - version=$(echo $line | awk -F $',' '{print $1;}') - path=$(echo $line | awk -F $'" ' '{print $2;}') - path+="/bin/java" - allJVMs+=("$version:$path") - done < <(/usr/libexec/java_home -V 2>&1 | grep '^[[:space:]]') - # unset while loop variables - unset version path + + # read JDK's from '/usr/libexec/java_home --xml' command with PlistBuddy and a custom Dict iterator + # idea: https://stackoverflow.com/a/14085460/1128689 and https://scriptingosx.com/2018/07/parsing-dscl-output-in-scripts/ + javaXml=$(/usr/libexec/java_home --xml) + javaCounter=$(/usr/libexec/PlistBuddy -c "Print" /dev/stdin <<< $javaXml | grep "Dict" | wc -l | tr -d ' ') + + # iterate over all Dict entries + # but only if there are any JVMs at all (#93) + if [ "$javaCounter" -gt "0" ] ; then + for idx in $(seq 0 $((javaCounter - 1))) + do + version=$(/usr/libexec/PlistBuddy -c "print :$idx:JVMVersion" /dev/stdin <<< $javaXml) + path=$(/usr/libexec/PlistBuddy -c "print :$idx:JVMHomePath" /dev/stdin <<< $javaXml) + path+="/bin/java" + allJVMs+=("$version:$path") + done + # unset for loop variables + unset version path + fi + + # add SDKMAN! java versions (#95) + if [ -d ~/.sdkman/candidates/java/ ] ; then + for sdkjdk in ~/.sdkman/candidates/java/*/ + do + if [[ ${sdkjdk} =~ /current/$ ]] ; then + continue + fi + + sdkjdkcmd="${sdkjdk}bin/java" + version=$(get_java_version_from_cmd "${sdkjdkcmd}") + allJVMs+=("$version:$sdkjdkcmd") + done + # unset for loop variables + unset version + fi # add Apple JRE if available if [ -x "${apple_jre_plugin}" ] ; then @@ -559,6 +688,9 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then # determine JVMs matching the min/max version requirement + + stub_logger "[JavaSearch] Filtering the result list for JVMs matching the min/max version requirement ..." + minC=$(get_comparable_java_version ${JVMVersion}) maxC=$(get_comparable_java_version ${JVMMaxVersion}) matchingJVMs=() @@ -643,7 +775,7 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then # debug output for i in "${matchingJVMs[@]}" do - stub_logger "[JavaSearch] ... ... matches all requirements: $i" + stub_logger "[JavaSearch] ... matches all requirements: $i" done @@ -700,9 +832,10 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then stub_logger "[EXIT 3] ${MSG_NO_SUITABLE_JAVA_EXPANDED}" # display error message with AppleScript - osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_NO_SUITABLE_JAVA_EXPANDED}\n${MSG_NO_SUITABLE_JAVA_CHECK}\" with title \"${CFBundleName}\" buttons {\" OK \", \"${MSG_VISIT_JAVA_DOT_COM}\"} default button \"${MSG_VISIT_JAVA_DOT_COM}\" with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" \ + osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_NO_SUITABLE_JAVA_EXPANDED}\n${MSG_NO_SUITABLE_JAVA_CHECK}\" with title \"${CFBundleName}\" buttons {\" OK \", \"${MSG_VISIT_JAVA_DOT_COM}\", \"${MSG_VISIT_ADOPTIUM}\"} default button 1${DialogWithIcon}" \ -e "set response to button returned of the result" \ - -e "if response is \"${MSG_VISIT_JAVA_DOT_COM}\" then open location \"http://java.com\"" + -e "if response is \"${MSG_VISIT_JAVA_DOT_COM}\" then open location \"https://www.java.com/download/\"" \ + -e "if response is \"${MSG_VISIT_ADOPTIUM}\" then open location \"https://adoptium.net/releases.html\"" # exit with error exit 3 @@ -710,9 +843,10 @@ if [ -z "${JAVACMD}" ] || [ ! -x "${JAVACMD}" ] ; then # log exit cause stub_logger "[EXIT 1] ${MSG_ERROR_LAUNCHING}" # display error message with AppleScript - osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_INSTALL_JAVA}\" with title \"${CFBundleName}\" buttons {\"${MSG_LATER}\", \"${MSG_VISIT_JAVA_DOT_COM}\"} default button \"${MSG_VISIT_JAVA_DOT_COM}\" with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" \ + osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_INSTALL_JAVA}\" with title \"${CFBundleName}\" buttons {\"${MSG_LATER}\", \"${MSG_VISIT_JAVA_DOT_COM}\", \"${MSG_VISIT_ADOPTIUM}\"} default button 1${DialogWithIcon}" \ -e "set response to button returned of the result" \ - -e "if response is \"${MSG_VISIT_JAVA_DOT_COM}\" then open location \"http://java.com\"" + -e "if response is \"${MSG_VISIT_JAVA_DOT_COM}\" then open location \"https://www.java.com/download/\"" \ + -e "if response is \"${MSG_VISIT_ADOPTIUM}\" then open location \"https://adoptium.net/releases.html\"" # exit with error exit 1 fi @@ -727,7 +861,7 @@ if [ -z "${JVMMainClass}" ]; then # log exit cause stub_logger "[EXIT 2] ${MSG_MISSING_MAINCLASS}" # display error message with AppleScript - osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_MISSING_MAINCLASS}\" with title \"${CFBundleName}\" buttons {\" OK \"} default button 1 with icon path to resource \"${CFBundleIconFile}\" in bundle (path to me)" + osascript -e "tell application \"System Events\" to display dialog \"${MSG_ERROR_LAUNCHING}\n\n${MSG_MISSING_MAINCLASS}\" with title \"${CFBundleName}\" buttons {\" OK \"} default button 1${DialogWithIcon}" # exit with error exit 2 fi @@ -761,13 +895,13 @@ stub_logger "[WorkingDirectory] ${WorkingDirectory}" # - main class # - main class arguments # - passthrough arguments from Terminal or Drag'n'Drop to Finder icon -stub_logger "[Exec] \"$JAVACMD\" -cp \"${JVMClassPath}\" -splash:\"${ResourcesFolder}/${JVMSplashFile}\" -Xdock:icon=\"${ResourcesFolder}/${CFBundleIconFile}\" -Xdock:name=\"${CFBundleName}\" ${JVMOptionsArr:+$(printf "'%s' " "${JVMOptionsArr[@]}") }${JVMDefaultOptions:+$JVMDefaultOptions }${JVMMainClass}${MainArgsArr:+ $(printf "'%s' " "${MainArgsArr[@]}")}${ArgsPassthru:+ $(printf "'%s' " "${ArgsPassthru[@]}")}" +stub_logger "[Exec] \"$JAVACMD\" -cp \"${JVMClassPath}\" ${JVMSplashFile:+ -splash:\"${ResourcesFolder}/${JVMSplashFile}\"} -Xdock:icon=\"${ResourcesFolder}/${CFBundleIconFile}\" -Xdock:name=\"${CFBundleName}\" ${JVMOptionsArr:+$(printf "'%s' " "${JVMOptionsArr[@]}") }${JVMDefaultOptions:+$JVMDefaultOptions }${JVMMainClass}${MainArgsArr:+ $(printf "'%s' " "${MainArgsArr[@]}")}${ArgsPassthru:+ $(printf "'%s' " "${ArgsPassthru[@]}")}" exec "${JAVACMD}" \ -cp "${JVMClassPath}" \ - -splash:"${ResourcesFolder}/${JVMSplashFile}" \ + ${JVMSplashFile:+ -splash:"${ResourcesFolder}/${JVMSplashFile}"} \ -Xdock:icon="${ResourcesFolder}/${CFBundleIconFile}" \ -Xdock:name="${CFBundleName}" \ - ${JVMOptions:+"${JVMOptions[@]}" }\ + ${JVMOptionsArr:+"${JVMOptionsArr[@]}" }\ ${JVMDefaultOptions:+$JVMDefaultOptions }\ "${JVMMainClass}"\ ${MainArgsArr:+ "${MainArgsArr[@]}"}\ From 4db8d8e70e1c8315ac375b0ecb0233801df7f34e Mon Sep 17 00:00:00 2001 From: Kuoi Date: Fri, 28 Apr 2023 23:14:46 +0800 Subject: [PATCH 40/65] fix: typo --- build_coalsim.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_coalsim.xml b/build_coalsim.xml index 8a66005daf..db081e2209 100644 --- a/build_coalsim.xml +++ b/build_coalsim.xml @@ -8,7 +8,7 @@ - 2 + From 9f0138e48104b1097dc73c72047cff285cfb4703 Mon Sep 17 00:00:00 2001 From: rambaut Date: Thu, 18 May 2023 15:13:05 +0100 Subject: [PATCH 41/65] Updating comment --- src/dr/evomodel/tree/AgeStatistic.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodel/tree/AgeStatistic.java b/src/dr/evomodel/tree/AgeStatistic.java index dc5b7685e3..0bf892c579 100644 --- a/src/dr/evomodel/tree/AgeStatistic.java +++ b/src/dr/evomodel/tree/AgeStatistic.java @@ -35,11 +35,10 @@ import java.util.Set; /** - * A statistic that tracks the time of MRCA of a set of taxa + * A statistic that calculates the age (absolute time) from a height * - * @author Alexei Drummond * @author Andrew Rambaut - * @version $Id: TMRCAStatistic.java,v 1.21 2005/07/11 14:06:25 rambaut Exp $ + * @version $Id: $ */ public class AgeStatistic extends Statistic.Abstract { From 422dc0812b1e6efa9d4cbde2a1d0caca6daaefa9 Mon Sep 17 00:00:00 2001 From: rambaut Date: Thu, 18 May 2023 15:16:02 +0100 Subject: [PATCH 42/65] Updating comment --- src/dr/app/beast/release_parsers.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index ba5a7ceeb3..2561497c49 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -287,7 +287,7 @@ dr.evomodelxml.tree.TreeLengthStatisticParser dr.evomodelxml.tree.NodeHeightsStatisticParser dr.evomodelxml.tree.TreeShapeStatisticParser dr.evomodelxml.tree.TMRCAStatisticParser -dr.evomodelxml.tree.AgeStatisticParser +dr.evomodelxml.tree.AgeStatisticParser dr.evomodelxml.tree.MRCATraitStatisticParser dr.evomodelxml.tree.AncestralTraitParser dr.evomodelxml.tree.ExternalLengthStatisticParser From 8a10723b9d2db3752cdfbc425347a66121377ab2 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Thu, 18 May 2023 15:30:05 +0100 Subject: [PATCH 43/65] errant space in a parser list --- src/dr/app/beast/release_parsers.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index 2561497c49..ba5a7ceeb3 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -287,7 +287,7 @@ dr.evomodelxml.tree.TreeLengthStatisticParser dr.evomodelxml.tree.NodeHeightsStatisticParser dr.evomodelxml.tree.TreeShapeStatisticParser dr.evomodelxml.tree.TMRCAStatisticParser -dr.evomodelxml.tree.AgeStatisticParser +dr.evomodelxml.tree.AgeStatisticParser dr.evomodelxml.tree.MRCATraitStatisticParser dr.evomodelxml.tree.AncestralTraitParser dr.evomodelxml.tree.ExternalLengthStatisticParser From 72af2af77714b5fcf3c2d8aefa837f53174f1fa3 Mon Sep 17 00:00:00 2001 From: Andrew Rambaut Date: Thu, 18 May 2023 15:34:24 +0100 Subject: [PATCH 44/65] Updating version info --- src/dr/app/beast/BeastVersion.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dr/app/beast/BeastVersion.java b/src/dr/app/beast/BeastVersion.java index 82e4613f34..c74ef47379 100644 --- a/src/dr/app/beast/BeastVersion.java +++ b/src/dr/app/beast/BeastVersion.java @@ -55,12 +55,12 @@ public class BeastVersion implements Version, Citable { */ private static final String VERSION = "1.10.5"; - private static final String DATE_STRING = "2002-2019"; + private static final String DATE_STRING = "2002-2023"; private static final boolean IS_PRERELEASE = true; // this is now being manually updated since the move to GitHub. 7 digits of GitHub hash. - private static final String REVISION = "23570d1"; + private static final String REVISION = "8a10723"; public String getVersion() { return VERSION; From 779e3a12e9f0b817811dfada2fa58a0a2d94dfc9 Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 2 Jun 2023 16:26:55 +0100 Subject: [PATCH 45/65] Added the ability for a local clock clade to have the transition time on the stem branch as a parameter. Should be bounded at 0 and 1. --- .../branchratemodel/LocalClockModel.java | 39 ++++++++++++------- .../LocalClockModelParser.java | 32 ++++++++++----- 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/src/dr/evomodel/branchratemodel/LocalClockModel.java b/src/dr/evomodel/branchratemodel/LocalClockModel.java index ecd2dddf97..1cd1856621 100644 --- a/src/dr/evomodel/branchratemodel/LocalClockModel.java +++ b/src/dr/evomodel/branchratemodel/LocalClockModel.java @@ -113,20 +113,26 @@ public void addExternalBranchClock(TaxonList taxonList, BranchRateModel branchRa addModel(branchRates); } - public void addCladeClock(TaxonList taxonList, Parameter rateParameter, boolean isRelativeRate, double stemProportion, boolean excludeClade) throws TreeUtils.MissingTaxonException { + public void addCladeClock(TaxonList taxonList, Parameter rateParameter, boolean isRelativeRate, Parameter stemParameter, boolean excludeClade) throws TreeUtils.MissingTaxonException { Set tips = TreeUtils.getTipsForTaxa(treeModel, taxonList); BitSet tipBitSet = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList); - LocalClock clock = new LocalClock(rateParameter, isRelativeRate, tips, stemProportion, excludeClade); + LocalClock clock = new LocalClock(rateParameter, isRelativeRate, tips, stemParameter, excludeClade); localCladeClocks.put(tipBitSet, clock); addVariable(rateParameter); + if (stemParameter != null) { + addVariable(stemParameter); + } } - public void addCladeClock(TaxonList taxonList, BranchRateModel branchRates, boolean isRelativeRate, double stemProportion, boolean excludeClade) throws TreeUtils.MissingTaxonException { + public void addCladeClock(TaxonList taxonList, BranchRateModel branchRates, boolean isRelativeRate, Parameter stemParameter, boolean excludeClade) throws TreeUtils.MissingTaxonException { Set tips = TreeUtils.getTipsForTaxa(treeModel, taxonList); BitSet tipBitSet = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList); - LocalClock clock = new LocalClock(branchRates, isRelativeRate, tips, stemProportion, excludeClade); + LocalClock clock = new LocalClock(branchRates, isRelativeRate, tips, stemParameter, excludeClade); localCladeClocks.put(tipBitSet, clock); addModel(branchRates); + if (stemParameter != null) { + addVariable(stemParameter); + } } public void addTrunkClock(TaxonList taxonList, Parameter rateParameter, Parameter indexParameter, boolean isRelativeRate) throws TreeUtils.MissingTaxonException { @@ -389,7 +395,7 @@ private class LocalClock { this.tips = tipSet; this.tipList = null; this.type = type; - this.stemProportion = 1.0; + this.stemParameter = null; this.excludeClade = true; } @@ -401,11 +407,11 @@ private class LocalClock { this.tips = tipSet; this.tipList = null; this.type = type; - this.stemProportion = 1.0; + this.stemParameter = null; this.excludeClade = true; } - LocalClock(Parameter rateParameter, boolean isRelativeRate, Set tips, double stemProportion, boolean excludeClade) { + LocalClock(Parameter rateParameter, boolean isRelativeRate, Set tips, Parameter stemParameter, boolean excludeClade) { this.rateParameter = rateParameter; this.branchRates = null; this.indexParameter = null; @@ -413,11 +419,11 @@ private class LocalClock { this.tips = tips; this.tipList = null; this.type = ClockType.CLADE; - this.stemProportion = stemProportion; + this.stemParameter = stemParameter; this.excludeClade = excludeClade; } - LocalClock(BranchRateModel branchRates, boolean isRelativeRate, Set tips, double stemProportion, boolean excludeClade) { + LocalClock(BranchRateModel branchRates, boolean isRelativeRate, Set tips, Parameter stemParameter, boolean excludeClade) { this.rateParameter = null; this.branchRates = branchRates; this.indexParameter = null; @@ -425,7 +431,7 @@ private class LocalClock { this.tips = tips; this.tipList = null; this.type = ClockType.CLADE; - this.stemProportion = stemProportion; + this.stemParameter = stemParameter; this.excludeClade = excludeClade; } @@ -437,7 +443,7 @@ private class LocalClock { this.tips = null; this.tipList = tipList; this.type = type; - this.stemProportion = 1.0; + this.stemParameter = null; this.excludeClade = true; } @@ -449,12 +455,17 @@ private class LocalClock { this.tips = null; this.tipList = tipList; this.type = type; - this.stemProportion = 1.0; + this.stemParameter = null; this.excludeClade = true; } double getStemProportion() { - return this.stemProportion; + if (stemParameter != null) { + return stemParameter.getParameterValue(0); + } else { + // if no parameter then default to 0 stem + return 0.0; + } } boolean excludeClade() { @@ -484,7 +495,7 @@ boolean isRelativeRate() { private final Set tips; private final List tipList; private final ClockType type; - private final double stemProportion; + private final Parameter stemParameter; private final boolean excludeClade; } diff --git a/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java b/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java index 164865a01b..06bda961b3 100644 --- a/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java +++ b/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java @@ -93,16 +93,24 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } boolean excludeClade = false; - double stemProportion = 0.0; - + Parameter stemParameter = null; + if (xoc.hasAttribute(INCLUDE_STEM)) { // if includeStem=true then assume it is the whole stem - stemProportion = xoc.getBooleanAttribute(INCLUDE_STEM) ? 1.0 : 0.0; + stemParameter = new Parameter.Default(xoc.getBooleanAttribute(INCLUDE_STEM) ? 1.0 : 0.0); } if (xoc.hasAttribute(STEM_PROPORTION)) { - stemProportion = xoc.getDoubleAttribute(STEM_PROPORTION); - if (stemProportion < 0.0 || stemProportion > 1.0) { + double stemValue = xoc.getDoubleAttribute(STEM_PROPORTION); + if (stemValue < 0.0 || stemValue > 1.0) { + throw new XMLParseException("A stem proportion should be between 0, 1"); + } + stemParameter = new Parameter.Default(stemValue); + } + + if (xoc.hasChildNamed(STEM_PROPORTION)) { + stemParameter = (Parameter) xoc.getElementFirstChild(STEM_PROPORTION); + if (stemParameter.getParameterValue(0) < 0.0 || stemParameter.getParameterValue(0) > 1.0) { throw new XMLParseException("A stem proportion should be between 0, 1"); } } @@ -113,9 +121,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { try { if (branchRates != null) { - localClockModel.addCladeClock(taxonList, branchRates, relative, stemProportion, excludeClade); + localClockModel.addCladeClock(taxonList, branchRates, relative, stemParameter, excludeClade); } else { - localClockModel.addCladeClock(taxonList, rateParameter, relative, stemProportion, excludeClade); + localClockModel.addCladeClock(taxonList, rateParameter, relative, stemParameter, excludeClade); } } catch (TreeUtils.MissingTaxonException mte) { @@ -204,8 +212,14 @@ public XMLSyntaxRule[] getSyntaxRules() { new ElementRule(CLADE, new XMLSyntaxRule[]{ AttributeRule.newBooleanRule(RELATIVE, true), - AttributeRule.newBooleanRule(INCLUDE_STEM, true, "determines whether or not the stem branch above this clade is included in the siteModel (default false)."), - AttributeRule.newDoubleRule(STEM_PROPORTION, true, "proportion of stem to include in clade rate (default 0)."), + new XORRule( + new XMLSyntaxRule[]{ + AttributeRule.newBooleanRule(INCLUDE_STEM, true, "determines whether or not the stem branch above this clade is included in the siteModel (default false)."), + AttributeRule.newDoubleRule(STEM_PROPORTION, true, "proportion of stem to include in clade rate (default 0)."), + new ElementRule(STEM_PROPORTION, Parameter.class, "A parameter for the proportion of stem to include in clade rate (0 - 1)", false) + }, + true + ), AttributeRule.newBooleanRule(EXCLUDE_CLADE, true, "determines whether to exclude actual branches of the clade from the siteModel (default false)."), new ElementRule(Taxa.class, "A set of taxa which defines a clade to apply a different site model to"), new XORRule( From 0b892650824425b0a87c99eca5f71a30f7994fd8 Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 9 Jun 2023 13:35:55 +0100 Subject: [PATCH 46/65] Putting getSequences into SequenceList interface. --- src/dr/evolution/alignment/GapStrippedAlignment.java | 4 ++++ src/dr/evolution/alignment/WrappedAlignment.java | 5 +++++ src/dr/evolution/sequence/SequenceList.java | 6 ++++++ src/dr/evolution/sequence/Sequences.java | 9 +++++---- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/dr/evolution/alignment/GapStrippedAlignment.java b/src/dr/evolution/alignment/GapStrippedAlignment.java index ac3218ca51..9dc75e4790 100644 --- a/src/dr/evolution/alignment/GapStrippedAlignment.java +++ b/src/dr/evolution/alignment/GapStrippedAlignment.java @@ -103,6 +103,10 @@ public final Object getSequenceAttribute(int index, String name) { throw new UnsupportedOperationException(); } + public List getSequences() { + return alignment.getSequences(); + } + public final int getTaxonCount() { return alignment.getTaxonCount(); } diff --git a/src/dr/evolution/alignment/WrappedAlignment.java b/src/dr/evolution/alignment/WrappedAlignment.java index a08237ee74..7ebd65662f 100644 --- a/src/dr/evolution/alignment/WrappedAlignment.java +++ b/src/dr/evolution/alignment/WrappedAlignment.java @@ -33,6 +33,7 @@ import dr.evolution.util.Taxon; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -181,6 +182,10 @@ public Object getSequenceAttribute(int index, String name) { return alignment.getSequenceAttribute(index, name); } + public List getSequences() { + return alignment.getSequences(); + } + /** * @return a count of the number of taxa in the list. */ diff --git a/src/dr/evolution/sequence/SequenceList.java b/src/dr/evolution/sequence/SequenceList.java index ae96c39fc0..d842bba878 100644 --- a/src/dr/evolution/sequence/SequenceList.java +++ b/src/dr/evolution/sequence/SequenceList.java @@ -27,6 +27,8 @@ import dr.evolution.util.TaxonList; +import java.util.List; + /** * Interface for a list of sequences. * @@ -61,5 +63,9 @@ public interface SequenceList extends TaxonList { */ public Object getSequenceAttribute(int index, String name); + /** + * @return an immutable iterable List of sequences + */ + public List getSequences(); } diff --git a/src/dr/evolution/sequence/Sequences.java b/src/dr/evolution/sequence/Sequences.java index e792e87006..e4115841d2 100644 --- a/src/dr/evolution/sequence/Sequences.java +++ b/src/dr/evolution/sequence/Sequences.java @@ -29,10 +29,7 @@ import dr.util.Attributable; import dr.util.Identifiable; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Vector; +import java.util.*; /** * Class for storing sequences. @@ -88,6 +85,10 @@ public Object getSequenceAttribute(int index, String name) { return sequence.getAttribute(name); } + public List getSequences() { + return Collections.unmodifiableList(sequences); + } + // ************************************************************** // TaxonList IMPLEMENTATION // ************************************************************** From bb9c7eb5598e956f465909e8bc81465d19edd98d Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 9 Jun 2023 13:38:13 +0100 Subject: [PATCH 47/65] Implemented issue #1152 - Dummy check for unaligned sequences in alignment --- src/dr/app/beauti/util/BEAUTiImporter.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/dr/app/beauti/util/BEAUTiImporter.java b/src/dr/app/beauti/util/BEAUTiImporter.java index 3e6e362eff..f6d9062a38 100644 --- a/src/dr/app/beauti/util/BEAUTiImporter.java +++ b/src/dr/app/beauti/util/BEAUTiImporter.java @@ -41,6 +41,7 @@ import dr.evolution.io.NexusImporter; import dr.evolution.io.NexusImporter.MissingBlockException; import dr.evolution.io.NexusImporter.NexusBlock; +import dr.evolution.sequence.Sequence; import dr.evolution.tree.Tree; import dr.evolution.util.Taxa; import dr.evolution.util.Taxon; @@ -560,6 +561,18 @@ private void setData(String fileName, TaxonList taxonList, Alignment alignment, options.fileNameStem = fileNameStem; } + // check the alignment before adding it... + if (alignment.getSiteCount() == 0) { + // sequences are different lengths + throw new ImportException("This alignment is of zero length"); + } + for (Sequence seq : alignment.getSequences()) { + if (seq.getLength() != alignment.getSiteCount()) { + // sequences are different lengths + throw new ImportException("The sequences in the alignment file are of different lengths - BEAST requires aligned sequences"); + } + } + addTaxonList(taxonList); addAlignment(alignment, charSets, model, fileName, fileNameStem); From 6d19a2c19f70a206a32239fe08e29eb4265b3bc2 Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 9 Jun 2023 13:40:46 +0100 Subject: [PATCH 48/65] Cleaning up import error message dialog box --- src/dr/app/beauti/BeautiFrame.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dr/app/beauti/BeautiFrame.java b/src/dr/app/beauti/BeautiFrame.java index 05e581cd58..e3f7d138c7 100644 --- a/src/dr/app/beauti/BeautiFrame.java +++ b/src/dr/app/beauti/BeautiFrame.java @@ -509,25 +509,25 @@ private void importFiles(File[] files) { // JOptionPane.showMessageDialog(this, "Unable to open file: File not found", // "Unable to open file", JOptionPane.ERROR_MESSAGE); } catch (IOException ioe) { - JOptionPane.showMessageDialog(this, "File I/O Error unable to read file: " + ioe.getMessage(), + JOptionPane.showMessageDialog(this, "File I/O Error unable to read file:\n " + ioe.getMessage(), "Unable to read file", JOptionPane.ERROR_MESSAGE); ioe.printStackTrace(); // there may be other files in the list so don't return // return; } catch (MissingBlockException ex) { - JOptionPane.showMessageDialog(this, "TAXON, DATA or CHARACTERS block is missing in Nexus file: " + ex, + JOptionPane.showMessageDialog(this, "TAXON, DATA or CHARACTERS block is missing in Nexus file:\n " + ex.getMessage(), "Missing Block in Nexus File", JOptionPane.ERROR_MESSAGE); ex.printStackTrace(); } catch (ImportException ime) { - JOptionPane.showMessageDialog(this, "Error parsing imported file: " + ime, + JOptionPane.showMessageDialog(this, "Error parsing imported file:\n " + ime.getMessage(), "Error reading file", JOptionPane.ERROR_MESSAGE); ime.printStackTrace(); } catch (JDOMException jde) { - JOptionPane.showMessageDialog(this, "Error parsing imported file: " + jde, + JOptionPane.showMessageDialog(this, "Error parsing imported file:\n " + jde.getMessage(), "Error reading file", JOptionPane.ERROR_MESSAGE); jde.printStackTrace(); From 8d136c08849505f13e455b280ccdf24d0b63516f Mon Sep 17 00:00:00 2001 From: Plemey Date: Mon, 19 Jun 2023 12:05:23 +0200 Subject: [PATCH 49/65] enhanced summary capacity for GetNSCountsFromTrees --- src/dr/app/tools/GetNSCountsFromTrees.java | 317 +++++++++++++++++---- 1 file changed, 269 insertions(+), 48 deletions(-) diff --git a/src/dr/app/tools/GetNSCountsFromTrees.java b/src/dr/app/tools/GetNSCountsFromTrees.java index 5486216bd0..b745c60f6a 100644 --- a/src/dr/app/tools/GetNSCountsFromTrees.java +++ b/src/dr/app/tools/GetNSCountsFromTrees.java @@ -8,7 +8,6 @@ import dr.evolution.io.TreeImporter; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; -import dr.evolution.tree.TreeUtils; import dr.inference.trace.TraceException; import dr.util.Version; @@ -20,7 +19,6 @@ */ public class GetNSCountsFromTrees { - //TODO: allow to add parent and descendent node discrete state private final static Version version = new BeastVersion(); public static final String BURNIN = "burnin"; public static final String totalcN = "N"; @@ -33,7 +31,7 @@ public class GetNSCountsFromTrees { public static final String BRANCHINFO = "branchInfo"; public static final String[] falseTrue = {"false", "true"}; public static final String BRANCHSET = "branchSet"; -// public static final String CLADETAXA = "cladeTaxa"; + // public static final String CLADETAXA = "cladeTaxa"; public static final String INCLUDECLADES = "includeClades"; public static final String CLADESTEM = "cladeStem"; public static final String EXCLUDECLADESTEM = "excludeCladeStem"; @@ -44,11 +42,18 @@ public class GetNSCountsFromTrees { public static final String CODONSITELIST = "codonSiteList"; public static final String MRSD = "mrsd"; public static final String EXCLUDECLADES = "excludeClades"; - - public GetNSCountsFromTrees(int burnin, String inputFileName, String outputFileName, boolean branchInfo, - BranchSet branchSet, List inclusionSets, boolean cladeStem, boolean zeroBranches, boolean summary, - int sites, double mrsd, List exclusionSets, boolean excludeCladeStems, - double[] siteList) throws IOException { + public static final String PREFIX = "prefix"; + public static final String HEIGHT_ABOVE = "heightAbove"; + public static final String HEIGHT_BELOW = "heightBelow"; + public static final String TIME_BEFORE = "timeBefore"; + public static final String TIME_AFTER = "timeAfter"; + public static final String DISCRETE_TRAIT_STATE_SET = "discreteTraitStateSet"; + public static final String DISCRETE_TRAIT_ATTRIBUTE = "discreteTraitAttribute"; + + public GetNSCountsFromTrees (int burnin, String inputFileName, String outputFileName, boolean branchInfo, + BranchSet branchSet, List inclusionSets, boolean cladeStem, boolean zeroBranches, + boolean summary, int sites, double mrsd, List exclusionSets, boolean excludeCladeStems, + double[] siteList, String prefix, double heightAbove, double heightBelow, String discreteTrait, String[] stateSet) throws IOException { File inputFile = new File(inputFileName); if (inputFile.isFile()){ @@ -59,17 +64,42 @@ public GetNSCountsFromTrees(int burnin, String inputFileName, String outputFileN } this.branchInfo = branchInfo; + // in case you would like to ignore branches without N or S substitutions (can't remember why I thought this could be useful) this.zeroBranches = zeroBranches; this.summary = summary; this.mrsd = mrsd; this.cladeStem = cladeStem; this.excludeCladeStems = excludeCladeStems; + this.discreteTrait = discreteTrait; this.sites = sites; if (siteList!=null){ this.sites = (siteList.length)*3; - progressStream.println("number of sites set based on site list provided"); + progressStream.println("Sites set based on site list provided (dN/dS may not be adequate as uN and uS does not account for site list)"); + } + + if (heightAbove > 0 || heightBelow < Double.MAX_VALUE) { + timeConstraints = true; + } else { + timeConstraints = false; + } + + if (prefix!=null){ + prefixTotalcN = prefix+"."+totalcN; + prefixTotalcS = prefix+"."+totalcS; + prefixTotaluN = prefix+"."+totaluN; + prefixTotaluS = prefix+"."+totaluS; + prefixHistoryN = prefix+"."+historyN; + prefixHistoryS = prefix+"."+historyS; + } else { + prefixTotalcN = totalcN; + prefixTotalcS = totalcS; + prefixTotaluN = totaluN; + prefixTotaluS = totaluS; + prefixHistoryN = historyN; + prefixHistoryS = historyS; } + // resultsStream = System.out; if (outputFileName != null) { @@ -83,11 +113,12 @@ public GetNSCountsFromTrees(int burnin, String inputFileName, String outputFileN resultsStream = new PrintStream(new File(inputFileName+".NSout.txt")); } - analyze(inputFile, burnin, branchSet, inclusionSets, exclusionSets, siteList); + analyze(inputFile, burnin, branchSet, inclusionSets, exclusionSets, siteList, stateSet, heightAbove, heightBelow); } - private void analyze(File inputFile, int burnin, BranchSet branchSet, List inclusionSets, List exclusionSets, double[] siteList){ + private void analyze(File inputFile, int burnin, BranchSet branchSet, List inclusionSets, List exclusionSets, + double[] siteList, String[] stateSet, double heightAbove, double heightBelow){ if (summary) { resultsStream.print("tree"+SEP+"cN"+SEP+"uN"+SEP+"cS"+SEP+"uS"+SEP+"cNrate"+SEP+"cSrate"+SEP+"dN/dS"+"\n"); @@ -121,7 +152,7 @@ private void analyze(File inputFile, int burnin, BranchSet branchSet, List Tree tree = importer.importNextTree(); if(count>=burnin) { - getNSCounts(tree, treeUsed, branchSet, inclusionSets, exclusionSets, siteList); + getNSCounts(tree, treeUsed, branchSet, inclusionSets, exclusionSets, siteList, stateSet, heightAbove, heightBelow); treeUsed ++; } count++; @@ -143,7 +174,8 @@ private void analyze(File inputFile, int burnin, BranchSet branchSet, List } } - private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List inclusionSets, List exclusionSets, double[] siteList){ + private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List inclusionSets, List exclusionSets, + double[] siteList, String[] stateSet, double heightAbove, double heightBelow){ int count = 0; double cN = 0; double uN = 0; @@ -155,33 +187,91 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List NodeRef node = tree.getNode(x); if (!tree.isRoot(node)){ count ++; - if (nodeToConsider(tree, node, branchSet, inclusionSets, exclusionSets)){ + if (nodeToConsider(tree, node, branchSet, inclusionSets, exclusionSets, stateSet, heightAbove, heightBelow)){ double branchLength = tree.getBranchLength(node); + if (timeConstraints){ + double nodeHeight = tree.getNodeHeight(node); + double parentNodeHeight = tree.getNodeHeight(tree.getParent(node)); + double upperHeight = parentNodeHeight; + double lowerHeight = nodeHeight; + if (heightBelow < Double.MAX_VALUE) { + if (parentNodeHeight > heightBelow) { + upperHeight = heightBelow; + } + } + if (heightAbove > 0){ + if(nodeHeight < heightAbove) { + lowerHeight = heightAbove; + } + } + if ((upperHeight - lowerHeight)>0){ + branchLength = upperHeight - lowerHeight; + } + } length += branchLength; - Object totalNObject = tree.getNodeAttribute(node, totalcN); - Object totalSObject = tree.getNodeAttribute(node, totalcS); + + Object totalNObject = tree.getNodeAttribute(node, prefixTotalcN); + Object totalSObject = tree.getNodeAttribute(node, prefixTotalcS); + //in case N and S would not be annotated if (totalNObject!=null && totalSObject!=null) { - double totalN = (Double) totalNObject; + //System.out.println("Hallo"); + double totalN = (Double) totalNObject; double totalS = (Double) totalSObject; - double totaluNObject = (Double) tree.getNodeAttribute(node, totaluN); - double totaluSObject = (Double) tree.getNodeAttribute(node, totaluS); - if(siteList==null){ + double totaluN = (Double) tree.getNodeAttribute(node, prefixTotaluN); + double totaluS = (Double) tree.getNodeAttribute(node, prefixTotaluS); + if(siteList==null && !timeConstraints){ cN += totalN; cS += totalS; } - uN += totaluNObject; - uS += totaluSObject; + uN += totaluN; + uS += totaluS; if (totalN > 0) { - Object[] allNObject = (Object[]) tree.getNodeAttribute(node, historyN); + Object[] allNObject = (Object[]) tree.getNodeAttribute(node, prefixHistoryN); for (int a = 0; a < allNObject.length; a++) { Object[] singleNObject = (Object[]) allNObject[a]; boolean proceedAgain = false; if(siteList==null) { - proceedAgain = true; + if (!timeConstraints){ + proceedAgain = true; + } else { + boolean timeCompatible = true; + if (heightBelow < Double.MAX_VALUE){ + if((Double) singleNObject[1] > heightBelow){ + timeCompatible = false; + } + } + if (heightAbove > 0){ + if((Double) singleNObject[1] < heightAbove) { + timeCompatible = false; + } + } + if (timeCompatible){ + proceedAgain = true; + cN ++; + } + } } else { if (inSiteList((Integer)singleNObject[0], siteList)) { - proceedAgain = true; - cN ++; + if (!timeConstraints){ + proceedAgain = true; + cN ++; + } else { + boolean timeCompatible = true; + if (heightBelow < Double.MAX_VALUE){ + if((Double) singleNObject[1] > heightBelow){ + timeCompatible = false; + } + } + if (heightAbove > 0){ + if((Double) singleNObject[1] < heightAbove) { + timeCompatible = false; + } + } + if (timeCompatible){ + proceedAgain = true; + cN ++; + } + } } } if(!summary && proceedAgain){ @@ -194,7 +284,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } resultsStream.print(singleNObject[2] + SEP + singleNObject[3] + SEP); if (branchInfo) { - resultsStream.print(branchLength + SEP + totalN + SEP + totaluNObject + "\n"); + resultsStream.print(branchLength + SEP + totalN + SEP + totaluN + "\n"); } else { resultsStream.print("\n"); } @@ -203,16 +293,52 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } } if (totalS > 0) { - Object[] allSObject = (Object[]) tree.getNodeAttribute(node, historyS); + Object[] allSObject = (Object[]) tree.getNodeAttribute(node, prefixHistoryS); for (int a = 0; a < allSObject.length; a++) { Object[] singleSObject = (Object[]) allSObject[a]; boolean proceedAgain = false; if(siteList==null) { - proceedAgain = true; + if (!timeConstraints){ + proceedAgain = true; + } else { + boolean timeCompatible = true; + if (heightBelow < Double.MAX_VALUE){ + if((Double) singleSObject[1] > heightBelow){ + timeCompatible = false; + } + } + if (heightAbove > 0){ + if((Double) singleSObject[1] < heightAbove) { + timeCompatible = false; + } + } + if (timeCompatible){ + proceedAgain = true; + cS ++; + } + } } else { if (inSiteList((Integer)singleSObject[0], siteList)) { - proceedAgain = true; - cS ++; + if (!timeConstraints){ + proceedAgain = true; + cS ++; + } else { + boolean timeCompatible = true; + if (heightBelow < Double.MAX_VALUE){ + if((Double) singleSObject[1] > heightBelow){ + timeCompatible = false; + } + } + if (heightAbove > 0){ + if((Double) singleSObject[1] < heightAbove) { + timeCompatible = false; + } + } + if (timeCompatible){ + proceedAgain = true; + cS ++; + } + } } } if(!summary && proceedAgain){ @@ -225,7 +351,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } resultsStream.print(singleSObject[2] + SEP + singleSObject[3] + SEP); if (branchInfo) { - resultsStream.print(branchLength + SEP + totalS + SEP + totaluSObject + "\n"); + resultsStream.print(branchLength + SEP + totalS + SEP + totaluS + "\n"); } else { resultsStream.print("\n"); } @@ -233,6 +359,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } } + //in case you do not wish to print out zero branch lengths if ((totalN+totalS)==0){ if (zeroBranches) { if (!summary){ @@ -249,7 +376,9 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List length -= branchLength; } } - //tree,branch,S/N,position,oriState,destState,time,branchLength,totalBranchcN,totalcS,totaluN,totaluS,summaryTine + } else { + System.err.println("No N or S annotations?"); + System.exit(-1); } } } @@ -260,7 +389,8 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } } - private boolean nodeToConsider(Tree tree, NodeRef node, BranchSet branchSet, List inclusionSets, List exclusionSets){ + private boolean nodeToConsider(Tree tree, NodeRef node, BranchSet branchSet, List inclusionSets, List exclusionSets, + String[] stateSet, double heightAbove, double heightBelow){ boolean nodeToConsider = false; if (branchSet == BranchSet.ALL) { nodeToConsider = true; @@ -297,6 +427,35 @@ private boolean nodeToConsider(Tree tree, NodeRef node, BranchSet branchSet, Lis } } } + + // making sure it is a branch that maintains a specified state = node and parent state are a specified state. + if(nodeToConsider && stateSet!=null) { + nodeToConsider = false; + String nodeState = ((String)tree.getNodeAttribute(node, discreteTrait)).replaceAll("\"",""); + String parentNodeState = ((String)tree.getNodeAttribute(tree.getParent(node), discreteTrait)).replaceAll("\"","");; + if (nodeState.equals(parentNodeState)){ +// System.out.println(nodeState); + for (String state: stateSet){ + if (state.equalsIgnoreCase(nodeState)){ + nodeToConsider = true; + break; + } + } + } + } + + //partially accounting for specified time constraints + if (nodeToConsider && timeConstraints) { + double nodeHeight = tree.getNodeHeight(node); + double parentNodeHeight = tree.getNodeHeight(tree.getParent(node)); + if(parentNodeHeight < heightAbove){ + nodeToConsider = false; + } + if(nodeHeight > heightBelow){ + nodeToConsider = false; + } + } + return nodeToConsider; } @@ -337,7 +496,7 @@ private static boolean onBackbone(Tree tree, NodeRef node, Set targetSet) { if (tree.isExternal(node)) return false; - Set leafSet = TreeUtils.getDescendantLeaves(tree, node); + Set leafSet = Tree.Utils.getDescendantLeaves(tree, node); int size = leafSet.size(); leafSet.retainAll(targetSet); @@ -347,7 +506,7 @@ private static boolean onBackbone(Tree tree, NodeRef node, Set targetSet) { // if all leaves below are in target then check just above. if (leafSet.size() == size) { - Set superLeafSet = TreeUtils.getDescendantLeaves(tree, tree.getParent(node)); + Set superLeafSet = Tree.Utils.getDescendantLeaves(tree, tree.getParent(node)); superLeafSet.removeAll(targetSet); // the branch is on ancestral path if the super tree has some non-targets in it @@ -359,7 +518,7 @@ private static boolean onBackbone(Tree tree, NodeRef node, Set targetSet) { } private static boolean inClade(Tree tree, NodeRef node, Set targetSet, boolean includeStem) { - Set leafSet = TreeUtils.getDescendantLeaves(tree, node); + Set leafSet = Tree.Utils.getDescendantLeaves(tree, node); leafSet.removeAll(targetSet); @@ -368,7 +527,7 @@ private static boolean inClade(Tree tree, NodeRef node, Set targetSet, boolean i if (includeStem){ return true; } else { - Set parentLeafSet = TreeUtils.getDescendantLeaves(tree, tree.getParent(node)); + Set parentLeafSet = Tree.Utils.getDescendantLeaves(tree, tree.getParent(node)); parentLeafSet.removeAll(targetSet); if (parentLeafSet.size() == 0){ return true; @@ -382,12 +541,12 @@ private static boolean inClade(Tree tree, NodeRef node, Set targetSet, boolean i } private static boolean isMRCAnode(Tree tree, NodeRef node, Set targetSet) { - NodeRef mrca = TreeUtils.getCommonAncestorNode(tree, targetSet); + NodeRef mrca = Tree.Utils.getCommonAncestorNode(tree, targetSet); if (node.equals(mrca)){ return true; } else { return false; - } + } } private boolean branchInfo; @@ -397,9 +556,18 @@ private static boolean isMRCAnode(Tree tree, NodeRef node, Set targetSet) { private double mrsd; private boolean cladeStem; private boolean excludeCladeStems; + private String discreteTrait; + private boolean timeConstraints; private static PrintStream progressStream = System.err; private PrintStream resultsStream; + public String prefixTotalcN; + public String prefixTotalcS; + public String prefixTotaluN; + public String prefixTotaluS; + public String prefixHistoryN; + public String prefixHistoryS; + enum BranchSet { ALL, INT, @@ -501,6 +669,15 @@ public static void main(String[] args) throws IOException, TraceException { new Arguments.StringOption(EXCLUDECLADES, "clade exclusion files", "specifies files with taxa that define clades to be excluded"), new Arguments.IntegerOption(SITESUM, "the number of nucleotide sites to summarize rates in per site per time unit [default = 1]"), new Arguments.StringOption(CODONSITELIST, "list of sites", "sites for which the summary is restricted to"), + new Arguments.StringOption(PREFIX, "annotation prefix", "specifies a prefix that is used for the annotations (e.g. to distinguish partition-specific annotations"), + + new Arguments.RealOption(HEIGHT_ABOVE, "specifies a boundary above which the height must be to be included in the summary [default=none]"), + new Arguments.RealOption(HEIGHT_BELOW, "specifies a boundary below which the height must be to be included in the summary [default=none]"), + new Arguments.RealOption(TIME_BEFORE, "specifies a boundary before which the time must be to be included in the summary [default=none]"), + new Arguments.RealOption(TIME_AFTER, "specifies a boundary after which the time must be to be included in the summary [default=none]"), + + new Arguments.StringOption(DISCRETE_TRAIT_ATTRIBUTE, "discrete trait attribute", "specifies the string used as attribute for the discrete trait state"), + new Arguments.StringOption(DISCRETE_TRAIT_STATE_SET, "discrete trait state set", "specifies which discrete trait states that branches need to maintain to be considered for the summary"), new Arguments.Option("help", "option to print this message") }); @@ -539,9 +716,9 @@ public static void main(String[] args) throws IOException, TraceException { if (set == set.BACKBONE) { if (arguments.hasOption(BACKBONETAXA)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(BACKBONETAXA)); - for (String singleSet: fileList){ + for (String singleSet : fileList) { inclusionSets.add(getTargetSet(singleSet)); - progressStream.println("getting target set for backbone inclusion: "+singleSet); + progressStream.println("getting target set for backbone inclusion: " + singleSet); } } else { progressStream.println("you want to get summaries for (a) backbone(s), but no files with taxa to define it are provided??"); @@ -556,9 +733,9 @@ public static void main(String[] args) throws IOException, TraceException { if (arguments.hasOption(INCLUDECLADES)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(INCLUDECLADES)); - for (String singleSet: fileList){ + for (String singleSet : fileList) { inclusionSets.add(getTargetSet(singleSet)); - progressStream.println("getting target set for clade inclusion: "+singleSet); + progressStream.println("getting target set for clade inclusion: " + singleSet); } } else { progressStream.println("you want to get summaries for one or more clades, but no files with taxa to define it are provided??"); @@ -568,7 +745,7 @@ public static void main(String[] args) throws IOException, TraceException { if (set == set.SINGLEBRANCH) { if (arguments.hasOption(INCLUDECLADES)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(INCLUDECLADES)); - if (fileList.length > 1){ + if (fileList.length > 1) { progressStream.println("more than one clade set is specified for a summary of a single branch??"); System.exit(-1); } else { @@ -619,9 +796,9 @@ public static void main(String[] args) throws IOException, TraceException { List exclusionSets = new ArrayList(); if (arguments.hasOption(EXCLUDECLADES)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(EXCLUDECLADES)); - for (String singleSet: fileList){ + for (String singleSet : fileList) { exclusionSets.add(getTargetSet(singleSet)); - progressStream.println("getting target set for clade exclusion: "+singleSet); + progressStream.println("getting target set for clade exclusion: " + singleSet); } } @@ -631,6 +808,48 @@ public static void main(String[] args) throws IOException, TraceException { progressStream.println("site list provided: note that dN/dS will not be accurately estimated because the neutral expectation to get dN/dS (uN and uS) is for all sites along a branch."); } + String prefix = arguments.getStringOption(PREFIX); + + double heightAbove = 0; + if (arguments.hasOption(HEIGHT_ABOVE)) { + heightAbove = arguments.getRealOption(HEIGHT_ABOVE); + } else if (arguments.hasOption(TIME_BEFORE)) { + if (mrsd > 0){ + heightAbove = mrsd - arguments.getRealOption(TIME_BEFORE); + } else { + System.err.println(TIME_BEFORE +" is specified but no mrsd (>0)??"); + System.exit(-1); + } + } + + double heightBelow = Double.MAX_VALUE; + if (arguments.hasOption(HEIGHT_BELOW)) { + heightBelow = arguments.getRealOption(HEIGHT_BELOW); + } else if (arguments.hasOption(TIME_AFTER)) { + if (mrsd > 0){ + heightBelow = mrsd - arguments.getRealOption(TIME_AFTER); + } else { + System.err.println(TIME_AFTER +" is specified but no mrsd (>0)??"); + System.exit(-1); + } + } + + String discreteTrait = null; + if (arguments.hasOption(DISCRETE_TRAIT_ATTRIBUTE)) { + discreteTrait = arguments.getStringOption(DISCRETE_TRAIT_ATTRIBUTE); + progressStream.println("discrete trait attribute provided is "+discreteTrait); + } + + String[] stateSet = null; + if (arguments.hasOption(DISCRETE_TRAIT_STATE_SET)) { + stateSet = parseVariableLengthStringArray(arguments.getStringOption(DISCRETE_TRAIT_STATE_SET)); + if (discreteTrait==null){ + System.err.println("stateSet provided nut no discrete trait attribute provided??"); + System.exit(-1); + } + progressStream.println("discrete trait state set provided for summary; first state in set = "+stateSet[0]); + } + String inputFileName = null; String outputFileName = null; @@ -661,7 +880,9 @@ public static void main(String[] args) throws IOException, TraceException { burnin = Integer.parseInt(br.readLine()); } - new GetNSCountsFromTrees(burnin, inputFileName, outputFileName, branchInfo, set, inclusionSets, cladeStem, zeroBranches, summary, sites, mrsd, exclusionSets, excludeCladeStems, siteList); + new GetNSCountsFromTrees(burnin, inputFileName, outputFileName, branchInfo, set, inclusionSets, cladeStem, zeroBranches, + summary, sites, mrsd, exclusionSets, excludeCladeStems, + siteList, prefix, heightAbove, heightBelow, discreteTrait, stateSet); System.exit(0); } From 3f3b99845f1fc8615b702a4353a063a83c22d7a7 Mon Sep 17 00:00:00 2001 From: Plemey Date: Mon, 19 Jun 2023 12:13:09 +0200 Subject: [PATCH 50/65] Revert "enhanced summary capacity for GetNSCountsFromTrees" This reverts commit 8d136c08849505f13e455b280ccdf24d0b63516f. --- src/dr/app/tools/GetNSCountsFromTrees.java | 317 ++++----------------- 1 file changed, 48 insertions(+), 269 deletions(-) diff --git a/src/dr/app/tools/GetNSCountsFromTrees.java b/src/dr/app/tools/GetNSCountsFromTrees.java index b745c60f6a..5486216bd0 100644 --- a/src/dr/app/tools/GetNSCountsFromTrees.java +++ b/src/dr/app/tools/GetNSCountsFromTrees.java @@ -8,6 +8,7 @@ import dr.evolution.io.TreeImporter; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; +import dr.evolution.tree.TreeUtils; import dr.inference.trace.TraceException; import dr.util.Version; @@ -19,6 +20,7 @@ */ public class GetNSCountsFromTrees { + //TODO: allow to add parent and descendent node discrete state private final static Version version = new BeastVersion(); public static final String BURNIN = "burnin"; public static final String totalcN = "N"; @@ -31,7 +33,7 @@ public class GetNSCountsFromTrees { public static final String BRANCHINFO = "branchInfo"; public static final String[] falseTrue = {"false", "true"}; public static final String BRANCHSET = "branchSet"; - // public static final String CLADETAXA = "cladeTaxa"; +// public static final String CLADETAXA = "cladeTaxa"; public static final String INCLUDECLADES = "includeClades"; public static final String CLADESTEM = "cladeStem"; public static final String EXCLUDECLADESTEM = "excludeCladeStem"; @@ -42,18 +44,11 @@ public class GetNSCountsFromTrees { public static final String CODONSITELIST = "codonSiteList"; public static final String MRSD = "mrsd"; public static final String EXCLUDECLADES = "excludeClades"; - public static final String PREFIX = "prefix"; - public static final String HEIGHT_ABOVE = "heightAbove"; - public static final String HEIGHT_BELOW = "heightBelow"; - public static final String TIME_BEFORE = "timeBefore"; - public static final String TIME_AFTER = "timeAfter"; - public static final String DISCRETE_TRAIT_STATE_SET = "discreteTraitStateSet"; - public static final String DISCRETE_TRAIT_ATTRIBUTE = "discreteTraitAttribute"; - - public GetNSCountsFromTrees (int burnin, String inputFileName, String outputFileName, boolean branchInfo, - BranchSet branchSet, List inclusionSets, boolean cladeStem, boolean zeroBranches, - boolean summary, int sites, double mrsd, List exclusionSets, boolean excludeCladeStems, - double[] siteList, String prefix, double heightAbove, double heightBelow, String discreteTrait, String[] stateSet) throws IOException { + + public GetNSCountsFromTrees(int burnin, String inputFileName, String outputFileName, boolean branchInfo, + BranchSet branchSet, List inclusionSets, boolean cladeStem, boolean zeroBranches, boolean summary, + int sites, double mrsd, List exclusionSets, boolean excludeCladeStems, + double[] siteList) throws IOException { File inputFile = new File(inputFileName); if (inputFile.isFile()){ @@ -64,42 +59,17 @@ public GetNSCountsFromTrees (int burnin, String inputFileName, String outputFile } this.branchInfo = branchInfo; - // in case you would like to ignore branches without N or S substitutions (can't remember why I thought this could be useful) this.zeroBranches = zeroBranches; this.summary = summary; this.mrsd = mrsd; this.cladeStem = cladeStem; this.excludeCladeStems = excludeCladeStems; - this.discreteTrait = discreteTrait; this.sites = sites; if (siteList!=null){ this.sites = (siteList.length)*3; - progressStream.println("Sites set based on site list provided (dN/dS may not be adequate as uN and uS does not account for site list)"); - } - - if (heightAbove > 0 || heightBelow < Double.MAX_VALUE) { - timeConstraints = true; - } else { - timeConstraints = false; - } - - if (prefix!=null){ - prefixTotalcN = prefix+"."+totalcN; - prefixTotalcS = prefix+"."+totalcS; - prefixTotaluN = prefix+"."+totaluN; - prefixTotaluS = prefix+"."+totaluS; - prefixHistoryN = prefix+"."+historyN; - prefixHistoryS = prefix+"."+historyS; - } else { - prefixTotalcN = totalcN; - prefixTotalcS = totalcS; - prefixTotaluN = totaluN; - prefixTotaluS = totaluS; - prefixHistoryN = historyN; - prefixHistoryS = historyS; + progressStream.println("number of sites set based on site list provided"); } - // resultsStream = System.out; if (outputFileName != null) { @@ -113,12 +83,11 @@ public GetNSCountsFromTrees (int burnin, String inputFileName, String outputFile resultsStream = new PrintStream(new File(inputFileName+".NSout.txt")); } - analyze(inputFile, burnin, branchSet, inclusionSets, exclusionSets, siteList, stateSet, heightAbove, heightBelow); + analyze(inputFile, burnin, branchSet, inclusionSets, exclusionSets, siteList); } - private void analyze(File inputFile, int burnin, BranchSet branchSet, List inclusionSets, List exclusionSets, - double[] siteList, String[] stateSet, double heightAbove, double heightBelow){ + private void analyze(File inputFile, int burnin, BranchSet branchSet, List inclusionSets, List exclusionSets, double[] siteList){ if (summary) { resultsStream.print("tree"+SEP+"cN"+SEP+"uN"+SEP+"cS"+SEP+"uS"+SEP+"cNrate"+SEP+"cSrate"+SEP+"dN/dS"+"\n"); @@ -152,7 +121,7 @@ private void analyze(File inputFile, int burnin, BranchSet branchSet, List Tree tree = importer.importNextTree(); if(count>=burnin) { - getNSCounts(tree, treeUsed, branchSet, inclusionSets, exclusionSets, siteList, stateSet, heightAbove, heightBelow); + getNSCounts(tree, treeUsed, branchSet, inclusionSets, exclusionSets, siteList); treeUsed ++; } count++; @@ -174,8 +143,7 @@ private void analyze(File inputFile, int burnin, BranchSet branchSet, List } } - private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List inclusionSets, List exclusionSets, - double[] siteList, String[] stateSet, double heightAbove, double heightBelow){ + private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List inclusionSets, List exclusionSets, double[] siteList){ int count = 0; double cN = 0; double uN = 0; @@ -187,91 +155,33 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List NodeRef node = tree.getNode(x); if (!tree.isRoot(node)){ count ++; - if (nodeToConsider(tree, node, branchSet, inclusionSets, exclusionSets, stateSet, heightAbove, heightBelow)){ + if (nodeToConsider(tree, node, branchSet, inclusionSets, exclusionSets)){ double branchLength = tree.getBranchLength(node); - if (timeConstraints){ - double nodeHeight = tree.getNodeHeight(node); - double parentNodeHeight = tree.getNodeHeight(tree.getParent(node)); - double upperHeight = parentNodeHeight; - double lowerHeight = nodeHeight; - if (heightBelow < Double.MAX_VALUE) { - if (parentNodeHeight > heightBelow) { - upperHeight = heightBelow; - } - } - if (heightAbove > 0){ - if(nodeHeight < heightAbove) { - lowerHeight = heightAbove; - } - } - if ((upperHeight - lowerHeight)>0){ - branchLength = upperHeight - lowerHeight; - } - } length += branchLength; - - Object totalNObject = tree.getNodeAttribute(node, prefixTotalcN); - Object totalSObject = tree.getNodeAttribute(node, prefixTotalcS); - //in case N and S would not be annotated + Object totalNObject = tree.getNodeAttribute(node, totalcN); + Object totalSObject = tree.getNodeAttribute(node, totalcS); if (totalNObject!=null && totalSObject!=null) { - //System.out.println("Hallo"); - double totalN = (Double) totalNObject; + double totalN = (Double) totalNObject; double totalS = (Double) totalSObject; - double totaluN = (Double) tree.getNodeAttribute(node, prefixTotaluN); - double totaluS = (Double) tree.getNodeAttribute(node, prefixTotaluS); - if(siteList==null && !timeConstraints){ + double totaluNObject = (Double) tree.getNodeAttribute(node, totaluN); + double totaluSObject = (Double) tree.getNodeAttribute(node, totaluS); + if(siteList==null){ cN += totalN; cS += totalS; } - uN += totaluN; - uS += totaluS; + uN += totaluNObject; + uS += totaluSObject; if (totalN > 0) { - Object[] allNObject = (Object[]) tree.getNodeAttribute(node, prefixHistoryN); + Object[] allNObject = (Object[]) tree.getNodeAttribute(node, historyN); for (int a = 0; a < allNObject.length; a++) { Object[] singleNObject = (Object[]) allNObject[a]; boolean proceedAgain = false; if(siteList==null) { - if (!timeConstraints){ - proceedAgain = true; - } else { - boolean timeCompatible = true; - if (heightBelow < Double.MAX_VALUE){ - if((Double) singleNObject[1] > heightBelow){ - timeCompatible = false; - } - } - if (heightAbove > 0){ - if((Double) singleNObject[1] < heightAbove) { - timeCompatible = false; - } - } - if (timeCompatible){ - proceedAgain = true; - cN ++; - } - } + proceedAgain = true; } else { if (inSiteList((Integer)singleNObject[0], siteList)) { - if (!timeConstraints){ - proceedAgain = true; - cN ++; - } else { - boolean timeCompatible = true; - if (heightBelow < Double.MAX_VALUE){ - if((Double) singleNObject[1] > heightBelow){ - timeCompatible = false; - } - } - if (heightAbove > 0){ - if((Double) singleNObject[1] < heightAbove) { - timeCompatible = false; - } - } - if (timeCompatible){ - proceedAgain = true; - cN ++; - } - } + proceedAgain = true; + cN ++; } } if(!summary && proceedAgain){ @@ -284,7 +194,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } resultsStream.print(singleNObject[2] + SEP + singleNObject[3] + SEP); if (branchInfo) { - resultsStream.print(branchLength + SEP + totalN + SEP + totaluN + "\n"); + resultsStream.print(branchLength + SEP + totalN + SEP + totaluNObject + "\n"); } else { resultsStream.print("\n"); } @@ -293,52 +203,16 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } } if (totalS > 0) { - Object[] allSObject = (Object[]) tree.getNodeAttribute(node, prefixHistoryS); + Object[] allSObject = (Object[]) tree.getNodeAttribute(node, historyS); for (int a = 0; a < allSObject.length; a++) { Object[] singleSObject = (Object[]) allSObject[a]; boolean proceedAgain = false; if(siteList==null) { - if (!timeConstraints){ - proceedAgain = true; - } else { - boolean timeCompatible = true; - if (heightBelow < Double.MAX_VALUE){ - if((Double) singleSObject[1] > heightBelow){ - timeCompatible = false; - } - } - if (heightAbove > 0){ - if((Double) singleSObject[1] < heightAbove) { - timeCompatible = false; - } - } - if (timeCompatible){ - proceedAgain = true; - cS ++; - } - } + proceedAgain = true; } else { if (inSiteList((Integer)singleSObject[0], siteList)) { - if (!timeConstraints){ - proceedAgain = true; - cS ++; - } else { - boolean timeCompatible = true; - if (heightBelow < Double.MAX_VALUE){ - if((Double) singleSObject[1] > heightBelow){ - timeCompatible = false; - } - } - if (heightAbove > 0){ - if((Double) singleSObject[1] < heightAbove) { - timeCompatible = false; - } - } - if (timeCompatible){ - proceedAgain = true; - cS ++; - } - } + proceedAgain = true; + cS ++; } } if(!summary && proceedAgain){ @@ -351,7 +225,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } resultsStream.print(singleSObject[2] + SEP + singleSObject[3] + SEP); if (branchInfo) { - resultsStream.print(branchLength + SEP + totalS + SEP + totaluS + "\n"); + resultsStream.print(branchLength + SEP + totalS + SEP + totaluSObject + "\n"); } else { resultsStream.print("\n"); } @@ -359,7 +233,6 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } } - //in case you do not wish to print out zero branch lengths if ((totalN+totalS)==0){ if (zeroBranches) { if (!summary){ @@ -376,9 +249,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List length -= branchLength; } } - } else { - System.err.println("No N or S annotations?"); - System.exit(-1); + //tree,branch,S/N,position,oriState,destState,time,branchLength,totalBranchcN,totalcS,totaluN,totaluS,summaryTine } } } @@ -389,8 +260,7 @@ private void getNSCounts(Tree tree, int treeUsed, BranchSet branchSet, List } } - private boolean nodeToConsider(Tree tree, NodeRef node, BranchSet branchSet, List inclusionSets, List exclusionSets, - String[] stateSet, double heightAbove, double heightBelow){ + private boolean nodeToConsider(Tree tree, NodeRef node, BranchSet branchSet, List inclusionSets, List exclusionSets){ boolean nodeToConsider = false; if (branchSet == BranchSet.ALL) { nodeToConsider = true; @@ -427,35 +297,6 @@ private boolean nodeToConsider(Tree tree, NodeRef node, BranchSet branchSet, Lis } } } - - // making sure it is a branch that maintains a specified state = node and parent state are a specified state. - if(nodeToConsider && stateSet!=null) { - nodeToConsider = false; - String nodeState = ((String)tree.getNodeAttribute(node, discreteTrait)).replaceAll("\"",""); - String parentNodeState = ((String)tree.getNodeAttribute(tree.getParent(node), discreteTrait)).replaceAll("\"","");; - if (nodeState.equals(parentNodeState)){ -// System.out.println(nodeState); - for (String state: stateSet){ - if (state.equalsIgnoreCase(nodeState)){ - nodeToConsider = true; - break; - } - } - } - } - - //partially accounting for specified time constraints - if (nodeToConsider && timeConstraints) { - double nodeHeight = tree.getNodeHeight(node); - double parentNodeHeight = tree.getNodeHeight(tree.getParent(node)); - if(parentNodeHeight < heightAbove){ - nodeToConsider = false; - } - if(nodeHeight > heightBelow){ - nodeToConsider = false; - } - } - return nodeToConsider; } @@ -496,7 +337,7 @@ private static boolean onBackbone(Tree tree, NodeRef node, Set targetSet) { if (tree.isExternal(node)) return false; - Set leafSet = Tree.Utils.getDescendantLeaves(tree, node); + Set leafSet = TreeUtils.getDescendantLeaves(tree, node); int size = leafSet.size(); leafSet.retainAll(targetSet); @@ -506,7 +347,7 @@ private static boolean onBackbone(Tree tree, NodeRef node, Set targetSet) { // if all leaves below are in target then check just above. if (leafSet.size() == size) { - Set superLeafSet = Tree.Utils.getDescendantLeaves(tree, tree.getParent(node)); + Set superLeafSet = TreeUtils.getDescendantLeaves(tree, tree.getParent(node)); superLeafSet.removeAll(targetSet); // the branch is on ancestral path if the super tree has some non-targets in it @@ -518,7 +359,7 @@ private static boolean onBackbone(Tree tree, NodeRef node, Set targetSet) { } private static boolean inClade(Tree tree, NodeRef node, Set targetSet, boolean includeStem) { - Set leafSet = Tree.Utils.getDescendantLeaves(tree, node); + Set leafSet = TreeUtils.getDescendantLeaves(tree, node); leafSet.removeAll(targetSet); @@ -527,7 +368,7 @@ private static boolean inClade(Tree tree, NodeRef node, Set targetSet, boolean i if (includeStem){ return true; } else { - Set parentLeafSet = Tree.Utils.getDescendantLeaves(tree, tree.getParent(node)); + Set parentLeafSet = TreeUtils.getDescendantLeaves(tree, tree.getParent(node)); parentLeafSet.removeAll(targetSet); if (parentLeafSet.size() == 0){ return true; @@ -541,12 +382,12 @@ private static boolean inClade(Tree tree, NodeRef node, Set targetSet, boolean i } private static boolean isMRCAnode(Tree tree, NodeRef node, Set targetSet) { - NodeRef mrca = Tree.Utils.getCommonAncestorNode(tree, targetSet); + NodeRef mrca = TreeUtils.getCommonAncestorNode(tree, targetSet); if (node.equals(mrca)){ return true; } else { return false; - } + } } private boolean branchInfo; @@ -556,18 +397,9 @@ private static boolean isMRCAnode(Tree tree, NodeRef node, Set targetSet) { private double mrsd; private boolean cladeStem; private boolean excludeCladeStems; - private String discreteTrait; - private boolean timeConstraints; private static PrintStream progressStream = System.err; private PrintStream resultsStream; - public String prefixTotalcN; - public String prefixTotalcS; - public String prefixTotaluN; - public String prefixTotaluS; - public String prefixHistoryN; - public String prefixHistoryS; - enum BranchSet { ALL, INT, @@ -669,15 +501,6 @@ public static void main(String[] args) throws IOException, TraceException { new Arguments.StringOption(EXCLUDECLADES, "clade exclusion files", "specifies files with taxa that define clades to be excluded"), new Arguments.IntegerOption(SITESUM, "the number of nucleotide sites to summarize rates in per site per time unit [default = 1]"), new Arguments.StringOption(CODONSITELIST, "list of sites", "sites for which the summary is restricted to"), - new Arguments.StringOption(PREFIX, "annotation prefix", "specifies a prefix that is used for the annotations (e.g. to distinguish partition-specific annotations"), - - new Arguments.RealOption(HEIGHT_ABOVE, "specifies a boundary above which the height must be to be included in the summary [default=none]"), - new Arguments.RealOption(HEIGHT_BELOW, "specifies a boundary below which the height must be to be included in the summary [default=none]"), - new Arguments.RealOption(TIME_BEFORE, "specifies a boundary before which the time must be to be included in the summary [default=none]"), - new Arguments.RealOption(TIME_AFTER, "specifies a boundary after which the time must be to be included in the summary [default=none]"), - - new Arguments.StringOption(DISCRETE_TRAIT_ATTRIBUTE, "discrete trait attribute", "specifies the string used as attribute for the discrete trait state"), - new Arguments.StringOption(DISCRETE_TRAIT_STATE_SET, "discrete trait state set", "specifies which discrete trait states that branches need to maintain to be considered for the summary"), new Arguments.Option("help", "option to print this message") }); @@ -716,9 +539,9 @@ public static void main(String[] args) throws IOException, TraceException { if (set == set.BACKBONE) { if (arguments.hasOption(BACKBONETAXA)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(BACKBONETAXA)); - for (String singleSet : fileList) { + for (String singleSet: fileList){ inclusionSets.add(getTargetSet(singleSet)); - progressStream.println("getting target set for backbone inclusion: " + singleSet); + progressStream.println("getting target set for backbone inclusion: "+singleSet); } } else { progressStream.println("you want to get summaries for (a) backbone(s), but no files with taxa to define it are provided??"); @@ -733,9 +556,9 @@ public static void main(String[] args) throws IOException, TraceException { if (arguments.hasOption(INCLUDECLADES)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(INCLUDECLADES)); - for (String singleSet : fileList) { + for (String singleSet: fileList){ inclusionSets.add(getTargetSet(singleSet)); - progressStream.println("getting target set for clade inclusion: " + singleSet); + progressStream.println("getting target set for clade inclusion: "+singleSet); } } else { progressStream.println("you want to get summaries for one or more clades, but no files with taxa to define it are provided??"); @@ -745,7 +568,7 @@ public static void main(String[] args) throws IOException, TraceException { if (set == set.SINGLEBRANCH) { if (arguments.hasOption(INCLUDECLADES)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(INCLUDECLADES)); - if (fileList.length > 1) { + if (fileList.length > 1){ progressStream.println("more than one clade set is specified for a summary of a single branch??"); System.exit(-1); } else { @@ -796,9 +619,9 @@ public static void main(String[] args) throws IOException, TraceException { List exclusionSets = new ArrayList(); if (arguments.hasOption(EXCLUDECLADES)) { String[] fileList = parseVariableLengthStringArray(arguments.getStringOption(EXCLUDECLADES)); - for (String singleSet : fileList) { + for (String singleSet: fileList){ exclusionSets.add(getTargetSet(singleSet)); - progressStream.println("getting target set for clade exclusion: " + singleSet); + progressStream.println("getting target set for clade exclusion: "+singleSet); } } @@ -808,48 +631,6 @@ public static void main(String[] args) throws IOException, TraceException { progressStream.println("site list provided: note that dN/dS will not be accurately estimated because the neutral expectation to get dN/dS (uN and uS) is for all sites along a branch."); } - String prefix = arguments.getStringOption(PREFIX); - - double heightAbove = 0; - if (arguments.hasOption(HEIGHT_ABOVE)) { - heightAbove = arguments.getRealOption(HEIGHT_ABOVE); - } else if (arguments.hasOption(TIME_BEFORE)) { - if (mrsd > 0){ - heightAbove = mrsd - arguments.getRealOption(TIME_BEFORE); - } else { - System.err.println(TIME_BEFORE +" is specified but no mrsd (>0)??"); - System.exit(-1); - } - } - - double heightBelow = Double.MAX_VALUE; - if (arguments.hasOption(HEIGHT_BELOW)) { - heightBelow = arguments.getRealOption(HEIGHT_BELOW); - } else if (arguments.hasOption(TIME_AFTER)) { - if (mrsd > 0){ - heightBelow = mrsd - arguments.getRealOption(TIME_AFTER); - } else { - System.err.println(TIME_AFTER +" is specified but no mrsd (>0)??"); - System.exit(-1); - } - } - - String discreteTrait = null; - if (arguments.hasOption(DISCRETE_TRAIT_ATTRIBUTE)) { - discreteTrait = arguments.getStringOption(DISCRETE_TRAIT_ATTRIBUTE); - progressStream.println("discrete trait attribute provided is "+discreteTrait); - } - - String[] stateSet = null; - if (arguments.hasOption(DISCRETE_TRAIT_STATE_SET)) { - stateSet = parseVariableLengthStringArray(arguments.getStringOption(DISCRETE_TRAIT_STATE_SET)); - if (discreteTrait==null){ - System.err.println("stateSet provided nut no discrete trait attribute provided??"); - System.exit(-1); - } - progressStream.println("discrete trait state set provided for summary; first state in set = "+stateSet[0]); - } - String inputFileName = null; String outputFileName = null; @@ -880,9 +661,7 @@ public static void main(String[] args) throws IOException, TraceException { burnin = Integer.parseInt(br.readLine()); } - new GetNSCountsFromTrees(burnin, inputFileName, outputFileName, branchInfo, set, inclusionSets, cladeStem, zeroBranches, - summary, sites, mrsd, exclusionSets, excludeCladeStems, - siteList, prefix, heightAbove, heightBelow, discreteTrait, stateSet); + new GetNSCountsFromTrees(burnin, inputFileName, outputFileName, branchInfo, set, inclusionSets, cladeStem, zeroBranches, summary, sites, mrsd, exclusionSets, excludeCladeStems, siteList); System.exit(0); } From 4859ecce1cfd9267900dc68827438d6a851245d8 Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 20 Jun 2023 14:23:33 +0100 Subject: [PATCH 51/65] Allow stem transition point in absolute time units --- .../branchratemodel/LocalClockModel.java | 34 ++++++++++++++----- .../LocalClockModelParser.java | 27 ++++++++++++--- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/dr/evomodel/branchratemodel/LocalClockModel.java b/src/dr/evomodel/branchratemodel/LocalClockModel.java index 1cd1856621..7029ea6937 100644 --- a/src/dr/evomodel/branchratemodel/LocalClockModel.java +++ b/src/dr/evomodel/branchratemodel/LocalClockModel.java @@ -113,10 +113,10 @@ public void addExternalBranchClock(TaxonList taxonList, BranchRateModel branchRa addModel(branchRates); } - public void addCladeClock(TaxonList taxonList, Parameter rateParameter, boolean isRelativeRate, Parameter stemParameter, boolean excludeClade) throws TreeUtils.MissingTaxonException { + public void addCladeClock(TaxonList taxonList, Parameter rateParameter, boolean isRelativeRate, Parameter stemParameter, boolean stemAsTime, boolean excludeClade) throws TreeUtils.MissingTaxonException { Set tips = TreeUtils.getTipsForTaxa(treeModel, taxonList); BitSet tipBitSet = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList); - LocalClock clock = new LocalClock(rateParameter, isRelativeRate, tips, stemParameter, excludeClade); + LocalClock clock = new LocalClock(rateParameter, isRelativeRate, tips, stemParameter, stemAsTime, excludeClade); localCladeClocks.put(tipBitSet, clock); addVariable(rateParameter); if (stemParameter != null) { @@ -124,10 +124,10 @@ public void addCladeClock(TaxonList taxonList, Parameter rateParameter, boolean } } - public void addCladeClock(TaxonList taxonList, BranchRateModel branchRates, boolean isRelativeRate, Parameter stemParameter, boolean excludeClade) throws TreeUtils.MissingTaxonException { + public void addCladeClock(TaxonList taxonList, BranchRateModel branchRates, boolean isRelativeRate, Parameter stemParameter, boolean stemAsTime, boolean excludeClade) throws TreeUtils.MissingTaxonException { Set tips = TreeUtils.getTipsForTaxa(treeModel, taxonList); BitSet tipBitSet = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList); - LocalClock clock = new LocalClock(branchRates, isRelativeRate, tips, stemParameter, excludeClade); + LocalClock clock = new LocalClock(branchRates, isRelativeRate, tips, stemParameter, stemAsTime, excludeClade); localCladeClocks.put(tipBitSet, clock); addModel(branchRates); if (stemParameter != null) { @@ -265,7 +265,18 @@ public double getBranchRate(final Tree tree, final NodeRef node) { parentRate = localClock.getBranchRate(tree, tree.getParent(node)); } } - stemProportion = localClock.getStemProportion(); + if (localClock.stemAsTime) { + // this could be greater than 1 in which case bad things might happen + stemProportion = Math.min(localClock.getStemValue() / tree.getBranchLength(node), 1.0); + +// stemProportion = localClock.getStemValue() / tree.getBranchLength(node); +// if (stemProportion > 1.0) { +// // it should be ensured that this never happens. +// throw new IllegalArgumentException("A stem proportion for a local clock is > 1.0"); +// } + } else { + stemProportion = localClock.getStemValue(); + } } if (localClock.isRelativeRate()) { @@ -396,6 +407,7 @@ private class LocalClock { this.tipList = null; this.type = type; this.stemParameter = null; + this.stemAsTime = false; this.excludeClade = true; } @@ -408,10 +420,11 @@ private class LocalClock { this.tipList = null; this.type = type; this.stemParameter = null; + this.stemAsTime = false; this.excludeClade = true; } - LocalClock(Parameter rateParameter, boolean isRelativeRate, Set tips, Parameter stemParameter, boolean excludeClade) { + LocalClock(Parameter rateParameter, boolean isRelativeRate, Set tips, Parameter stemParameter, boolean stemAsTime, boolean excludeClade) { this.rateParameter = rateParameter; this.branchRates = null; this.indexParameter = null; @@ -420,10 +433,11 @@ private class LocalClock { this.tipList = null; this.type = ClockType.CLADE; this.stemParameter = stemParameter; + this.stemAsTime = stemAsTime; this.excludeClade = excludeClade; } - LocalClock(BranchRateModel branchRates, boolean isRelativeRate, Set tips, Parameter stemParameter, boolean excludeClade) { + LocalClock(BranchRateModel branchRates, boolean isRelativeRate, Set tips, Parameter stemParameter, boolean stemAsTime, boolean excludeClade) { this.rateParameter = null; this.branchRates = branchRates; this.indexParameter = null; @@ -432,6 +446,7 @@ private class LocalClock { this.tipList = null; this.type = ClockType.CLADE; this.stemParameter = stemParameter; + this.stemAsTime = stemAsTime; this.excludeClade = excludeClade; } @@ -444,6 +459,7 @@ private class LocalClock { this.tipList = tipList; this.type = type; this.stemParameter = null; + this.stemAsTime = false; this.excludeClade = true; } @@ -456,10 +472,11 @@ private class LocalClock { this.tipList = tipList; this.type = type; this.stemParameter = null; + this.stemAsTime = false; this.excludeClade = true; } - double getStemProportion() { + double getStemValue() { if (stemParameter != null) { return stemParameter.getParameterValue(0); } else { @@ -496,6 +513,7 @@ boolean isRelativeRate() { private final List tipList; private final ClockType type; private final Parameter stemParameter; + private final boolean stemAsTime; private final boolean excludeClade; } diff --git a/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java b/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java index 06bda961b3..78794827b9 100644 --- a/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java +++ b/src/dr/evomodelxml/branchratemodel/LocalClockModelParser.java @@ -46,6 +46,7 @@ public class LocalClockModelParser extends AbstractXMLObjectParser { public static final String CLADE = "clade"; public static final String INCLUDE_STEM = "includeStem"; public static final String STEM_PROPORTION = "stemProportion"; + public static final String STEM_TIME = "stemTime"; public static final String EXCLUDE_CLADE = "excludeClade"; public static final String EXTERNAL_BRANCHES = "externalBranches"; public static final String TRUNK = "trunk"; @@ -94,7 +95,8 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean excludeClade = false; Parameter stemParameter = null; - + boolean stemAsTime = false; + if (xoc.hasAttribute(INCLUDE_STEM)) { // if includeStem=true then assume it is the whole stem stemParameter = new Parameter.Default(xoc.getBooleanAttribute(INCLUDE_STEM) ? 1.0 : 0.0); @@ -107,6 +109,14 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } stemParameter = new Parameter.Default(stemValue); } + if (xoc.hasAttribute(STEM_TIME)) { + double stemValue = xoc.getDoubleAttribute(STEM_TIME); + if (stemValue < 0.0) { + throw new XMLParseException("A stem time should be >= 0"); + } + stemParameter = new Parameter.Default(stemValue); + stemAsTime = true; + } if (xoc.hasChildNamed(STEM_PROPORTION)) { stemParameter = (Parameter) xoc.getElementFirstChild(STEM_PROPORTION); @@ -114,6 +124,13 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { throw new XMLParseException("A stem proportion should be between 0, 1"); } } + if (xoc.hasChildNamed(STEM_TIME)) { + stemParameter = (Parameter) xoc.getElementFirstChild(STEM_TIME); + if (stemParameter.getParameterValue(0) < 0.0) { + throw new XMLParseException("A stem time should be >= 0"); + } + stemAsTime = true; + } if (xoc.hasAttribute(EXCLUDE_CLADE)) { excludeClade = xoc.getBooleanAttribute(EXCLUDE_CLADE); @@ -121,9 +138,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { try { if (branchRates != null) { - localClockModel.addCladeClock(taxonList, branchRates, relative, stemParameter, excludeClade); + localClockModel.addCladeClock(taxonList, branchRates, relative, stemParameter, stemAsTime, excludeClade); } else { - localClockModel.addCladeClock(taxonList, rateParameter, relative, stemParameter, excludeClade); + localClockModel.addCladeClock(taxonList, rateParameter, relative, stemParameter, stemAsTime, excludeClade); } } catch (TreeUtils.MissingTaxonException mte) { @@ -216,7 +233,9 @@ public XMLSyntaxRule[] getSyntaxRules() { new XMLSyntaxRule[]{ AttributeRule.newBooleanRule(INCLUDE_STEM, true, "determines whether or not the stem branch above this clade is included in the siteModel (default false)."), AttributeRule.newDoubleRule(STEM_PROPORTION, true, "proportion of stem to include in clade rate (default 0)."), - new ElementRule(STEM_PROPORTION, Parameter.class, "A parameter for the proportion of stem to include in clade rate (0 - 1)", false) + new ElementRule(STEM_PROPORTION, Parameter.class, "A parameter for the proportion of stem to include in clade rate (0 - 1)", false), + AttributeRule.newDoubleRule(STEM_TIME, true, "time within the stem to include in clade rate (default 0)."), + new ElementRule(STEM_TIME, Parameter.class, "A parameter for the time of stem to include in clade rate", false) }, true ), From 0d45ef43e04c28ae9d6d62aafe08c3ab2d301dc4 Mon Sep 17 00:00:00 2001 From: ghassler Date: Thu, 20 Jul 2023 10:54:20 -0700 Subject: [PATCH 52/65] save ellapsed time to log file --- .../app/beast/development_parsers.properties | 3 ++ src/dr/inference/loggers/TimeLogger.java | 32 +++++++++++++++++ .../loggers/TimeLoggerXMLParser.java | 34 +++++++++++++++++++ 3 files changed, 69 insertions(+) create mode 100644 src/dr/inference/loggers/TimeLogger.java create mode 100644 src/dr/inferencexml/loggers/TimeLoggerXMLParser.java diff --git a/src/dr/app/beast/development_parsers.properties b/src/dr/app/beast/development_parsers.properties index f0c17bfb5b..eeb4fe2bf3 100644 --- a/src/dr/app/beast/development_parsers.properties +++ b/src/dr/app/beast/development_parsers.properties @@ -323,3 +323,6 @@ dr.evomodelxml.coalescent.smooth.SmoothSkygridLikelihoodParser dr.evomodelxml.substmodel.AminoAcidMixtureParser dr.evomodelxml.treedatalikelihood.SequenceDistanceStatisticParser dr.evomodelxml.tree.TreeReportParser + +# Misc +dr.inferencexml.loggers.TimeLoggerXMLParser diff --git a/src/dr/inference/loggers/TimeLogger.java b/src/dr/inference/loggers/TimeLogger.java new file mode 100644 index 0000000000..0eaa3dc70f --- /dev/null +++ b/src/dr/inference/loggers/TimeLogger.java @@ -0,0 +1,32 @@ +package dr.inference.loggers; + +import dr.util.Timer; + +public class TimeLogger implements Loggable { + + private final Timer timer; + private Boolean hasStarted = false; + + public TimeLogger() { + this.timer = new Timer(); + } + + + @Override + public LogColumn[] getColumns() { + return new LogColumn[]{ + new LogColumn.Abstract("secondsElapsed") { + + @Override + protected String getFormattedValue() { + if (!hasStarted) { + hasStarted = true; + timer.start(); + } + + return Double.toString(timer.toSeconds()); + } + } + }; + } +} diff --git a/src/dr/inferencexml/loggers/TimeLoggerXMLParser.java b/src/dr/inferencexml/loggers/TimeLoggerXMLParser.java new file mode 100644 index 0000000000..43ca1d5564 --- /dev/null +++ b/src/dr/inferencexml/loggers/TimeLoggerXMLParser.java @@ -0,0 +1,34 @@ +package dr.inferencexml.loggers; + +import dr.inference.loggers.TimeLogger; +import dr.xml.AbstractXMLObjectParser; +import dr.xml.XMLObject; +import dr.xml.XMLParseException; +import dr.xml.XMLSyntaxRule; + +public class TimeLoggerXMLParser extends AbstractXMLObjectParser { + @Override + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + return new TimeLogger(); + } + + @Override + public XMLSyntaxRule[] getSyntaxRules() { + return new XMLSyntaxRule[0]; + } + + @Override + public String getParserDescription() { + return "saves elapsed seconds to log file"; + } + + @Override + public Class getReturnType() { + return TimeLogger.class; + } + + @Override + public String getParserName() { + return "timeLogger"; + } +} From 31bd16ff3548954d1c07934787b8e28909dac4d5 Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 8 Aug 2023 17:27:12 +0100 Subject: [PATCH 53/65] Added a BranchLengthStatistic which extracts the length of the branch above an MRCA --- src/dr/app/beast/release_parsers.properties | 1 + .../evomodel/tree/BranchLengthStatistic.java | 83 ++++++++++++++ .../tree/BranchLengthStatisticParser.java | 102 ++++++++++++++++++ 3 files changed, 186 insertions(+) create mode 100644 src/dr/evomodel/tree/BranchLengthStatistic.java create mode 100644 src/dr/evomodelxml/tree/BranchLengthStatisticParser.java diff --git a/src/dr/app/beast/release_parsers.properties b/src/dr/app/beast/release_parsers.properties index ba5a7ceeb3..cef4cdd2f0 100644 --- a/src/dr/app/beast/release_parsers.properties +++ b/src/dr/app/beast/release_parsers.properties @@ -287,6 +287,7 @@ dr.evomodelxml.tree.TreeLengthStatisticParser dr.evomodelxml.tree.NodeHeightsStatisticParser dr.evomodelxml.tree.TreeShapeStatisticParser dr.evomodelxml.tree.TMRCAStatisticParser +dr.evomodelxml.tree.BranchLengthStatisticParser dr.evomodelxml.tree.AgeStatisticParser dr.evomodelxml.tree.MRCATraitStatisticParser dr.evomodelxml.tree.AncestralTraitParser diff --git a/src/dr/evomodel/tree/BranchLengthStatistic.java b/src/dr/evomodel/tree/BranchLengthStatistic.java new file mode 100644 index 0000000000..74c6fdfe14 --- /dev/null +++ b/src/dr/evomodel/tree/BranchLengthStatistic.java @@ -0,0 +1,83 @@ +/* + * TMRCAStatistic.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodel.tree; + +import dr.evolution.tree.NodeRef; +import dr.evolution.tree.Tree; +import dr.evolution.tree.TreeUtils; +import dr.evolution.util.Taxon; +import dr.evolution.util.TaxonList; + +import java.util.Set; + +/** + * A statistic that extracts the length of the stem branch of a set of taxa + * + * @author Andrew Rambaut + * @version $Id:$ + */ +public class BranchLengthStatistic extends TreeStatistic { + + public BranchLengthStatistic(String name, Tree tree, TaxonList taxa) + throws TreeUtils.MissingTaxonException { + super(name); + this.tree = tree; + if (taxa != null) { + this.leafSet = TreeUtils.getLeavesForTaxa(tree, taxa); + } else { + throw new IllegalArgumentException("taxa cannot be null"); + } + } + + public void setTree(Tree tree) { + this.tree = tree; + } + + public Tree getTree() { + return tree; + } + + public int getDimension() { + return 1; + } + + /** + * @return the height of the MRCA node. + */ + public double getStatisticValue(int dim) { + NodeRef node = TreeUtils.getCommonAncestorNode(tree, leafSet); + + if (node == null) { + throw new RuntimeException("No node found that is MRCA of " + leafSet); + } + + return tree.getBranchLength(node); + } + + private Tree tree = null; + private Set leafSet = null; + +} diff --git a/src/dr/evomodelxml/tree/BranchLengthStatisticParser.java b/src/dr/evomodelxml/tree/BranchLengthStatisticParser.java new file mode 100644 index 0000000000..81032d5ab1 --- /dev/null +++ b/src/dr/evomodelxml/tree/BranchLengthStatisticParser.java @@ -0,0 +1,102 @@ +/* + * TMRCAStatisticParser.java + * + * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard + * + * This file is part of BEAST. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership and licensing. + * + * BEAST is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as + * published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * BEAST is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with BEAST; if not, write to the + * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, + * Boston, MA 02110-1301 USA + */ + +package dr.evomodelxml.tree; + +import dr.evolution.tree.Tree; +import dr.evolution.tree.TreeUtils; +import dr.evolution.util.Taxa; +import dr.evolution.util.TaxonList; +import dr.evomodel.tree.BranchLengthStatistic; +import dr.evomodel.tree.TMRCAStatistic; +import dr.inference.model.Statistic; +import dr.xml.*; + +/** + * + * To get the length of a stem branch of an MRCA: + * + * + * + * + * + * + * + * @author Alexei Drummond + * @author Andrew Rambaut + */ +public class BranchLengthStatisticParser extends AbstractXMLObjectParser { + + public static final String BRANCH_LENGTH_STATISTIC = "branchLengthStatistic"; + public static final String MRCA = "mrca"; + + public String getParserName() { + return BRANCH_LENGTH_STATISTIC; + } + + public Object parseXMLObject(XMLObject xo) throws XMLParseException { + + String name = xo.getAttribute(Statistic.NAME, xo.getId()); + Tree tree = (Tree) xo.getChild(Tree.class); + TaxonList taxa = null; + + if (xo.hasChildNamed(MRCA)) { + taxa = (TaxonList) xo.getElementFirstChild(MRCA); + } + + try { + return new BranchLengthStatistic(name, tree, taxa); + } catch (TreeUtils.MissingTaxonException mte) { + throw new XMLParseException( + "Taxon, " + mte + ", in " + getParserName() + "was not found in the tree."); + } + } + + //************************************************************************ + // AbstractXMLObjectParser implementation + //************************************************************************ + + public String getParserDescription() { + return "A statistic that has as its value the height of the most recent common ancestor " + + "of a set of taxa in a given tree. "; + } + + public Class getReturnType() { + return BranchLengthStatistic.class; + } + + public XMLSyntaxRule[] getSyntaxRules() { + return rules; + } + + private final XMLSyntaxRule[] rules = { + new ElementRule(Tree.class), + new StringAttributeRule("name", + "A name for this statistic primarily for the purposes of logging", true), + new ElementRule(MRCA, + new XMLSyntaxRule[]{new ElementRule(Taxa.class)}, true) + }; + +} From 2f72576c05c8ff0a94e16309362a0bb005cf59a0 Mon Sep 17 00:00:00 2001 From: rambaut Date: Tue, 8 Aug 2023 17:29:03 +0100 Subject: [PATCH 54/65] removing a hidden bound on the stem proportion --- .../evomodel/branchratemodel/LocalClockModel.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/dr/evomodel/branchratemodel/LocalClockModel.java b/src/dr/evomodel/branchratemodel/LocalClockModel.java index 7029ea6937..826b592e8a 100644 --- a/src/dr/evomodel/branchratemodel/LocalClockModel.java +++ b/src/dr/evomodel/branchratemodel/LocalClockModel.java @@ -255,6 +255,7 @@ public double getBranchRate(final Tree tree, final NodeRef node) { if (localClock != null) { double parentRate = rate; double stemProportion = 1.0; + double stemTime = 0.0; if (localClock != parentClock) { // this is the branch where the rate switch occurs @@ -267,15 +268,16 @@ public double getBranchRate(final Tree tree, final NodeRef node) { } if (localClock.stemAsTime) { // this could be greater than 1 in which case bad things might happen - stemProportion = Math.min(localClock.getStemValue() / tree.getBranchLength(node), 1.0); + stemTime = localClock.getStemValue(); + stemProportion = stemTime / tree.getBranchLength(node); -// stemProportion = localClock.getStemValue() / tree.getBranchLength(node); -// if (stemProportion > 1.0) { -// // it should be ensured that this never happens. -// throw new IllegalArgumentException("A stem proportion for a local clock is > 1.0"); -// } + if (stemProportion > 1.0) { + // it should be ensured that this never happens. + throw new IllegalArgumentException("A stem proportion for a local clock is > 1.0"); + } } else { stemProportion = localClock.getStemValue(); + stemTime = tree.getBranchLength(node) * stemProportion; } } From 88772fd0882e0c0a5306bbf17ac24f277826e798 Mon Sep 17 00:00:00 2001 From: rambaut Date: Thu, 10 Aug 2023 14:25:09 +0100 Subject: [PATCH 55/65] Alternative attributes for beta distribution prior - alpha and beta --- .../distribution/PriorParsers.java | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/dr/inferencexml/distribution/PriorParsers.java b/src/dr/inferencexml/distribution/PriorParsers.java index 822d7417fc..563d5ae76a 100644 --- a/src/dr/inferencexml/distribution/PriorParsers.java +++ b/src/dr/inferencexml/distribution/PriorParsers.java @@ -77,6 +77,7 @@ public class PriorParsers { public static final String HALF_T_PRIOR = "halfTPrior"; public static final String DIRICHLET_PRIOR = "dirichletPrior"; public static final String ALPHA = "alpha"; + public static final String BETA = "beta"; public static final String COUNTS = "counts"; public static final String SUMS_TO = "sumsTo"; @@ -847,8 +848,15 @@ public String getParserName() { public Object parseXMLObject(XMLObject xo) throws XMLParseException { - final double shape = xo.getDoubleAttribute(SHAPE); - final double shapeB = xo.getDoubleAttribute(SHAPEB); + double shape; + double shapeB; + if (xo.hasAttribute(ALPHA) && xo.hasAttribute(BETA)) { + shape = xo.getDoubleAttribute(ALPHA); + shapeB = xo.getDoubleAttribute(BETA); + } else { + shape = xo.getDoubleAttribute(SHAPE); + shapeB = xo.getDoubleAttribute(SHAPEB); + } final double offset = xo.getAttribute(OFFSET, 0.0); final double scale = xo.getAttribute(SCALE, 1.0); @@ -869,8 +877,16 @@ public XMLSyntaxRule[] getSyntaxRules() { } private final XMLSyntaxRule[] rules = { - AttributeRule.newDoubleRule(SHAPE), - AttributeRule.newDoubleRule(SHAPEB), + new XORRule( + new AndRule( + AttributeRule.newDoubleRule(SHAPE), + AttributeRule.newDoubleRule(SHAPEB) + ), + new AndRule( + AttributeRule.newDoubleRule(ALPHA), + AttributeRule.newDoubleRule(BETA) + ) + ), AttributeRule.newDoubleRule(OFFSET, true), new ElementRule(Statistic.class, 1, Integer.MAX_VALUE) }; From 4b4c39d965e9f0b60102518178576ebdcd2e4c68 Mon Sep 17 00:00:00 2001 From: rambaut Date: Fri, 11 Aug 2023 11:58:29 +0100 Subject: [PATCH 56/65] Adding a bit more logging of info to a few parsers (taxa and treemodel) --- src/dr/evomodelxml/tree/TreeModelParser.java | 21 ++++++++++++++++++-- src/dr/evoxml/SitePatternsParser.java | 2 +- src/dr/evoxml/TaxaParser.java | 6 ++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/dr/evomodelxml/tree/TreeModelParser.java b/src/dr/evomodelxml/tree/TreeModelParser.java index 8ede318402..99f05cc12e 100644 --- a/src/dr/evomodelxml/tree/TreeModelParser.java +++ b/src/dr/evomodelxml/tree/TreeModelParser.java @@ -264,9 +264,26 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } + double minTaxonHeight = Double.MAX_VALUE; + double maxTaxonHeight = Double.MIN_VALUE; + for (int i = 0; i < treeModel.getTaxonCount(); i++) { + Taxon taxon = treeModel.getTaxon(i); + double h = Taxon.getHeightFromDate(taxon.getDate()); + if (h < minTaxonHeight) { + minTaxonHeight = h; + } + if (h > maxTaxonHeight) { + maxTaxonHeight = h; + } + } + // Logger.getLogger("dr.evomodel").info(" initial tree topology = " + TreeUtils.uniqueNewick(treeModel, treeModel.getRoot())); - Logger.getLogger("dr.evomodel").info(" taxon count = " + treeModel.getExternalNodeCount()); - Logger.getLogger("dr.evomodel").info(" tree height = " + treeModel.getNodeHeight(treeModel.getRoot())); + Logger.getLogger("dr.evomodel").info(" taxon count = " + treeModel.getExternalNodeCount()); + Logger.getLogger("dr.evomodel").info(" tree height = " + treeModel.getNodeHeight(treeModel.getRoot())); + Logger.getLogger("dr.evomodel").info(" min tip height = " + minTaxonHeight); + Logger.getLogger("dr.evomodel").info(" max tip height = " + maxTaxonHeight); + ; + return treeModel; } diff --git a/src/dr/evoxml/SitePatternsParser.java b/src/dr/evoxml/SitePatternsParser.java index f552a25942..afe58bf9d0 100644 --- a/src/dr/evoxml/SitePatternsParser.java +++ b/src/dr/evoxml/SitePatternsParser.java @@ -117,7 +117,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (xo.hasAttribute(XMLParser.ID)) { final Logger logger = Logger.getLogger("dr.evoxml"); - logger.info("Site patterns '" + xo.getId() + "' created from positions " + + logger.info("\nSite patterns '" + xo.getId() + "' created from positions " + Integer.toString(f) + "-" + Integer.toString(t) + " of alignment '" + alignment.getId() + "'"); diff --git a/src/dr/evoxml/TaxaParser.java b/src/dr/evoxml/TaxaParser.java index 4c77315c0b..5259ed8b41 100644 --- a/src/dr/evoxml/TaxaParser.java +++ b/src/dr/evoxml/TaxaParser.java @@ -30,6 +30,8 @@ import dr.evolution.util.TaxonList; import dr.xml.*; +import java.util.logging.Logger; + /** * @author Alexei Drummond * @author Andrew Rambaut @@ -81,6 +83,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { } } + final Logger logger = Logger.getLogger("dr.evoxml"); + logger.info("\nTaxon list '" + xo.getId() + "' created with " + taxonList.getTaxonCount() + " taxa."); + logger.info(" most recent taxon date = " + Taxon.getMostRecentDate()); + return taxonList; } From b11643b950bc323d7194c84acdba2b5378881f32 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 13:46:02 +0100 Subject: [PATCH 57/65] Changing the default so BEAGLE CPU threading is on unless explicitly turned off. Default thread number needs thinking about. --- .../treedatalikelihood/BeagleDataLikelihoodDelegate.java | 8 ++++++-- src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java index 8aaf8d452c..400b97b143 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java @@ -438,8 +438,12 @@ public BeagleDataLikelihoodDelegate(Tree tree, instanceFlags = instanceDetails.getFlags(); - if (IS_THREAD_COUNT_COMPATIBLE() && threadCount > 1) { - beagle.setCPUThreadCount(threadCount); + if (IS_THREAD_COUNT_COMPATIBLE()) { + if (threadCount > 0) { + beagle.setCPUThreadCount(threadCount); + } else { // if no thread_count is specified then this will be -1 so put no upper bound on threads + beagle.setCPUThreadCount(Integer.MAX_VALUE); + } } if (patternList instanceof UncertainSiteList) { // TODO Remove diff --git a/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java b/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java index 15e2afd514..f3cd2cd036 100644 --- a/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java +++ b/src/dr/evomodel/treelikelihood/BeagleTreeLikelihood.java @@ -434,8 +434,12 @@ public BeagleTreeLikelihood(PatternList patternList, logger.info(" No external BEAGLE resources available, or resource list/requirements not met, using Java implementation"); } - if (IS_THREAD_COUNT_COMPATIBLE() && threadCount > 1) { - beagle.setCPUThreadCount(threadCount); + if (IS_THREAD_COUNT_COMPATIBLE()) { + if (threadCount > 0) { + beagle.setCPUThreadCount(threadCount); + } else { // if no thread_count is specified then this will be -1 so put no upper bound on threads + beagle.setCPUThreadCount(Integer.MAX_VALUE); + } } logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); From a3d1eb4cffcc97a6be91d0513d88cbd061d7287f Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 13:58:33 +0100 Subject: [PATCH 58/65] Check that taxa have dates. --- src/dr/evomodelxml/tree/TreeModelParser.java | 26 ++++++++++++-------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/dr/evomodelxml/tree/TreeModelParser.java b/src/dr/evomodelxml/tree/TreeModelParser.java index 99f05cc12e..50fa40fb23 100644 --- a/src/dr/evomodelxml/tree/TreeModelParser.java +++ b/src/dr/evomodelxml/tree/TreeModelParser.java @@ -266,24 +266,30 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { double minTaxonHeight = Double.MAX_VALUE; double maxTaxonHeight = Double.MIN_VALUE; + boolean hasDates = false; for (int i = 0; i < treeModel.getTaxonCount(); i++) { Taxon taxon = treeModel.getTaxon(i); - double h = Taxon.getHeightFromDate(taxon.getDate()); - if (h < minTaxonHeight) { - minTaxonHeight = h; - } - if (h > maxTaxonHeight) { - maxTaxonHeight = h; + if (taxon.getDate() != null) { + hasDates = true; + double h = Taxon.getHeightFromDate(taxon.getDate()); + if (h < minTaxonHeight) { + minTaxonHeight = h; + } + if (h > maxTaxonHeight) { + maxTaxonHeight = h; + } } } // Logger.getLogger("dr.evomodel").info(" initial tree topology = " + TreeUtils.uniqueNewick(treeModel, treeModel.getRoot())); Logger.getLogger("dr.evomodel").info(" taxon count = " + treeModel.getExternalNodeCount()); Logger.getLogger("dr.evomodel").info(" tree height = " + treeModel.getNodeHeight(treeModel.getRoot())); - Logger.getLogger("dr.evomodel").info(" min tip height = " + minTaxonHeight); - Logger.getLogger("dr.evomodel").info(" max tip height = " + maxTaxonHeight); - ; - + if (hasDates) { + Logger.getLogger("dr.evomodel").info(" min tip height = " + minTaxonHeight); + Logger.getLogger("dr.evomodel").info(" max tip height = " + maxTaxonHeight); + } else { + Logger.getLogger("dr.evomodel").info(" tip heights = 0"); + } return treeModel; } From 492cb2da5e43e080e7d3ef3d41c87f3fbe03c2b4 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 14:01:14 +0100 Subject: [PATCH 59/65] Check that taxa have dates. --- src/dr/evomodelxml/tree/TreeModelParser.java | 31 ++++++++------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/src/dr/evomodelxml/tree/TreeModelParser.java b/src/dr/evomodelxml/tree/TreeModelParser.java index 50fa40fb23..d35a3048df 100644 --- a/src/dr/evomodelxml/tree/TreeModelParser.java +++ b/src/dr/evomodelxml/tree/TreeModelParser.java @@ -263,33 +263,28 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { throw new XMLParseException("illegal child element in " + getParserName() + ": " + xo.getChildName(i) + " " + xo.getChild(i)); } } - + double minTaxonHeight = Double.MAX_VALUE; - double maxTaxonHeight = Double.MIN_VALUE; - boolean hasDates = false; + double maxTaxonHeight = -Double.MAX_VALUE; for (int i = 0; i < treeModel.getTaxonCount(); i++) { Taxon taxon = treeModel.getTaxon(i); + double h = 0; if (taxon.getDate() != null) { - hasDates = true; - double h = Taxon.getHeightFromDate(taxon.getDate()); - if (h < minTaxonHeight) { - minTaxonHeight = h; - } - if (h > maxTaxonHeight) { - maxTaxonHeight = h; - } + h = Taxon.getHeightFromDate(taxon.getDate()); + } + if (h < minTaxonHeight) { + minTaxonHeight = h; + } + if (h > maxTaxonHeight) { + maxTaxonHeight = h; } } - + // Logger.getLogger("dr.evomodel").info(" initial tree topology = " + TreeUtils.uniqueNewick(treeModel, treeModel.getRoot())); Logger.getLogger("dr.evomodel").info(" taxon count = " + treeModel.getExternalNodeCount()); Logger.getLogger("dr.evomodel").info(" tree height = " + treeModel.getNodeHeight(treeModel.getRoot())); - if (hasDates) { - Logger.getLogger("dr.evomodel").info(" min tip height = " + minTaxonHeight); - Logger.getLogger("dr.evomodel").info(" max tip height = " + maxTaxonHeight); - } else { - Logger.getLogger("dr.evomodel").info(" tip heights = 0"); - } + Logger.getLogger("dr.evomodel").info(" min tip height = " + minTaxonHeight); + Logger.getLogger("dr.evomodel").info(" max tip height = " + maxTaxonHeight); return treeModel; } From f4ae1b047af7fd5b3ec1562087c91d9463dfab27 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 14:42:26 +0100 Subject: [PATCH 60/65] Re-implemented the "beagle_instances" option which divides patterns amongst independent BEAGLE instances --- .../TreeDataLikelihoodParser.java | 77 ++++++++++++------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 53f51ee387..0a0b7d2d5f 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -26,6 +26,7 @@ package dr.evomodelxml.treedatalikelihood; import dr.evolution.alignment.PatternList; +import dr.evolution.alignment.Patterns; import dr.evolution.tree.Tree; import dr.evolution.util.Taxon; import dr.evomodel.branchmodel.BranchModel; @@ -36,12 +37,7 @@ import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.evomodel.tipstatesmodel.TipStatesModel; -import dr.evomodel.tree.TreeModel; -import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate; -import dr.evomodel.treedatalikelihood.PreOrderSettings; -import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate; -import dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate; -import dr.evomodel.treedatalikelihood.TreeDataLikelihood; +import dr.evomodel.treedatalikelihood.*; import dr.evomodel.treelikelihood.PartialsRescalingScheme; import dr.inference.model.CompoundLikelihood; import dr.inference.model.Likelihood; @@ -110,7 +106,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, if (patternLists.size() > 1) { // will currently recommend true if using GPU, CUDA or OpenCL. useBeagle3MultiPartition = MultiPartitionDataLikelihoodDelegate.IS_MULTI_PARTITION_RECOMMENDED(); - + if (System.getProperty("USE_BEAGLE3_EXTENSIONS") != null) { useBeagle3MultiPartition = Boolean.parseBoolean(System.getProperty("USE_BEAGLE3_EXTENSIONS")); } @@ -139,7 +135,13 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } } - if ( useBeagle3MultiPartition && !useJava) { + int instanceCount = 0; + String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); + if (ic != null && ic.length() > 0) { + instanceCount = Integer.parseInt(ic); + } + + if ( useBeagle3MultiPartition && instanceCount == 0 && !useJava) { if (beagleThreadCount == -1 && threadCount >= 0) { System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount)); @@ -154,7 +156,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, useAmbiguities, scalingScheme, delayRescalingUntilUnderflow - ); + ); return new TreeDataLikelihood( dataLikelihoodDelegate, @@ -164,7 +166,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, useBeagle3MultiPartition = false; } - } + } // The multipartition data likelihood isn't available so make a set of single partition data likelihoods List treeDataLikelihoods = new ArrayList(); @@ -175,24 +177,45 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } for (int i = 0; i < patternLists.size(); i++) { - - DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( - treeModel, - patternLists.get(i), - branchModels.get(i), - siteRateModels.get(i), - useAmbiguities, - preferGPU, - scalingScheme, - delayRescalingUntilUnderflow, - settings); - - treeDataLikelihoods.add( - new TreeDataLikelihood( - dataLikelihoodDelegate, + if (instanceCount > 1) { + for (int j = 0; j < instanceCount; j++) { + PatternList patterns = new Patterns(patternLists.get(i), j, instanceCount); + DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( treeModel, - branchRateModel)); + patterns, + branchModels.get(i), + siteRateModels.get(i), + useAmbiguities, + preferGPU, + scalingScheme, + delayRescalingUntilUnderflow, + settings); + + treeDataLikelihoods.add( + new TreeDataLikelihood( + dataLikelihoodDelegate, + treeModel, + branchRateModel)); + } + } else { + DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( + treeModel, + patternLists.get(i), + branchModels.get(i), + siteRateModels.get(i), + useAmbiguities, + preferGPU, + scalingScheme, + delayRescalingUntilUnderflow, + settings); + treeDataLikelihoods.add( + new TreeDataLikelihood( + dataLikelihoodDelegate, + treeModel, + branchRateModel)); + + } } if (treeDataLikelihoods.size() == 1) { @@ -200,7 +223,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } return new CompoundLikelihood(treeDataLikelihoods); - + } public Object parseXMLObject(XMLObject xo) throws XMLParseException { From 6d5dc32999fd32743763e042eafb425f732c2c61 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 14:56:18 +0100 Subject: [PATCH 61/65] Improved reporting --- .../BeagleDataLikelihoodDelegate.java | 14 +++++++------- .../treedatalikelihood/TreeDataLikelihood.java | 4 ++-- .../TreeDataLikelihoodParser.java | 7 +++++++ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java index 400b97b143..7e6571aea8 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java @@ -367,7 +367,7 @@ public BeagleDataLikelihoodDelegate(Tree tree, benchmarkFlags = BeagleBenchmarkFlag.SCALING_DYNAMIC.getMask(); } - logger.info("\nRunning benchmarks to automatically select fastest BEAGLE resource for analysis or partition... "); + logger.info("\t\tRunning benchmarks to automatically select fastest BEAGLE resource for analysis or partition... "); List benchmarkedResourceDetails = BeagleFactory.getBenchmarkedResourceDetails( @@ -455,12 +455,12 @@ public BeagleDataLikelihoodDelegate(Tree tree, // } //add in logger info for preOrder traversal - logger.info(" " + (settings.usePreOrder ? "Using" : "Ignoring") + " preOrder partials in tree likelihood."); - logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); - logger.info(" With " + patternList.getPatternCount() + " unique site patterns."); + logger.info(" " + (settings.usePreOrder ? "Using" : "Ignoring") + " preOrder partials in tree likelihood."); + logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); + logger.info(" With " + patternList.getPatternCount() + " unique site patterns."); if (patternList.areUncertain() && !useAmbiguities) { - logger.info(" WARNING: Uncertain site patterns will be ignored."); + logger.info(" WARNING: Uncertain site patterns will be ignored."); } for (int i = 0; i < tipCount; i++) { @@ -482,13 +482,13 @@ public BeagleDataLikelihoodDelegate(Tree tree, beagle.setPatternWeights(patternWeights); - String rescaleMessage = " Using rescaling scheme : " + this.rescalingScheme.getText(); + String rescaleMessage = " Using rescaling scheme : " + this.rescalingScheme.getText(); if (this.rescalingScheme == PartialsRescalingScheme.AUTO && resourceDetails != null && (resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) { // If auto scaling in BEAGLE is not supported then do it here this.rescalingScheme = PartialsRescalingScheme.DYNAMIC; - rescaleMessage = " Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText(); + rescaleMessage = " Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText(); } boolean parenthesis = false; if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { diff --git a/src/dr/evomodel/treedatalikelihood/TreeDataLikelihood.java b/src/dr/evomodel/treedatalikelihood/TreeDataLikelihood.java index 4482b586f3..29bd3d65bd 100644 --- a/src/dr/evomodel/treedatalikelihood/TreeDataLikelihood.java +++ b/src/dr/evomodel/treedatalikelihood/TreeDataLikelihood.java @@ -63,7 +63,7 @@ public TreeDataLikelihood(DataLikelihoodDelegate likelihoodDelegate, final Logger logger = Logger.getLogger("dr.evomodel"); - logger.info("\nUsing TreeDataLikelihood"); + logger.info("\nCreating TreeDataLikelihood"); this.likelihoodDelegate = likelihoodDelegate; addModel(likelihoodDelegate); @@ -79,7 +79,7 @@ public TreeDataLikelihood(DataLikelihoodDelegate likelihoodDelegate, this.branchRateModel = branchRateModel; if (!(branchRateModel instanceof DefaultBranchRateModel)) { - logger.info(" Branch rate model used: " + branchRateModel.getModelName()); + logger.info(" Branch rate model: " + branchRateModel.getModelName()); } addModel(this.branchRateModel); diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 0a0b7d2d5f..b87a347610 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -45,6 +45,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.logging.Logger; /** * @author Andrew Rambaut @@ -86,6 +87,9 @@ protected Likelihood createTreeDataLikelihood(List patternLists, boolean delayRescalingUntilUnderflow, PreOrderSettings settings) throws XMLParseException { + final Logger logger = Logger.getLogger("dr.evomodel"); + logger.info("\nCreating tree data likelihoods for " + patternLists.size() + " partitions"); + if (tipStatesModel != null) { throw new XMLParseException("Tip State Error models are not supported yet with TreeDataLikelihood"); } @@ -176,6 +180,9 @@ protected Likelihood createTreeDataLikelihood(List patternLists, System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size())); } + if (instanceCount > 1) { + logger.info(" Dividing each partition amongst " + instanceCount + " BEAGLE instances:"); + } for (int i = 0; i < patternLists.size(); i++) { if (instanceCount > 1) { for (int j = 0; j < instanceCount; j++) { From 9a50939e5c34e1fe33bbb9d0c3aea5bc18468df0 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 15:18:06 +0100 Subject: [PATCH 62/65] More reporting of the use of threads --- .../BeagleDataLikelihoodDelegate.java | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java index 7e6571aea8..cea0c55d72 100644 --- a/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java +++ b/src/dr/evomodel/treedatalikelihood/BeagleDataLikelihoodDelegate.java @@ -438,11 +438,19 @@ public BeagleDataLikelihoodDelegate(Tree tree, instanceFlags = instanceDetails.getFlags(); - if (IS_THREAD_COUNT_COMPATIBLE()) { - if (threadCount > 0) { - beagle.setCPUThreadCount(threadCount); - } else { // if no thread_count is specified then this will be -1 so put no upper bound on threads - beagle.setCPUThreadCount(Integer.MAX_VALUE); + if ((instanceFlags & BeagleFlag.THREADING_CPP.getMask()) != 0) { + if (IS_THREAD_COUNT_COMPATIBLE() || threadCount != 0) { + if (threadCount > 0) { + beagle.setCPUThreadCount(threadCount); + logger.info(" Using " + threadCount + " threads for CPU."); + } else { // if no thread_count is specified then this will be -1 so put no upper bound on threads + logger.info(" Using default thread count for CPU."); + // this is just intended to remove the cap on number of threads so BEAGLE will + // make its own decision (for better or worse). + beagle.setCPUThreadCount(1000); + } + } else { + logger.info(" BEAGLE threading turned off (or unavailable) for CPU."); } } From a1a4612da069a42abf40acc1fa978f1819c96746 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 18:03:22 +0100 Subject: [PATCH 63/65] Revert "Re-implemented the "beagle_instances" option which divides patterns amongst independent BEAGLE instances" This reverts commit f4ae1b04 --- .../TreeDataLikelihoodParser.java | 77 +++++++------------ 1 file changed, 27 insertions(+), 50 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index b87a347610..660d2ad3ab 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -26,7 +26,6 @@ package dr.evomodelxml.treedatalikelihood; import dr.evolution.alignment.PatternList; -import dr.evolution.alignment.Patterns; import dr.evolution.tree.Tree; import dr.evolution.util.Taxon; import dr.evomodel.branchmodel.BranchModel; @@ -37,7 +36,12 @@ import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.evomodel.tipstatesmodel.TipStatesModel; -import dr.evomodel.treedatalikelihood.*; +import dr.evomodel.tree.TreeModel; +import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.PreOrderSettings; +import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.MultiPartitionDataLikelihoodDelegate; +import dr.evomodel.treedatalikelihood.TreeDataLikelihood; import dr.evomodel.treelikelihood.PartialsRescalingScheme; import dr.inference.model.CompoundLikelihood; import dr.inference.model.Likelihood; @@ -110,7 +114,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, if (patternLists.size() > 1) { // will currently recommend true if using GPU, CUDA or OpenCL. useBeagle3MultiPartition = MultiPartitionDataLikelihoodDelegate.IS_MULTI_PARTITION_RECOMMENDED(); - + if (System.getProperty("USE_BEAGLE3_EXTENSIONS") != null) { useBeagle3MultiPartition = Boolean.parseBoolean(System.getProperty("USE_BEAGLE3_EXTENSIONS")); } @@ -139,13 +143,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } } - int instanceCount = 0; - String ic = System.getProperty(BEAGLE_INSTANCE_COUNT); - if (ic != null && ic.length() > 0) { - instanceCount = Integer.parseInt(ic); - } - - if ( useBeagle3MultiPartition && instanceCount == 0 && !useJava) { + if ( useBeagle3MultiPartition && !useJava) { if (beagleThreadCount == -1 && threadCount >= 0) { System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount)); @@ -160,7 +158,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, useAmbiguities, scalingScheme, delayRescalingUntilUnderflow - ); + ); return new TreeDataLikelihood( dataLikelihoodDelegate, @@ -170,7 +168,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, useBeagle3MultiPartition = false; } - } + } // The multipartition data likelihood isn't available so make a set of single partition data likelihoods List treeDataLikelihoods = new ArrayList(); @@ -184,45 +182,24 @@ protected Likelihood createTreeDataLikelihood(List patternLists, logger.info(" Dividing each partition amongst " + instanceCount + " BEAGLE instances:"); } for (int i = 0; i < patternLists.size(); i++) { - if (instanceCount > 1) { - for (int j = 0; j < instanceCount; j++) { - PatternList patterns = new Patterns(patternLists.get(i), j, instanceCount); - DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( - treeModel, - patterns, - branchModels.get(i), - siteRateModels.get(i), - useAmbiguities, - preferGPU, - scalingScheme, - delayRescalingUntilUnderflow, - settings); - - treeDataLikelihoods.add( - new TreeDataLikelihood( - dataLikelihoodDelegate, - treeModel, - branchRateModel)); - } - } else { - DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( - treeModel, - patternLists.get(i), - branchModels.get(i), - siteRateModels.get(i), - useAmbiguities, - preferGPU, - scalingScheme, - delayRescalingUntilUnderflow, - settings); - treeDataLikelihoods.add( - new TreeDataLikelihood( - dataLikelihoodDelegate, - treeModel, - branchRateModel)); + DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( + treeModel, + patternLists.get(i), + branchModels.get(i), + siteRateModels.get(i), + useAmbiguities, + preferGPU, + scalingScheme, + delayRescalingUntilUnderflow, + settings); + + treeDataLikelihoods.add( + new TreeDataLikelihood( + dataLikelihoodDelegate, + treeModel, + branchRateModel)); - } } if (treeDataLikelihoods.size() == 1) { @@ -230,7 +207,7 @@ protected Likelihood createTreeDataLikelihood(List patternLists, } return new CompoundLikelihood(treeDataLikelihoods); - + } public Object parseXMLObject(XMLObject xo) throws XMLParseException { From 4b76ba429c41709e85920b4f525cabd43e26b317 Mon Sep 17 00:00:00 2001 From: rambaut Date: Mon, 14 Aug 2023 14:56:18 +0100 Subject: [PATCH 64/65] Reverting beagle instance stuff --- .../treedatalikelihood/TreeDataLikelihoodParser.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java index 660d2ad3ab..ccd70eb9e5 100644 --- a/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java +++ b/src/dr/evomodelxml/treedatalikelihood/TreeDataLikelihoodParser.java @@ -178,9 +178,6 @@ protected Likelihood createTreeDataLikelihood(List patternLists, System.setProperty(BEAGLE_THREAD_COUNT, Integer.toString(threadCount / patternLists.size())); } - if (instanceCount > 1) { - logger.info(" Dividing each partition amongst " + instanceCount + " BEAGLE instances:"); - } for (int i = 0; i < patternLists.size(); i++) { DataLikelihoodDelegate dataLikelihoodDelegate = new BeagleDataLikelihoodDelegate( From 49fb15f28b61fc89e3f0188ceb55d335c7c69318 Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Mon, 14 Aug 2023 13:49:15 -0700 Subject: [PATCH 65/65] enable exact gradients for glm substitution models --- .../substmodel/OldGLMSubstitutionModel.java | 97 ++++++++++++++++++- .../hmc/CompoundGradientParser.java | 8 +- 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/src/dr/evomodel/substmodel/OldGLMSubstitutionModel.java b/src/dr/evomodel/substmodel/OldGLMSubstitutionModel.java index 6f5e5dfedb..e7c6e3ae9a 100644 --- a/src/dr/evomodel/substmodel/OldGLMSubstitutionModel.java +++ b/src/dr/evomodel/substmodel/OldGLMSubstitutionModel.java @@ -31,6 +31,8 @@ import dr.inference.loggers.LogColumn; import dr.inference.model.BayesianStochasticSearchVariableSelection; import dr.inference.model.Model; +import dr.inference.model.Parameter; +import dr.math.matrixAlgebra.WrappedMatrix; import dr.util.Citation; import dr.util.CommonCitations; @@ -41,7 +43,7 @@ * @author Marc A. Suchard */ @Deprecated -public class OldGLMSubstitutionModel extends ComplexSubstitutionModel { +public class OldGLMSubstitutionModel extends ComplexSubstitutionModel implements DifferentiableSubstitutionModel { public OldGLMSubstitutionModel(String name, DataType dataType, FrequencyModel rootFreqModel, LogLinearModel glm) { @@ -103,6 +105,95 @@ public List getCitations() { return Collections.singletonList(CommonCitations.LEMEY_2014_UNIFYING); } - private LogLinearModel glm; - private double[] testProbabilities; + final private LogLinearModel glm; + final private double[] testProbabilities; + + @Override + public WrappedMatrix getInfinitesimalDifferentialMatrix(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt) { + // TODO all instantiations of this function currently do the same thing; remove duplication + return DifferentiableSubstitutionModelUtil.getInfinitesimalDifferentialMatrix(wrt, this); + } + + @Override + public DifferentialMassProvider.DifferentialWrapper.WrtParameter factory(Parameter parameter, int dim) { + for (int i = 0; i < glm.getNumberOfFixedEffects(); ++i) { + Parameter effect = glm.getFixedEffect(i); + if (parameter == effect) { + return new WrtGlmCoefficient(effect, i, dim); + } + } + throw new RuntimeException("Parameter not found"); + } + + @Override + public void setupDifferentialRates(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt, + double[] differentialRates, + double normalizingConstant) { + double[] relativeRates = new double[rateCount]; + setupRelativeRates(relativeRates); // TODO These are large; should cache + wrt.setupDifferentialRates(differentialRates, relativeRates, normalizingConstant); + } + + @Override + public void setupDifferentialFrequency(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt, + double[] differentialFrequency) { + double[] frequencies = freqModel.getFrequencies(); + System.arraycopy(frequencies, 0, differentialFrequency, 0, frequencies.length); + } + + @Override + public double getWeightedNormalizationGradient(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt, + double[][] differentialMassMatrix, + double[] differentialFrequencies) { + double weight = 0.0; + for (int i = 0; i < stateCount; ++i) { + weight -= differentialMassMatrix[i][i] * getFrequencyModel().getFrequency(i); + } + return weight; + } + + class WrtGlmCoefficient implements DifferentialMassProvider.DifferentialWrapper.WrtParameter { + + final private Parameter parameter; + final int effect; + final int dim; + + public WrtGlmCoefficient(Parameter parameter, int effect, int dim) { + this.parameter = parameter; + this.effect = effect; + this.dim = dim; + } + + @Override + public void setupDifferentialRates(double[] differentialRates, + double[] relativeRates, + double normalizingConstant) { + + final double chainRule = getChainRule(); + double[][] design = glm.getX(effect); + + for (int i = 0; i < relativeRates.length; ++i) { + differentialRates[i] = design[i][dim] / normalizingConstant * chainRule; + } + } + + double getChainRule() { + return Math.exp(parameter.getParameterValue(dim)); + } + + @Override + public double getRate(int switchCase) { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public double getNormalizationDifferential() { + throw new RuntimeException("Not yet implemented"); + } + + @Override + public void setupDifferentialFrequencies(double[] differentialFrequencies, double[] frequencies) { + throw new RuntimeException("Not yet implemented"); + } + } } diff --git a/src/dr/inferencexml/hmc/CompoundGradientParser.java b/src/dr/inferencexml/hmc/CompoundGradientParser.java index 44fffcd99c..6bb1d63466 100644 --- a/src/dr/inferencexml/hmc/CompoundGradientParser.java +++ b/src/dr/inferencexml/hmc/CompoundGradientParser.java @@ -28,7 +28,6 @@ import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.MultivariateDistributionLikelihood; import dr.inference.hmc.CompoundDerivative; -import dr.inference.hmc.CompoundGradient; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.*; import dr.xml.*; @@ -57,15 +56,13 @@ public String[] getParserNames() { @Override public Object parseXMLObject(XMLObject xo) throws XMLParseException { - List gradList = new ArrayList(); - List likelihoodList = new ArrayList(); // TODO Remove? + List gradList = new ArrayList<>(); for (int i = 0; i < xo.getChildCount(); ++i) { Object obj = xo.getChild(i); GradientWrtParameterProvider grad; - Likelihood likelihood; if (obj instanceof DistributionLikelihood) { DistributionLikelihood dl = (DistributionLikelihood) obj; @@ -83,19 +80,16 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { final GradientProvider provider = (GradientProvider) mdl.getDistribution(); final Parameter parameter = mdl.getDataParameter(); - likelihood = mdl; grad = new GradientWrtParameterProvider.ParameterWrapper(provider, parameter, mdl); } else if (obj instanceof GradientWrtParameterProvider) { grad = (GradientWrtParameterProvider) obj; - likelihood = grad.getLikelihood(); } else { throw new XMLParseException("Not a Gaussian process"); } gradList.add(grad); - likelihoodList.add(likelihood); } return new CompoundDerivative(gradList);