Skip to content

Commit

Permalink
re-enable random rotations
Browse files Browse the repository at this point in the history
  • Loading branch information
rdk committed Jun 26, 2024
1 parent 8a2d483 commit 7766312
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 95 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ dependencies {
implementation 'org.tukaani:xz:1.9'
implementation 'com.github.luben:zstd-jni:1.5.6-3'
implementation 'com.github.dpaukov:combinatoricslib3:3.4.0'
//implementation 'us.ihmc:euclid:0.21.0'
implementation 'us.ihmc:euclid:0.21.0'

implementation 'org.slf4j:slf4j-api:1.7.36'
implementation 'org.apache.logging.log4j:log4j-slf4j-impl:2.23.1'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ abstract class GeometricTransformation {
transformAtom(atom)
}


void applyToAtoms(Iterable<Atom> atoms) {
for (Atom atom : atoms) {
transformAtom(atom)
Expand Down
24 changes: 8 additions & 16 deletions src/main/groovy/cz/siret/prank/geom/transform/Rotation.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,25 @@ package cz.siret.prank.geom.transform
import groovy.transform.CompileStatic
import org.biojava.nbio.structure.Atom
import org.biojava.nbio.structure.Calc
//import us.ihmc.euclid.matrix.RotationMatrix
import us.ihmc.euclid.matrix.RotationMatrix

/**
*
*/
@CompileStatic
class Rotation extends GeometricTransformation {

// private double[][] matrix
//
// Rotation(String name, RotationMatrix rotMatrix) {
// super(name)
//
// this.matrix = Rotations.rotationMatrixToArrays(rotMatrix)
// }
//
// @Override
// void transformAtom(Atom atom) {
// Calc.rotate(atom, matrix)
// }

Rotation(String name) {
private double[][] matrix

Rotation(String name, RotationMatrix rotMatrix) {
super(name)

this.matrix = Rotations.rotationMatrixToArrays(rotMatrix)
}

@Override
void transformAtom(Atom atom) {

Calc.rotate(atom, matrix)
}

}
68 changes: 35 additions & 33 deletions src/main/groovy/cz/siret/prank/geom/transform/Rotations.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import org.biojava.nbio.structure.Calc;
import org.biojava.nbio.structure.Structure;
import org.biojava.nbio.structure.StructureException;
//import us.ihmc.euclid.matrix.RotationMatrix;
//import us.ihmc.euclid.tools.EuclidCoreRandomTools;
import us.ihmc.euclid.matrix.RotationMatrix;
import us.ihmc.euclid.tools.EuclidCoreRandomTools;
import us.ihmc.euclid.matrix.RotationMatrix;
import us.ihmc.euclid.tools.EuclidCoreRandomTools;

import java.util.Random;

Expand All @@ -13,36 +15,36 @@
*/
public class Rotations {

// /**
// * see https://msl.cs.uiuc.edu/planning/node198.html
// */
// public static RotationMatrix generateRandomRotation(Random rand) {
// return EuclidCoreRandomTools.nextRotationMatrix(rand);
// }
//
// public static double[][] rotationMatrixToArrays(RotationMatrix mat) {
// double[] aux = new double[9];
// mat.get(aux); // fill in
//
// double[][] res = new double[3][];
//
// res[0] = new double[] { aux[0], aux[1], aux[2] };
// res[1] = new double[] { aux[3], aux[4], aux[5] };
// res[2] = new double[] { aux[6], aux[7], aux[8] };
//
// return res;
// }
//
// public static void rotateStructureInplace(Structure structure, double[][] rotationMatrix3D) {
// try {
// Calc.rotate(structure, rotationMatrix3D);
// } catch (StructureException e) {
// throw new RuntimeException("Failed to rotate the structure.", e);
// }
// }
//
// public static void rotateStructureInplace(Structure structure, RotationMatrix rotationMatrix3D) {
// rotateStructureInplace(structure, rotationMatrixToArrays(rotationMatrix3D));
// }
/**
* see https://msl.cs.uiuc.edu/planning/node198.html
*/
public static RotationMatrix generateRandomRotation(Random rand) {
return EuclidCoreRandomTools.nextRotationMatrix(rand);
}

public static double[][] rotationMatrixToArrays(RotationMatrix mat) {
double[] aux = new double[9];
mat.get(aux); // fill in

double[][] res = new double[3][];

res[0] = new double[] { aux[0], aux[1], aux[2] };
res[1] = new double[] { aux[3], aux[4], aux[5] };
res[2] = new double[] { aux[6], aux[7], aux[8] };

return res;
}

public static void rotateStructureInplace(Structure structure, double[][] rotationMatrix3D) {
try {
Calc.rotate(structure, rotationMatrix3D);
} catch (StructureException e) {
throw new RuntimeException("Failed to rotate the structure.", e);
}
}

public static void rotateStructureInplace(Structure structure, RotationMatrix rotationMatrix3D) {
rotateStructureInplace(structure, rotationMatrixToArrays(rotationMatrix3D));
}

}
10 changes: 5 additions & 5 deletions src/main/groovy/cz/siret/prank/program/params/Params.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -1290,11 +1290,11 @@ class Params {
int loaded_pockets_limit = 0


// /**
// * Add random rotations of each protein (from training dataset) to the training dataset
// */
// @RuntimeParam // training
// int train_random_rotated_copies = 0
/**
* Add random rotations of each protein (from training dataset) to the training dataset
*/
@RuntimeParam // training
int train_random_rotated_copies = 0

//===========================================================================================================//
// Derived parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import cz.siret.prank.utils.PerfUtils
import cz.siret.prank.utils.WekaUtils
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
//import us.ihmc.euclid.matrix.RotationMatrix
import us.ihmc.euclid.matrix.RotationMatrix
import weka.core.Instance
import weka.core.Instances

Expand Down Expand Up @@ -49,13 +49,13 @@ class CollectVectorsRoutine extends Routine {
dataset = dataset.randomSubset(params.train_protein_limit, params.seed)
}

// // add random rotations
// // TODO move to TrainEvalRoutine to make use of dataset caching
// if (params.train_random_rotated_copies > 0) {
// dataset = expandDatasetWithRandomRotations(dataset, params.train_random_rotated_copies)
//
// // savePdbsToDir(dataset, outdir + "/train_pdbs") // TODO remove
// }
// add random rotations
// TODO move to TrainEvalRoutine to make use of dataset caching
if (params.train_random_rotated_copies > 0) {
dataset = expandDatasetWithRandomRotations(dataset, params.train_random_rotated_copies)

// savePdbsToDir(dataset, outdir + "/train_pdbs") // debug
}

return dataset
}
Expand All @@ -69,37 +69,34 @@ class CollectVectorsRoutine extends Routine {
}

private Dataset expandDatasetWithRandomRotations(Dataset dataset, int numRotations) {
throw new UnsupportedOperationException("Rotations are not supported in this version")
// log.info "Extending training dataset with {} random rotations of each protein", numRotations
//
// Random rand = new Random(params.seed)
//
// List<Dataset.Item> newItems = new ArrayList<>()
//
// newItems.addAll( dataset.items.collect { it.copy() } )
//
// for (int i=1; i<=numRotations; ++i) {
// String nameSuffix = "rotation." + i
//
// RotationMatrix matrix = Rotations.generateRandomRotation(rand)
// Rotation rotation = new Rotation(nameSuffix, matrix)
//
// //matrix.setIdentity() // TODO xxx temp
// matrix.normalize()
//
// log.info "Random rotation $i: " + matrix
//
// List<Dataset.Item> rotItems = dataset.items.collect { it.cleanCopy() }
// for (Dataset.Item item : rotItems) {
// item.label += nameSuffix
// item.transformation = rotation
// // TODO conditionally rotate predictions (pockets of other methods)
// }
//
// newItems.addAll(rotItems)
// }
//
// return dataset.copyWithNewItems(newItems, dataset.name + "-with-rotations")
log.info "Extending training dataset with {} random rotations of each protein", numRotations

Random rand = new Random(params.seed)

List<Dataset.Item> newItems = new ArrayList<>()

newItems.addAll( dataset.items.collect { it.copy() } )

for (int i=1; i<=numRotations; ++i) {
String nameSuffix = "rotation." + i

RotationMatrix matrix = Rotations.generateRandomRotation(rand)
Rotation rotation = new Rotation(nameSuffix, matrix)

matrix.normalize()

log.info "Random rotation $i: " + matrix

List<Dataset.Item> rotItems = dataset.items.collect { it.cleanCopy() }
for (Dataset.Item item : rotItems) {
item.label += nameSuffix
item.transformation = rotation
}

newItems.addAll(rotItems)
}

return dataset.copyWithNewItems(newItems, dataset.name + "-with-rotations")
}

/**
Expand Down

0 comments on commit 7766312

Please sign in to comment.