From 9db30ac11b58c77564be19ce68347bdf56dc6bbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Wed, 23 Oct 2024 22:51:18 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=8A=A0=E5=9B=9E=E6=97=A7=E7=89=88?= =?UTF-8?q?=E7=9A=84arg=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/shell/exec.go | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/pkg/shell/exec.go b/pkg/shell/exec.go index 8353e1f9..b4e9d910 100644 --- a/pkg/shell/exec.go +++ b/pkg/shell/exec.go @@ -8,15 +8,19 @@ import ( "io" "os" "os/exec" + "slices" "strings" "time" ) // Execf 执行 shell 命令 func Execf(shell string, args ...any) (string, error) { - var cmd *exec.Cmd + if !preCheckArg(args) { + return "", errors.New("command contains illegal characters") + } + _ = os.Setenv("LC_ALL", "C") - cmd = exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) + cmd := exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout @@ -32,9 +36,12 @@ func Execf(shell string, args ...any) (string, error) { // ExecfAsync 异步执行 shell 命令 func ExecfAsync(shell string, args ...any) error { - var cmd *exec.Cmd + if !preCheckArg(args) { + return errors.New("command contains illegal characters") + } + _ = os.Setenv("LC_ALL", "C") - cmd = exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) + cmd := exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) err := cmd.Start() if err != nil { @@ -52,9 +59,12 @@ func ExecfAsync(shell string, args ...any) error { // ExecfWithTimeout 执行 shell 命令并设置超时时间 func ExecfWithTimeout(timeout time.Duration, shell string, args ...any) (string, error) { - var cmd *exec.Cmd + if !preCheckArg(args) { + return "", errors.New("command contains illegal characters") + } + _ = os.Setenv("LC_ALL", "C") - cmd = exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) + cmd := exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout @@ -85,6 +95,10 @@ func ExecfWithTimeout(timeout time.Duration, shell string, args ...any) (string, // ExecfWithOutput 执行 shell 命令并输出到终端 func ExecfWithOutput(shell string, args ...any) error { + if !preCheckArg(args) { + return errors.New("command contains illegal characters") + } + _ = os.Setenv("LC_ALL", "C") cmd := exec.Command("bash", "-c", fmt.Sprintf(shell, args...)) cmd.Stdout = os.Stdout @@ -95,8 +109,11 @@ func ExecfWithOutput(shell string, args ...any) error { // ExecfWithPipe 执行 shell 命令并返回管道 func ExecfWithPipe(ctx context.Context, shell string, args ...any) (out io.ReadCloser, err error) { - _ = os.Setenv("LC_ALL", "C") + if !preCheckArg(args) { + return nil, errors.New("command contains illegal characters") + } + _ = os.Setenv("LC_ALL", "C") cmd := exec.CommandContext(ctx, "bash", "-c", fmt.Sprintf(shell, args...)) out, err = cmd.StdoutPipe() @@ -108,3 +125,14 @@ func ExecfWithPipe(ctx context.Context, shell string, args ...any) (out io.ReadC err = cmd.Start() return } + +func preCheckArg(args []any) bool { + illegals := []any{`&`, `|`, `;`, `$`, `'`, `"`, "`", `(`, `)`, "\n", "\r", `>`, `<`} + for arg := range slices.Values(args) { + if slices.Contains(illegals, arg) { + return false + } + } + + return true +}