diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/CompanyInvestEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/CompanyInvestEvent.java deleted file mode 100644 index 75e5fbdf..00000000 --- a/src/main/java/ldbc/finbench/datagen/generation/events/CompanyInvestEvent.java +++ /dev/null @@ -1,47 +0,0 @@ -package ldbc.finbench.datagen.generation.events; - -import java.io.Serializable; -import java.util.List; -import java.util.Random; -import ldbc.finbench.datagen.entities.edges.CompanyInvestCompany; -import ldbc.finbench.datagen.entities.nodes.Company; -import ldbc.finbench.datagen.generation.DatagenParams; -import ldbc.finbench.datagen.util.RandomGeneratorFarm; - -public class CompanyInvestEvent implements Serializable { - private final RandomGeneratorFarm randomFarm; - private final Random randIndex; - - public CompanyInvestEvent() { - randomFarm = new RandomGeneratorFarm(); - randIndex = new Random(DatagenParams.defaultSeed); - } - - public void resetState(int seed) { - randomFarm.resetRandomGenerators(seed); - randIndex.setSeed(seed); - } - - public List companyInvestPartition(List investors, List targets) { - Random numInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_COMPANY_INVEST); - Random chooseInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_COMPANY_INVESTOR); - for (Company target : targets) { - int numInvestors = numInvestorsRand.nextInt( - DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 - ) + DatagenParams.minInvestors; - for (int i = 0; i < numInvestors; i++) { - int index = chooseInvestorRand.nextInt(investors.size()); - Company investor = investors.get(index); - if (cannotInvest(investor, target)) { - continue; - } - CompanyInvestCompany.createCompanyInvestCompany(randomFarm, investor, target); - } - } - return targets; - } - - public boolean cannotInvest(Company investor, Company target) { - return (investor == target) || investor.hasInvestedBy(target) || target.hasInvestedBy(investor); - } -} diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/InvestActivitesEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/InvestActivitesEvent.java new file mode 100644 index 00000000..1e26065b --- /dev/null +++ b/src/main/java/ldbc/finbench/datagen/generation/events/InvestActivitesEvent.java @@ -0,0 +1,67 @@ +package ldbc.finbench.datagen.generation.events; + +import java.io.Serializable; +import java.util.List; +import java.util.Random; +import ldbc.finbench.datagen.entities.edges.CompanyInvestCompany; +import ldbc.finbench.datagen.entities.edges.PersonInvestCompany; +import ldbc.finbench.datagen.entities.nodes.Company; +import ldbc.finbench.datagen.entities.nodes.Person; +import ldbc.finbench.datagen.generation.DatagenParams; +import ldbc.finbench.datagen.util.RandomGeneratorFarm; + +public class InvestActivitesEvent implements Serializable { + private final RandomGeneratorFarm randomFarm; + + public InvestActivitesEvent() { + randomFarm = new RandomGeneratorFarm(); + } + + public void resetState(int seed) { + randomFarm.resetRandomGenerators(seed); + } + + public List investPartition(List personinvestors, List companyInvestors, + List targets) { + Random numPersonInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_PERSON_INVEST); + Random choosePersonInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_PERSON_INVESTOR); + Random numCompanyInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_COMPANY_INVEST); + Random chooseCompanyInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_COMPANY_INVESTOR); + for (Company target : targets) { + // Person investors + int numPersonInvestors = numPersonInvestorsRand.nextInt( + DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 + ) + DatagenParams.minInvestors; + for (int i = 0; i < numPersonInvestors; i++) { + int index = choosePersonInvestorRand.nextInt(personinvestors.size()); + Person investor = personinvestors.get(index); + if (cannotInvest(investor, target)) { + continue; + } + PersonInvestCompany.createPersonInvestCompany(randomFarm, investor, target); + } + + // Company investors + int numCompanyInvestors = numCompanyInvestorsRand.nextInt( + DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 + ) + DatagenParams.minInvestors; + for (int i = 0; i < numCompanyInvestors; i++) { + int index = chooseCompanyInvestorRand.nextInt(companyInvestors.size()); + Company investor = companyInvestors.get(index); + if (cannotInvest(investor, target)) { + continue; + } + CompanyInvestCompany.createCompanyInvestCompany(randomFarm, investor, target); + } + } + return targets; + } + + public boolean cannotInvest(Person investor, Company target) { + return target.hasInvestedBy(investor); + } + + public boolean cannotInvest(Company investor, Company target) { + return (investor == target) || investor.hasInvestedBy(target) || target.hasInvestedBy(investor); + } +} diff --git a/src/main/java/ldbc/finbench/datagen/generation/events/PersonInvestEvent.java b/src/main/java/ldbc/finbench/datagen/generation/events/PersonInvestEvent.java deleted file mode 100644 index 297a70f5..00000000 --- a/src/main/java/ldbc/finbench/datagen/generation/events/PersonInvestEvent.java +++ /dev/null @@ -1,48 +0,0 @@ -package ldbc.finbench.datagen.generation.events; - -import java.io.Serializable; -import java.util.List; -import java.util.Random; -import ldbc.finbench.datagen.entities.edges.PersonInvestCompany; -import ldbc.finbench.datagen.entities.nodes.Company; -import ldbc.finbench.datagen.entities.nodes.Person; -import ldbc.finbench.datagen.generation.DatagenParams; -import ldbc.finbench.datagen.util.RandomGeneratorFarm; - -public class PersonInvestEvent implements Serializable { - private final RandomGeneratorFarm randomFarm; - private final Random randIndex; - - public PersonInvestEvent() { - randomFarm = new RandomGeneratorFarm(); - randIndex = new Random(DatagenParams.defaultSeed); - } - - public void resetState(int seed) { - randomFarm.resetRandomGenerators(seed); - randIndex.setSeed(seed); - } - - public List personInvestPartition(List investors, List targets) { - Random numInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_PERSON_INVEST); - Random chooseInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_PERSON_INVESTOR); - for (Company target : targets) { - int numInvestors = numInvestorsRand.nextInt( - DatagenParams.maxInvestors - DatagenParams.minInvestors + 1 - ) + DatagenParams.minInvestors; - for (int i = 0; i < numInvestors; i++) { - int index = chooseInvestorRand.nextInt(investors.size()); - Person investor = investors.get(index); - if (cannotInvest(investor, target)) { - continue; - } - PersonInvestCompany.createPersonInvestCompany(randomFarm, investor, target); - } - } - return targets; - } - - public boolean cannotInvest(Person investor, Company target) { - return target.hasInvestedBy(investor); - } -} diff --git a/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala b/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala index 2e0810b0..d4e93ea6 100644 --- a/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala +++ b/src/main/scala/ldbc/finbench/datagen/generation/generators/ActivityGenerator.scala @@ -82,12 +82,9 @@ class ActivityGenerator()(implicit spark: SparkSession) personRDD: RDD[Person], companyRDD: RDD[Company] ): RDD[Company] = { - val persons = spark.sparkContext.broadcast(personRDD.collect().toList) - val companies = spark.sparkContext.broadcast(companyRDD.collect().toList) - - val personInvestEvent = new PersonInvestEvent() - val companyInvestEvent = new CompanyInvestEvent() + val investActivitesEvent = new InvestActivitesEvent + val numPartitions = companyRDD.getNumPartitions companyRDD .sample( withReplacement = false, @@ -95,17 +92,30 @@ class ActivityGenerator()(implicit spark: SparkSession) sampleRandom.nextLong() ) .mapPartitionsWithIndex { (index, targets) => - personInvestEvent.resetState(index) - personInvestEvent - .personInvestPartition(persons.value.asJava, targets.toList.asJava) - .iterator() - .asScala - } - .mapPartitionsWithIndex { (index, targets) => - companyInvestEvent.resetState(index) - companyInvestEvent - .companyInvestPartition( - companies.value.asJava, + val persons = personRDD + .sample( + withReplacement = false, + 1.0 / numPartitions, + sampleRandom.nextLong() + ) + .collect() + .toList + .asJava + val companies = companyRDD + .sample( + withReplacement = false, + 1.0 / numPartitions, + sampleRandom.nextLong() + ) + .collect() + .toList + .asJava + + investActivitesEvent.resetState(index) + investActivitesEvent + .investPartition( + persons, + companies, targets.toList.asJava ) .iterator() @@ -118,23 +128,22 @@ class ActivityGenerator()(implicit spark: SparkSession) mediumRDD: RDD[Medium], accountRDD: RDD[Account] ): RDD[Medium] = { - val accountSampleList = spark.sparkContext.broadcast( - accountRDD + val signInEvent = new SignInEvent + val numPartitions = mediumRDD.getNumPartitions + mediumRDD.mapPartitionsWithIndex((index, mediums) => { + val accountSampleList = accountRDD .sample( withReplacement = false, - DatagenParams.accountSignedInFraction, + DatagenParams.accountSignedInFraction / numPartitions, sampleRandom.nextLong() ) .collect() .toList - ) - - val signInEvent = new SignInEvent - mediumRDD.mapPartitionsWithIndex((index, mediums) => { + .asJava signInEvent .signIn( mediums.toList.asJava, - accountSampleList.value.asJava, + accountSampleList, index ) .iterator()