Skip to content

Commit

Permalink
feat filter memshell when godzilla
Browse files Browse the repository at this point in the history
  • Loading branch information
yoloyyh committed Dec 18, 2023
1 parent 0638779 commit d3c4d83
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 190 deletions.
219 changes: 41 additions & 178 deletions rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.lmax.disruptor.EventHandler;
import com.lmax.disruptor.InsufficientCapacityException;
import com.lmax.disruptor.RingBuffer;

import com.lmax.disruptor.dsl.Disruptor;
import com.lmax.disruptor.util.DaemonThreadFactory;
import com.security.smith.asm.SmithClassVisitor;
import com.security.smith.asm.SmithClassWriter;
import com.security.smith.client.message.*;
import com.security.smith.common.Reflection;

import com.security.smith.common.SmithHandler;
import com.security.smith.log.SmithLogger;
import com.security.smith.module.Patcher;
Expand Down Expand Up @@ -41,33 +40,32 @@
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerArray;

import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;


public class SmithProbe implements ClassFileTransformer, MessageHandler, EventHandler<Trace> {
private static final SmithProbe ourInstance = new SmithProbe();
private static final int TRACE_BUFFER_SIZE = 1024;
private static final int CLASS_MAX_ID = 30;
private static final int METHOD_MAX_ID = 20;
private static final int DEFAULT_QUOTA = 12000;

private Boolean disable;
private Boolean scanswitch;
private Instrumentation inst;
private final Client client;
private final Heartbeat heartbeat;
private final Disruptor<Trace> disruptor;

private final Map<String, SmithClass> smithClasses;
private final Map<String, Patcher> patchers;
private final Map<Pair<Integer, Integer>, Filter> filters;
private final Map<Pair<Integer, Integer>, Block> blocks;
private final Map<Pair<Integer, Integer>, Integer> limits;
private final AtomicIntegerArray[] quotas;
private final Disruptor<Trace> disruptor;

private final Rule_Mgr rulemgr;
private final Rule_Config ruleconfig;
private SmithProbeProxy smithProxy;

enum Action {
STOP,
Expand All @@ -90,9 +88,8 @@ public SmithProbe() {

heartbeat = new Heartbeat();
client = new Client(this);

disruptor = new Disruptor<>(Trace::new, TRACE_BUFFER_SIZE, DaemonThreadFactory.INSTANCE);
quotas = Stream.generate(() -> new AtomicIntegerArray(METHOD_MAX_ID)).limit(CLASS_MAX_ID).toArray(AtomicIntegerArray[]::new);

rulemgr = new Rule_Mgr();
ruleconfig = new Rule_Config(rulemgr);
}
Expand Down Expand Up @@ -121,13 +118,11 @@ public void start() {

Thread clientThread = new Thread(client::start);

clientThread.setDaemon(true);
clientThread.start();



disruptor.handleEventsWith(this);
disruptor.start();

clientThread.setDaemon(true);
clientThread.start();

new Timer(true).schedule(
new TimerTask() {
Expand All @@ -138,17 +133,19 @@ public void run() {
},
TimeUnit.MINUTES.toMillis(1)
);

smithProxy = SmithProbeProxy.getInstance();
new Timer(true).schedule(
new TimerTask() {
@Override
public void run() {
onTimer();
smithProxy.onTimer();
}
},
0,
TimeUnit.MINUTES.toMillis(1)
);
SmithProbeProxy.getInstance().setClient(client);
SmithProbeProxy.getInstance().setDisruptor(disruptor);
}

private void reloadClasses() {
Expand All @@ -168,147 +165,6 @@ private void reloadClasses(Collection<String> classes) {
}
}

public void detect(int classID, int methodID, Object[] args) {
Block block = blocks.get(new ImmutablePair<>(classID, methodID));

if (block == null)
return;

if (Arrays.stream(block.getRules()).anyMatch(rule -> {
if (rule.getIndex() >= args.length)
return false;

return Pattern.compile(rule.getRegex()).matcher(args[rule.getIndex()].toString()).find();
})) {
throw new SecurityException("API blocked by RASP");
}
}

public void checkAddServletPre(int classID, int methodID, Object[] args) {
SmithLogger.logger.info("checkAddServlet post_hook call success");
if (args.length < 3) {
return;
}
try {
Object context = args[0];
String name = (String)args[2];
if (context != null) {
Class<?>[] argTypes = new Class[]{String.class};

Object wrapper = Reflection.invokeMethod(context, "findChild", argTypes, name);

if(wrapper != null) {
Class<?>[] emptyArgTypes = new Class[]{};

Object servlet = Reflection.invokeMethod(wrapper, "getServlet", emptyArgTypes);
if(servlet != null) {
ClassFilter classFilter = new ClassFilter();
//classFilter.setClassName(name);
SmithHandler.queryClassFilter(servlet.getClass(), classFilter);
classFilter.setTransId();
classFilter.setRuleId(-1);
classFilter.setStackTrace(Thread.currentThread().getStackTrace());
client.write(Operate.SCANCLASS, classFilter);
SmithLogger.logger.info("send metadata: " + classFilter.toString());
sendClass(servlet.getClass(), classFilter.getTransId());
}
}
}

} catch (Exception e) {
SmithLogger.exception(e);
}
}

public void checkAddFilterPre(int classID, int methodID, Object[] args) {
SmithLogger.logger.info("checkAddFilter post_hook call success");
if (args.length < 2) {
return;
}
try {
Object filterDef = args[1];
Object filter = null;
if (filterDef != null) {
Class<?>[] emptyArgTypes = new Class[]{};
filter = Reflection.invokeMethod(filterDef, "getFilter", emptyArgTypes);
if (filter != null) {
ClassFilter classFilter = new ClassFilter();
SmithHandler.queryClassFilter(filter.getClass(), classFilter);
classFilter.setTransId();
classFilter.setRuleId(-1);
classFilter.setStackTrace(Thread.currentThread().getStackTrace());
client.write(Operate.SCANCLASS, classFilter);
SmithLogger.logger.info("send metadata: " + classFilter.toString());
sendClass(filter.getClass(), classFilter.getTransId());
}

}

} catch (Exception e) {
SmithLogger.exception(e);
}
}

public void checkAddValvePre(int classID, int methodID, Object[] args) {
if (args.length < 2) {
return;
}
try {
Object valve = args[1];
if (valve != null) {
ClassFilter classFilter = new ClassFilter();
SmithHandler.queryClassFilter(valve.getClass(), classFilter);
classFilter.setTransId();
classFilter.setRuleId(-1);
classFilter.setStackTrace(Thread.currentThread().getStackTrace());
client.write(Operate.SCANCLASS, classFilter);
SmithLogger.logger.info("send metadata: " + classFilter.toString());
sendClass(valve.getClass(), classFilter.getTransId());
}

} catch (Exception e) {
SmithLogger.exception(e);
}
}

public void checkAddListenerPre(int classID, int methodID, Object[] args) {
checkAddValvePre(classID, methodID, args);
}

public void trace(int classID, int methodID, Object[] args, Object ret, boolean blocked) {
if (classID >= CLASS_MAX_ID || methodID >= METHOD_MAX_ID)
return;

while (true) {
int quota = quotas[classID].get(methodID);

if (quota <= 0)
return;

if (quotas[classID].compareAndSet(methodID, quota, quota - 1))
break;
}

RingBuffer<Trace> ringBuffer = disruptor.getRingBuffer();

try {
long sequence = ringBuffer.tryNext();

Trace trace = ringBuffer.get(sequence);

trace.setClassID(classID);
trace.setMethodID(methodID);
trace.setBlocked(blocked);
trace.setRet(ret);
trace.setArgs(args);
trace.setStackTrace(Thread.currentThread().getStackTrace());

ringBuffer.publish(sequence);
} catch (InsufficientCapacityException ignored) {

}
}

@Override
public void onEvent(Trace trace, long sequence, boolean endOfBatch) {
Filter filter = filters.get(new ImmutablePair<>(trace.getClassID(), trace.getMethodID()));
Expand Down Expand Up @@ -339,23 +195,6 @@ public void onEvent(Trace trace, long sequence, boolean endOfBatch) {
client.write(Operate.TRACE, trace);
}

private void onTimer() {
client.write(Operate.HEARTBEAT, heartbeat);

for (int i = 0; i < CLASS_MAX_ID; i++) {
for (int j = 0; j < METHOD_MAX_ID; j++) {
Integer quota = limits.get(new ImmutablePair<>(i, j));

if (quota == null) {
quotas[i].set(j, DEFAULT_QUOTA);
continue;
}

quotas[i].set(j, quota);
}
}
}

public void printClassfilter(ClassFilter data) {
/*
SmithLogger.logger.info("className:" + data.getClassName());
Expand Down Expand Up @@ -774,7 +613,7 @@ public void onScanAllClass() {
/*
* send class file
*/
private void sendClass(Class<?> clazz, String transId) {
public void sendClass(Class<?> clazz, String transId) {
if (clazz == null || transId == null) {
return;
}
Expand Down Expand Up @@ -828,4 +667,28 @@ private void sendByte(byte[] data, String transId) {
//}
}

public Heartbeat getHeartbeat() {
return heartbeat;
}

public Map<Pair<Integer, Integer>, Integer> getLimits() {
return limits;
}

public Map<Pair<Integer, Integer>, Block> GetBlocks() {
return blocks;
}

public Map<Pair<Integer, Integer>, Filter> GetFiltes() {
return filters;
}

public Client getClient() {
return client;
}

public Disruptor<Trace> getDisruptor() {
return disruptor;
}

}
Loading

0 comments on commit d3c4d83

Please sign in to comment.