#include "AudioOutputManager.h"

#include <mmdeviceapi.h>
#include <comip.h>
#include <stdexcept>
#include <iostream>
#include <string>
#include <comdef.h>
#include <atlstr.h>
#include "propvarutil.h"
#include <locale>
#include <codecvt>
#include <memory>
#include <type_traits>
#include <vector>


#include <Functiondiscoverykeys_devpkey.h>
#include <endpointvolume.h>
#include "mmreg.h"

//#include "PolicyConfig.h"



// Converts a wide string to a UTF-8 encoded std::string.
// Throws std::range_error (from std::wstring_convert::to_bytes) if the input
// cannot be converted, e.g. an unpaired UTF-16 surrogate.
std::string wstringToString(std::wstring string_to_convert) // thanks zumalifeguard @ stackoverflow
{
  // Where wchar_t is 16 bits (Windows) it holds UTF-16, so characters outside
  // the BMP arrive as surrogate pairs. codecvt_utf8<wchar_t> would treat each
  // unit as UCS-2 and produce invalid UTF-8 for them; codecvt_utf8_utf16 joins
  // the pairs. Where wchar_t is 32 bits (UTF-32), codecvt_utf8 is correct.
  using convert_type = std::conditional<sizeof(wchar_t) == 2,
                                        std::codecvt_utf8_utf16<wchar_t>,
                                        std::codecvt_utf8<wchar_t> >::type;
  std::wstring_convert<convert_type, wchar_t> converter;

  // .to_bytes: wstring -> UTF-8 string
  return converter.to_bytes(string_to_convert);
}


// AudioEndpoint: a plain data container, nothing fancy.

// Stores the endpoint ID, its human-readable names and whether it is the
// current default endpoint.
AudioEndpoint::AudioEndpoint(std::wstring  deviceid, std::wstring  friendlyName, std::wstring description, bool isDefault)
  : deviceid(deviceid),
    friendlyName(friendlyName),
    description(description),
    isDefault(isDefault)
{
}

// Implementation of the audio endpoint volume-change notification callback.

// COM callback object passed to IAudioEndpointVolume::RegisterControlChangeNotify.
// Each volume/mute notification makes it call refresh() on its AudioOutputManager.
// The object is allocated with `new` by AudioOutputManager; AddRef/Release do
// not implement reference counting.
class AudioEndpointVolumeNotificationImplementation : public IAudioEndpointVolumeCallback
{
  // Manager to notify; not owned. It is dereferenced in OnNotify, so it must stay
  // valid for as long as this callback is registered.
  AudioOutputManager* manager;
public:
  explicit AudioEndpointVolumeNotificationImplementation(AudioOutputManager* manager);
  ULONG STDMETHODCALLTYPE AddRef() override;
  ULONG STDMETHODCALLTYPE Release() override;
  HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, VOID** ppvInterface) override;
  HRESULT STDMETHODCALLTYPE OnNotify(PAUDIO_VOLUME_NOTIFICATION_DATA pNotify) override;
};
 

// Remembers the manager whose refresh() is invoked on every volume notification.
AudioEndpointVolumeNotificationImplementation::AudioEndpointVolumeNotificationImplementation(AudioOutputManager* manager)
  : manager(manager)
{
}
  
 
  // IUnknown methods -- AddRef, Release, and QueryInterface

  // IUnknown::AddRef. No reference count is kept for this object (AudioOutputManager
  // allocates it with `new`), so a constant 0 is reported.
  ULONG STDMETHODCALLTYPE AudioEndpointVolumeNotificationImplementation::AddRef()
  {
    return 0UL;
  }

  // IUnknown::Release. No reference count is kept, so the object is never freed
  // from here and a constant 0 is reported.
  ULONG STDMETHODCALLTYPE AudioEndpointVolumeNotificationImplementation::Release()
  {
    return 0UL;
  }

  // IUnknown::QueryInterface. Hands out this object for IUnknown and
  // IAudioEndpointVolumeCallback; any other IID gets E_NOINTERFACE with the
  // out-pointer set to NULL, as the COM contract requires. The previous version
  // returned S_OK without writing *ppvInterface, leaving the caller with garbage.
  HRESULT STDMETHODCALLTYPE AudioEndpointVolumeNotificationImplementation::QueryInterface(REFIID riid, VOID** ppvInterface)
  {
    if (ppvInterface == NULL)
    {
      return E_POINTER;
    }
    if (riid == __uuidof(IUnknown))
    {
      *ppvInterface = static_cast<IUnknown*>(this);
    }
    else if (riid == __uuidof(IAudioEndpointVolumeCallback))
    {
      *ppvInterface = static_cast<IAudioEndpointVolumeCallback*>(this);
    }
    else
    {
      *ppvInterface = NULL;
      return E_NOINTERFACE;
    }
    AddRef();
    return S_OK;
  }

  // Callback method for endpoint-volume-change notifications.

  // IAudioEndpointVolumeCallback::OnNotify: a volume or mute change happened on
  // the endpoint this callback is registered with. The notification payload is
  // not used; the manager re-enumerates and calls the user callback instead.
  // Returns E_INVALIDARG for a NULL payload and E_FAIL if refresh() threw.
  HRESULT STDMETHODCALLTYPE AudioEndpointVolumeNotificationImplementation::OnNotify(PAUDIO_VOLUME_NOTIFICATION_DATA pNotify)
  {
    if (pNotify == NULL)
    {
      return E_INVALIDARG;
    }
    // refresh() throws std::runtime_error on HRESULT failures. A C++ exception
    // must not propagate out of a COM callback into its (system) caller, so
    // report the failure as an HRESULT instead.
    // NOTE(review): refresh() unregisters and re-registers this very callback
    // from inside the callback; confirm the audio API allows that.
    try
    {
      this->manager->refresh(VOLUMECHANGE);
    }
    catch (...)
    {
      return E_FAIL;
    }
    return S_OK;
  }




/*
 *  Audio endpoint change callbacks
 */


// COM callback object passed to IMMDeviceEnumerator::RegisterEndpointNotificationCallback.
// Default-device changes and device add/remove/state changes make it call
// refresh() on its AudioOutputManager; property changes are ignored.
// AddRef/Release do not implement reference counting.
class NotificationClientImplementation2 : public IMMNotificationClient
{
  // Manager to notify; not owned. It must stay valid while this client is registered.
  AudioOutputManager* manager ;

public:

  explicit NotificationClientImplementation2(AudioOutputManager* manager);
  HRESULT  STDMETHODCALLTYPE OnDefaultDeviceChanged(EDataFlow flow, ERole     role, LPCWSTR   pwstrDefaultDeviceId) override;
  HRESULT STDMETHODCALLTYPE OnDeviceAdded(LPCWSTR pwstrDeviceId) override;
  HRESULT STDMETHODCALLTYPE OnDeviceRemoved(LPCWSTR pwstrDeviceId) override;
  HRESULT STDMETHODCALLTYPE OnDeviceStateChanged(LPCWSTR pwstrDeviceId, DWORD   dwNewState) override;
  HRESULT STDMETHODCALLTYPE  OnPropertyValueChanged(LPCWSTR           pwstrDeviceId, const PROPERTYKEY key) override;
  HRESULT STDMETHODCALLTYPE QueryInterface(REFIID   riid, LPVOID* ppvObj) override;
  ULONG STDMETHODCALLTYPE AddRef() override;
  ULONG STDMETHODCALLTYPE Release() override;
};

// Remembers the manager whose refresh() is invoked on endpoint changes.
NotificationClientImplementation2::NotificationClientImplementation2(AudioOutputManager* manager)
  : manager(manager)
{
}
// IMMNotificationClient::OnDefaultDeviceChanged: refreshes the manager for any
// data flow and role. Returns E_FAIL if refresh() threw.
HRESULT  STDMETHODCALLTYPE NotificationClientImplementation2::OnDefaultDeviceChanged(EDataFlow flow, ERole     role, LPCWSTR   pwstrDefaultDeviceId)
{
  // refresh() throws on HRESULT failures, e.g. when no default render device is
  // left. A C++ exception must not escape a COM callback, so convert it to an HRESULT.
  try
  {
    this->manager->refresh(ENDPOINTCHANGE);
  }
  catch (...)
  {
    return E_FAIL;
  }
  return NOERROR;
}
// IMMNotificationClient::OnDeviceAdded: refreshes the manager.
// Returns E_FAIL if refresh() threw.
HRESULT STDMETHODCALLTYPE NotificationClientImplementation2::OnDeviceAdded(LPCWSTR pwstrDeviceId)
{
  // A C++ exception from refresh() must not escape a COM callback.
  try
  {
    this->manager->refresh(ENDPOINTCHANGE);
  }
  catch (...)
  {
    return E_FAIL;
  }
  return NOERROR;
}
// IMMNotificationClient::OnDeviceRemoved: refreshes the manager.
// Returns E_FAIL if refresh() threw.
HRESULT STDMETHODCALLTYPE NotificationClientImplementation2::OnDeviceRemoved(LPCWSTR pwstrDeviceId)
{
  // refresh() throws if, for example, no default device remains after removal;
  // a C++ exception must not escape a COM callback.
  try
  {
    this->manager->refresh(ENDPOINTCHANGE);
  }
  catch (...)
  {
    return E_FAIL;
  }
  return NOERROR;
}
// IMMNotificationClient::OnDeviceStateChanged: refreshes the manager for any new state.
// Returns E_FAIL if refresh() threw.
HRESULT STDMETHODCALLTYPE NotificationClientImplementation2::OnDeviceStateChanged(LPCWSTR pwstrDeviceId, DWORD   dwNewState)
{
  // refresh() throws if, for example, the last active render device was
  // unplugged; a C++ exception must not escape a COM callback.
  try
  {
    this->manager->refresh(ENDPOINTCHANGE);
  }
  catch (...)
  {
    return E_FAIL;
  }
  return NOERROR;
}
// IMMNotificationClient::OnPropertyValueChanged: property changes do not
// trigger a refresh; the notification is acknowledged and ignored.
HRESULT STDMETHODCALLTYPE NotificationClientImplementation2::OnPropertyValueChanged(LPCWSTR pwstrDeviceId, const PROPERTYKEY key)
{
  (void)pwstrDeviceId;
  (void)key;
  return S_OK;
}
// IUnknown::QueryInterface. Hands out this object for IUnknown and
// IMMNotificationClient; any other IID gets E_NOINTERFACE with the out-pointer
// set to NULL. The previous version returned success without writing *ppvObj.
HRESULT STDMETHODCALLTYPE NotificationClientImplementation2::QueryInterface(REFIID   riid, LPVOID* ppvObj)
{
  if (ppvObj == NULL)
  {
    return E_POINTER;
  }
  if (riid == __uuidof(IUnknown))
  {
    *ppvObj = static_cast<IUnknown*>(this);
  }
  else if (riid == __uuidof(IMMNotificationClient))
  {
    *ppvObj = static_cast<IMMNotificationClient*>(this);
  }
  else
  {
    *ppvObj = NULL;
    return E_NOINTERFACE;
  }
  AddRef();
  return S_OK;
}
// IUnknown::AddRef: no reference count is kept for this object; always reports 0.
ULONG STDMETHODCALLTYPE NotificationClientImplementation2::AddRef()
{
  return 0UL;
}
// IUnknown::Release: no reference count is kept, so the object is never freed
// from here; always reports 0.
ULONG STDMETHODCALLTYPE NotificationClientImplementation2::Release()
{
  return 0UL;
}



// set/get user data: lets the caller attach an arbitrary pointer to the manager.
void    AudioOutputManager::set_userData(void* data)
{
  this->userData = data;
}
void* AudioOutputManager::get_userData()
{
  return this->userData;
}

// return the AudioEndpoint count
int AudioOutputManager::get_count(void) { return (int)this->endpoints.size();}

// return a human readable name for AudioEndpoint #index
// Returns the UTF-8 friendly name of endpoint #index.
// Throws std::runtime_error if index is outside [0, get_count()).
string AudioOutputManager::get_FriendlyName(int index)
{
  const bool inRange = index >= 0 && static_cast<size_t>(index) < this->endpoints.size();
  if (!inRange)
  {
    throw std::runtime_error("invalid endpoint index");
  }
  const AudioEndpoint* endpoint = this->endpoints[index];
  return wstringToString(endpoint->friendlyName);
}

// set AudioEndpoint #index as default one
void     AudioOutputManager::set_defaultOutputAbs(int index)
{ if ((index < 0) || (index >= this->endpoints.size())) throw std::runtime_error("invalid endpoint index");
  this->SetDefaultEndpointOneRole(this->endpoints[index]->deviceid);
}
// set AudioEndpoint #current_index+index_ofset   as default one
// Makes endpoint #(current_default + index_ofset), wrapped around the list,
// the new default. Does nothing for an offset of 0 or when no endpoint in
// the list is flagged as default (which includes an empty list).
void    AudioOutputManager::set_defaultOutputRel(int index_ofset)
{
  if (index_ofset == 0) return;
  const int current = this->get_defaultOutput();
  if (current < 0) return;
  // current >= 0 implies the list is non-empty, so count >= 1.
  const int count = (int)this->endpoints.size();
  // Reduce the offset into [0, count) first. This replaces the
  // O(|offset| / count) loop and keeps current + step below 2 * count, so the
  // sum can no longer overflow int for large offsets.
  int step = index_ofset % count;
  if (step < 0) step += count;
  this->set_defaultOutputAbs((current + step) % count);
}
// return default  AudioEndpoint index 
// Returns the index of the first endpoint flagged as default, or -1 if none is.
int AudioOutputManager::get_defaultOutput(void)
{
  const size_t total = this->endpoints.size();
  for (size_t position = 0; position < total; ++position)
  {
    if (this->endpoints[position]->isDefault)
    {
      return static_cast<int>(position);
    }
  }
  return -1;
}

// called when AudioEndpoint volume has changed

void  AudioOutputManager::defaultEndPointVolumeChanged(float volume, bool muted)
 {
  this->DefaultEndPointVolume = volume;
  this->DefaultEndPointMuted = muted;
 }


inline string GetMessageForHresult(HRESULT hr) //  thanks   nietras @ stackoverflow 
 { _com_error error(hr); 
   std::string cs = std::string((char*)error.ErrorMessage());
   return cs; 
 }


// constructor

// Class ID and interface ID passed to CoCreateInstance in the constructor.
// NOTE(review): these names match the SDK's CLSID_/IID_ symbols; if an SDK header
// already declares them extern, these lines are their definitions. Confirm that
// before changing their linkage (e.g. adding `static`).
const CLSID CLSID_MMDeviceEnumerator = __uuidof(MMDeviceEnumerator);
const IID IID_IMMDeviceEnumerator = __uuidof(IMMDeviceEnumerator);

// Creates the manager, subscribes to endpoint change notifications and performs
// an initial enumeration (which also invokes changeCallback with NONE).
// Throws std::runtime_error carrying the system message on any COM failure.
AudioOutputManager::AudioOutputManager(AudioEndPointManagerChangeCallbackPtr changeCallback,void* userdata)
{
  this->initdone = "Marsupilami";
  this->userData = userdata;
  this->changeCallback = changeCallback;
  this->notifications = new NotificationClientImplementation2(this);
  this->volNotification = new  AudioEndpointVolumeNotificationImplementation(this);
  this->m_device_state = DEVICE_STATE_ACTIVE;
  this->m_device_type  = ::eRender;
  this->pEndptVol = NULL;
  this->endpoints = {};

  HRESULT Result = CoInitializeEx(nullptr, COINIT_APARTMENTTHREADED);
  // RPC_E_CHANGED_MODE: COM is already initialised on this thread with another
  // concurrency model. COM is still usable, so this is not an error here.
  if (FAILED(Result) && Result != RPC_E_CHANGED_MODE)  throw std::runtime_error(GetMessageForHresult(Result));

  Result = CoCreateInstance( CLSID_MMDeviceEnumerator, NULL, CLSCTX_ALL, IID_IMMDeviceEnumerator, (void**)&pDeviceEnumerator);
  if (FAILED(Result))  throw std::runtime_error (GetMessageForHresult(Result));

  // Previously ignored; a failed registration would silently disable change tracking.
  Result = pDeviceEnumerator->RegisterEndpointNotificationCallback(this->notifications);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  try
  {
    this->refresh(NONE);
  }
  catch (...)
  {
    // The destructor does not run when a constructor throws. Undo the callback
    // registrations here; otherwise notifications would later reach a destroyed object.
    // NOTE(review): the two callback objects allocated above are still leaked on this path.
    if (this->pEndptVol != NULL)
      this->pEndptVol->UnregisterControlChangeNotify((IAudioEndpointVolumeCallback*)this->volNotification);
    pDeviceEnumerator->UnregisterEndpointNotificationCallback(this->notifications);
    throw;
  }
}

// destructor

// Unsubscribes from every notification source and frees the endpoint list.
// Never throws: HRESULTs from the unregister calls are ignored on purpose.
AudioOutputManager::~AudioOutputManager()
{
  if (pDeviceEnumerator != NULL)
    pDeviceEnumerator->UnregisterEndpointNotificationCallback(this->notifications);

  // volNotification keeps a raw pointer back to this object. If it stayed
  // registered, the next volume change would call refresh() on freed memory.
  if (this->pEndptVol != NULL)
    this->pEndptVol->UnregisterControlChangeNotify((IAudioEndpointVolumeCallback*)this->volNotification);

  // The AudioEndpoint objects are allocated with `new` by EnumerateEndpoints.
  for (AudioEndpoint* endpoint : this->endpoints)
    delete endpoint;
  this->endpoints.clear();

  // NOTE(review): notifications/volNotification and the COM interfaces are not
  // released here; the member declarations (raw vs smart pointers) are not
  // visible from this file. Confirm before adding Release()/delete.
}

// internal: re-enumerates the endpoints, then calls the user callback (if any) to report that something about an audio endpoint has changed
// Rebuilds the endpoint list, then reports `cause` through the user callback
// when one was supplied. Exceptions from the enumeration propagate to the caller.
void AudioOutputManager::refresh(ChangeCause cause)
{
  EnumerateEndpoints(pDeviceEnumerator, m_device_state, m_device_type);
  if (changeCallback == NULL)
  {
    return;
  }
  (*changeCallback)(this, cause);
}

// construct a list of all audioendpoint

void AudioOutputManager::EnumerateEndpoints( IMMDeviceEnumerator* pDeviceEnumerator, unsigned deviceState,EDataFlow deviceFlow )
{
  HRESULT Result = CoInitializeEx(nullptr, COINIT_APARTMENTTHREADED);

  IMMDeviceCollection* pDeviceCollection;
  Result = pDeviceEnumerator->EnumAudioEndpoints(deviceFlow, deviceState, &pDeviceCollection);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
   
  // find out the defualt end point ID
  IMMDevice* pDefaultEndpoint;
  Result = pDeviceEnumerator->GetDefaultAudioEndpoint(deviceFlow, eMultimedia, &pDefaultEndpoint);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
  WCHAR* wszDefaultDeviceId = nullptr;
  Result = pDefaultEndpoint->GetId(&wszDefaultDeviceId);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
  std::wstring DEfaultEndpointId(wszDefaultDeviceId);
    

  if (this->pEndptVol != NULL)
  {
    Result =  this->pEndptVol->UnregisterControlChangeNotify((IAudioEndpointVolumeCallback*)this->volNotification);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    this->pEndptVol = NULL;
  }

  Result = pDefaultEndpoint->Activate(__uuidof(IAudioEndpointVolume), CLSCTX_ALL, NULL, (void**)&this->pEndptVol);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  Result = this->pEndptVol->RegisterControlChangeNotify( (IAudioEndpointVolumeCallback*)this->volNotification);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
 
  IAudioEndpointVolume* endpointVolume = NULL;
  Result = pDefaultEndpoint->Activate(__uuidof(IAudioEndpointVolume), CLSCTX_INPROC_SERVER, NULL, (LPVOID*)&endpointVolume);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  Result = endpointVolume->GetMasterVolumeLevelScalar(&this->DefaultEndPointVolume);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  Result = endpointVolume->GetMute(&this->DefaultEndPointMuted);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  UINT nCount = 0;
  Result = pDeviceCollection->GetCount(&nCount);
   if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result)); 
 
   std::vector<AudioEndpoint*> newendpoints = {};
   std::vector<AudioEndpoint*> oldendpoints = this->endpoints;
   newendpoints.resize(nCount);
 

   
  for (unsigned i = 0; i < nCount; ++i)
  {
    IMMDevice* pDevice;
    Result = pDeviceCollection->Item(i, &pDevice);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

    DWORD dwState;
    Result = pDevice->GetState(&dwState);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));


    WCHAR* wszDeviceId = nullptr;
    Result = pDevice->GetId(&wszDeviceId);

    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    std::wstring deviceid(wszDeviceId);
   

    IPropertyStore*  pPropertyStore;
    Result = pDevice->OpenPropertyStore(STGM_READ, &pPropertyStore);

    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  
    PROPVARIANT DeviceDescProp;
    Result = pPropertyStore->GetValue(PKEY_Device_DeviceDesc, &DeviceDescProp);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));   
    WCHAR strbuffer[128];
    Result = PropVariantToString(DeviceDescProp, strbuffer, ARRAYSIZE(strbuffer));
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    std::wstring description(strbuffer);
   

    PROPVARIANT FriendlyNameProp;
    Result = pPropertyStore->GetValue(PKEY_Device_FriendlyName, &FriendlyNameProp);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    Result = PropVariantToString(FriendlyNameProp, strbuffer, ARRAYSIZE(strbuffer));
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    std::wstring friendlyName(strbuffer);
    pPropertyStore->Release();

    newendpoints[i] = new AudioEndpoint(deviceid, friendlyName, description, deviceid == DEfaultEndpointId);

  }
  this->endpoints = newendpoints;
  for (unsigned i = 0; i < oldendpoints.size(); i++)
    delete oldendpoints[i];
  oldendpoints.clear();

  pDefaultEndpoint->Release();
  pDeviceCollection->Release();




}

// retrieve the volume-control interface (IAudioEndpointVolume) of the default audio endpoint

// Activates IAudioEndpointVolume on the current default render endpoint
// (eMultimedia role). The caller owns the returned reference and must Release() it.
// Throws std::runtime_error on failure.
IAudioEndpointVolume* AudioOutputManager::get_DefaultAudioEndpointVolume()
{
    IMMDevice* pDefaultEndpoint = NULL;
    HRESULT Result = pDeviceEnumerator->GetDefaultAudioEndpoint(::eRender, eMultimedia, &pDefaultEndpoint);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

    IAudioEndpointVolume* endpointVolume = NULL;
    Result = pDefaultEndpoint->Activate(__uuidof(IAudioEndpointVolume), CLSCTX_INPROC_SERVER, NULL, (LPVOID*)&endpointVolume);
    // Release the device before checking Result so the error path no longer leaks it.
    pDefaultEndpoint->Release();
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    return  endpointVolume;
}

// returns the  default audioendpoint volume
// Returns the default endpoint's master volume as a scalar in [0, 1].
// Throws std::runtime_error on failure.
float AudioOutputManager::get_volume()
{
    float currentVolume = 0;
    IAudioEndpointVolume* endpointVolume = this->get_DefaultAudioEndpointVolume();
    HRESULT Result = endpointVolume->GetMasterVolumeLevelScalar(&currentVolume);
    // get_DefaultAudioEndpointVolume hands over a reference; it used to leak on every call.
    endpointVolume->Release();
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    return currentVolume;
}

// sets the  default audioendpoint volume
void AudioOutputManager::set_volume(float volumeOfset)
{

  IAudioEndpointVolume* endpointVolume = this->get_DefaultAudioEndpointVolume();
  

  float currentVolume = 0;

  HRESULT Result = endpointVolume->GetMasterVolumeLevelScalar(&currentVolume);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));

  float newVolume = currentVolume + volumeOfset;
  if (newVolume < 0) newVolume = 0;
  if (newVolume > 1) newVolume = 1;

  Result = endpointVolume->SetMasterVolumeLevelScalar((float)newVolume, NULL);
  if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
}

// mutes/unmutes the default audio endpoint
// Mutes (state != FALSE) or unmutes the default endpoint.
// Throws std::runtime_error on failure.
void     AudioOutputManager::set_mute(BOOL state)
{
    IAudioEndpointVolume* endpointVolume = this->get_DefaultAudioEndpointVolume();

    HRESULT Result = endpointVolume->SetMute(state, NULL);
    // Release the reference from get_DefaultAudioEndpointVolume (previously leaked).
    endpointVolume->Release();
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
}
// returns the default audio endpoint's mute state
// Returns TRUE when the default endpoint is muted.
// Throws std::runtime_error on failure.
BOOL  AudioOutputManager::get_mute()
{
    BOOL res = FALSE;
    IAudioEndpointVolume* endpointVolume = this->get_DefaultAudioEndpointVolume();

    HRESULT Result = endpointVolume->GetMute(&res);
    // Release the reference from get_DefaultAudioEndpointVolume (previously leaked).
    endpointVolume->Release();
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
    return res;
}



//  The following code is used to change the default audio endpoint.
//  It was copied from the DefSound library; the original author appears to
//  be someone going by the name of EreTIk
//  (see github.com/Belphemur/AudioEndPointLibrary)


// Forward declarations that attach GUIDs for __uuidof(): IPolicyConfig is the
// interface wrapped by CPolicyConfigPtr, and CPolicyConfigClient is the COM class
// created in SetDefaultEndpointOneRole. The GUID strings must stay exactly as written.
interface DECLSPEC_UUID("f8679f50-850a-41cf-9c72-430f290290c8") IPolicyConfig;
class DECLSPEC_UUID("870af99c-171d-4f9e-af0d-e63df40c2bc9") CPolicyConfigClient;
// ----------------------------------------------------------------------------
// class CPolicyConfigClient
// {870af99c-171d-4f9e-af0d-e63df40c2bc9}
//  
// interface IPolicyConfig
// {f8679f50-850a-41cf-9c72-430f290290c8}
//
// Query interface:
// CComPtr<IPolicyConfig> PolicyConfig;
// PolicyConfig.CoCreateInstance(__uuidof(CPolicyConfigClient));
// 
// @compatible: Windows 7 and Later
// ----------------------------------------------------------------------------
// Undocumented Windows interface used only for SetDefaultEndpoint.
// Calls go through vtable slots by position, so the declaration order below must
// not change. Every method is pure virtual and __stdcall (STDMETHODCALLTYPE), like
// any COM interface. GetMixFormat previously lacked STDMETHODCALLTYPE, which
// mismatches the calling convention on 32-bit builds.
interface IPolicyConfig : public IUnknown
{
public:

  virtual HRESULT STDMETHODCALLTYPE GetMixFormat(
    PCWSTR,
    WAVEFORMATEX**
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE GetDeviceFormat(
    PCWSTR,
    INT,
    WAVEFORMATEX**
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE ResetDeviceFormat(
    PCWSTR
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE SetDeviceFormat(
    PCWSTR,
    WAVEFORMATEX*,
    WAVEFORMATEX*
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE GetProcessingPeriod(
    PCWSTR,
    INT,
    PINT64,
    PINT64
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE SetProcessingPeriod(
    PCWSTR,
    PINT64
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE GetShareMode(
    PCWSTR,
    struct DeviceShareMode*
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE SetShareMode(
    PCWSTR,
    struct DeviceShareMode*
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE GetPropertyValue(
    PCWSTR,
    const PROPERTYKEY&,
    PROPVARIANT*
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE SetPropertyValue(
    PCWSTR,
    const PROPERTYKEY&,
    PROPVARIANT*
  ) = 0;

  // Makes wszDeviceId the default endpoint for role eRole.
  virtual HRESULT STDMETHODCALLTYPE SetDefaultEndpoint(
    __in PCWSTR wszDeviceId,
    __in ERole eRole
  ) = 0;

  virtual HRESULT STDMETHODCALLTYPE SetEndpointVisibility(
    PCWSTR,
    INT
  ) = 0;
};

// comip.h smart pointer for IPolicyConfig; its IID comes from the DECLSPEC_UUID declaration.
using CPolicyConfigPtr = _com_ptr_t< _com_IIID<IPolicyConfig, &__uuidof(IPolicyConfig)> >;

void AudioOutputManager::SetDefaultEndpointOneRole( std::wstring id )
  {
    CPolicyConfigPtr pPolicyConfig;
    HRESULT Result = pPolicyConfig.CreateInstance(__uuidof(CPolicyConfigClient));
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
      

    Result = pPolicyConfig->SetDefaultEndpoint(id.c_str(), ::eMultimedia);
    if (FAILED(Result))  throw std::runtime_error(GetMessageForHresult(Result));
      

    
  }
  

  
 

  

