package encryption_utils
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
)
// AES holds the symmetric key and initialization vector used for AES-CBC
// operations. Construct values with NewAES, which fills both with random bytes.
type AES struct {
	key, iv []byte
}
// NewAES returns an AES helper holding a freshly generated random 32-byte
// (256-bit) key and a random IV of aes.BlockSize bytes, both read from
// crypto/rand. The key is filled first, then the IV.
//
// It panics if the system random source returns an error.
func NewAES() *AES {
	out := &AES{
		key: make([]byte, 32),
		iv:  make([]byte, aes.BlockSize),
	}
	for _, buf := range [][]byte{out.key, out.iv} {
		if _, readErr := rand.Read(buf); readErr != nil {
			panic(readErr)
		}
	}
	return out
}
// EncryptedResponse is the JSON-serializable result of an AES encryption.
// Field order is significant: it determines the key order of the encoded JSON.
type EncryptedResponse struct {
	EncryptedData string `json:"encryptedData"` // base64 (StdEncoding) AES-CBC ciphertext
	AESProperties string `json:"aesProperties"` // RSA-encrypted "<base64 key>.<base64 iv>" string
	AES           bool   `json:"aes"`           // flag marking the payload as AES-encrypted
}
// pkcs7Pad returns a new slice containing data followed by PKCS#7 padding to a
// multiple of blockSize. Between 1 and blockSize bytes are always appended, each
// equal to the pad length, so already-aligned input gains a full block.
//
// The input slice is never modified. The result is freshly allocated even when
// data has spare capacity, so callers' backing arrays are not overwritten.
//
// blockSize must be in [1, 255], because the pad length is stored in a single
// byte. Any other value is a programmer error and panics.
func pkcs7Pad(data []byte, blockSize int) []byte {
	if blockSize < 1 || blockSize > 255 {
		panic(fmt.Sprintf("pkcs7Pad: invalid block size %d", blockSize))
	}
	padding := blockSize - len(data)%blockSize
	// Copy into a buffer sized for the final result so append never touches
	// the caller's backing array.
	out := make([]byte, len(data), len(data)+padding)
	copy(out, data)
	return append(out, bytes.Repeat([]byte{byte(padding)}, padding)...)
}
// EncryptWithAES encrypts data with AES-CBC under a's key, using PKCS#7
// padding, and returns the base64 ciphertext plus the key and IV wrapped by r.
//
// A fresh random IV (aes.BlockSize bytes from crypto/rand) is generated on every
// call. CBC needs an unpredictable, never-reused IV per message under a given
// key. Reusing one fixed IV lets an observer tell when two plaintexts share a
// common prefix. The IV used is carried in the wrapped properties, so decryption
// does not depend on a's stored IV.
//
// The wrapped properties are the string "<base64 key>.<base64 iv>" (standard
// base64) passed to r.EncryptUsingPublicKey with usePkcsPadding forwarded
// unchanged. Errors from cipher construction, IV generation, or the RSA step
// are returned, and the response is nil in that case.
func (a *AES) EncryptWithAES(data string, r *RSA, usePkcsPadding bool) (*EncryptedResponse, error) {
	block, err := aes.NewCipher(a.key)
	if err != nil {
		return nil, err
	}

	iv := make([]byte, aes.BlockSize)
	if _, err := rand.Read(iv); err != nil {
		return nil, fmt.Errorf("generating IV: %w", err)
	}

	plaintext := pkcs7Pad([]byte(data), aes.BlockSize)
	ciphertext := make([]byte, len(plaintext))
	cipher.NewCBCEncrypter(block, iv).CryptBlocks(ciphertext, plaintext)

	properties := fmt.Sprintf("%s.%s",
		base64.StdEncoding.EncodeToString(a.key),
		base64.StdEncoding.EncodeToString(iv))
	aesProperties, err := r.EncryptUsingPublicKey(properties, usePkcsPadding)
	if err != nil {
		return nil, err
	}

	return &EncryptedResponse{
		EncryptedData: base64.StdEncoding.EncodeToString(ciphertext),
		AESProperties: aesProperties,
		AES:           true,
	}, nil
}
// DecryptWithAES reverses EncryptWithAES. It proceeds in four steps:
//   - RSA-decrypt encryptedAESProperties with r.DecryptUsingPrivateKey, forwarding
//     usePkcsPadding.
//   - Split the result on "." into a standard-base64 key and IV.
//   - AES-CBC-decrypt the standard-base64 encryptedData with that key and IV.
//   - Strip the PKCS#7 padding that EncryptWithAES adds.
//
// The key and IV come from the message; the receiver's own key and IV are not
// used.
//
// Malformed input is reported as an error instead of a panic. This covers bad
// properties, an IV that is not aes.BlockSize bytes, ciphertext that is empty or
// not a whole number of blocks, and invalid padding.
//
// NOTE(review): CBC without a MAC is malleable, and if padding errors are
// distinguishable to a remote caller they can act as a padding oracle.
// Consider an authenticated mode (e.g. AES-GCM) for new protocols.
func (a *AES) DecryptWithAES(encryptedAESProperties, encryptedData string, r *RSA, usePkcsPadding bool) (string, error) {
	decryptedProperties, err := r.DecryptUsingPrivateKey(encryptedAESProperties, usePkcsPadding)
	if err != nil {
		return "", err
	}
	parts := strings.Split(decryptedProperties, ".")
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid AES properties format")
	}
	iv, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("failed to decode IV: %w", err)
	}
	// NewCBCDecrypter panics on a wrong-length IV, so validate it up front.
	if len(iv) != aes.BlockSize {
		return "", fmt.Errorf("invalid IV length %d, want %d", len(iv), aes.BlockSize)
	}
	key, err := base64.StdEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("failed to decode key: %w", err)
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	payload, err := base64.StdEncoding.DecodeString(encryptedData)
	if err != nil {
		return "", err
	}
	// CryptBlocks panics unless the input is a whole number of blocks, and a
	// padded CBC message is never empty.
	if len(payload) == 0 || len(payload)%aes.BlockSize != 0 {
		return "", fmt.Errorf("invalid ciphertext length %d: must be a positive multiple of %d", len(payload), aes.BlockSize)
	}
	decrypted := make([]byte, len(payload))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(decrypted, payload)

	plaintext, err := pkcs7Unpad(decrypted, aes.BlockSize)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}

// pkcs7Unpad strips PKCS#7 padding from data. The length of data must be a
// positive multiple of blockSize. The returned slice aliases data.
func pkcs7Unpad(data []byte, blockSize int) ([]byte, error) {
	if blockSize <= 0 || len(data) == 0 || len(data)%blockSize != 0 {
		return nil, fmt.Errorf("invalid padded data length %d", len(data))
	}
	n := int(data[len(data)-1])
	if n == 0 || n > blockSize {
		return nil, fmt.Errorf("invalid PKCS#7 padding")
	}
	for _, b := range data[len(data)-n:] {
		if int(b) != n {
			return nil, fmt.Errorf("invalid PKCS#7 padding")
		}
	}
	return data[:len(data)-n], nil
}