Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement saving and retrieval of session tasks #1728

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions client/command/tasks/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ func Commands(con *console.SliverClient) []*cobra.Command {
Run: func(cmd *cobra.Command, args []string) {
TasksCmd(cmd, con, args)
},
GroupID: consts.SliverCoreHelpGroup,
Annotations: flags.RestrictTargets(consts.BeaconCmdsFilter),
GroupID: consts.SliverCoreHelpGroup,
}
flags.Bind("tasks", true, tasksCmd, func(f *pflag.FlagSet) {
f.IntP("timeout", "t", flags.DefaultTimeout, "grpc timeout in seconds")
Expand Down
2 changes: 1 addition & 1 deletion client/command/tasks/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import (

// TasksFetchCmd - Manage beacon tasks.
func TasksFetchCmd(cmd *cobra.Command, con *console.SliverClient, args []string) {
beacon := con.ActiveTarget.GetBeaconInteractive()
beacon := con.ActiveTarget.GetSessionOrBeaconInteractive()
if beacon == nil {
return
}
Expand Down
16 changes: 13 additions & 3 deletions client/command/tasks/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,22 @@ func SelectBeaconTask(tasks []*clientpb.BeaconTask) (*clientpb.BeaconTask, error
// BeaconTaskIDCompleter returns a structured list of tasks completions, grouped by state.
func BeaconTaskIDCompleter(con *console.SliverClient) carapace.Action {
callback := func(ctx carapace.Context) carapace.Action {
id := ""
beacon := con.ActiveTarget.GetBeacon()
if beacon == nil {
return carapace.ActionMessage("no active beacon")
if beacon != nil {
id = beacon.ID
} else {
session := con.ActiveTarget.GetSession()
if session != nil {
id = session.ID
}
}

beaconTasks, err := con.Rpc.GetBeaconTasks(context.Background(), &clientpb.Beacon{ID: beacon.ID})
if id == "" {
return carapace.ActionMessage("no active beacon or session")
}

beaconTasks, err := con.Rpc.GetBeaconTasks(context.Background(), &clientpb.Beacon{ID: id})
if err != nil {
return carapace.ActionMessage("Failed to fetch tasks: %s", err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion client/command/tasks/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (

// TasksCmd - Manage beacon tasks.
func TasksCmd(cmd *cobra.Command, con *console.SliverClient, args []string) {
beacon := con.ActiveTarget.GetBeaconInteractive()
beacon := con.ActiveTarget.GetSessionOrBeaconInteractive()
if beacon == nil {
return
}
Expand Down
12 changes: 12 additions & 0 deletions client/console/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,18 @@ func (s *ActiveTarget) GetBeaconInteractive() *clientpb.Beacon {
return s.beacon
}

// GetSessionOrBeaconInteractive
func (s *ActiveTarget) GetSessionOrBeaconInteractive() *clientpb.Beacon {
if s.beacon != nil {
return s.beacon
} else if s.session != nil {
return &clientpb.Beacon{ID: s.session.ID}
} else {
fmt.Printf(Warn + "Please select a beacon or session via `use`\n")
return nil
}
}

// GetBeacon - Same as GetBeacon() but doesn't print a warning.
func (s *ActiveTarget) GetBeacon() *clientpb.Beacon {
return s.beacon
Expand Down
6 changes: 1 addition & 5 deletions server/rpc/rpc-beacons.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ func (rpc *Server) RmBeacon(ctx context.Context, req *clientpb.Beacon) (*commonp

// GetBeaconTasks - Get a list of tasks for a specific beacon
func (rpc *Server) GetBeaconTasks(ctx context.Context, req *clientpb.Beacon) (*clientpb.BeaconTasks, error) {
beacon, err := db.BeaconByID(req.ID)
if err != nil {
return nil, ErrInvalidBeaconID
}
tasks, err := db.BeaconTasksByBeaconID(beacon.ID.String())
tasks, err := db.BeaconTasksByBeaconID(req.ID)
return &clientpb.BeaconTasks{Tasks: tasks}, err
}

Expand Down
80 changes: 50 additions & 30 deletions server/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ import (
"strings"
"time"

consts "github.com/bishopfox/sliver/client/constants"
"github.com/bishopfox/sliver/client/version"
"github.com/bishopfox/sliver/protobuf/clientpb"
"github.com/bishopfox/sliver/protobuf/commonpb"
"github.com/bishopfox/sliver/protobuf/rpcpb"
"github.com/bishopfox/sliver/protobuf/sliverpb"
"github.com/bishopfox/sliver/server/core"
"github.com/bishopfox/sliver/server/db"
"github.com/bishopfox/sliver/server/db/models"
"github.com/bishopfox/sliver/server/log"
"github.com/gofrs/uuid"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -104,22 +107,49 @@ func (rpc *Server) GenericHandler(req GenericRequest, resp GenericResponse) erro
if request == nil {
return ErrMissingRequestField
}
reqData, err := proto.Marshal(req)
if err != nil {
return err
}

taskResponse := resp.GetResponse()
taskResponse.Async = request.Async
beacon := models.Beacon{}
if request.BeaconID != "" {
beacon.ID, err = uuid.FromString(request.BeaconID)
} else if request.SessionID != "" {
beacon.ID, err = uuid.FromString(request.SessionID)
}
task, err := beacon.Task(&sliverpb.Envelope{
Type: sliverpb.MsgNumber(req),
Data: reqData,
})
if err != nil {
rpcLog.Errorf("Database error: %s", err)
return ErrDatabaseFailure
}
parts := strings.Split(string(req.ProtoReflect().Descriptor().FullName().Name()), ".")
name := parts[len(parts)-1]
task.Description = name
err = db.Session().Save(task).Error
if err != nil {
rpcLog.Errorf("Database error: %s", err)
return ErrDatabaseFailure
}
rpcLog.Warningf("Task: %#v", task)

if request.Async {
err = rpc.asyncGenericHandler(req, resp)
err = rpc.asyncGenericHandler(req, resp, task)
return err
}

task.SentAt = time.Now().Unix()
// Sync request
session := core.Sessions.Get(request.SessionID)
if session == nil {
return ErrInvalidSessionID
}

reqData, err := proto.Marshal(req)
if err != nil {
return err
}

data, err := session.Request(sliverpb.MsgNumber(req), rpc.getTimeout(req), reqData)
if err != nil {
return err
Expand All @@ -128,11 +158,24 @@ func (rpc *Server) GenericHandler(req GenericRequest, resp GenericResponse) erro
if err != nil {
return err
}
task.State = models.COMPLETED
task.CompletedAt = time.Now().Unix()
task.Response = data

err = db.Session().Updates(task).Error
if err != nil {
rpcLog.Errorf("Error updating db task: %s", err)
}
eventData, _ := proto.Marshal(task.ToProtobuf(false))
core.EventBroker.Publish(core.Event{
EventType: consts.BeaconTaskResultEvent,
Data: eventData,
})
return rpc.getError(resp)
}

// asyncGenericHandler - Generic handler for async request/response's for beacon tasks
func (rpc *Server) asyncGenericHandler(req GenericRequest, resp GenericResponse) error {
func (rpc *Server) asyncGenericHandler(req GenericRequest, resp GenericResponse, task *models.BeaconTask) error {
// VERY VERBOSE
// rpcLog.Debugf("Async Generic Handler: %#v", req)
request := req.GetRequest()
Expand All @@ -146,32 +189,9 @@ func (rpc *Server) asyncGenericHandler(req GenericRequest, resp GenericResponse)
return ErrInvalidBeaconID
}

// Overwrite unused implant fields before re-serializing
request.SessionID = ""
request.BeaconID = ""
reqData, err := proto.Marshal(req)
if err != nil {
return err
}
taskResponse := resp.GetResponse()
taskResponse.Async = true
taskResponse.BeaconID = beacon.ID.String()
task, err := beacon.Task(&sliverpb.Envelope{
Type: sliverpb.MsgNumber(req),
Data: reqData,
})
if err != nil {
rpcLog.Errorf("Database error: %s", err)
return ErrDatabaseFailure
}
parts := strings.Split(string(req.ProtoReflect().Descriptor().FullName().Name()), ".")
name := parts[len(parts)-1]
task.Description = name
err = db.Session().Save(task).Error
if err != nil {
rpcLog.Errorf("Database error: %s", err)
return ErrDatabaseFailure
}
taskResponse.TaskID = task.ID.String()
rpcLog.Debugf("Successfully tasked beacon: %#v", taskResponse)
return nil
Expand Down
Loading