Skip to content

Commit

Permalink
avoid collect
Browse files Browse the repository at this point in the history
  • Loading branch information
qishipengqsp committed Sep 19, 2024
1 parent cfa327a commit 4a30958
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 119 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package ldbc.finbench.datagen.generation.events;

import java.io.Serializable;
import java.util.List;
import java.util.Random;
import ldbc.finbench.datagen.entities.edges.CompanyInvestCompany;
import ldbc.finbench.datagen.entities.edges.PersonInvestCompany;
import ldbc.finbench.datagen.entities.nodes.Company;
import ldbc.finbench.datagen.entities.nodes.Person;
import ldbc.finbench.datagen.generation.DatagenParams;
import ldbc.finbench.datagen.util.RandomGeneratorFarm;

public class InvestActivitesEvent implements Serializable {
private final RandomGeneratorFarm randomFarm;

public InvestActivitesEvent() {
randomFarm = new RandomGeneratorFarm();
}

public void resetState(int seed) {
randomFarm.resetRandomGenerators(seed);
}

public List<Company> investPartition(List<Person> personinvestors, List<Company> companyInvestors,
List<Company> targets) {
Random numPersonInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_PERSON_INVEST);
Random choosePersonInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_PERSON_INVESTOR);
Random numCompanyInvestorsRand = randomFarm.get(RandomGeneratorFarm.Aspect.NUMS_COMPANY_INVEST);
Random chooseCompanyInvestorRand = randomFarm.get(RandomGeneratorFarm.Aspect.CHOOSE_COMPANY_INVESTOR);
for (Company target : targets) {
// Person investors
int numPersonInvestors = numPersonInvestorsRand.nextInt(
DatagenParams.maxInvestors - DatagenParams.minInvestors + 1
) + DatagenParams.minInvestors;
for (int i = 0; i < numPersonInvestors; i++) {
int index = choosePersonInvestorRand.nextInt(personinvestors.size());
Person investor = personinvestors.get(index);
if (cannotInvest(investor, target)) {
continue;
}
PersonInvestCompany.createPersonInvestCompany(randomFarm, investor, target);
}

// Company investors
int numCompanyInvestors = numCompanyInvestorsRand.nextInt(
DatagenParams.maxInvestors - DatagenParams.minInvestors + 1
) + DatagenParams.minInvestors;
for (int i = 0; i < numCompanyInvestors; i++) {
int index = chooseCompanyInvestorRand.nextInt(companyInvestors.size());
Company investor = companyInvestors.get(index);
if (cannotInvest(investor, target)) {
continue;
}
CompanyInvestCompany.createCompanyInvestCompany(randomFarm, investor, target);
}
}
return targets;
}

public boolean cannotInvest(Person investor, Company target) {
return target.hasInvestedBy(investor);
}

public boolean cannotInvest(Company investor, Company target) {
return (investor == target) || investor.hasInvestedBy(target) || target.hasInvestedBy(investor);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,40 @@ class ActivityGenerator()(implicit spark: SparkSession)
personRDD: RDD[Person],
companyRDD: RDD[Company]
): RDD[Company] = {
val persons = spark.sparkContext.broadcast(personRDD.collect().toList)
val companies = spark.sparkContext.broadcast(companyRDD.collect().toList)

val personInvestEvent = new PersonInvestEvent()
val companyInvestEvent = new CompanyInvestEvent()
val investActivitesEvent = new InvestActivitesEvent

val numPartitions = companyRDD.getNumPartitions
companyRDD
.sample(
withReplacement = false,
DatagenParams.companyInvestedFraction,
sampleRandom.nextLong()
)
.mapPartitionsWithIndex { (index, targets) =>
personInvestEvent.resetState(index)
personInvestEvent
.personInvestPartition(persons.value.asJava, targets.toList.asJava)
.iterator()
.asScala
}
.mapPartitionsWithIndex { (index, targets) =>
companyInvestEvent.resetState(index)
companyInvestEvent
.companyInvestPartition(
companies.value.asJava,
val persons = personRDD
.sample(
withReplacement = false,
1.0 / numPartitions,
sampleRandom.nextLong()
)
.collect()
.toList
.asJava
val companies = companyRDD
.sample(
withReplacement = false,
1.0 / numPartitions,
sampleRandom.nextLong()
)
.collect()
.toList
.asJava

investActivitesEvent.resetState(index)
investActivitesEvent
.investPartition(
persons,
companies,
targets.toList.asJava
)
.iterator()
Expand All @@ -118,23 +128,22 @@ class ActivityGenerator()(implicit spark: SparkSession)
mediumRDD: RDD[Medium],
accountRDD: RDD[Account]
): RDD[Medium] = {
val accountSampleList = spark.sparkContext.broadcast(
accountRDD
val signInEvent = new SignInEvent
val numPartitions = mediumRDD.getNumPartitions
mediumRDD.mapPartitionsWithIndex((index, mediums) => {
val accountSampleList = accountRDD
.sample(
withReplacement = false,
DatagenParams.accountSignedInFraction,
DatagenParams.accountSignedInFraction / numPartitions,
sampleRandom.nextLong()
)
.collect()
.toList
)

val signInEvent = new SignInEvent
mediumRDD.mapPartitionsWithIndex((index, mediums) => {
.asJava
signInEvent
.signIn(
mediums.toList.asJava,
accountSampleList.value.asJava,
accountSampleList,
index
)
.iterator()
Expand Down

0 comments on commit 4a30958

Please sign in to comment.