Skip to main content

RSA


package encryption_utils

import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
)

// RSA bundles the key material used by EncryptUsingPublicKey and
// DecryptUsingPrivateKey. Either key may be nil; the method that needs a
// missing key returns an error rather than dereferencing it.
type RSA struct {
	// privateKey is used for decryption; nil when none was loaded.
	privateKey *rsa.PrivateKey

	// publicKey is used for encryption; nil when none was loaded.
	publicKey *rsa.PublicKey
}

// NewRSA builds an RSA from a private key and a public key, each in the
// string format accepted by loadPrivateKey / loadPublicKey (at minimum, the
// base64 body of a PEM block without its BEGIN/END lines).
//
// Parsing is best-effort. A key that is empty or fails to parse is left nil
// instead of being reported, so callers can supply only the key they need
// (for example, just a public key for encryption). As a consequence, a
// malformed key is only noticed later: the method that needs it returns a
// "... key is not set" error.
func NewRSA(privateKey, publicKey string) *RSA {
	// Errors are intentionally discarded (see above). Both loaders return a
	// nil key whenever they return an error, so the field simply stays nil.
	prKey, _ := loadPrivateKey(privateKey)
	pbKey, _ := loadPublicKey(publicKey)

	return &RSA{privateKey: prKey, publicKey: pbKey}
}

// EncryptUsingPublicKey encrypts data with the configured public key and
// returns the ciphertext encoded as standard base64.
//
// usePkcsPadding selects the padding scheme:
//   - true: PKCS#1 v1.5.
//   - false: RSA-OAEP with SHA-256 and an empty label.
//
// DecryptUsingPrivateKey must be called with the same choice.
//
// The plaintext must fit in a single RSA block. Let k be the modulus size in
// bytes; the limit is k-11 bytes for PKCS#1 v1.5 and k-2*32-2 bytes for
// OAEP/SHA-256. Longer input yields an error wrapping the crypto/rsa failure
// (rsa.ErrMessageTooLong).
func (r *RSA) EncryptUsingPublicKey(data string, usePkcsPadding bool) (encryptedData string, err error) {
	if r.publicKey == nil {
		return "", fmt.Errorf("public key is not set")
	}

	var encryptedBytes []byte
	if usePkcsPadding {
		encryptedBytes, err = rsa.EncryptPKCS1v15(rand.Reader, r.publicKey, []byte(data))
		if err != nil {
			return "", fmt.Errorf("failed to encrypt data using PKCS#1 v1.5: %w", err)
		}
	} else {
		// The hash and (nil) label must match those used by the decrypting side.
		encryptedBytes, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, r.publicKey, []byte(data), nil)
		if err != nil {
			return "", fmt.Errorf("failed to encrypt data using OAEP: %w", err)
		}
	}

	// Base64 so the binary ciphertext can travel safely as text.
	return base64.StdEncoding.EncodeToString(encryptedBytes), nil
}

// DecryptUsingPrivateKey reverses EncryptUsingPublicKey. It base64-decodes
// encryptedData (standard encoding) and decrypts the result with the
// configured private key.
//
// usePkcsPadding must match the value used for encryption:
//   - true: PKCS#1 v1.5.
//   - false: RSA-OAEP with SHA-256 and an empty label.
//
// Errors wrap their cause with %w, so callers can inspect them with
// errors.Is / errors.As (for example rsa.ErrDecryption or
// base64.CorruptInputError).
//
// Security note: with PKCS#1 v1.5, whether decryption fails leaks
// information about the plaintext (a Bleichenbacher-style padding oracle).
// Do not expose these errors to untrusted parties; prefer OAEP for new data.
func (r *RSA) DecryptUsingPrivateKey(encryptedData string, usePkcsPadding bool) (decryptedData string, err error) {
	if r.privateKey == nil {
		return "", fmt.Errorf("private key is not set")
	}

	encryptedBytes, err := base64.StdEncoding.DecodeString(encryptedData)
	if err != nil {
		return "", fmt.Errorf("failed to decode base64 encoded data: %w", err)
	}

	var decryptedBytes []byte
	if usePkcsPadding {
		decryptedBytes, err = rsa.DecryptPKCS1v15(rand.Reader, r.privateKey, encryptedBytes)
		if err != nil {
			return "", fmt.Errorf("failed to decrypt data using PKCS#1 v1.5: %w", err)
		}
	} else {
		// The hash and (nil) label must be exactly those used at encryption
		// time (SHA-256, no label); any other hash makes decryption fail.
		decryptedBytes, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, r.privateKey, encryptedBytes, nil)
		if err != nil {
			return "", fmt.Errorf("failed to decrypt data using OAEP: %w", err)
		}
	}

	return string(decryptedBytes), nil
}

// loadPrivateKey parses an RSA private key given in either of two forms:
//   - a complete PEM document (with BEGIN/END lines);
//   - the bare base64 body of such a block, which is wrapped in
//     "PRIVATE KEY" armor before decoding.
//
// The DER payload may be PKCS#8 ("PRIVATE KEY") or PKCS#1
// ("RSA PRIVATE KEY"). On any failure it returns a nil key and a non-nil
// error.
func loadPrivateKey(privateKey string) (*rsa.PrivateKey, error) {
	// Bare base64 can never contain "-----BEGIN", so pem.Decode only
	// succeeds here when the caller passed a full PEM document.
	block, _ := pem.Decode([]byte(privateKey))
	if block == nil {
		pemPrivateKey := "-----BEGIN PRIVATE KEY-----\n" + privateKey + "\n-----END PRIVATE KEY-----"
		block, _ = pem.Decode([]byte(pemPrivateKey))
	}
	if block == nil || (block.Type != "PRIVATE KEY" && block.Type != "RSA PRIVATE KEY") {
		return nil, fmt.Errorf("failed to decode PEM block containing private key")
	}

	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		// Fall back to PKCS#1, the other widespread RSA private-key encoding.
		rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
		if pkcs1Err != nil {
			return nil, fmt.Errorf("parsing private key: %w", err)
		}
		return rsaKey, nil
	}

	rsaKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("not an RSA private key")
	}
	return rsaKey, nil
}

// loadPublicKey parses an RSA public key given in either of two forms:
//   - a complete PEM document (with BEGIN/END lines);
//   - the bare base64 body of such a block, which is wrapped in
//     "PUBLIC KEY" armor before decoding.
//
// The DER payload may be PKIX/SubjectPublicKeyInfo ("PUBLIC KEY") or PKCS#1
// ("RSA PUBLIC KEY"). On any failure it returns a nil key and a non-nil
// error.
func loadPublicKey(publicKey string) (*rsa.PublicKey, error) {
	// Bare base64 can never contain "-----BEGIN", so pem.Decode only
	// succeeds here when the caller passed a full PEM document.
	block, _ := pem.Decode([]byte(publicKey))
	if block == nil {
		pemPublicKey := "-----BEGIN PUBLIC KEY-----\n" + publicKey + "\n-----END PUBLIC KEY-----"
		block, _ = pem.Decode([]byte(pemPublicKey))
	}
	if block == nil || (block.Type != "PUBLIC KEY" && block.Type != "RSA PUBLIC KEY") {
		return nil, fmt.Errorf("failed to decode PEM block containing public key")
	}

	key, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		// Fall back to PKCS#1, the bare RSA public-key encoding.
		rsaKey, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes)
		if pkcs1Err != nil {
			return nil, fmt.Errorf("parsing public key: %w", err)
		}
		return rsaKey, nil
	}

	rsaKey, ok := key.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("not an RSA public key")
	}
	return rsaKey, nil
}