no need for another tmpfilemanager impl
ehigham committed Dec 17, 2024
1 parent 074016b commit 40a7005
Showing 5 changed files with 80 additions and 125 deletions.
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/backend/ExecuteContext.scala
@@ -23,7 +23,7 @@ trait TempFileManager extends AutoCloseable {
def newTmpPath(tmpdir: String, prefix: String, extension: String = null): String
}

class OwningTempFileManager(fs: FS) extends TempFileManager {
class OwningTempFileManager(val fs: FS) extends TempFileManager {
private[this] val tmpPaths = mutable.ArrayBuffer[String]()

override def newTmpPath(tmpdir: String, prefix: String, extension: String): String = {
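
A minimal sketch of what promoting the constructor parameter to a `val` buys: the field becomes publicly readable, so callers can recover the file system from the manager instead of tracking it separately (which is what the Py4JBackendApi changes below rely on). `FileSystemLike` is a hypothetical stand-in for Hail's FS type, not the real API.

import scala.collection.mutable

trait FileSystemLike { def scheme: String } // hypothetical stand-in for Hail's FS

class OwningTempFileManager(val fs: FileSystemLike) extends AutoCloseable {
  private val tmpPaths = mutable.ArrayBuffer[String]()

  def newTmpPath(tmpdir: String, prefix: String): String = {
    val p = s"$tmpdir/$prefix-${java.util.UUID.randomUUID()}"
    tmpPaths += p
    p
  }

  // The real implementation would delete the recorded paths via `fs` before clearing.
  override def close(): Unit = tmpPaths.clear()
}

object ValFsDemo extends App {
  val mgr = new OwningTempFileManager(new FileSystemLike { def scheme = "file" })
  println(mgr.fs.scheme)                  // readable because `fs` is a `val`
  println(mgr.newTmpPath("/tmp", "demo"))
  mgr.close()
}
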
174 changes: 60 additions & 114 deletions hail/src/main/scala/is/hail/backend/api/Py4JBackendApi.scala
@@ -6,10 +6,7 @@ import is.hail.backend._
import is.hail.backend.caching.BlockMatrixCache
import is.hail.backend.spark.SparkBackend
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
import is.hail.expr.ir.{
BaseIR, BlockMatrixIR, CodeCacheKey, CompiledFunction, EncodedLiteral, GetFieldByIdx, IRParser,
Interpret, MatrixIR, MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue,
}
import is.hail.expr.ir._
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -31,77 +28,55 @@ import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConv
import java.io.Closeable
import java.net.InetSocketAddress
import java.util
import java.util.concurrent._
import java.util.concurrent.locks.ReentrantReadWriteLock

import com.google.api.client.http.HttpStatusCodes
import com.sun.net.httpserver.{HttpExchange, HttpServer}
import org.apache.hadoop
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.DataFrame
import org.json4s
import org.json4s._
import org.json4s.jackson.{JsonMethods, Serialization}
import sourcecode.Enclosing

final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandling {

private[this] val rwlock = new ReentrantReadWriteLock()
private[this] def weak = rwlock.readLock()
private[this] def mutex = rwlock.writeLock()

private[this] val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
private[this] val hcl = new HailClassLoader(getClass.getClassLoader)
private[this] val references = mutable.Map(ReferenceGenome.builtinReferences().toSeq: _*)
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIr = mutable.Map[Int, BaseIR]()
private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)

private[this] var irID: Int = 0

private[this] var tmpdir: String = _
private[this] var localTmpdir: String = _

private[this] object tmpFileManager extends TempFileManager {
private[this] var fs = newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
private[this] var manager = new OwningTempFileManager(fs)

def setFs(fs: FS): Unit = {
close()
this.fs = fs
manager = new OwningTempFileManager(fs)
}

def getFs: FS =
fs

override def newTmpPath(tmpdir: String, prefix: String, extension: String): String =
manager.newTmpPath(tmpdir, prefix, extension)

override def close(): Unit =
manager.close()
}
private[this] var tmpFileManager = new OwningTempFileManager(
newFs(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
)

def pyFs: FS =
using(weak.acquire())(_ => tmpFileManager.getFs)
synchronized(tmpFileManager.fs)

def pyGetFlag(name: String): String =
using(weak.acquire())(_ => flags.get(name))
synchronized(flags.get(name))

def pySetFlag(name: String, value: String): Unit =
using(weak.acquire())(_ => flags.set(name, value))
synchronized(flags.set(name, value))

def pyAvailableFlags: java.util.ArrayList[String] =
flags.available

def pySetRemoteTmp(tmp: String): Unit =
using(weak.acquire())(_ => tmpdir = tmp)
synchronized { tmpdir = tmp }

def pySetLocalTmp(tmp: String): Unit =
using(weak.acquire())(_ => localTmpdir = tmp)
synchronized { localTmpdir = tmp }

def pySetGcsRequesterPaysConfig(project: String, buckets: util.List[String]): Unit =
using(weak.acquire()) { _ =>
synchronized {
tmpFileManager.close()

val cloudfsConf = CloudStorageFSConfig.fromFlagsAndEnv(None, flags)

val rpConfig: Option[RequesterPaysConfig] =
@@ -126,20 +101,20 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
)
)

tmpFileManager.setFs(fs)
tmpFileManager = new OwningTempFileManager(fs)
}

def pyRemoveJavaIR(id: Int): Unit =
using(weak.acquire())(_ => persistedIr.remove(id))
synchronized(persistedIr.remove(id))

def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
using(weak.acquire()) { _ =>
val seq = IndexedFastaSequenceFile(tmpFileManager.getFs, fastaFile, indexFile)
synchronized {
val seq = IndexedFastaSequenceFile(tmpFileManager.fs, fastaFile, indexFile)
references(name).addSequence(seq)
}

def pyRemoveSequence(name: String): Unit =
using(weak.acquire())(_ => references(name).removeSequence())
synchronized(references(name).removeSequence())
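
A minimal sketch of the locking style this commit switches to, assuming (as the diff suggests) that every critical section is small: the read/write lock pair is replaced by the object's intrinsic lock via synchronized. The names below are illustrative, not Hail's.

final class Settings {
  private val flags = scala.collection.mutable.Map.empty[String, String]
  private var tmpdir: String = _

  // Each accessor serialises on the object's monitor; for short, rarely contended
  // critical sections this is simpler than managing a ReentrantReadWriteLock.
  def getFlag(name: String): Option[String] = synchronized(flags.get(name))
  def setFlag(name: String, value: String): Unit = synchronized { flags(name) = value }
  def setTmpdir(dir: String): Unit = synchronized { tmpdir = dir }
}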

def pyExportBlockMatrix(
pathIn: String,
@@ -240,46 +215,44 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin

def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] =
withExecuteContext() { ctx =>
implicit val fmts: Formats = DefaultFormats
log.info("pyReadMultipleMatrixTables: got query")
val kvs = JsonMethods.parse(jsonQuery) match {
case json4s.JObject(values) => values.toMap
}

val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map {
case json4s.JString(s) => s
}

val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s)
val kvs = JsonMethods.parse(jsonQuery).extract[Map[String, JValue]]
val paths = kvs("paths").extract[IndexedSeq[String]]
val intervalPointType = parseType(kvs("intervalPointType").extract[String])
val intervalObjects =
JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType)))
.asInstanceOf[IndexedSeq[Interval]]

val opts = NativeReaderOptions(intervalObjects, intervalPointType)
val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p =>
log.info(s"creating MatrixRead node for $p")
val mnr = MatrixNativeReader(ctx.fs, p, Some(opts))
MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR
}
val matrixReaders: util.List[MatrixIR] =
paths.map { p =>
log.info(s"creating MatrixRead node for $p")
val mnr = MatrixNativeReader(ctx.fs, p, Some(opts))
MatrixRead(mnr.fullMatrixTypeWithoutUIDs, false, false, mnr): MatrixIR
}.asJava

log.info("pyReadMultipleMatrixTables: returning N matrix tables")
matrixReaders.asJava
matrixReaders
}._1
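
A self-contained sketch of the json4s extraction style used above; the JSON payload is made up, and only the paths and intervalPointType fields mirror what the method reads.

import org.json4s._
import org.json4s.jackson.JsonMethods

object ExtractDemo extends App {
  implicit val fmts: Formats = DefaultFormats

  val query: JValue =
    JsonMethods.parse("""{"paths": ["a.mt", "b.mt"], "intervalPointType": "int32"}""")

  val kvs = query.extract[Map[String, JValue]]
  val paths = kvs("paths").extract[IndexedSeq[String]]      // Vector(a.mt, b.mt)
  val pointType = kvs("intervalPointType").extract[String]  // "int32"

  println((paths, pointType))
}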

def pyAddReference(jsonConfig: String): Unit =
using(weak.acquire())(_ => addReference(ReferenceGenome.fromJSON(jsonConfig)))
synchronized(addReference(ReferenceGenome.fromJSON(jsonConfig)))

def pyRemoveReference(name: String): Unit =
using(weak.acquire())(_ => removeReference(name))
synchronized(removeReference(name))

def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit =
using(weak.acquire()) { _ =>
synchronized {
references(name).addLiftover(
references(destRGName),
LiftOver(tmpFileManager.getFs, chainFile),
LiftOver(tmpFileManager.fs, chainFile),
)
}

def pyRemoveLiftover(name: String, destRGName: String): Unit =
using(weak.acquire())(_ => references(name).removeLiftover(destRGName))
synchronized(references(name).removeLiftover(destRGName))

def parse_blockmatrix_ir(s: String): BlockMatrixIR =
withExecuteContext(selfContainedExecution = false) { ctx =>
@@ -289,6 +262,15 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
private[this] def addReference(rg: ReferenceGenome): Unit =
ReferenceGenome.addFatalOnCollision(references, FastSeq(rg))

override def close(): Unit =
synchronized {
bmCache.close()
codeCache.clear()
persistedIr.clear()
coercerCache.clear()
backend.close()
}

private[this] def removeReference(name: String): Unit =
references -= name

@@ -298,13 +280,13 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
f: ExecuteContext => T
)(implicit E: Enclosing
): (T, Timings) =
using(mutex.acquire()) { _ =>
synchronized {
ExecutionTimer.time { timer =>
ExecuteContext.scoped(
tmpdir = tmpdir,
localTmpdir = localTmpdir,
backend = backend,
fs = tmpFileManager.getFs,
fs = tmpFileManager.fs,
timer = timer,
tempFileManager =
if (selfContainedExecution) null
@@ -355,23 +337,6 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
id
}

override def close(): Unit =
using(weak.acquire()) { _ =>
bmCache.close()
codeCache.clear()
persistedIr.clear()
coercerCache.clear()
backend.close()

if (backend.isInstanceOf[SparkBackend]) {
// Hadoop does not honor the hadoop configuration as a component of the cache key for file
// systems, so we blow away the cache so that a new configuration can successfully take
// effect.
// https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443
hadoop.fs.FileSystem.closeAll()
}
}

def pyHttpServer: HttpLikeBackendRpc[HttpExchange] with Closeable =
new HttpLikeBackendRpc[HttpExchange] with Closeable {
implicit object Handler extends Routing with Write[HttpExchange] with Context[HttpExchange] {
@@ -425,41 +390,22 @@ final class Py4JBackendApi(backend: Backend) extends Closeable with ErrorHandlin
// 0 => let the OS pick an available port
private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10)

private[this] val thread = {
// This HTTP server *must not* start non-daemon threads because such threads keep the JVM
// alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest
/* when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of
* the */
// JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel
// explicitly regardless of the JVM). It *does* manifest when submitting jobs with
//
// gcloud dataproc submit ...
//
// or
//
// spark-submit
//
// setExecutor(null) ensures the server creates no new threads:
//
/* > If this method is not called (before start()) or if it is called with a null Executor,
* then */
/* > a default implementation is used, which uses the thread which was created by the
* start() */
// > method.
//
// Source:
/* https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */
//
httpServer.createContext("/", runRpc(_: HttpExchange))
httpServer.setExecutor(null)
val t = Executors.defaultThreadFactory().newThread(() => httpServer.start())
t.setDaemon(true)
t
}

@nowarn def port: Int = httpServer.getAddress.getPort
override def close(): Unit = httpServer.stop(10)

thread.start()
// This HTTP server *must not* start non-daemon threads because such threads keep the JVM
// alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest
// when you run jobs in a local pyspark (because you'll Ctrl-C out of Python regardless of
// the JVM's state) nor does it manifest in a Notebook (again, you'll kill the Notebook kernel
// explicitly regardless of the JVM). It *does* manifest when submitting jobs with
//
// gcloud dataproc submit ...
//
// or
//
// spark-submit
httpServer.createContext("/", runRpc(_: HttpExchange))
httpServer.setExecutor(null) // ensures the server creates no new threads
httpServer.start()
}
}
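
A runnable sketch of the com.sun.net.httpserver idiom described in the comment above: with a null executor the server handles requests on the single thread created by start(), so it adds no non-daemon threads of its own. The handler and port handling here are illustrative only.

import java.net.InetSocketAddress
import com.sun.net.httpserver.{HttpExchange, HttpServer}

object DaemonFriendlyServer extends App {
  val server = HttpServer.create(new InetSocketAddress(0), 10) // port 0 => OS picks a free port

  server.createContext("/", (exchange: HttpExchange) => {
    val body = "ok".getBytes("UTF-8")
    exchange.sendResponseHeaders(200, body.length.toLong)
    exchange.getResponseBody.write(body)
    exchange.close()
  })

  server.setExecutor(null) // null executor => requests run on the thread start() creates
  server.start()
  println(s"listening on port ${server.getAddress.getPort}")
  server.stop(0)
}
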
19 changes: 15 additions & 4 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import java.io.PrintWriter

import com.fasterxml.jackson.core.StreamReadConstraints
import org.apache.hadoop
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -275,8 +276,7 @@ object SparkBackend {

def stop(): Unit = synchronized {
if (theSparkBackend != null) {
if (theSparkBackend.sparkSession.isEvaluated) theSparkBackend.sparkSession.close()
theSparkBackend.sc.stop()
theSparkBackend.close()
theSparkBackend = null
}
}
@@ -356,8 +356,11 @@ class SparkBackend(val sc: SparkContext) extends Backend {
)
}
} catch {
case e: ExecutionException => failure = failure.orElse(Some(e.getCause))
case NonFatal(t) => failure = failure.orElse(Some(t))
case e: ExecutionException => failure = failure.orElse(Some(e.getCause))
case _: InterruptedException =>
sc.cancelAllJobs()
Thread.currentThread().interrupt()
}

(failure, buffer.sortBy(_._2))
@@ -368,7 +371,15 @@ class SparkBackend(val sc: SparkContext) extends Backend {
override def asSpark(implicit E: Enclosing): SparkBackend = this

def close(): Unit =
SparkBackend.stop()
synchronized {
if (sparkSession.isEvaluated) sparkSession.close()
sc.stop()
// Hadoop does not honor the hadoop configuration as a component of the cache key for file
// systems, so we blow away the cache so that a new configuration can successfully take
// effect.
// https://github.com/hail-is/hail/pull/12133#issuecomment-1241322443
hadoop.fs.FileSystem.closeAll()
}

def startProgressBar(): Unit =
ProgressBarBuilder.build(sc)
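
A sketch of the interruption handling added to SparkBackend's parallel runner above; collectFailure and runTask are illustrative names, not Hail's API.

import java.util.concurrent.ExecutionException

import scala.util.control.NonFatal

import org.apache.spark.SparkContext

object InterruptAwareRunner {
  // Returns the first failure, if any; on interrupt, cancels in-flight Spark jobs
  // and preserves the thread's interrupt status for the caller.
  def collectFailure(sc: SparkContext)(runTask: () => Unit): Option[Throwable] =
    try { runTask(); None }
    catch {
      case e: ExecutionException => Some(e.getCause) // unwrap failures raised inside futures
      case NonFatal(t)           => Some(t)
      case _: InterruptedException =>
        sc.cancelAllJobs()
        Thread.currentThread().interrupt()
        None
    }
}
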
5 changes: 4 additions & 1 deletion hail/src/main/scala/is/hail/utils/package.scala
@@ -1052,14 +1052,17 @@ package object utils
val buffer = new mutable.ArrayBuffer[(A, Int)](tasks.length)
val completer = new ExecutorCompletionService[(A, Int)](executor)

tasks.foreach { case (t, k) => completer.submit(() => t() -> k) }
val futures = tasks.map { case (t, k) => completer.submit(() => t() -> k) }
tasks.foreach { _ =>
try buffer += completer.take().get()
catch {
case e: ExecutionException =>
err = accum(err, e.getCause)
case NonFatal(ex) =>
err = accum(err, ex)
case _: InterruptedException =>
futures.foreach(_.cancel(true))
Thread.currentThread().interrupt()
}
}

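
A self-contained sketch of the pattern in the hunk above: keep the Futures returned by submit so that an interrupt can cancel outstanding tasks before re-setting the interrupt flag. The task payloads are made up.

import java.util.concurrent.{ExecutionException, ExecutorCompletionService, Executors}

import scala.util.control.NonFatal

object CompletionServiceDemo extends App {
  val executor = Executors.newFixedThreadPool(4)
  val completer = new ExecutorCompletionService[Int](executor)

  val tasks: Seq[() => Int] = (1 to 8).map(i => () => i * i)
  val futures = tasks.map(t => completer.submit(() => t()))

  val results = scala.collection.mutable.ArrayBuffer[Int]()
  tasks.foreach { _ =>
    try results += completer.take().get()
    catch {
      case e: ExecutionException => println(s"task failed: ${e.getCause}")
      case NonFatal(ex)          => println(s"unexpected: $ex")
      case _: InterruptedException =>
        futures.foreach(_.cancel(true))      // stop work that has not completed
        Thread.currentThread().interrupt()   // restore the interrupt flag
    }
  }

  executor.shutdown()
  println(results.sorted)
}
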
5 changes: 0 additions & 5 deletions hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala
@@ -12,10 +12,8 @@ import scala.reflect.ClassTag
import scala.util.matching.Regex

import java.io.InputStream
import java.util.concurrent.locks.Lock

import breeze.linalg.DenseMatrix
import org.apache.hadoop.util.AutoCloseableLock
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix
import org.apache.spark.rdd.RDD
@@ -146,7 +144,4 @@ trait Implicits {

implicit def valueToRichCodeIterator[T](it: Value[Iterator[T]]): RichCodeIterator[T] =
new RichCodeIterator[T](it)

implicit def lockToAutoClosableLock(l: Lock): AutoCloseableLock =
new AutoCloseableLock(l)
}
