Skip to content

Commit

Permalink
support text format save
Browse files Browse the repository at this point in the history
  • Loading branch information
allwefantasy committed Jul 27, 2023
1 parent 13a0a9c commit 1759c5c
Showing 1 changed file with 39 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package streaming.core.datasource.impl

import org.apache.spark.ml.param.Param
import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row, functions => F}
import org.apache.spark.sql.mlsql.session.MLSQLException
import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row, SaveMode, functions => F}
import streaming.core.datasource._
import streaming.dsl.ScriptSQLExec
import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
import tech.mlsql.tool.HDFSOperatorV2

/**
* 2019-03-20 WilliamZhu(allwefantasy@gmail.com)
Expand All @@ -26,7 +28,39 @@ class MLSQLText(override val uid: String) extends MLSQLBaseFileSource with WowPa
}

override def save(writer: DataFrameWriter[Row], config: DataSinkConfig): Unit = {
throw new RuntimeException("text not support save")
val context = ScriptSQLExec.contextGetOrForTest()
val baseDir = resourceRealPath(context.execListener, Option(context.owner), config.path)

if (HDFSOperatorV2.fileExists(baseDir)) {
if (config.mode == SaveMode.Overwrite) {
HDFSOperatorV2.deleteDir(baseDir)
}
if (config.mode == SaveMode.ErrorIfExists) {
throw new MLSQLException(s"${baseDir} is exists")
}
}

config.config.get(contentColumn.name).map { m =>
set(contentColumn, m)
}.getOrElse {
throw new MLSQLException(s"${contentColumn.name} is required")
}

config.config.get(fileName.name).map { m =>
set(fileName, m)
}.getOrElse {
throw new MLSQLException(s"${fileName.name} is required")
}

val _fileName = $(fileName)
val _contentColumn = $(contentColumn)

val saveContent = (fileName: String, buffer: String) => {
HDFSOperatorV2.saveFile(baseDir, fileName, List(("",buffer)).toIterator)
baseDir + "/" + fileName
}

config.df.get.rdd.map(r => saveContent(r.getAs[String](_fileName), r.getAs[String](_contentColumn))).count()
}

override def sourceInfo(config: DataAuthConfig): SourceInfo = {
Expand All @@ -50,4 +84,7 @@ class MLSQLText(override val uid: String) extends MLSQLBaseFileSource with WowPa

final val wholetext: Param[Boolean] = new Param[Boolean](this, "wholetext", "`wholetext` (default `false`): If true, read a file as a single row and not split by \"\\n\".")
final val lineSep: Param[String] = new Param[String](this, "lineSep", "`(default covers all `\\r`, `\\r\\n` and `\\n`): defines the line separator\n * that should be used for parsing.")

final val contentColumn: Param[String] = new Param[String](this, "imageColumn", "for save mode")
final val fileName: Param[String] = new Param[String](this, "fileName", "for save mode")
}

0 comments on commit 1759c5c

Please sign in to comment.